package org.allenai.ml.sequences;

import com.gs.collections.api.list.primitive.MutableDoubleList;
import com.gs.collections.impl.list.mutable.primitive.DoubleArrayList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import org.allenai.ml.math.SloppyMath;

/* loaded from: input_file:org/allenai/ml/sequences/ForwardBackwards.class */
public class ForwardBackwards<S> {
    // State space this instance runs inference over; supplied to the constructor.
    private final StateSpace<S> stateSpace;
    // Cached stateSpace.states().size(); used as the width of alpha/beta/node-marginal rows.
    private final int numStates;
    // Cached stateSpace.transitions().size(); used as the width of edge-marginal rows.
    private final int numTransitions;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/allenai/ml/sequences/ForwardBackwards$LogAddRing.class */
    /**
     * {@link RingOp} that combines the added values as
     * {@code max + log(sum_i sloppyExp(x_i - max))}, where {@code max} is the largest value
     * seen since the last {@link #clear()}.
     *
     * <p>A value is only retained if, at the moment it is added, it lies within 30 of the
     * running maximum; values further below are dropped. Values retained earlier are not
     * evicted when the maximum later grows.
     */
    public class LogAddRing implements RingOp {
        // Constructed with the transition count as its argument (presumably the initial capacity).
        MutableDoubleList xs = new DoubleArrayList(ForwardBackwards.this.stateSpace.transitions().size());
        double max = Double.NEGATIVE_INFINITY;

        LogAddRing() {
        }

        /** Forgets all retained values and resets the running maximum to negative infinity. */
        @Override
        public void clear() {
            this.max = Double.NEGATIVE_INFINITY;
            this.xs.clear();
        }

        /**
         * Updates the running maximum with {@code d} and retains {@code d} when it is no more
         * than 30 below that maximum.
         *
         * @param d the value to accumulate
         */
        @Override
        public void add(double d) {
            // Plain comparison (not Math.max) so NaN and signed zeros are handled as before.
            this.max = (d > this.max) ? d : this.max;
            final double gapToMax = d - this.max;
            // A NaN gap (e.g. infinity minus infinity) fails this test and is not retained.
            if (gapToMax >= -30.0d) {
                this.xs.add(d);
            }
        }

        /**
         * @return {@code max + log(total)} where {@code total} sums {@code sloppyExp(x - max)}
         *         over the retained values; just {@code max} when that total is not positive
         *         (for instance when nothing was retained)
         */
        @Override
        public double compute() {
            final int count = this.xs.size();
            double total = 0.0d;
            for (int k = 0; k < count; k++) {
                total += SloppyMath.sloppyExp(this.xs.get(k) - this.max);
            }
            if (total > 0.0d) {
                return this.max + Math.log(total);
            }
            return this.max;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/allenai/ml/sequences/ForwardBackwards$MaxRing.class */
    /**
     * {@link RingOp} that keeps only the largest value added since the last {@link #clear()}.
     * With nothing added, {@link #compute()} yields negative infinity.
     */
    public class MaxRing implements RingOp {
        private double max = Double.NEGATIVE_INFINITY;

        MaxRing() {
        }

        /** Resets the tracked maximum to negative infinity. */
        @Override
        public void clear() {
            this.max = Double.NEGATIVE_INFINITY;
        }

        /**
         * Replaces the tracked maximum when {@code d} is strictly greater than it.
         *
         * @param d the candidate value
         */
        @Override
        public void add(double d) {
            // Plain comparison (not Math.max) so a NaN argument leaves the maximum untouched.
            this.max = (d > this.max) ? d : this.max;
        }

        /** @return the largest value added since the last clear */
        @Override
        public double compute() {
            return this.max;
        }
    }

    /* loaded from: input_file:org/allenai/ml/sequences/ForwardBackwards$Result.class */
    public class Result {
        // Scores supplied to the constructor; read as potentials[position][transition.selfIndex].
        private final double[][] potentials;
        // potentials.length + 1: number of rows in the alpha/beta/node-marginal tables.
        private final int seqLen;
        // Lazy caches, one per derived quantity. Each holds null until first computed; afterwards
        // it holds the computed value, or the reference itself as a marker for a null result.
        private final AtomicReference<Object> viterbi;
        private final AtomicReference<Object> alphas;
        private final AtomicReference<Object> betas;
        private final AtomicReference<Object> nodeMarginals;
        private final AtomicReference<Object> edgeMarginals;

        /**
         * Wraps a table of potentials; nothing is computed until a getter is called.
         *
         * @param dArr potentials, one row per step; the table is stored as-is (not copied)
         * @throws NullPointerException if {@code dArr} is null
         */
        private Result(double[][] dArr) {
            this.potentials = dArr;
            // One more row than there are steps in the potentials table.
            this.seqLen = dArr.length + 1;

            // Empty caches; each is filled on first access by the matching getter.
            this.viterbi = new AtomicReference<>();
            this.alphas = new AtomicReference<>();
            this.betas = new AtomicReference<>();
            this.nodeMarginals = new AtomicReference<>();
            this.edgeMarginals = new AtomicReference<>();
        }

        /**
         * Returns the log-add forward score of the stop state in the last row of the alpha
         * table (presumably the log partition function, going by the name).
         *
         * @return {@code alphas[seqLen - 1][stopStateIndex]}
         */
        public double getLogZ() {
            double[][] forward = getAlphas();
            double[] lastRow = forward[this.seqLen - 1];
            return lastRow[ForwardBackwards.this.stateSpace.stopStateIndex()];
        }

        /**
         * Fills the forward table. Row 0 is negative infinity everywhere except the start
         * state, which is 0. Each later cell combines, through {@code ringOp}, the value
         * {@code previousRow[fromState] + potentials[pos - 1][selfIndex]} of every transition
         * into that state.
         *
         * @param ringOp accumulator used to combine incoming scores; cleared before every cell
         * @return a {@code seqLen x numStates} table
         */
        private double[][] computeAlphas(RingOp ringOp) {
            final int stateCount = ForwardBackwards.this.numStates;
            double[][] table = new double[this.seqLen][stateCount];
            for (int row = 0; row < table.length; row++) {
                Arrays.fill(table[row], Double.NEGATIVE_INFINITY);
            }
            table[0][ForwardBackwards.this.stateSpace.startStateIndex()] = 0.0d;

            for (int pos = 1; pos < this.seqLen; pos++) {
                final double[] previous = table[pos - 1];
                final double[] stepScores = this.potentials[pos - 1];
                final double[] current = table[pos];
                for (int state = 0; state < stateCount; state++) {
                    ringOp.clear();
                    for (Transition incoming : ForwardBackwards.this.stateSpace.transitionsTo(state)) {
                        ringOp.add(previous[incoming.fromState] + stepScores[incoming.selfIndex]);
                    }
                    current[state] = ringOp.compute();
                }
            }
            return table;
        }

        /**
         * Fills the backward table. The last row is negative infinity everywhere except the
         * stop state, which is 0. Each earlier cell combines, with a {@link LogAddRing}, the
         * value {@code nextRow[toState] + potentials[pos][selfIndex]} of every transition out
         * of that state.
         *
         * @return a {@code seqLen x numStates} table
         */
        private double[][] computeBetas() {
            final int stateCount = ForwardBackwards.this.numStates;
            double[][] table = new double[this.seqLen][stateCount];
            for (int row = 0; row < table.length; row++) {
                Arrays.fill(table[row], Double.NEGATIVE_INFINITY);
            }
            table[this.seqLen - 1][ForwardBackwards.this.stateSpace.stopStateIndex()] = 0.0d;

            LogAddRing accumulator = new LogAddRing();
            int pos = this.seqLen - 2;
            while (pos >= 0) {
                final double[] following = table[pos + 1];
                final double[] stepScores = this.potentials[pos];
                final double[] current = table[pos];
                for (int state = 0; state < stateCount; state++) {
                    accumulator.clear();
                    for (Transition outgoing : ForwardBackwards.this.stateSpace.transitionsFrom(state)) {
                        accumulator.add(following[outgoing.toState] + stepScores[outgoing.selfIndex]);
                    }
                    current[state] = accumulator.compute();
                }
                pos--;
            }
            return table;
        }

        /**
         * Recovers a highest-scoring state sequence by back-tracing through the forward table
         * computed with {@link MaxRing}.
         *
         * <p>Starting from the stop state in the last row, each step looks for the first
         * incoming transition whose {@code potential + previous alpha} matches the current
         * best score (within 1e-8), moves to that transition's source state, and records it.
         * The state reached in row 0 is not recorded, so the result has {@code seqLen - 2}
         * entries when {@code seqLen >= 2}, and is empty otherwise.
         *
         * @return the recovered states in sequence order
         * @throws RuntimeException if no incoming transition reproduces the best score at
         *         some step (for example when the best score is infinite)
         */
        private List<S> computeViterbi() {
            double[][] maxAlphas = computeAlphas(new MaxRing());
            int currentState = ForwardBackwards.this.stateSpace.stopStateIndex();
            double bestScore = maxAlphas[this.seqLen - 1][currentState];
            List<S> path = new ArrayList<>();

            for (int pos = this.seqLen - 2; pos >= 0; pos--) {
                Transition chosen = null;
                for (Transition candidate : ForwardBackwards.this.stateSpace.transitionsTo(currentState)) {
                    double score = this.potentials[pos][candidate.selfIndex] + maxAlphas[pos][candidate.fromState];
                    // A NaN difference (infinite scores) never satisfies this test.
                    if (Math.abs(score - bestScore) < 1.0E-8d) {
                        chosen = candidate;
                        break;
                    }
                }
                if (chosen == null) {
                    throw new RuntimeException("viterbi can't find path found by computeAlphas(MAX)");
                }
                currentState = chosen.fromState;
                bestScore = maxAlphas[pos][currentState];
                // At pos == 0 the source state belongs to the first row and is left out.
                if (pos > 0) {
                    path.add(ForwardBackwards.this.stateSpace.states().get(currentState));
                }
            }
            // States were collected back-to-front.
            Collections.reverse(path);
            return path;
        }

        /**
         * Builds the per-position state marginals from the edge marginals.
         *
         * <p>Row 0 gets 1 at the start state and the last row gets 1 at the stop state. For
         * every position strictly in between, a state's marginal is the sum of the edge
         * marginals (at that position) of all transitions leaving it; states whose alpha is
         * negative infinity are skipped and stay 0.
         *
         * @return a {@code seqLen x numStates} table
         */
        private double[][] computeNodeMarginals() {
            double[][] forward = getAlphas();
            double[][] marginals = new double[this.seqLen][ForwardBackwards.this.numStates];
            double[][] edgeProbs = getEdgeMarginals();

            marginals[0][ForwardBackwards.this.stateSpace.startStateIndex()] = 1.0d;
            marginals[this.seqLen - 1][ForwardBackwards.this.stateSpace.stopStateIndex()] = 1.0d;

            final int lastPos = this.seqLen - 1;
            for (int pos = 1; pos < lastPos; pos++) {
                for (int state = 0; state < ForwardBackwards.this.numStates; state++) {
                    if (forward[pos][state] == Double.NEGATIVE_INFINITY) {
                        continue;
                    }
                    for (Transition outgoing : ForwardBackwards.this.stateSpace.transitionsFrom(state)) {
                        marginals[pos][state] += edgeProbs[pos][outgoing.selfIndex];
                    }
                }
            }
            return marginals;
        }

        /**
         * Builds the per-position transition marginals:
         * {@code sloppyExp(alpha[pos][from] + potential[pos][t] + beta[pos + 1][to] - logZ)}.
         *
         * <p>Entries stay 0 when the source alpha, the potential, or the target beta is
         * negative infinity.
         *
         * @return a {@code (seqLen - 1) x numTransitions} table
         */
        private double[][] computeEdgeMarginals() {
            double[][] forward = getAlphas();
            double[][] backward = getBetas();
            double[][] marginals = new double[this.seqLen - 1][ForwardBackwards.this.numTransitions];
            double logZ = getLogZ();

            final int steps = this.seqLen - 1;
            for (int pos = 0; pos < steps; pos++) {
                for (int state = 0; state < ForwardBackwards.this.numStates; state++) {
                    final double alpha = forward[pos][state];
                    if (alpha == Double.NEGATIVE_INFINITY) {
                        continue;
                    }
                    for (Transition outgoing : ForwardBackwards.this.stateSpace.transitionsFrom(state)) {
                        final double potential = this.potentials[pos][outgoing.selfIndex];
                        // Short-circuit keeps the beta lookup from happening for impossible edges.
                        if (potential == Double.NEGATIVE_INFINITY
                                || backward[pos + 1][outgoing.toState] == Double.NEGATIVE_INFINITY) {
                            continue;
                        }
                        final double beta = backward[pos + 1][outgoing.toState];
                        marginals[pos][outgoing.selfIndex] =
                                SloppyMath.sloppyExp(((alpha + potential) + beta) - logZ);
                    }
                }
            }
            return marginals;
        }

        /**
         * Returns the Viterbi state sequence, computing it on first access and caching it.
         *
         * <p>The cache is checked once without locking and again while synchronized on the
         * cache reference, so {@code computeViterbi()} runs at most once per {@code Result}.
         * A null computation result is cached by storing the reference itself as a marker.
         *
         * @return the cached list (the same instance on every call), or null if the
         *         computation produced null
         */
        public List<S> getViterbi() {
            Object cached = this.viterbi.get();
            if (cached == null) {
                synchronized (this.viterbi) {
                    cached = this.viterbi.get();
                    if (cached == null) {
                        List<S> fresh = computeViterbi();
                        cached = (fresh != null) ? fresh : this.viterbi;
                        this.viterbi.set(cached);
                    }
                }
            }
            if (cached == this.viterbi) {
                return null;
            }
            // The only non-marker value ever stored is the List<S> from computeViterbi().
            @SuppressWarnings("unchecked")
            List<S> path = (List<S>) cached;
            return path;
        }

        /**
         * Returns the forward table built with a {@link LogAddRing}, computing it on first
         * access and caching it. The cache is re-checked while synchronized on the cache
         * reference so the computation runs at most once.
         *
         * @return the cached table, or null if the computation produced null
         */
        private double[][] getAlphas() {
            Object cached = this.alphas.get();
            if (cached == null) {
                synchronized (this.alphas) {
                    cached = this.alphas.get();
                    if (cached == null) {
                        double[][] fresh = computeAlphas(new LogAddRing());
                        // The reference itself marks "computed, but null".
                        if (fresh == null) {
                            cached = this.alphas;
                        } else {
                            cached = fresh;
                        }
                        this.alphas.set(cached);
                    }
                }
            }
            if (cached == this.alphas) {
                return null;
            }
            return (double[][]) cached;
        }

        /**
         * Returns the backward table, computing it on first access and caching it. The cache
         * is re-checked while synchronized on the cache reference so the computation runs at
         * most once.
         *
         * @return the cached table, or null if the computation produced null
         */
        private double[][] getBetas() {
            Object cached = this.betas.get();
            if (cached == null) {
                synchronized (this.betas) {
                    cached = this.betas.get();
                    if (cached == null) {
                        double[][] fresh = computeBetas();
                        // The reference itself marks "computed, but null".
                        if (fresh == null) {
                            cached = this.betas;
                        } else {
                            cached = fresh;
                        }
                        this.betas.set(cached);
                    }
                }
            }
            if (cached == this.betas) {
                return null;
            }
            return (double[][]) cached;
        }

        /**
         * Returns the per-position state marginals, computing them on first access and
         * caching them. The cache is re-checked while synchronized on the cache reference so
         * the computation runs at most once.
         *
         * @return the cached {@code seqLen x numStates} table (the same array on every call,
         *         not a copy), or null if the computation produced null
         */
        public double[][] getNodeMarginals() {
            Object cached = this.nodeMarginals.get();
            if (cached == null) {
                synchronized (this.nodeMarginals) {
                    cached = this.nodeMarginals.get();
                    if (cached == null) {
                        double[][] fresh = computeNodeMarginals();
                        // The reference itself marks "computed, but null".
                        if (fresh == null) {
                            cached = this.nodeMarginals;
                        } else {
                            cached = fresh;
                        }
                        this.nodeMarginals.set(cached);
                    }
                }
            }
            if (cached == this.nodeMarginals) {
                return null;
            }
            return (double[][]) cached;
        }

        /**
         * Returns the per-position transition marginals, computing them on first access and
         * caching them. The cache is re-checked while synchronized on the cache reference so
         * the computation runs at most once.
         *
         * @return the cached {@code (seqLen - 1) x numTransitions} table (the same array on
         *         every call, not a copy), or null if the computation produced null
         */
        public double[][] getEdgeMarginals() {
            Object cached = this.edgeMarginals.get();
            if (cached == null) {
                synchronized (this.edgeMarginals) {
                    cached = this.edgeMarginals.get();
                    if (cached == null) {
                        double[][] fresh = computeEdgeMarginals();
                        // The reference itself marks "computed, but null".
                        if (fresh == null) {
                            cached = this.edgeMarginals;
                        } else {
                            cached = fresh;
                        }
                        this.edgeMarginals.set(cached);
                    }
                }
            }
            if (cached == this.edgeMarginals) {
                return null;
            }
            return (double[][]) cached;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/allenai/ml/sequences/ForwardBackwards$RingOp.class */
    /**
     * Resettable accumulator over doubles. In this file it is implemented by
     * {@code LogAddRing} and {@code MaxRing}, and {@code computeAlphas} drives it as
     * {@code clear()}, then one {@code add(...)} per incoming transition, then
     * {@code compute()}.
     */
    public interface RingOp {
        /** Discards everything accumulated so far. */
        void clear();

        /**
         * Feeds one value into the accumulator.
         *
         * @param d the value to accumulate
         */
        void add(double d);

        /** @return the combined result of the values added since the last {@link #clear()} */
        double compute();
    }

    public ForwardBackwards(StateSpace<S> stateSpace) {
        this.stateSpace = stateSpace;
        this.numStates = stateSpace.states().size();
        this.numTransitions = stateSpace.transitions().size();
    }

    /**
     * Wraps a table of potentials in a {@code Result}. No inference runs here; each quantity
     * is computed when first requested from the returned object.
     *
     * @param dArr potentials, one row per step, read as {@code dArr[step][transition.selfIndex]}
     * @return a new result bound to this instance's state space
     * @throws NullPointerException if {@code dArr} is null
     */
    public ForwardBackwards<S>.Result compute(double[][] dArr) {
        final ForwardBackwards<S>.Result result = new Result(dArr);
        return result;
    }
}
