package org.allenai.ml.sequences.crf;

import com.gs.collections.api.block.function.primitive.IntToIntFunction;
import java.beans.ConstructorProperties;
import java.lang.invoke.SerializedLambda;
import java.util.List;
import java.util.Optional;
import java.util.function.IntBinaryOperator;
import org.allenai.ml.linalg.Vector;
import org.allenai.ml.objective.ExampleObjectiveFn;
import org.allenai.ml.sequences.ForwardBackwards;
import org.allenai.ml.sequences.Transition;

/* loaded from: input_file:org/allenai/ml/sequences/crf/CRFLogLikelihoodObjective.class */
/**
 * Per-example conditional log-likelihood objective for a linear-chain CRF.
 *
 * <p>{@link #evaluate} returns {@code goldScore - logZ}, where {@code goldScore} is the sum of the
 * potentials (from {@link CRFWeightsEncoder#fillPotentials}) along the gold label path and
 * {@code logZ} comes from {@link ForwardBackwards}. The gradient is accumulated into the supplied
 * vector via {@code Vector.inc}: gold feature counts are added and marginal-weighted expected
 * feature counts are subtracted.
 *
 * <p>Only sequence positions {@code 0 .. length - 2} contribute, and each position is paired with
 * the transition leaving it. NOTE(review): presumably the final position is a padded stop state —
 * confirm against {@code CRFIndexedExample}.
 *
 * @param <S> state (label) type of the CRF state space
 */
public class CRFLogLikelihoodObjective<S> implements ExampleObjectiveFn<CRFIndexedExample> {

    /**
     * Slack allowed when asserting that the gold-path score does not exceed log Z. The two values
     * are accumulated in different floating-point orders, so when a single path carries nearly all
     * of the probability mass they can legitimately differ by a few ULPs.
     */
    private static final double LOG_Z_TOLERANCE = 1e-6;

    private final CRFWeightsEncoder<S> weightEncoder;

    @ConstructorProperties({"weightEncoder"})
    public CRFLogLikelihoodObjective(CRFWeightsEncoder<S> weightEncoder) {
        this.weightEncoder = weightEncoder;
    }

    /**
     * Computes the log-likelihood of the example's gold labels and accumulates its gradient.
     *
     * @param example labeled example to score
     * @param weights current model weights
     * @param grad vector the gradient is added into (it is incremented, not overwritten)
     * @return gold-path score minus log Z
     * @throws IllegalArgumentException if the example is unlabeled or the gold labels use a
     *     transition that the state space does not contain
     */
    @Override
    public double evaluate(CRFIndexedExample example, Vector weights, Vector grad) {
        if (!example.isLabeled()) {
            throw new IllegalArgumentException("Requires labeled example");
        }
        double[][] logPotentials = weightEncoder.fillPotentials(weights, example);
        ForwardBackwards<S>.Result fbResult =
            new ForwardBackwards<S>(weightEncoder.stateSpace).compute(logPotentials);

        double goldScore = addGoldCounts(example, logPotentials, grad);
        double logZ = fbResult.getLogZ();
        subtractExpectedCounts(example, fbResult, grad);

        assert goldScore <= logZ + LOG_Z_TOLERANCE
            : "gold path score " + goldScore + " exceeds log Z " + logZ;
        return goldScore - logZ;
    }

    /**
     * Adds the gold path's node and edge feature counts to {@code grad}.
     *
     * @return the sum of {@code logPotentials} over the gold transitions
     */
    private double addGoldCounts(CRFIndexedExample example, double[][] logPotentials, Vector grad) {
        int[] goldLabels = example.getGoldLabels();
        double goldScore = 0.0;
        for (int pos = 0; pos + 1 < goldLabels.length; pos++) {
            int from = goldLabels[pos];
            int to = goldLabels[pos + 1];
            Optional<Transition> transition = weightEncoder.stateSpace.transitionFor(from, to);
            if (!transition.isPresent()) {
                List<S> states = weightEncoder.stateSpace.states();
                throw new IllegalArgumentException(String.format(
                    "Gold transition doesn't exist [%s, %s]", states.get(from), states.get(to)));
            }
            int transitionIndex = transition.get().selfIndex;
            goldScore += logPotentials[pos][transitionIndex];
            // Node features fire for the gold "from" state; edge features for the gold transition.
            updateGrad(grad, example.getNodePredicateValues(pos),
                predicate -> weightEncoder.nodeWeightIndex(predicate, from), 1.0);
            updateGrad(grad, example.getEdgePredicateValues(pos),
                predicate -> weightEncoder.edgeWeightIndex(predicate, transitionIndex), 1.0);
        }
        return goldScore;
    }

    /**
     * Subtracts the model's expected node and edge feature counts from {@code grad}. Each count is
     * weighted by the forward-backward marginal of the corresponding state or transition.
     */
    private void subtractExpectedCounts(
            CRFIndexedExample example, ForwardBackwards<S>.Result fbResult, Vector grad) {
        double[][] nodeMarginals = fbResult.getNodeMarginals();
        double[][] edgeMarginals = fbResult.getEdgeMarginals();
        int numStates = weightEncoder.stateSpace.states().size();
        int numTransitions = weightEncoder.stateSpace.transitions().size();
        for (int pos = 0; pos + 1 < example.getSequenceLength(); pos++) {
            subtractExpectations(grad, example.getNodePredicateValues(pos), nodeMarginals[pos],
                numStates, (predicate, state) -> weightEncoder.nodeWeightIndex(predicate, state));
            subtractExpectations(grad, example.getEdgePredicateValues(pos), edgeMarginals[pos],
                numTransitions,
                (predicate, transition) -> weightEncoder.edgeWeightIndex(predicate, transition));
        }
    }

    /**
     * For every active predicate and every outcome in {@code [0, numOutcomes)}, subtracts
     * {@code predicateValue * marginals[outcome]} from {@code grad} at
     * {@code weightIndex(predicate, outcome)}.
     */
    private static void subtractExpectations(Vector grad, Vector.Iterator predicates,
            double[] marginals, int numOutcomes, IntBinaryOperator weightIndex) {
        // Rewind first (as updateGrad does): if the example hands out a shared iterator, it has
        // already been consumed while accumulating the gold counts.
        predicates.reset();
        while (!predicates.isExhausted()) {
            int predicate = (int) predicates.index();
            double value = predicates.value();
            for (int outcome = 0; outcome < numOutcomes; outcome++) {
                grad.inc(weightIndex.applyAsInt(predicate, outcome), value * -marginals[outcome]);
            }
            predicates.advance();
        }
    }

    /**
     * Adds {@code scale * predicateValue} to {@code grad} at {@code weightIndex(predicate)} for
     * every active predicate. The iterator is rewound first so an already-consumed one still works.
     */
    private static void updateGrad(
            Vector grad, Vector.Iterator predicates, IntToIntFunction weightIndex, double scale) {
        predicates.reset();
        while (!predicates.isExhausted()) {
            grad.inc(weightIndex.valueOf((int) predicates.index()), scale * predicates.value());
            predicates.advance();
        }
    }
}
