/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.ndarray.index;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.index.NDIndexAll;
import ai.djl.ndarray.index.NDIndexBooleans;
import ai.djl.ndarray.index.NDIndexElement;
import ai.djl.ndarray.index.NDIndexFixed;
import ai.djl.ndarray.index.NDIndexFullSlice;
import ai.djl.ndarray.index.NDIndexSlice;
import ai.djl.ndarray.types.Shape;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;

/**
 * An index that selects a sub-region of an {@link NDArray}.
 *
 * <p>An index can be built from a comma-separated string such as {@code "1, 2:5, *"}, from a
 * list of fixed integer indices, or incrementally through the {@code add*} methods. Each string
 * item, fixed index or slice contributes one dimension to the rank. A boolean index contributes
 * as many dimensions as its mask has.
 */
public class NDIndex {

    /**
     * Matches one index item. Group 1 is {@code *}. Groups 3, 4 and 6 are the optional min, max
     * and step of a {@code min:max:step} slice. Group 7 is a single fixed integer index.
     */
    private static final Pattern ITEM_PATTERN =
            Pattern.compile("(\\*)|((-?\\d+)?:(-?\\d+)?(:(-?\\d+))?)|(-?\\d+)");

    /** Number of target dimensions covered by this index. */
    private int rank = 0;

    /** The index elements, in the order they were added. */
    private List<NDIndexElement> indices = new ArrayList<>();

    /** Creates an empty index. */
    public NDIndex() {
    }

    /**
     * Creates an index from a comma-separated string.
     *
     * @param indices items such as {@code "*"}, {@code "3"}, {@code "1:4"} or {@code "::2"}
     * @throws IllegalArgumentException if an item is malformed
     */
    public NDIndex(String indices) {
        this();
        this.addIndices(indices);
    }

    /**
     * Creates an index of fixed integer indices, one per dimension.
     *
     * @param indices the fixed index for each leading dimension
     */
    public NDIndex(long... indices) {
        this();
        this.addIndices(indices);
    }

    /**
     * Returns the number of target dimensions covered by this index.
     *
     * @return the rank of this index
     */
    public int getRank() {
        return this.rank;
    }

    /**
     * Returns the element at the given position in the element list.
     *
     * @param dimension the element position
     * @return the index element at that position
     */
    public NDIndexElement get(int dimension) {
        return this.indices.get(dimension);
    }

    /**
     * Returns the internal element list. It is not copied, so mutations are visible to this index.
     *
     * @return the index elements
     */
    public List<NDIndexElement> getIndices() {
        return this.indices;
    }

    /**
     * Appends the items of a comma-separated index string.
     *
     * <p>Rank is increased only for items that were successfully parsed and added. A malformed
     * item therefore cannot leave the rank out of sync with the stored elements.
     *
     * @param indices the comma-separated items
     * @return this index
     * @throws IllegalArgumentException if an item is malformed
     */
    public final NDIndex addIndices(String indices) {
        for (String indexItem : indices.split(",")) {
            this.addIndexItem(indexItem);
            ++this.rank;
        }
        return this;
    }

    /**
     * Appends one fixed integer index per given value.
     *
     * @param indices the fixed indices
     * @return this index
     */
    public final NDIndex addIndices(long... indices) {
        for (long index : indices) {
            this.indices.add(new NDIndexFixed(index));
            ++this.rank;
        }
        return this;
    }

    /**
     * Appends a boolean mask index. The rank grows by the number of dimensions of the mask.
     *
     * @param index the boolean mask
     * @return this index
     */
    public NDIndex addBooleanIndex(NDArray index) {
        this.rank += index.getShape().dimension();
        this.indices.add(new NDIndexBooleans(index));
        return this;
    }

    /**
     * Appends a slice {@code min:max} with no explicit step.
     *
     * @param min the inclusive start
     * @param max the exclusive end
     * @return this index
     */
    public NDIndex addSliceDim(long min, long max) {
        ++this.rank;
        this.indices.add(new NDIndexSlice(min, max, null));
        return this;
    }

    /**
     * Appends a slice {@code min:max:step}.
     *
     * @param min the inclusive start
     * @param max the exclusive end
     * @param step the stride between selected elements
     * @return this index
     */
    public NDIndex addSliceDim(long min, long max, long step) {
        ++this.rank;
        this.indices.add(new NDIndexSlice(min, max, step));
        return this;
    }

    /**
     * Returns a stream over the index elements.
     *
     * @return a stream of the elements
     */
    public Stream<NDIndexElement> stream() {
        return this.indices.stream();
    }

    /**
     * Parses a single trimmed index item and appends the matching element.
     *
     * @param indexItem the raw item text
     * @throws IllegalArgumentException if the item does not match {@link #ITEM_PATTERN}
     */
    private void addIndexItem(String indexItem) {
        String item = indexItem.trim();
        Matcher m = ITEM_PATTERN.matcher(item);
        if (!m.matches()) {
            throw new IllegalArgumentException("Invalid argument index: " + item);
        }
        if (m.group(1) != null) {
            this.indices.add(new NDIndexAll());
            return;
        }
        String digit = m.group(7);
        if (digit != null) {
            this.indices.add(new NDIndexFixed(Long.parseLong(digit)));
            return;
        }
        Long min = parseOptionalLong(m.group(3));
        Long max = parseOptionalLong(m.group(4));
        Long step = parseOptionalLong(m.group(6));
        if (min == null && max == null && step == null) {
            // A bare ":" selects the whole dimension, the same as "*".
            this.indices.add(new NDIndexAll());
        } else {
            this.indices.add(new NDIndexSlice(min, max, step));
        }
    }

    /** Parses an optional regex group as a {@code Long}. Returns null when the group is absent. */
    private static Long parseOptionalLong(String value) {
        return value == null ? null : Long.valueOf(value);
    }

    /**
     * Converts this index into a full slice over {@code target}, if possible.
     *
     * <p>Fixed dimensions become length-1 slices and are recorded for squeezing. Dimensions
     * beyond the index rank are taken whole.
     *
     * <p>NOTE(review): negative bounds are passed through as-is and are not normalized against
     * the dimension size, so shapes for negative indices may be wrong. Confirm what the engines
     * expect.
     *
     * @param target the shape of the array being indexed
     * @return the full slice, or empty if the index contains elements other than all, fixed or
     *     slice
     * @throws IllegalArgumentException if the index has more dimensions than {@code target}, or
     *     a slice has a zero step
     */
    public Optional<NDIndexFullSlice> getAsFullSlice(Shape target) {
        boolean onlyBasicElements =
                this.stream()
                        .allMatch(
                                ie ->
                                        ie instanceof NDIndexAll
                                                || ie instanceof NDIndexFixed
                                                || ie instanceof NDIndexSlice);
        if (!onlyBasicElements) {
            return Optional.empty();
        }
        int indDimensions = this.getRank();
        int targetDimensions = target.dimension();
        if (indDimensions > targetDimensions) {
            throw new IllegalArgumentException(
                    "The index has too many dimensions - "
                            + indDimensions
                            + " dimensions for array with "
                            + targetDimensions
                            + " dimensions");
        }
        long[] min = new long[targetDimensions];
        long[] max = new long[targetDimensions];
        long[] step = new long[targetDimensions];
        long[] shape = new long[targetDimensions];
        List<Integer> toSqueeze = new ArrayList<>(targetDimensions);
        List<Long> squeezedShape = new ArrayList<>(targetDimensions);
        for (int i = 0; i < targetDimensions; ++i) {
            // Only basic elements remain here, so there is exactly one element per dimension.
            NDIndexElement ie = i < indDimensions ? this.get(i) : null;
            if (ie instanceof NDIndexFixed) {
                long index = ((NDIndexFixed) ie).getIndex();
                min[i] = index;
                max[i] = index + 1L;
                step[i] = 1L;
                shape[i] = 1L;
                toSqueeze.add(i);
            } else if (ie instanceof NDIndexSlice) {
                NDIndexSlice slice = (NDIndexSlice) ie;
                min[i] = Optional.ofNullable(slice.getMin()).orElse(0L);
                max[i] = Optional.ofNullable(slice.getMax()).orElse(target.size(i));
                step[i] = Optional.ofNullable(slice.getStep()).orElse(1L);
                shape[i] = sliceLength(min[i], max[i], step[i]);
                squeezedShape.add(shape[i]);
            } else {
                // NDIndexAll, or a trailing dimension the index does not mention.
                long size = target.size(i);
                min[i] = 0L;
                max[i] = size;
                step[i] = 1L;
                shape[i] = size;
                squeezedShape.add(size);
            }
        }
        NDIndexFullSlice fullSlice =
                new NDIndexFullSlice(
                        min, max, step, toSqueeze, new Shape(shape), new Shape(squeezedShape));
        return Optional.of(fullSlice);
    }

    /**
     * Counts the elements selected by a {@code min:max:step} slice.
     *
     * <p>The count is the ceiling of the span divided by the stride, or 0 for an empty range.
     */
    private static long sliceLength(long min, long max, long step) {
        if (step > 0L) {
            return max > min ? (max - min - 1L) / step + 1L : 0L;
        }
        if (step < 0L) {
            return min > max ? (min - max - 1L) / -step + 1L : 0L;
        }
        throw new IllegalArgumentException("Slice step cannot be zero");
    }
}

