/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor.functions;

import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * A primitive tensor function which aggregates the cells of its argument tensor
 * over a list of dimensions using an {@link Aggregator}.
 * When the list of dimensions is empty, the argument is reduced over all of its dimensions.
 *
 * @param <NAMETYPE> the name type shared with the evaluation and type contexts
 */
public class Reduce<NAMETYPE extends Name>
extends PrimitiveTensorFunction<NAMETYPE> {
    /** The function producing the tensor to reduce; checked to be non-null on construction. */
    private final TensorFunction<NAMETYPE> argument;
    /** Unmodifiable copy of the names of the dimensions to reduce over; empty means all dimensions. */
    private final List<String> dimensions;
    /** The aggregation applied to cells that collapse into the same output cell; non-null. */
    private final Aggregator aggregator;

    /**
     * Creates a reduce function which aggregates over all dimensions of the argument.
     *
     * @param argument the function producing the tensor to reduce
     * @param aggregator the aggregation to apply
     * @throws NullPointerException if {@code argument} or {@code aggregator} is null
     */
    public Reduce(TensorFunction<NAMETYPE> argument, Aggregator aggregator) {
        this(argument, aggregator, List.of());
    }

    /**
     * Creates a reduce function which aggregates over a single dimension.
     *
     * @param argument the function producing the tensor to reduce
     * @param aggregator the aggregation to apply
     * @param dimension the name of the dimension to reduce over
     * @throws NullPointerException if any argument is null
     */
    public Reduce(TensorFunction<NAMETYPE> argument, Aggregator aggregator, String dimension) {
        this(argument, aggregator, List.of(Objects.requireNonNull(dimension, "The dimension cannot be null")));
    }

    /**
     * Creates a reduce function which aggregates over the given dimensions.
     *
     * @param argument the function producing the tensor to reduce
     * @param aggregator the aggregation to apply
     * @param dimensions the names of the dimensions to reduce over; an empty list means
     *                   all dimensions. The list is copied, so later changes to it have no effect.
     * @throws NullPointerException if any argument, or any element of {@code dimensions}, is null
     */
    public Reduce(TensorFunction<NAMETYPE> argument, Aggregator aggregator, List<String> dimensions) {
        this.argument = Objects.requireNonNull(argument, "The argument tensor cannot be null");
        this.aggregator = Objects.requireNonNull(aggregator, "The aggregator cannot be null");
        // Explicit check so a null list fails with a message, like the two checks above.
        this.dimensions = List.copyOf(Objects.requireNonNull(dimensions, "The dimensions cannot be null"));
    }

    /**
     * Returns the type resulting from reducing a tensor of the given type over the given dimensions.
     * The computation is delegated to {@code TypeResolver.reduce}.
     *
     * @param inputType the type of the tensor being reduced
     * @param reduceDimensions the names of the dimensions to reduce over
     * @return the type of the reduced tensor
     */
    public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) {
        TensorType reducedType = TypeResolver.reduce(inputType, reduceDimensions);
        return reducedType;
    }

    /** Returns the function producing the tensor which this reduces. */
    public TensorFunction<NAMETYPE> argument() {
        return argument;
    }

    /** Returns the aggregator this applies when reducing. */
    Aggregator aggregator() {
        return aggregator;
    }

    /** Returns the unmodifiable list of dimension names this reduces over; empty means all dimensions. */
    List<String> dimensions() {
        return dimensions;
    }

    /** Returns a single-element immutable list holding the argument of this function. */
    @Override
    public List<TensorFunction<NAMETYPE>> arguments() {
        return List.of(argument);
    }

    /**
     * Returns a copy of this function with its argument replaced, keeping the
     * same aggregator and dimensions.
     *
     * @param arguments a list holding exactly one function, the new argument
     * @return a new Reduce over the given argument
     * @throws IllegalArgumentException if the list does not contain exactly one element
     */
    @Override
    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
        int count = arguments.size();
        if (count == 1) {
            return new Reduce<>(arguments.get(0), aggregator, dimensions);
        }
        throw new IllegalArgumentException("Reduce must have 1 argument, got " + count);
    }

    /**
     * Returns a Reduce with the same aggregator and dimensions whose argument
     * has been converted with its own {@code toPrimitive()}.
     */
    @Override
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
        TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive();
        return new Reduce<>(primitiveArgument, aggregator, dimensions);
    }

    /**
     * Returns this function on the form {@code reduce(argument, aggregator[, dimension]*)}.
     *
     * @param context the context passed on when rendering the argument
     */
    @Override
    public String toString(ToStringContext<NAMETYPE> context) {
        StringBuilder text = new StringBuilder("reduce(");
        text.append(argument.toString(context));
        text.append(", ").append(aggregator);
        text.append(commaSeparated(dimensions));
        text.append(")");
        return text.toString();
    }

    /**
     * Renders each element prefixed by {@code ", "}, so the result can be appended
     * directly after a preceding item. An empty list gives the empty string.
     *
     * @param list the elements to render
     * @return the concatenation of {@code ", " + element} for every element, in order
     */
    static String commaSeparated(List<String> list) {
        if (list.isEmpty()) {
            return "";
        }
        String separator = ", ";
        return separator + String.join(separator, list);
    }

    /**
     * Returns the type this function produces: the type of the argument in the
     * given context, reduced over the dimensions of this.
     */
    @Override
    public TensorType type(TypeContext<NAMETYPE> context) {
        TensorType argumentType = argument.type(context);
        return outputType(argumentType, dimensions);
    }

    /**
     * Evaluates the argument in the given context and reduces the resulting
     * tensor over the dimensions of this using the aggregator of this.
     */
    @Override
    public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
        Tensor input = argument.evaluate(context);
        return evaluate(input, dimensions, aggregator);
    }

    /** Returns a hash of the function name, argument, dimensions and aggregator. */
    @Override
    public int hashCode() {
        return Objects.hash("reduce", argument, dimensions, aggregator);
    }

    /**
     * Reduces the given tensor over the given dimensions.
     *
     * @param argument the tensor to reduce
     * @param dimensions the names of the dimensions to reduce over; empty means all dimensions
     * @param aggregator the aggregation applied to the cells which collapse into the same output cell
     * @return the reduced tensor; a dimensionless tensor when all dimensions are reduced
     * @throws IllegalArgumentException if some of the given dimensions are not present in the tensor
     */
    static Tensor evaluate(Tensor argument, List<String> dimensions, Aggregator aggregator) {
        TensorType argumentType = argument.type();
        if (!dimensions.isEmpty() && !argumentType.dimensionNames().containsAll(dimensions)) {
            throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + dimensions + ": Not all those dimensions are present in this tensor");
        }

        // Reducing over every dimension collapses the tensor into a single value
        if (dimensions.isEmpty() || dimensions.size() == argumentType.dimensions().size()) {
            if (argument.isEmpty()) {
                // An empty tensor reduces to 0.0 regardless of the aggregator
                return Tensor.from(0.0);
            }
            if (argumentType.dimensions().size() == 1 && argument instanceof IndexedTensor) {
                return reduceIndexedVector((IndexedTensor) argument, aggregator);
            }
            return reduceAllGeneral(argument, aggregator);
        }

        // Partial reduce: group cells by their address with the reduced dimensions removed
        TensorType reducedType = outputType(argumentType, dimensions);
        Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>();
        for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) {
            Map.Entry<TensorAddress, Double> cell = i.next();
            TensorAddress reducedAddress = reduceDimensions(cell.getKey(), argumentType, reducedType, dimensions);
            // Only create an aggregator the first time a reduced address is seen
            aggregatingCells.computeIfAbsent(reducedAddress, address -> ValueAggregator.ofType(aggregator))
                            .aggregate(cell.getValue());
        }

        Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType);
        for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet()) {
            reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue());
        }
        return reducedBuilder.build();
    }

    /**
     * Returns the given address with the labels of the reduced dimensions removed.
     *
     * @param address an address in the argument tensor
     * @param argumentType the type of the argument tensor, used to look up dimension indexes
     * @param reducedType the type of the reduced tensor, used to size the resulting address
     * @param dimensions the names of the dimensions to remove; each must exist in {@code argumentType}
     * @return the address of the cell in the reduced tensor which the given address contributes to
     */
    private static TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType, List<String> dimensions) {
        HashSet<Integer> droppedIndexes = new HashSet<>();
        for (String dropped : dimensions) {
            droppedIndexes.add(argumentType.indexOfDimension(dropped).get());
        }

        String[] keptLabels = new String[reducedType.dimensions().size()];
        int nextKept = 0;
        int position = 0;
        while (position < address.size()) {
            if (!droppedIndexes.contains(position)) {
                keptLabels[nextKept] = address.label(position);
                nextKept++;
            }
            position++;
        }
        return TensorAddress.of(keptLabels);
    }

    /**
     * Aggregates every value of the given tensor into a single dimensionless tensor.
     *
     * @param argument the tensor whose values are aggregated
     * @param aggregator the aggregation to apply
     * @return a tensor of the empty type holding the aggregated value
     */
    private static Tensor reduceAllGeneral(Tensor argument, Aggregator aggregator) {
        ValueAggregator total = ValueAggregator.ofType(aggregator);
        for (Iterator<Double> values = argument.valueIterator(); values.hasNext(); ) {
            total.aggregate(values.next());
        }
        Tensor.Builder scalar = Tensor.Builder.of(TensorType.empty);
        return scalar.cell(total.aggregatedValue(), new long[0]).build();
    }

    /**
     * Aggregates every value of an indexed tensor into a single dimensionless tensor,
     * reading values directly by index. Only the size of dimension 0 is consulted;
     * the caller in this file only passes tensors with exactly one dimension.
     *
     * @param argument the indexed tensor whose values are aggregated
     * @param aggregator the aggregation to apply
     * @return a tensor of the empty type holding the aggregated value
     */
    private static Tensor reduceIndexedVector(IndexedTensor argument, Aggregator aggregator) {
        ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
        // The size is a long: count with a long so the index can never wrap around,
        // and read the loop-invariant size once.
        long size = argument.dimensionSizes().size(0);
        for (long i = 0; i < size; i++) {
            valueAggregator.aggregate(argument.get(i));
        }
        return Tensor.Builder.of(TensorType.empty).cell(valueAggregator.aggregatedValue(), new long[0]).build();
    }

    /** The aggregations which can be applied to the values that are reduced into one cell. */
    public static enum Aggregator {
        /** The sum of the values divided by the number of values. */
        avg,
        /** The number of values. */
        count,
        /** The largest value; negative infinity if no value is larger than that. */
        max,
        /** The median of the values; NaN if any value is NaN or there are no values. */
        median,
        /** The smallest value; positive infinity if no value is smaller than that. */
        min,
        /** The product of the values, starting from 1.0. */
        prod,
        /** The sum of the values, starting from 0.0. */
        sum;

    }

    /**
     * Accumulates a sequence of values into one aggregated value.
     * Instances are stateful: one instance is used per output cell.
     */
    static abstract class ValueAggregator {
        ValueAggregator() {
        }

        /**
         * Creates a new, empty aggregator implementing the given aggregation.
         *
         * @param aggregator the aggregation to implement
         * @return a fresh aggregator instance
         * @throws UnsupportedOperationException if the aggregation has no implementation
         */
        static ValueAggregator ofType(Aggregator aggregator) {
            // Case labels use the unqualified constant names: qualified names
            // (Aggregator.avg) in case labels are only accepted from Java 21.
            return switch (aggregator) {
                case avg -> new AvgAggregator();
                case count -> new CountAggregator();
                case max -> new MaxAggregator();
                case median -> new MedianAggregator();
                case min -> new MinAggregator();
                case prod -> new ProdAggregator();
                case sum -> new SumAggregator();
                default -> throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented");
            };
        }

        /** Adds a new value to those aggregated by this. */
        public abstract void aggregate(double var1);

        /** Returns the aggregate of the values added so far. */
        public abstract double aggregatedValue();

        /** Returns this to its initial state, discarding the values added so far. */
        public abstract void reset();

        /** Subclasses must supply their own hash code. */
        public abstract int hashCode();
    }

    /** Aggregates values by adding them together, starting from 0.0. */
    private static class SumAggregator extends ValueAggregator {
        private double total = 0.0;

        private SumAggregator() {
        }

        /** Adds the value to the running sum. */
        @Override
        public void aggregate(double value) {
            total = total + value;
        }

        /** Returns the sum of the values aggregated so far. */
        @Override
        public double aggregatedValue() {
            return total;
        }

        /** Resets the running sum to 0.0. */
        @Override
        public void reset() {
            total = 0.0;
        }

        /** Returns a constant hash which is the same for all instances. */
        @Override
        public int hashCode() {
            return "sumAggregator".hashCode();
        }
    }

    /** Aggregates values by multiplying them together, starting from 1.0. */
    private static class ProdAggregator extends ValueAggregator {
        private double product = 1.0;

        private ProdAggregator() {
        }

        /** Multiplies the running product by the value. */
        @Override
        public void aggregate(double value) {
            product = product * value;
        }

        /** Returns the product of the values aggregated so far. */
        @Override
        public double aggregatedValue() {
            return product;
        }

        /** Resets the running product to 1.0. */
        @Override
        public void reset() {
            product = 1.0;
        }

        /** Returns a constant hash which is the same for all instances. */
        @Override
        public int hashCode() {
            return "prodAggregator".hashCode();
        }
    }

    /** Aggregates values by keeping the smallest one seen, starting from positive infinity. */
    private static class MinAggregator extends ValueAggregator {
        private double smallest = Double.POSITIVE_INFINITY;

        private MinAggregator() {
        }

        /** Keeps the value if it is strictly smaller than the smallest seen so far. */
        @Override
        public void aggregate(double value) {
            // Plain comparison on purpose: a NaN value compares false and is skipped
            smallest = value < smallest ? value : smallest;
        }

        /** Returns the smallest value aggregated so far. */
        @Override
        public double aggregatedValue() {
            return smallest;
        }

        /** Resets the smallest value to positive infinity. */
        @Override
        public void reset() {
            smallest = Double.POSITIVE_INFINITY;
        }

        /** Returns a constant hash which is the same for all instances. */
        @Override
        public int hashCode() {
            return "minAggregator".hashCode();
        }
    }

    /**
     * Aggregates values by collecting them and returning their median.
     * The result is NaN if any aggregated value is NaN, or if there are no values.
     */
    private static class MedianAggregator extends ValueAggregator {
        private boolean sawNaN = false;
        private List<Double> samples = new ArrayList<>();

        private MedianAggregator() {
        }

        /** Records the value; once a NaN has been seen no further values are stored. */
        @Override
        public void aggregate(double value) {
            if (sawNaN) {
                return;
            }
            if (Double.isNaN(value)) {
                sawNaN = true;
                return;
            }
            samples.add(value);
        }

        /**
         * Returns the median of the recorded values: the middle value for an odd count,
         * the mean of the two middle values for an even count. Sorts the stored values in place.
         */
        @Override
        public double aggregatedValue() {
            if (sawNaN || samples.isEmpty()) {
                return Double.NaN;
            }
            Collections.sort(samples);
            int count = samples.size();
            int upperMiddle = count / 2;
            if (count % 2 != 0) {
                return samples.get(upperMiddle);
            }
            return (samples.get(upperMiddle - 1) + samples.get(upperMiddle)) / 2.0;
        }

        /** Forgets the recorded values and the NaN marker. */
        @Override
        public void reset() {
            sawNaN = false;
            samples = new ArrayList<>();
        }

        /** Returns a constant hash which is the same for all instances. */
        @Override
        public int hashCode() {
            return "medianAggregator".hashCode();
        }
    }

    /** Aggregates values by keeping the largest one seen, starting from negative infinity. */
    private static class MaxAggregator extends ValueAggregator {
        private double largest = Double.NEGATIVE_INFINITY;

        private MaxAggregator() {
        }

        /** Keeps the value if it is strictly larger than the largest seen so far. */
        @Override
        public void aggregate(double value) {
            // Plain comparison on purpose: a NaN value compares false and is skipped
            largest = value > largest ? value : largest;
        }

        /** Returns the largest value aggregated so far. */
        @Override
        public double aggregatedValue() {
            return largest;
        }

        /** Resets the largest value to negative infinity. */
        @Override
        public void reset() {
            largest = Double.NEGATIVE_INFINITY;
        }

        /** Returns a constant hash which is the same for all instances. */
        @Override
        public int hashCode() {
            return "maxAggregator".hashCode();
        }
    }

    /** Aggregates values by counting them; the values themselves are ignored. */
    private static class CountAggregator extends ValueAggregator {
        private int seen = 0;

        private CountAggregator() {
        }

        /** Increments the count by one. */
        @Override
        public void aggregate(double value) {
            seen += 1;
        }

        /** Returns the number of values aggregated so far. */
        @Override
        public double aggregatedValue() {
            return seen;
        }

        /** Resets the count to zero. */
        @Override
        public void reset() {
            seen = 0;
        }

        /** Returns a constant hash which is the same for all instances. */
        @Override
        public int hashCode() {
            return "countAggregator".hashCode();
        }
    }

    /** Aggregates values by computing their arithmetic mean: the sum divided by the count. */
    private static class AvgAggregator extends ValueAggregator {
        private int seen = 0;
        private double total = 0.0;

        private AvgAggregator() {
        }

        /** Counts the value and adds it to the running sum. */
        @Override
        public void aggregate(double value) {
            seen += 1;
            total = total + value;
        }

        /** Returns the running sum divided by the count, using floating point division. */
        @Override
        public double aggregatedValue() {
            double divisor = seen;
            return total / divisor;
        }

        /** Resets both the count and the running sum to zero. */
        @Override
        public void reset() {
            seen = 0;
            total = 0.0;
        }

        /** Returns a constant hash which is the same for all instances. */
        @Override
        public int hashCode() {
            return "avgAggregator".hashCode();
        }
    }
}

