package org.allenai.ml.sequences.crf;

import com.gs.collections.api.list.primitive.ImmutableDoubleList;
import com.gs.collections.api.list.primitive.ImmutableIntList;
import com.gs.collections.api.list.primitive.IntList;
import com.gs.collections.api.list.primitive.MutableDoubleList;
import com.gs.collections.api.list.primitive.MutableIntList;
import com.gs.collections.impl.list.mutable.primitive.DoubleArrayList;
import com.gs.collections.impl.list.mutable.primitive.IntArrayList;
import java.beans.ConstructorProperties;
import java.util.List;
import org.allenai.ml.linalg.Vector;

/* loaded from: input_file:org/allenai/ml/sequences/crf/CRFIndexedExample.class */
/**
 * A sequence example whose node and edge predicate vectors have been flattened into
 * two parallel lists of predicate indices and predicate values.
 *
 * <p>Layout of {@code offsets}: the first {@code sequenceLength} entries are the start
 * positions of each node's predicates inside the flat lists; the entries after that are
 * the start positions of each edge's predicates. The end of a segment is the next offset,
 * or the end of the flat lists when there is no next offset.
 *
 * <p>Only non-zero predicate values are stored.
 */
public class CRFIndexedExample {
    /** Predicate indices of all node vectors followed by all edge vectors. */
    private final ImmutableIntList allPredicateIndices;
    /** Predicate values, parallel to {@link #allPredicateIndices}. */
    private final ImmutableDoubleList allPredicateValues;
    /** Segment start positions: node segments first, then edge segments. */
    private final ImmutableIntList offsets;
    /** Number of node predicate vectors the example was built from. */
    private final int sequenceLength;
    /** Gold labels, or {@code null} when the example is unlabeled. */
    private final ImmutableIntList goldLabels;

    /**
     * Iterates over the half-open segment {@code [start, stop)} of the enclosing example's
     * flattened predicate lists.
     */
    private class Iterator implements Vector.Iterator {
        private final int start;
        private final int stop;
        /** Position relative to {@code start}. */
        private int offset = 0;

        @ConstructorProperties({"start", "stop"})
        public Iterator(int start, int stop) {
            this.start = start;
            this.stop = stop;
        }

        @Override // org.allenai.ml.linalg.Vector.Iterator
        public boolean isExhausted() {
            return this.offset >= this.stop - this.start;
        }

        @Override // org.allenai.ml.linalg.Vector.Iterator
        public void advance() {
            this.offset++;
        }

        /**
         * @throws IllegalStateException if the iterator has moved past its segment
         */
        private void ensureNotExhausted() {
            if (isExhausted()) {
                throw new IllegalStateException("Iterator is exhausted");
            }
        }

        @Override // org.allenai.ml.linalg.Vector.Iterator
        public long index() {
            ensureNotExhausted();
            return CRFIndexedExample.this.allPredicateIndices.get(this.start + this.offset);
        }

        @Override // org.allenai.ml.linalg.Vector.Iterator
        public double value() {
            ensureNotExhausted();
            return CRFIndexedExample.this.allPredicateValues.get(this.start + this.offset);
        }

        @Override // org.allenai.ml.linalg.Vector.Iterator
        public void reset() {
            this.offset = 0;
        }
    }

    /**
     * Builds an example from per-node and per-edge predicate vectors.
     *
     * @param nodePredicates one predicate vector per sequence position
     * @param edgePredicates one predicate vector per transition; must hold exactly one
     *                       fewer element than {@code nodePredicates}
     * @param goldLabels     one label per sequence position, or {@code null} for an
     *                       unlabeled example
     * @throws AssertionError if assertions are enabled and the sizes are inconsistent
     * @throws ArithmeticException if a predicate index does not fit in an {@code int}
     */
    public CRFIndexedExample(List<Vector> nodePredicates, List<Vector> edgePredicates, int[] goldLabels) {
        assert nodePredicates.size() == edgePredicates.size() + 1;
        assert goldLabels == null || goldLabels.length == nodePredicates.size();

        MutableIntList indices = new IntArrayList();
        MutableDoubleList values = new DoubleArrayList();
        // Nodes are flattened first, so edge segments start where the node data ends.
        IntList nodeOffsets = flattenPredicates(nodePredicates, indices, values, 0);
        IntList edgeOffsets = flattenPredicates(edgePredicates, indices, values, indices.size());

        MutableIntList allOffsets = new IntArrayList(nodeOffsets.size() + edgeOffsets.size());
        allOffsets.addAll(nodeOffsets);
        allOffsets.addAll(edgeOffsets);

        this.sequenceLength = nodePredicates.size();
        this.offsets = allOffsets.toImmutable();
        this.allPredicateIndices = indices.toImmutable();
        this.allPredicateValues = values.toImmutable();
        this.goldLabels = goldLabels != null ? new IntArrayList(goldLabels).toImmutable() : null;
    }

    /**
     * Builds an unlabeled example.
     *
     * @param nodePredicates one predicate vector per sequence position
     * @param edgePredicates one predicate vector per transition
     */
    public CRFIndexedExample(List<Vector> nodePredicates, List<Vector> edgePredicates) {
        this(nodePredicates, edgePredicates, null);
    }

    /**
     * Appends the non-zero entries of every vector to the shared flat lists and records
     * where each vector's entries begin.
     *
     * @param vectors     vectors to flatten, in order
     * @param indices     receives the predicate index of every non-zero entry
     * @param values      receives the value of every non-zero entry
     * @param startOffset position in the flat lists at which the first vector begins
     * @return the start position of each vector, in the same order as {@code vectors}
     * @throws ArithmeticException if a predicate index does not fit in an {@code int}
     */
    private static IntList flattenPredicates(List<Vector> vectors, MutableIntList indices,
                                             MutableDoubleList values, int startOffset) {
        MutableIntList starts = new IntArrayList(vectors.size());
        int position = startOffset;
        for (Vector vector : vectors) {
            starts.add(position);
            for (Vector.Iterator it = vector.iterator(); !it.isExhausted(); it.advance()) {
                double value = it.value();
                if (value != 0.0d) {
                    // Fail loudly rather than silently truncating an out-of-range index.
                    indices.add(Math.toIntExact(it.index()));
                    values.add(value);
                    position++;
                }
            }
        }
        return starts;
    }

    /**
     * Returns the gold labels as an array.
     *
     * @throws IllegalStateException if the example is unlabeled; check {@link #isLabeled()} first
     */
    public int[] getGoldLabels() {
        if (this.goldLabels == null) {
            throw new IllegalStateException("Example is unlabeled; check isLabeled() before requesting gold labels");
        }
        return this.goldLabels.toArray();
    }

    /** Returns {@code true} when gold labels were supplied at construction. */
    public boolean isLabeled() {
        return this.goldLabels != null;
    }

    /**
     * Returns an iterator over the stored predicates of the node at position {@code i}.
     *
     * @param i node position, in {@code [0, getSequenceLength())}
     * @throws IllegalArgumentException if {@code i} is outside that range
     */
    public Vector.Iterator getNodePredicateValues(int i) {
        if (i < 0 || i >= getSequenceLength()) {
            throw new IllegalArgumentException("Invalid node predicate index");
        }
        // For a single-position sequence there are no edge offsets after the node
        // offsets, so the end must fall back to the end of the flat lists.
        return new Iterator(this.offsets.get(i), segmentEnd(i + 1));
    }

    /**
     * Returns an iterator over the stored predicates of the edge at position {@code i}.
     *
     * @param i edge position, in {@code [0, getSequenceLength() - 1)}
     * @throws IllegalArgumentException if {@code i} is outside that range
     */
    public Vector.Iterator getEdgePredicateValues(int i) {
        if (i < 0 || i >= getSequenceLength() - 1) {
            throw new IllegalArgumentException("Invalid node transition edge index");
        }
        // Edge offsets are stored directly after the node offsets.
        int offsetIndex = getSequenceLength() + i;
        return new Iterator(this.offsets.get(offsetIndex), segmentEnd(offsetIndex + 1));
    }

    /**
     * Returns the end position of a segment given the index of the following offset:
     * that offset when it exists, otherwise the end of the flat lists.
     */
    private int segmentEnd(int nextOffsetIndex) {
        return nextOffsetIndex < this.offsets.size()
                ? this.offsets.get(nextOffsetIndex)
                : this.allPredicateIndices.size();
    }

    /** Returns the number of node positions in the sequence. */
    public int getSequenceLength() {
        return this.sequenceLength;
    }
}
