/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor.functions;

import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/**
 * A tensor function which reduces (aggregates) the argument tensor over some of its dimensions.
 * <p>
 * The result type is the argument type with the reduced dimensions removed. Each result cell holds the
 * aggregate, computed by the chosen {@link Aggregator}, of all argument cells whose address matches it
 * once the reduced dimensions are dropped. An empty dimension list means "reduce over all dimensions".
 * In that case the result has the empty type and holds a single value.
 */
public class Reduce extends PrimitiveTensorFunction {

    private final TensorFunction argument;
    private final List<String> dimensions;
    private final Aggregator aggregator;

    /** Creates a reduce function which reduces the argument over all its dimensions. */
    public Reduce(TensorFunction argument, Aggregator aggregator) {
        this(argument, aggregator, Collections.emptyList());
    }

    /** Creates a reduce function which reduces the argument over a single dimension. */
    public Reduce(TensorFunction argument, Aggregator aggregator, String dimension) {
        this(argument, aggregator, Collections.singletonList(dimension));
    }

    /**
     * Creates a reduce function.
     *
     * @param argument   the function producing the tensor to reduce
     * @param aggregator the aggregator combining the values of the cells that are reduced together
     * @param dimensions the dimensions to reduce away; the empty list means all dimensions.
     *                   The list is copied into an immutable list.
     * @throws NullPointerException if any argument is null
     */
    public Reduce(TensorFunction argument, Aggregator aggregator, List<String> dimensions) {
        Objects.requireNonNull(argument, "The argument tensor cannot be null");
        Objects.requireNonNull(aggregator, "The aggregator cannot be null");
        Objects.requireNonNull(dimensions, "The dimensions cannot be null");
        this.argument = argument;
        this.aggregator = aggregator;
        this.dimensions = ImmutableList.copyOf(dimensions);
    }

    /**
     * Returns the type produced by reducing a tensor of the given type over the given dimensions.
     * This is the input type without those dimensions, or the empty type if no dimensions are given.
     */
    public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) {
        return type(inputType, reduceDimensions);
    }

    public TensorFunction argument() {
        return argument;
    }

    Aggregator aggregator() {
        return aggregator;
    }

    List<String> dimensions() {
        return dimensions;
    }

    @Override
    public List<TensorFunction> arguments() {
        return Collections.singletonList(argument);
    }

    @Override
    public TensorFunction withArguments(List<TensorFunction> arguments) {
        if (arguments.size() != 1) {
            throw new IllegalArgumentException("Reduce must have 1 argument, got " + arguments.size());
        }
        return new Reduce(arguments.get(0), aggregator, dimensions);
    }

    @Override
    public PrimitiveTensorFunction toPrimitive() {
        return new Reduce(argument.toPrimitive(), aggregator, dimensions);
    }

    @Override
    public String toString(ToStringContext context) {
        return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")";
    }

    /** Returns the elements of the list, each prefixed by ", " (the empty string for an empty list). */
    static String commaSeparated(List<String> list) {
        StringBuilder b = new StringBuilder();
        for (String element : list) {
            b.append(", ").append(element);
        }
        return b.toString();
    }

    @Override
    public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
        return type(argument.type(context), dimensions);
    }

    /**
     * Returns the argument type with the given dimensions removed, or the empty type if none are given.
     * Names not present in the argument type are ignored here; evaluation rejects them.
     */
    private static TensorType type(TensorType argumentType, List<String> dimensions) {
        if (dimensions.isEmpty()) {
            return TensorType.empty; // reducing over no named dimensions means reducing over all of them
        }
        TensorType.Builder builder = new TensorType.Builder();
        for (TensorType.Dimension dimension : argumentType.dimensions()) {
            if ( ! dimensions.contains(dimension.name())) {
                builder.dimension(dimension);
            }
        }
        return builder.build();
    }

    @Override
    public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
        return evaluate(argument.evaluate(context), dimensions, aggregator);
    }

    /**
     * Reduces a concrete tensor.
     *
     * @param argument   the tensor to reduce
     * @param dimensions the dimensions to reduce away; the empty list means all dimensions
     * @param aggregator the aggregator to apply to the cells reduced together
     * @return the reduced tensor
     * @throws IllegalArgumentException if some given dimension is not present in the argument
     */
    static Tensor evaluate(Tensor argument, List<String> dimensions, Aggregator aggregator) {
        TensorType argumentType = argument.type();
        if ( ! dimensions.isEmpty() && ! argumentType.dimensionNames().containsAll(dimensions)) {
            throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + dimensions + ": Not all those dimensions are present in this tensor");
        }

        // Count distinct names. Otherwise a repeated name such as [x, x] would be mistaken for
        // "all dimensions" of an {x, y} tensor.
        Set<String> distinctDimensions = new HashSet<>(dimensions);
        if (distinctDimensions.isEmpty() || distinctDimensions.size() == argumentType.dimensions().size()) {
            if (argumentType.dimensions().size() == 1 && argument instanceof IndexedTensor) {
                return reduceIndexedVector((IndexedTensor) argument, aggregator);
            }
            return reduceAllGeneral(argument, aggregator);
        }

        TensorType reducedType = type(argumentType, dimensions);
        // These do not depend on the cell, so compute them once rather than per cell
        Set<Integer> indexesToRemove = dimensionIndexes(argumentType, dimensions);
        int reducedRank = reducedType.dimensions().size();

        Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>();
        for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) {
            Map.Entry<TensorAddress, Double> cell = i.next();
            TensorAddress reducedAddress = reduceDimensions(cell.getKey(), indexesToRemove, reducedRank);
            aggregatingCells.computeIfAbsent(reducedAddress, address -> ValueAggregator.ofType(aggregator))
                            .aggregate(cell.getValue());
        }

        Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType);
        for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet()) {
            reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue());
        }
        return reducedBuilder.build();
    }

    /** Returns the positions in the given type of the given dimensions, which must all be present. */
    private static Set<Integer> dimensionIndexes(TensorType type, List<String> dimensions) {
        Set<Integer> indexes = new HashSet<>();
        for (String dimension : dimensions) {
            indexes.add(type.indexOfDimension(dimension).get());
        }
        return indexes;
    }

    /**
     * Returns the given address with the labels at the given positions dropped.
     *
     * @param address         the address of a cell in the argument tensor
     * @param indexesToRemove the positions of the reduced dimensions in the argument type
     * @param reducedRank     the number of dimensions in the reduced type
     */
    private static TensorAddress reduceDimensions(TensorAddress address, Set<Integer> indexesToRemove, int reducedRank) {
        String[] reducedLabels = new String[reducedRank];
        int reducedLabelIndex = 0;
        for (int i = 0; i < address.size(); ++i) {
            if ( ! indexesToRemove.contains(i)) {
                reducedLabels[reducedLabelIndex++] = address.label(i);
            }
        }
        return TensorAddress.of(reducedLabels);
    }

    /** Aggregates every value of the argument into a single value of the empty tensor type. */
    private static Tensor reduceAllGeneral(Tensor argument, Aggregator aggregator) {
        ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
        for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); ) {
            valueAggregator.aggregate(i.next());
        }
        return Tensor.Builder.of(TensorType.empty).cell(valueAggregator.aggregatedValue(), new long[0]).build();
    }

    /** Aggregates all values of a one-dimensional indexed tensor by position, avoiding cell iteration. */
    private static Tensor reduceIndexedVector(IndexedTensor argument, Aggregator aggregator) {
        ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
        long size = argument.dimensionSizes().size(0);
        for (long i = 0; i < size; i++) { // long counter: the size is a long and an int would overflow
            valueAggregator.aggregate(argument.get(i));
        }
        return Tensor.Builder.of(TensorType.empty).cell(valueAggregator.aggregatedValue(), new long[0]).build();
    }

    /** Computes the minimum. Starts at +Infinity, so aggregating no values yields +Infinity. */
    private static class MinAggregator extends ValueAggregator {

        private double minValue = Double.POSITIVE_INFINITY;

        private MinAggregator() {
        }

        @Override
        public void aggregate(double value) {
            if (value < minValue) {
                minValue = value;
            }
        }

        @Override
        public double aggregatedValue() {
            return minValue;
        }

        @Override
        public void reset() {
            minValue = Double.POSITIVE_INFINITY;
        }
    }

    /**
     * Computes the maximum. Starts at -Infinity, so aggregating no values yields -Infinity.
     * Double.MIN_VALUE must not be used here: it is the smallest positive double, not the most negative one.
     */
    private static class MaxAggregator extends ValueAggregator {

        private double maxValue = Double.NEGATIVE_INFINITY;

        private MaxAggregator() {
        }

        @Override
        public void aggregate(double value) {
            if (value > maxValue) {
                maxValue = value;
            }
        }

        @Override
        public double aggregatedValue() {
            return maxValue;
        }

        @Override
        public void reset() {
            maxValue = Double.NEGATIVE_INFINITY;
        }
    }

    /** Computes the sum. Aggregating no values yields 0. */
    private static class SumAggregator extends ValueAggregator {

        private double valueSum = 0.0;

        private SumAggregator() {
        }

        @Override
        public void aggregate(double value) {
            valueSum += value;
        }

        @Override
        public double aggregatedValue() {
            return valueSum;
        }

        @Override
        public void reset() {
            valueSum = 0.0;
        }
    }

    /** Computes the product. Aggregating no values yields 1. */
    private static class ProdAggregator extends ValueAggregator {

        private double valueProd = 1.0;

        private ProdAggregator() {
        }

        @Override
        public void aggregate(double value) {
            valueProd *= value;
        }

        @Override
        public double aggregatedValue() {
            return valueProd;
        }

        @Override
        public void reset() {
            valueProd = 1.0;
        }
    }

    /** Counts the aggregated values, ignoring the values themselves. */
    private static class CountAggregator extends ValueAggregator {

        private long valueCount = 0;

        private CountAggregator() {
        }

        @Override
        public void aggregate(double value) {
            ++valueCount;
        }

        @Override
        public double aggregatedValue() {
            return valueCount;
        }

        @Override
        public void reset() {
            valueCount = 0;
        }
    }

    /** Computes the arithmetic mean. Aggregating no values yields NaN (0.0 / 0). */
    private static class AvgAggregator extends ValueAggregator {

        private long valueCount = 0;
        private double valueSum = 0.0;

        private AvgAggregator() {
        }

        @Override
        public void aggregate(double value) {
            ++valueCount;
            valueSum += value;
        }

        @Override
        public double aggregatedValue() {
            return valueSum / (double) valueCount;
        }

        @Override
        public void reset() {
            valueCount = 0;
            valueSum = 0.0;
        }
    }

    /** A mutable accumulator which folds a stream of values into a single aggregated value. */
    static abstract class ValueAggregator {

        ValueAggregator() {
        }

        /**
         * Returns a new, empty aggregator for the given aggregation.
         *
         * @throws UnsupportedOperationException if the aggregator has no implementation
         */
        static ValueAggregator ofType(Aggregator aggregator) {
            switch (aggregator) {
                case avg: return new AvgAggregator();
                case count: return new CountAggregator();
                case prod: return new ProdAggregator();
                case sum: return new SumAggregator();
                case max: return new MaxAggregator();
                case min: return new MinAggregator();
                default: break; // fall through to the exception below
            }
            throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented");
        }

        /** Adds a value to this aggregate. */
        public abstract void aggregate(double value);

        /** Returns the aggregate of the values added since creation or the last reset. */
        public abstract double aggregatedValue();

        /** Returns this aggregator to its initial, empty state. */
        public abstract void reset();
    }

    /** The aggregations a reduce can apply. */
    public enum Aggregator {
        avg, count, prod, sum, max, min;
    }

}

