package io.improbable.keanu.tensor.jvm;

import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.jvm.Slicer;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang3.ArrayUtils;

/* loaded from: input_file:io/improbable/keanu/tensor/jvm/SlicerIndexMapper.class */
/**
 * {@link IndexMapper} for a slice of a tensor: translates a flat index into the sliced result into
 * the flat index of the corresponding element of the source.
 *
 * <p>Internally the result shape is kept "without rank loss": every source dimension is retained,
 * and a dimension selected by a single index is kept with length 1. Those single-index dimensions
 * are removed only from the shape and stride reported by {@link #getResultShape()} and
 * {@link #getResultStride()}.
 *
 * <p>Arrays passed to the constructor are copied, and the arrays returned by the getters are never
 * the internal ones, so neither side can corrupt the other's state.
 */
public final class SlicerIndexMapper implements IndexMapper {

    /**
     * Stop value marking a slice that selects the single element at {@code start} and drops that
     * dimension from the reported result shape.
     * NOTE(review): mirrors the literal used by the slicing code; presumably matches a constant on
     * {@link Slicer} -- confirm and reference it directly if so.
     */
    private static final long STOP_DROP_DIMENSION = -2L;

    /**
     * Stop value meaning "up to the full length of the source dimension".
     * NOTE(review): presumably matches a constant on {@link Slicer} -- confirm.
     */
    private static final long STOP_UPPER_BOUND = -1L;

    /** Per-dimension slices, fetched once so the per-element mapping does not re-query the slicer. */
    private final List<Slicer.StartStopStep> slices;
    private final long[] sourceShape;
    private final long[] sourceStride;
    private final long[] resultShapeWithoutRankLoss;
    private final long[] resultStrideWithoutRankLoss;
    private final int[] dimensionsDropped;

    /**
     * Builds a mapper for {@code slicer} applied to a source of the given shape and stride.
     *
     * @param slicer       per-dimension slices; it may specify fewer slices than the source has
     *                     dimensions, in which case the trailing dimensions are taken whole
     * @param sourceShape  shape of the tensor being sliced (copied)
     * @param sourceStride stride of the tensor being sliced (copied)
     * @throws IllegalArgumentException if more slices are given than the source has dimensions
     */
    public SlicerIndexMapper(Slicer slicer, long[] sourceShape, long[] sourceStride) {
        this.slices = slicer.getSlices();
        this.sourceShape = sourceShape.clone();
        this.sourceStride = sourceStride.clone();

        List<Long> resultShape = new ArrayList<>();
        List<Integer> droppedDimensions = new ArrayList<>();
        initialize(this.sourceShape, resultShape, droppedDimensions, this.slices);

        this.resultShapeWithoutRankLoss = Longs.toArray(resultShape);
        this.resultStrideWithoutRankLoss = TensorShape.getRowFirstStride(this.resultShapeWithoutRankLoss);
        this.dimensionsDropped = Ints.toArray(droppedDimensions);
    }

    /**
     * Computes the result shape (keeping every source dimension) and the dimensions to drop.
     *
     * @param sourceShape       shape of the tensor being sliced
     * @param resultShape       receives one length per source dimension
     * @param dimensionsDropped receives the index of every single-index (dropped) dimension
     * @param slices            per-dimension slices, at most one per source dimension
     * @throws IllegalArgumentException if there are more slices than source dimensions
     */
    private static void initialize(long[] sourceShape,
                                   List<Long> resultShape,
                                   List<Integer> dimensionsDropped,
                                   List<Slicer.StartStopStep> slices) {
        Preconditions.checkArgument(slices.size() <= sourceShape.length, "Too many slices specified for shape");

        for (int dim = 0; dim < sourceShape.length; dim++) {
            // Dimensions beyond the last specified slice are taken whole, exactly like ALL.
            Slicer.StartStopStep slice = dim < slices.size() ? slices.get(dim) : Slicer.StartStopStep.ALL;

            if (slice == Slicer.StartStopStep.ALL) {
                resultShape.add(sourceShape[dim]);
            } else if (slice.getStop() == STOP_DROP_DIMENSION) {
                // Keep a length-1 placeholder so result indices still line up with source dimensions.
                resultShape.add(1L);
                dimensionsDropped.add(dim);
            } else {
                resultShape.add(slicedLength(slice, sourceShape[dim]));
            }
        }
    }

    /**
     * Number of elements a start/stop/step slice selects from a dimension of length
     * {@code sourceLength}, i.e. {@code 1 + (stop - 1 - start) / |step|} with an upper-bound stop
     * replaced by {@code sourceLength}.
     * NOTE(review): the formula assumes stop > start; with a negative step and an explicit stop below
     * start it yields a non-positive length -- confirm how Slicer normalizes negative steps.
     */
    private static long slicedLength(Slicer.StartStopStep slice, long sourceLength) {
        long stop = slice.getStop() == STOP_UPPER_BOUND ? sourceLength : slice.getStop();
        return 1 + (stop - 1 - slice.getStart()) / Math.abs(slice.getStep());
    }

    /**
     * @return the result shape with single-index dimensions removed; always a fresh array
     */
    @Override
    public long[] getResultShape() {
        return dimensionsDropped.length > 0
            ? ArrayUtils.removeAll(resultShapeWithoutRankLoss, dimensionsDropped)
            : resultShapeWithoutRankLoss.clone();
    }

    /**
     * @return the row-first stride of the result with single-index dimensions removed; always a
     *     fresh array. Removing length-1 dimensions leaves the remaining strides unchanged.
     */
    @Override
    public long[] getResultStride() {
        return dimensionsDropped.length > 0
            ? ArrayUtils.removeAll(resultStrideWithoutRankLoss, dimensionsDropped)
            : resultStrideWithoutRankLoss.clone();
    }

    /**
     * @param resultIndex flat index into the (rank-full) result
     * @return the flat index of the same element in the source
     */
    @Override
    public long getSourceIndexFromResultIndex(long resultIndex) {
        long[] resultIndices = TensorShape.getShapeIndices(
            resultShapeWithoutRankLoss, resultStrideWithoutRankLoss, resultIndex);
        return TensorShape.getFlatIndex(sourceShape, sourceStride, getIndicesOfSource(resultIndices));
    }

    /**
     * Converts per-dimension result indices into per-dimension source indices, in place.
     *
     * @param indices per-dimension indices into the result; overwritten with source indices
     * @return {@code indices}, now holding source indices
     */
    private long[] getIndicesOfSource(long[] indices) {
        for (int dim = 0; dim < slices.size(); dim++) {
            // A dimension whose sliced length equals its source length is skipped: for positive steps
            // and in-bounds start/stop such a slice starts at 0 with step 1 (or the dimension has only
            // index 0), so the mapping would be the identity anyway.
            if (resultShapeWithoutRankLoss[dim] != sourceShape[dim]) {
                Slicer.StartStopStep slice = slices.get(dim);
                indices[dim] = slice.getStart() + indices[dim] * slice.getStep();
            }
        }
        return indices;
    }
}
