/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.util;

import java.io.InputStream;
import java.io.OutputStream;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Persistable;
import org.deeplearning4j.util.SerializationUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Viterbi decoder over a chain of {@code states} discrete labels.
 *
 * <p>The model is symmetric: a state is kept from one frame to the next with probability
 * {@code metaStability} and the remaining mass is shared equally by the other states; an observed
 * label matches the true state with probability {@code pCorrect} and the remaining mass is shared
 * equally by the other states. {@link #decode(INDArray, boolean)} works exclusively on the cached
 * logarithms of those quantities.
 *
 * <p>Field naming is historical and kept for interface compatibility: {@code logMetaInstability}
 * holds {@code log(metaStability)} (the "stay" entry of a transition row) and
 * {@code logOfDiangnalTProb} holds the log of every <em>other</em> entry of that row.
 *
 * <p>Instances are mutable and not synchronized.
 */
public class Viterbi
implements Persistable {
    private double metaStability = 0.9;
    private double pCorrect = 0.99;
    private INDArray possibleLabels;
    private int states;
    private double logPCorrect;
    private double logPIncorrect;
    private double logMetaInstability;
    private double logOfDiangnalTProb;
    private double logStates;

    /**
     * Creates a decoder whose number of states is the length of {@code possibleLabels}.
     *
     * @param possibleLabels the possible labels of the chain; only its length is used here
     */
    public Viterbi(INDArray possibleLabels) {
        this.possibleLabels = possibleLabels;
        this.states = possibleLabels.length();
        // All derived logs are computed here, after 'states' is known. Computing them in field
        // initializers would see states == 0.
        this.logStates = FastMath.log((double) this.states);
        this.logPCorrect = FastMath.log(this.pCorrect);
        this.logPIncorrect = Viterbi.logOfEvenlySharedRemainder(this.pCorrect, this.states);
        this.logMetaInstability = FastMath.log(this.metaStability);
        this.logOfDiangnalTProb = Viterbi.logOfEvenlySharedRemainder(this.metaStability, this.states);
    }

    /**
     * Decodes {@code labels}, treating it as a binary label matrix.
     *
     * @param labels the observed labels
     * @return see {@link #decode(INDArray, boolean)}
     */
    public Pair<Double, INDArray> decode(INDArray labels) {
        return this.decode(labels, true);
    }

    /**
     * Finds the highest scoring state sequence for the observed labels.
     *
     * @param labels            the observed labels
     * @param binaryLabelMatrix whether {@code labels} holds one row per frame that must be reduced
     *                          to a single outcome index per frame
     * @return a pair of (log score of the best path, vector with one decoded state index per frame)
     * @throws IllegalArgumentException if the outcome sequence is empty
     */
    public Pair<Double, INDArray> decode(INDArray labels, boolean binaryLabelMatrix) {
        // NOTE(review): vectors are always reduced row by row as well, so the flag only matters
        // for non-vector input. Kept exactly as before; confirm this is the intended contract.
        INDArray outcomeSequence = labels.isColumnVector() || labels.isRowVector() || binaryLabelMatrix
                ? this.toOutcomesFromBinaryLabelMatrix(labels)
                : labels;
        int frames = outcomeSequence.length();
        if (frames < 1) {
            throw new IllegalArgumentException("Cannot decode an empty label sequence");
        }

        // V[t][k]: best log score of any path ending in state k at frame t.
        INDArray V = Nd4j.ones(frames, this.states);
        // pointers[t][k]: predecessor state (at frame t - 1) on that best path.
        INDArray pointers = Nd4j.zeros(frames, this.states);

        // Uniform prior over the states; only the observed state gets the "correct" emission.
        INDArray firstRow = V.getRow(0);
        firstRow.assign(this.logPIncorrect - this.logStates);
        V.putRow(0, firstRow);
        V.put(0, (int) outcomeSequence.getDouble(0), this.logPCorrect - this.logStates);

        for (int t = 1; t < frames; ++t) {
            int observed = (int) outcomeSequence.getDouble(t);
            // Only row t is written inside the inner loop, so row t - 1 is loop invariant.
            INDArray previousScores = V.getRow(t - 1);
            for (int k = 0; k < this.states; ++k) {
                INDArray candidates = this.rowOfLogTransitionMatrix(k).add(previousScores);
                // A plain arg-max is required: BLAS iamax picks the largest absolute value,
                // which for (negative) log scores is the worst predecessor.
                int bestPredecessor = Viterbi.argMax(candidates);
                double bestScore = candidates.getDouble(bestPredecessor);
                double emission = k == observed ? this.logPCorrect : this.logPIncorrect;
                pointers.put(t, k, (double) bestPredecessor);
                V.put(t, k, emission + bestScore);
            }
        }

        INDArray finalScores = V.getRow(frames - 1);
        int bestFinalState = Viterbi.argMax(finalScores);
        double bestLogScore = finalScores.getDouble(bestFinalState);

        // Walk the back-pointers from the last frame down to (and including) frame 0.
        INDArray rectified = Nd4j.zeros(frames);
        rectified.putScalar(frames - 1, (double) bestFinalState);
        for (int t = frames - 2; t >= 0; --t) {
            int successor = (int) rectified.getDouble(t + 1);
            rectified.putScalar(t, pointers.getDouble(t + 1, successor));
        }
        return new Pair<Double, INDArray>(bestLogScore, rectified);
    }

    /**
     * Builds row {@code k} of the log transition matrix: {@code logMetaInstability} at index
     * {@code k}, {@code logOfDiangnalTProb} everywhere else.
     */
    private INDArray rowOfLogTransitionMatrix(int k) {
        INDArray row = Nd4j.ones(1, this.states).muli(this.logOfDiangnalTProb);
        row.putScalar(k, this.logMetaInstability);
        return row;
    }

    /**
     * Reduces each row of {@code outcomes} to the index reported by BLAS {@code iamax} for that
     * row and returns those indices as a {@code rows x 1} array.
     */
    private INDArray toOutcomesFromBinaryLabelMatrix(INDArray outcomes) {
        INDArray ret = Nd4j.create(outcomes.rows(), 1);
        for (int i = 0; i < outcomes.rows(); ++i) {
            ret.put(i, 0, Nd4j.getBlasWrapper().iamax(outcomes.getRow(i)));
        }
        return ret;
    }

    /** Returns the linear index of the largest element of {@code values} (first one on ties). */
    private static int argMax(INDArray values) {
        int bestIndex = 0;
        double bestValue = values.getDouble(0);
        for (int i = 1; i < values.length(); ++i) {
            double candidate = values.getDouble(i);
            if (candidate > bestValue) {
                bestValue = candidate;
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    /**
     * Returns {@code log((1 - probability) / (states - 1))}: the log of the share each of the
     * other states receives when the mass left over by {@code probability} is split evenly.
     * With fewer than two states there is no other state, so the result is log(0).
     */
    private static double logOfEvenlySharedRemainder(double probability, int states) {
        if (states < 2) {
            return Double.NEGATIVE_INFINITY;
        }
        return FastMath.log((1.0 - probability) / (double) (states - 1));
    }

    /** Serializes this instance to {@code os} through {@link SerializationUtils#writeObject}. */
    @Override
    public void write(OutputStream os) {
        SerializationUtils.writeObject(this, os);
    }

    /**
     * Replaces the state of this instance with the one read from {@code is}.
     *
     * <p>NaN values for the two "remainder" logs are recomputed from the loaded probabilities
     * rather than copied, so that a stream carrying NaN there does not poison decoding.
     */
    @Override
    public void load(InputStream is) {
        Viterbi ret = (Viterbi) SerializationUtils.readObject(is);
        this.possibleLabels = ret.possibleLabels;
        this.states = ret.states;
        this.logStates = ret.logStates;
        this.metaStability = ret.metaStability;
        this.logMetaInstability = ret.logMetaInstability;
        this.logOfDiangnalTProb = Double.isNaN(ret.logOfDiangnalTProb)
                ? Viterbi.logOfEvenlySharedRemainder(ret.metaStability, ret.states)
                : ret.logOfDiangnalTProb;
        this.pCorrect = ret.pCorrect;
        this.logPCorrect = ret.logPCorrect;
        this.logPIncorrect = Double.isNaN(ret.logPIncorrect)
                ? Viterbi.logOfEvenlySharedRemainder(ret.pCorrect, ret.states)
                : ret.logPIncorrect;
    }

    public double getMetaStability() {
        return this.metaStability;
    }

    /**
     * Sets the probability of staying in the same state and refreshes the two cached transition
     * logs derived from it, so the change takes effect in {@code decode}.
     */
    public void setMetaStability(double metaStability) {
        this.metaStability = metaStability;
        this.logMetaInstability = FastMath.log(metaStability);
        this.logOfDiangnalTProb = Viterbi.logOfEvenlySharedRemainder(metaStability, this.states);
    }

    public double getpCorrect() {
        return this.pCorrect;
    }

    /**
     * Sets the probability that an observed label is correct and refreshes the two cached
     * emission logs derived from it, so the change takes effect in {@code decode}.
     */
    public void setpCorrect(double pCorrect) {
        this.pCorrect = pCorrect;
        this.logPCorrect = FastMath.log(pCorrect);
        this.logPIncorrect = Viterbi.logOfEvenlySharedRemainder(pCorrect, this.states);
    }

    public INDArray getPossibleLabels() {
        return this.possibleLabels;
    }

    /** Replaces the stored labels; the number of states is not changed by this call. */
    public void setPossibleLabels(INDArray possibleLabels) {
        this.possibleLabels = possibleLabels;
    }

    public int getStates() {
        return this.states;
    }

    /** Sets the number of states and refreshes every cached log that depends on it. */
    public void setStates(int states) {
        this.states = states;
        this.logStates = FastMath.log((double) states);
        this.logPIncorrect = Viterbi.logOfEvenlySharedRemainder(this.pCorrect, states);
        this.logOfDiangnalTProb = Viterbi.logOfEvenlySharedRemainder(this.metaStability, states);
    }

    public double getLogPCorrect() {
        return this.logPCorrect;
    }

    /** Overrides the cached value directly; {@code pCorrect} is left untouched. */
    public void setLogPCorrect(double logPCorrect) {
        this.logPCorrect = logPCorrect;
    }

    public double getLogPIncorrect() {
        return this.logPIncorrect;
    }

    /** Overrides the cached value directly; {@code pCorrect} is left untouched. */
    public void setLogPIncorrect(double logPIncorrect) {
        this.logPIncorrect = logPIncorrect;
    }

    public double getLogMetaInstability() {
        return this.logMetaInstability;
    }

    /** Overrides the cached value directly; {@code metaStability} is left untouched. */
    public void setLogMetaInstability(double logMetaInstability) {
        this.logMetaInstability = logMetaInstability;
    }

    public double getLogOfDiangnalTProb() {
        return this.logOfDiangnalTProb;
    }

    /** Overrides the cached value directly; {@code metaStability} is left untouched. */
    public void setLogOfDiangnalTProb(double logOfDiangnalTProb) {
        this.logOfDiangnalTProb = logOfDiangnalTProb;
    }

    public double getLogStates() {
        return this.logStates;
    }

    /** Overrides the cached value directly; {@code states} is left untouched. */
    public void setLogStates(double logStates) {
        this.logStates = logStates;
    }
}

