/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor.functions;

import com.google.common.annotations.Beta;
import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * The <i>slice</i> tensor function: returns the subspace of the argument tensor selected by a
 * (partial) subspace address.
 * <p>
 * Every dimension named in the subspace address is removed from the result type; all other
 * dimensions of the argument are kept. A single address entry may omit the dimension name
 * ("short form"): an entry carrying an index then slices the argument's single indexed
 * dimension, and an entry carrying a label slices its single mapped dimension.
 *
 * @param <NAMETYPE> the name type used by the evaluation and type contexts
 */
@Beta
public class Slice<NAMETYPE extends Name>
extends PrimitiveTensorFunction<NAMETYPE> {

    private final TensorFunction<NAMETYPE> argument;
    private final List<DimensionValue<NAMETYPE>> subspaceAddress;

    /**
     * Creates a slice of the given argument.
     *
     * @param argument the tensor function producing the tensor to slice
     * @param subspaceAddress the entries addressing the subspace to return; the list is copied,
     *                        so later changes to it do not affect this function
     * @throws NullPointerException if argument, subspaceAddress or any of its elements is null
     * @throws IllegalArgumentException if more than one entry is given and any of them omits
     *                                  its dimension name
     */
    public Slice(TensorFunction<NAMETYPE> argument, List<DimensionValue<NAMETYPE>> subspaceAddress) {
        this.argument = Objects.requireNonNull(argument, "Argument cannot be null");
        // Defensive copy: the address is read again on every evaluation and type resolution
        List<DimensionValue<NAMETYPE>> address =
                List.copyOf(Objects.requireNonNull(subspaceAddress, "Subspace address cannot be null"));
        if (address.size() > 1 && address.stream().anyMatch(c -> c.dimension().isEmpty())) {
            throw new IllegalArgumentException("Short form of subspace addresses is only supported with a single dimension: Specify dimension names explicitly instead");
        }
        this.subspaceAddress = address;
    }

    @Override
    public List<TensorFunction<NAMETYPE>> arguments() {
        return List.of(this.argument);
    }

    /**
     * Returns a slice with the same subspace address applied to the given single argument.
     *
     * @throws IllegalArgumentException if arguments does not contain exactly one element
     */
    @Override
    public Slice<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
        if (arguments.size() != 1) {
            throw new IllegalArgumentException("Slice takes exactly one argument but got " + arguments.size());
        }
        return new Slice<>(arguments.get(0), this.subspaceAddress);
    }

    @Override
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
        return this;
    }

    /**
     * Evaluates the argument and returns the cells whose address matches the subspace address,
     * re-addressed by the dimensions that remain in the result type.
     */
    @Override
    public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
        Tensor tensor = this.argument.evaluate(context);
        TensorType argumentType = tensor.type();
        TensorType resultType = this.resultType(argumentType);
        PartialAddress address = this.subspaceToAddress(argumentType, context);
        if (resultType.rank() == 0) {
            // Every argument dimension is sliced away: the address identifies a single cell
            return Tensor.from(tensor.get(address.asAddress(argumentType)));
        }
        Tensor.Builder b = Tensor.Builder.of(resultType);
        for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
            Tensor.Cell cell = i.next();
            if ( ! this.matches(address, cell.getKey(), argumentType)) {
                continue;
            }
            b.cell(this.remaining(resultType, cell.getKey(), argumentType), (double) cell.getValue());
        }
        return b.build();
    }

    /**
     * Resolves the subspace address entries into a partial address over the given argument type,
     * evaluating index functions in the given context.
     */
    private PartialAddress subspaceToAddress(TensorType type, EvaluationContext<NAMETYPE> context) {
        PartialAddress.Builder b = new PartialAddress.Builder(this.subspaceAddress.size());
        for (int i = 0; i < this.subspaceAddress.size(); ++i) {
            DimensionValue<NAMETYPE> value = this.subspaceAddress.get(i);
            String dimension = this.dimensionName(value, i, type);
            if (value.label().isPresent()) {
                b.add(dimension, value.label().get());
            } else {
                b.add(dimension, value.index().get().apply(context).intValue());
            }
        }
        return b.build();
    }

    /**
     * Returns the name of the argument dimension sliced by the given address entry.
     * An entry with an explicit dimension name slices that dimension. A short-form entry
     * slices the argument's indexed dimension when it carries an index, and its mapped
     * dimension when it carries a label. These are the same dimensions resultType
     * removes from the result type.
     */
    private String dimensionName(DimensionValue<NAMETYPE> value, int i, TensorType type) {
        if (value.dimension().isPresent()) {
            return value.dimension().get();
        }
        boolean indexed = value.index().isPresent();
        for (TensorType.Dimension dimension : type.dimensions()) {
            if (dimension.isIndexed() == indexed) {
                return dimension.name();
            }
        }
        // No dimension of the required kind: keep the previous positional resolution.
        // NOTE(review): resultType keeps every argument dimension in this case; confirm intended semantics
        if (i >= type.dimensions().size()) {
            throw new IllegalArgumentException(this + " cannot be applied to " + type + ", which has no dimension at position " + i);
        }
        return type.dimensions().get(i).name();
    }

    /** Returns whether the given cell address has the subspace address label in every sliced dimension. */
    private boolean matches(PartialAddress subspaceAddress, TensorAddress address, TensorType type) {
        for (int i = 0; i < subspaceAddress.size(); ++i) {
            String label = address.label(type.indexOfDimension(subspaceAddress.dimension(i)).orElseThrow());
            if ( ! label.equals(subspaceAddress.label(i))) {
                return false;
            }
        }
        return true;
    }

    /** Returns the part of the given cell address which is in the dimensions of subspaceType. */
    private TensorAddress remaining(TensorType subspaceType, TensorAddress address, TensorType type) {
        TensorAddress.Builder b = new TensorAddress.Builder(subspaceType);
        for (int i = 0; i < address.size(); ++i) {
            String dimension = type.dimensions().get(i).name();
            if (subspaceType.dimension(dimension).isPresent()) {
                b.add(dimension, address.label(i));
            }
        }
        return b.build();
    }

    @Override
    public TensorType type(TypeContext<NAMETYPE> context) {
        return this.resultType(this.argument.type(context));
    }

    /**
     * Returns the type produced by slicing a tensor of the given type: the argument type
     * without the sliced dimensions.
     *
     * @throws IllegalArgumentException if a short-form entry is ambiguous for the argument type,
     *                                  or a named dimension is not present in it
     */
    private TensorType resultType(TensorType argumentType) {
        TensorType.Builder b = new TensorType.Builder();
        if (this.subspaceAddress.size() == 1 && this.subspaceAddress.get(0).dimension().isEmpty()) {
            if (this.subspaceAddress.get(0).index().isPresent()) {
                if (argumentType.dimensions().stream().filter(d -> d.isIndexed()).count() > 1L) {
                    throw new IllegalArgumentException(this + " slices a single indexed dimension, cannot be applied  to " + argumentType + ", which have multiple");
                }
                for (TensorType.Dimension dimension : argumentType.dimensions()) {
                    if ( ! dimension.isIndexed()) {
                        b.dimension(dimension);
                    }
                }
            } else {
                if (argumentType.dimensions().stream().filter(d -> !d.isIndexed()).count() > 1L) {
                    throw new IllegalArgumentException(this + " slices a single mapped dimension, cannot be applied  to " + argumentType + ", which have multiple");
                }
                for (TensorType.Dimension dimension : argumentType.dimensions()) {
                    if (dimension.isIndexed()) {
                        b.dimension(dimension);
                    }
                }
            }
        } else {
            // The constructor guarantees every entry names its dimension in this branch
            Set<String> slicedDimensions = this.subspaceAddress.stream()
                    .map(d -> d.dimension().get())
                    .collect(Collectors.toSet());
            // Do not mutate the collected set: Collectors.toSet() does not guarantee mutability
            Set<String> missingDimensions = slicedDimensions.stream()
                    .filter(name -> argumentType.dimension(name).isEmpty())
                    .collect(Collectors.toSet());
            if ( ! missingDimensions.isEmpty()) {
                throw new IllegalArgumentException(this + " slices " + missingDimensions + " which are not present in " + argumentType);
            }
            for (TensorType.Dimension dimension : argumentType.dimensions()) {
                if ( ! slicedDimensions.contains(dimension.name())) {
                    b.dimension(dimension);
                }
            }
        }
        return b.build();
    }

    @Override
    public String toString(ToStringContext context) {
        StringBuilder b = new StringBuilder(this.argument.toString(context));
        if (this.subspaceAddress.size() == 1 && this.subspaceAddress.get(0).dimension().isEmpty()) {
            if (this.subspaceAddress.get(0).index().isPresent()) {
                b.append("[").append(this.subspaceAddress.get(0).index().get().toString(context)).append("]");
            } else {
                b.append("{").append(this.subspaceAddress.get(0).label().get()).append("}");
            }
        } else {
            b.append("{").append(this.subspaceAddress.stream().map(i -> i.toString(context)).collect(Collectors.joining(", "))).append("}");
        }
        return b.toString();
    }

    /** A scalar function which always returns the same integer value. */
    private static class ConstantIntegerFunction<NAMETYPE extends Name>
    implements ScalarFunction<NAMETYPE> {
        private final int value;

        public ConstantIntegerFunction(int value) {
            this.value = value;
        }

        @Override
        public Double apply(EvaluationContext<NAMETYPE> context) {
            // An int cannot be boxed directly to Double; widen to double first
            return (double) this.value;
        }

        @Override
        public String toString() {
            return String.valueOf(this.value);
        }
    }

    /**
     * One entry of a subspace address: an optional dimension name and either a label or an
     * index function. An entry without a dimension name is the "short form".
     */
    public static class DimensionValue<NAMETYPE extends Name> {
        private final Optional<String> dimension;
        private final String label;
        private final ScalarFunction<NAMETYPE> index;

        public DimensionValue(String dimension, String label) {
            this(Optional.of(dimension), label, null);
        }

        public DimensionValue(String dimension, int index) {
            this(Optional.of(dimension), null, new ConstantIntegerFunction<>(index));
        }

        public DimensionValue(int index) {
            this(Optional.empty(), null, new ConstantIntegerFunction<>(index));
        }

        public DimensionValue(String label) {
            this(Optional.empty(), label, null);
        }

        public DimensionValue(ScalarFunction<NAMETYPE> index) {
            this(Optional.empty(), null, index);
        }

        public DimensionValue(Optional<String> dimension, String label) {
            this(dimension, label, null);
        }

        public DimensionValue(Optional<String> dimension, ScalarFunction<NAMETYPE> index) {
            this(dimension, null, index);
        }

        public DimensionValue(String dimension, ScalarFunction<NAMETYPE> index) {
            this(Optional.of(dimension), null, index);
        }

        private DimensionValue(Optional<String> dimension, String label, ScalarFunction<NAMETYPE> index) {
            this.dimension = dimension;
            this.label = label;
            this.index = index;
        }

        /** Returns the dimension name, or empty if this entry uses the short form. */
        public Optional<String> dimension() {
            return this.dimension;
        }

        /** Returns the label, or empty if this entry holds an index function. */
        public Optional<String> label() {
            return Optional.ofNullable(this.label);
        }

        /** Returns the index function, or empty if this entry holds a label. */
        public Optional<ScalarFunction<NAMETYPE>> index() {
            return Optional.ofNullable(this.index);
        }

        @Override
        public String toString() {
            return this.toString(null);
        }

        /** Returns this entry as "dimension:value", or just "value" in the short form. */
        public String toString(ToStringContext context) {
            StringBuilder b = new StringBuilder();
            this.dimension.ifPresent(d -> b.append(d).append(":"));
            if (this.label != null) {
                b.append(this.label);
            } else {
                b.append(this.index.toString(context));
            }
            return b.toString();
        }
    }
}

