package org.allenai.ml.sequences.crf;

import java.beans.ConstructorProperties;
import java.util.List;
import org.allenai.ml.linalg.Vector;
import org.allenai.ml.sequences.StateSpace;
import org.allenai.ml.sequences.Transition;

/* loaded from: input_file:org/allenai/ml/sequences/crf/CRFWeightsEncoder.class */
/**
 * Lays out CRF node and edge predicate weights in a single flat parameter vector and
 * computes potentials from that vector.
 *
 * <p>Layout (total length {@link #numParameters()}):
 * <ul>
 *   <li>node region: indices {@code [0, numNodePredicates * numStates)}; the weight for
 *       node predicate {@code p} and state {@code s} is at {@code p * numStates + s};</li>
 *   <li>edge region: starts right after the node region; the weight for edge predicate
 *       {@code p} and transition {@code t} is at
 *       {@code numNodePredicates * numStates + p * numTransitions + t}.</li>
 * </ul>
 *
 * <p>NOTE(review): earlier revisions started the edge region at {@code numNodePredicates}
 * (without the {@code numStates} factor), which made edge weights overlap node weights and
 * left the tail of the vector unused. Weight vectors produced against that layout are not
 * index-compatible with this one — confirm no persisted models depend on it.
 *
 * @param <S> the state type of the underlying {@link StateSpace}
 */
public class CRFWeightsEncoder<S> {
    public final StateSpace<S> stateSpace;
    public final int numNodePredicates;
    public final int numEdgePredicates;

    /**
     * @param stateSpace        state space supplying the states and transitions
     * @param numNodePredicates number of distinct node predicates
     * @param numEdgePredicates number of distinct edge predicates
     */
    @ConstructorProperties({"stateSpace", "numNodePredicates", "numEdgePredicates"})
    public CRFWeightsEncoder(StateSpace<S> stateSpace, int numNodePredicates, int numEdgePredicates) {
        this.stateSpace = stateSpace;
        this.numNodePredicates = numNodePredicates;
        this.numEdgePredicates = numEdgePredicates;
    }

    /**
     * @return the length of the flat weight vector: one weight per (node predicate, state)
     *         pair plus one per (edge predicate, transition) pair
     */
    public int numParameters() {
        int numStates = this.stateSpace.states().size();
        int numTransitions = this.stateSpace.transitions().size();
        return (this.numNodePredicates * numStates) + (this.numEdgePredicates * numTransitions);
    }

    /**
     * First index of the edge region, i.e. the size of the node region. Kept in one place so
     * {@link #edgeWeightIndex(int, int)} and the potential computation cannot drift apart.
     */
    private int edgeWeightOffset() {
        return this.numNodePredicates * this.stateSpace.states().size();
    }

    /**
     * Accumulates, for every column of a row-major weight block, the sum over active
     * predicates of {@code weight * predicateValue}.
     *
     * @param weights         the flat weight vector
     * @param predicateValues iterator over (predicate index, value) pairs; consumed until exhausted
     * @param rowSize         number of columns per predicate row (and length of the result)
     * @param offset          index in {@code weights} at which the block starts
     * @return an array of length {@code rowSize} holding the accumulated potentials
     */
    static double[] fillRowPotentials(Vector weights, Vector.Iterator predicateValues, int rowSize, int offset) {
        double[] potentials = new double[rowSize];
        while (!predicateValues.isExhausted()) {
            long predicateIndex = predicateValues.index();
            double predicateValue = predicateValues.value();
            // Start of this predicate's row inside the flat vector.
            long rowStart = (predicateIndex * rowSize) + offset;
            for (int col = 0; col < rowSize; col++) {
                potentials[col] += weights.at(rowStart + col) * predicateValue;
            }
            predicateValues.advance();
        }
        return potentials;
    }

    /** Per-state potentials from the node region (one entry per state). */
    private double[] nodePotentials(Vector weights, Vector.Iterator nodePredicateValues) {
        return fillRowPotentials(weights, nodePredicateValues, this.stateSpace.states().size(), 0);
    }

    /** Per-transition potentials from the edge region (one entry per transition). */
    private double[] edgePotentials(Vector weights, Vector.Iterator edgePredicateValues) {
        List<Transition> transitions = this.stateSpace.transitions();
        // The edge region begins after ALL node weights (numNodePredicates * numStates),
        // matching numParameters() and edgeWeightIndex().
        return fillRowPotentials(weights, edgePredicateValues, transitions.size(), edgeWeightOffset());
    }

    /**
     * Computes the potential of every transition at every position of the example except
     * the last one.
     *
     * <p>Entry {@code [pos][t]} is the edge potential of transition {@code t} at {@code pos}
     * plus the node potential of that transition's {@code fromState} at {@code pos}.
     *
     * @param weights the flat weight vector
     * @param example the indexed example providing node and edge predicate values per position
     * @return a {@code (sequenceLength - 1) x numTransitions} matrix of potentials
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public double[][] fillPotentials(Vector weights, CRFIndexedExample example) {
        int numPositions = example.getSequenceLength() - 1;
        List<Transition> transitions = this.stateSpace.transitions();
        int numTransitions = transitions.size();
        double[][] potentials = new double[numPositions][numTransitions];
        for (int pos = 0; pos < numPositions; pos++) {
            double[] nodePotentials = nodePotentials(weights, example.getNodePredicateValues(pos));
            double[] edgePotentials = edgePotentials(weights, example.getEdgePredicateValues(pos));
            for (int t = 0; t < numTransitions; t++) {
                potentials[pos][t] = edgePotentials[t] + nodePotentials[transitions.get(t).fromState];
            }
        }
        return potentials;
    }

    /**
     * @param predicateIndex index of the node predicate
     * @param state          index of the state
     * @return the position of the (node predicate, state) weight in the flat vector
     */
    public int nodeWeightIndex(int predicateIndex, int state) {
        return (predicateIndex * this.stateSpace.states().size()) + state;
    }

    /**
     * @param predicateIndex  index of the edge predicate
     * @param transitionIndex index of the transition
     * @return the position of the (edge predicate, transition) weight in the flat vector
     */
    public int edgeWeightIndex(int predicateIndex, int transitionIndex) {
        return edgeWeightOffset() + (predicateIndex * this.stateSpace.transitions().size()) + transitionIndex;
    }
}
