/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.tensor.jvm;

import com.google.common.base.Preconditions;
import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.jvm.IndexMapper;
import io.improbable.keanu.tensor.jvm.Slicer;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang3.ArrayUtils;

/**
 * {@link IndexMapper} that maps flat indices of a sliced view back to flat indices of the source
 * tensor it was sliced from.
 *
 * <p>The result shape is first computed "without rank loss": a dimension whose slice carries the
 * drop-dimension sentinel is kept with length 1, so indices can be mapped dimension by dimension.
 * {@link #getResultShape()} and {@link #getResultStride()} then remove those dimensions.
 *
 * <p>The source shape and stride arrays are held by reference and must not be mutated by the
 * caller afterwards.
 */
public final class SlicerIndexMapper implements IndexMapper {

    // Stop sentinels interpreted by this mapper. They must match the values Slicer.StartStopStep
    // produces. NOTE(review): keep in sync with Slicer (not visible from this file).
    /** Stop value marking a dimension that is kept with length 1 and dropped from the public shape. */
    private static final long DROP_DIM = -2L;
    /** Stop value meaning "run to the end of the source dimension". */
    private static final long UPPER_BOUND = -1L;

    private final List<Slicer.StartStopStep> slices;
    private final long[] sourceShape;
    private final long[] sourceStride;
    private final long[] resultShapeWithoutRankLoss;
    private final long[] resultStrideWithoutRankLoss;
    private final int[] dimensionsDropped;

    /**
     * Creates a mapper for applying {@code slicer} to a tensor of the given shape and stride.
     *
     * @param slicer       the slices to apply; its slice list is read once, here
     * @param sourceShape  shape of the tensor being sliced
     * @param sourceStride stride of the tensor being sliced
     * @throws IllegalArgumentException if there are more slices than source dimensions, or a
     *                                  range slice has a zero step
     */
    public SlicerIndexMapper(Slicer slicer, long[] sourceShape, long[] sourceStride) {
        // Cached once: the per-element mapping below would otherwise re-fetch it on every call.
        this.slices = slicer.getSlices();
        this.sourceShape = sourceShape;
        this.sourceStride = sourceStride;

        List<Long> shapeList = new ArrayList<>();
        List<Integer> droppedList = new ArrayList<>();
        initialize(sourceShape, shapeList, droppedList, this.slices);

        this.resultShapeWithoutRankLoss = Longs.toArray(shapeList);
        this.resultStrideWithoutRankLoss = TensorShape.getRowFirstStride(this.resultShapeWithoutRankLoss);
        this.dimensionsDropped = Ints.toArray(droppedList);
    }

    /**
     * Fills {@code shapeList} with the rank-preserved result length of every source dimension, and
     * {@code droppedList} with the indices of dimensions whose slice carries {@link #DROP_DIM}.
     */
    private static void initialize(long[] sourceShape, List<Long> shapeList, List<Integer> droppedList, List<Slicer.StartStopStep> slices) {
        Preconditions.checkArgument(slices.size() <= sourceShape.length, "Too many slices specified for shape");

        for (int i = 0; i < sourceShape.length; i++) {
            // Dimensions without an explicit slice, or sliced with ALL, are kept whole.
            if (i >= slices.size() || slices.get(i) == Slicer.StartStopStep.ALL) {
                shapeList.add(sourceShape[i]);
                continue;
            }

            Slicer.StartStopStep slice = slices.get(i);
            if (slice.getStop() == DROP_DIM) {
                // Kept as length 1 here so index mapping stays per-dimension; removed on output.
                shapeList.add(1L);
                droppedList.add(i);
                continue;
            }

            long stop = slice.getStop() == UPPER_BOUND ? sourceShape[i] : slice.getStop();
            shapeList.add(sliceLength(slice.getStart(), stop, slice.getStep()));
        }
    }

    /**
     * Returns how many positions of {@code [start, stop)} are selected when stepping by
     * {@code |step|}. This is {@code ceil((stop - start) / |step|)}, or 0 when {@code stop <= start}.
     *
     * <p>NOTE(review): like the original formula, this treats the range as {@code start < stop}
     * regardless of the step's sign. Confirm how Slicer encodes negative-step slices.
     *
     * @throws IllegalArgumentException if {@code step} is zero
     */
    private static long sliceLength(long start, long stop, long step) {
        Preconditions.checkArgument(step != 0L, "Slice step must not be zero");
        long span = stop - start;
        if (span <= 0L) {
            // Empty range. Truncating division would otherwise yield 1 (|step| >= 2) or a negative
            // length here.
            return 0L;
        }
        return 1L + (span - 1L) / Math.abs(step);
    }

    /**
     * Returns the shape of the sliced result, with dropped dimensions removed.
     *
     * @return a new array; mutating it does not affect this mapper
     */
    @Override
    public long[] getResultShape() {
        return dimensionsDropped.length > 0
            ? ArrayUtils.removeAll(resultShapeWithoutRankLoss, dimensionsDropped)
            : resultShapeWithoutRankLoss.clone();
    }

    /**
     * Returns the row-major stride of the sliced result, with dropped dimensions removed.
     *
     * <p>Removing entries is valid because every dropped dimension has length 1 and so contributes
     * nothing to the other strides.
     *
     * @return a new array; mutating it does not affect this mapper
     */
    @Override
    public long[] getResultStride() {
        return dimensionsDropped.length > 0
            ? ArrayUtils.removeAll(resultStrideWithoutRankLoss, dimensionsDropped)
            : resultStrideWithoutRankLoss.clone();
    }

    /**
     * Maps a flat index into the sliced result to the flat index of the same element in the source.
     *
     * <p>The index is decomposed with the rank-preserved result shape. This yields the same
     * decomposition as the reduced shape, because dropped dimensions have length 1.
     */
    @Override
    public long getSourceIndexFromResultIndex(long resultIndex) {
        long[] shapeIndices = TensorShape.getShapeIndices(resultShapeWithoutRankLoss, resultStrideWithoutRankLoss, resultIndex);
        long[] sourceIndices = getIndicesOfSource(shapeIndices);
        return TensorShape.getFlatIndex(sourceShape, sourceStride, sourceIndices);
    }

    /**
     * Converts per-dimension result indices to per-dimension source indices, in place, as
     * {@code start + index * step} for each explicitly sliced dimension.
     */
    private long[] getIndicesOfSource(long[] shapeIndices) {
        for (int i = 0; i < slices.size(); i++) {
            // A dimension whose sliced length equals its source length is passed through
            // unchanged. NOTE(review): that is right for full forward slices, but would be wrong
            // for a full-length slice with a non-zero start or a negative step, if Slicer can
            // produce one. Confirm against Slicer.
            if (resultShapeWithoutRankLoss[i] == sourceShape[i]) {
                continue;
            }
            Slicer.StartStopStep slice = slices.get(i);
            shapeIndices[i] = slice.getStart() + shapeIndices[i] * slice.getStep();
        }
        return shapeIndices;
    }
}

