/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor.functions;

import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.CompositeTensorFunction;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import java.util.Arrays;
import java.util.List;
import java.util.function.DoubleBinaryOperator;
import java.util.stream.Collectors;

/**
 * A composite tensor function computing {@code reduce(join(argumentA, argumentB, combinator), aggregator, dimensions)}.
 * <p>
 * When both arguments evaluate to tensors whose dimensions are all bound indexed, and the reduced dimensions are
 * exactly the dimensions the arguments share, the result is computed directly without materializing the joined
 * tensor (with dedicated loops for vector-vector, vector-matrix and matrix-matrix shapes). In every other case
 * evaluation falls back to {@link Join#evaluate} followed by {@link Reduce#evaluate}.
 */
public class ReduceJoin extends CompositeTensorFunction {

    private final TensorFunction argumentA;
    private final TensorFunction argumentB;
    private final DoubleBinaryOperator combinator;
    private final Reduce.Aggregator aggregator;
    /**
     * The dimensions to reduce over. When empty, the optimized paths below treat every dimension shared by the
     * two arguments as reduced.
     */
    private final List<String> dimensions;

    /** Fuses an existing reduce over a join into a single function. */
    public ReduceJoin(Reduce reduce, Join join) {
        this(join.arguments().get(0), join.arguments().get(1), join.combinator(), reduce.aggregator(), reduce.dimensions());
    }

    public ReduceJoin(TensorFunction argumentA, TensorFunction argumentB, DoubleBinaryOperator combinator,
                      Reduce.Aggregator aggregator, List<String> dimensions) {
        this.argumentA = argumentA;
        this.argumentB = argumentB;
        this.combinator = combinator;
        this.aggregator = aggregator;
        this.dimensions = ImmutableList.copyOf(dimensions);
    }

    @Override
    public List<TensorFunction> arguments() {
        // Typed list: the decompiled (Object) casts produced an ImmutableList<Object>, which does not compile here
        return ImmutableList.of(argumentA, argumentB);
    }

    @Override
    public TensorFunction withArguments(List<TensorFunction> arguments) {
        if (arguments.size() != 2) {
            throw new IllegalArgumentException("ReduceJoin must have 2 arguments, got " + arguments.size());
        }
        return new ReduceJoin(arguments.get(0), arguments.get(1), combinator, aggregator, dimensions);
    }

    /** Expands this into the equivalent primitive {@code reduce(join(...))} expression. */
    @Override
    public PrimitiveTensorFunction toPrimitive() {
        Join join = new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
        return new Reduce((TensorFunction) join, aggregator, dimensions);
    }

    @Override
    public final <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
        Tensor a = argumentA.evaluate(context);
        Tensor b = argumentB.evaluate(context);
        TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
        if (canOptimize(a, b)) {
            return evaluate((IndexedTensor) a, (IndexedTensor) b, joinedType);
        }
        return Reduce.evaluate(Join.evaluate(a, b, joinedType, combinator), dimensions, aggregator);
    }

    /**
     * Returns whether the fused evaluation can be used for these argument values.
     * <p>
     * Requires both tensors to be non-scalar {@link IndexedTensor}s with only bound indexed dimensions. It also
     * requires the set of reduced dimensions to equal the set of shared dimensions. If {@code dimensions} is empty,
     * both tensors must consist solely of shared dimensions.
     */
    public boolean canOptimize(Tensor a, Tensor b) {
        if (a.type().dimensions().isEmpty() || b.type().dimensions().isEmpty()) {
            return false;
        }
        if ( ! isBoundIndexed(a) || ! isBoundIndexed(b)) {
            return false;
        }
        TensorType commonDimensions = dimensionsInCommon((IndexedTensor) a, (IndexedTensor) b);
        int commonCount = commonDimensions.dimensions().size();
        if (dimensions.isEmpty()) {
            return a.type().dimensions().size() == commonCount && b.type().dimensions().size() == commonCount;
        }
        // Every shared dimension must be reduced, otherwise the optimized loops would mix up independent cells ...
        for (TensorType.Dimension dimension : commonDimensions.dimensions()) {
            if ( ! dimensions.contains(dimension.name())) {
                return false;
            }
        }
        // ... and every reduced dimension must be shared. The optimized loops only iterate shared dimensions, so a
        // reduced dimension present in just one argument would be skipped (or rejected) rather than aggregated.
        for (String dimension : dimensions) {
            if ( ! commonDimensions.indexOfDimension(dimension).isPresent()) {
                return false;
            }
        }
        return true;
    }

    /** Returns whether the tensor is an IndexedTensor whose dimensions are all indexed and bound. */
    private static boolean isBoundIndexed(Tensor tensor) {
        return tensor instanceof IndexedTensor
               && tensor.type().dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound);
    }

    /** Dispatches to a shape-specialized loop when possible, otherwise to the general indexed loop. */
    private Tensor evaluate(IndexedTensor a, IndexedTensor b, TensorType joinedType) {
        TensorType reducedType = Reduce.outputType(joinedType, dimensions);
        if (reduceDimensionIsInnermost(a, b)) {
            int rankA = a.type().dimensions().size();
            int rankB = b.type().dimensions().size();
            if (rankA == 1 && rankB == 1) {
                return vectorVectorProduct(a, b, reducedType);
            }
            if (rankA == 1 && rankB == 2) {
                return vectorMatrixProduct(a, b, reducedType, false);
            }
            if (rankA == 2 && rankB == 1) {
                return vectorMatrixProduct(b, a, reducedType, true);
            }
            if (rankA == 2 && rankB == 2) {
                return matrixMatrixProduct(a, b, reducedType);
            }
        }
        return evaluateGeneral(a, b, reducedType);
    }

    /** Aggregates combinator(a[i], b[i]) over the common length into a single cell. */
    private Tensor vectorVectorProduct(IndexedTensor a, IndexedTensor b, TensorType reducedType) {
        if (a.type().dimensions().size() != 1 || b.type().dimensions().size() != 1) {
            throw new IllegalArgumentException("Wrong dimension sizes for tensors for vector-vector product");
        }
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(reducedType);
        long commonSize = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
        Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator);
        for (long ic = 0; ic < commonSize; ic++) {
            agg.aggregate(combinator.applyAsDouble(a.get(ic), b.get(ic)));
        }
        builder.cellByDirectIndex(0L, agg.aggregatedValue());
        return builder.build();
    }

    /**
     * For each row of matrix {@code b}, aggregates the combination of vector {@code a} with that row (reduced
     * dimension innermost in both). {@code swapped} restores the original argument order for the combinator
     * when the caller passed the arguments reversed.
     */
    private Tensor vectorMatrixProduct(IndexedTensor a, IndexedTensor b, TensorType reducedType, boolean swapped) {
        if (a.type().dimensions().size() != 1 || b.type().dimensions().size() != 2) {
            throw new IllegalArgumentException("Wrong dimension sizes for tensors for vector-matrix product");
        }
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(reducedType);
        DimensionSizes sizesA = a.dimensionSizes();
        DimensionSizes sizesB = b.dimensionSizes();
        long rows = sizesB.size(0);
        long rowLength = sizesB.size(1);
        long commonSize = Math.min(sizesA.size(0), rowLength);
        Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator);
        for (long ib = 0; ib < rows; ib++) {
            agg.reset();
            for (long ic = 0; ic < commonSize; ic++) {
                double va = a.get(ic);
                double vb = b.get(ib * rowLength + ic);
                agg.aggregate(swapped ? combinator.applyAsDouble(vb, va) : combinator.applyAsDouble(va, vb));
            }
            builder.cellByDirectIndex(ib, agg.aggregatedValue());
        }
        return builder.build();
    }

    /**
     * Aggregates over the shared innermost dimension of two matrices, producing one cell per (row of a, row of b).
     * The output strides follow the order in which the two row dimensions appear in the reduced type.
     */
    private Tensor matrixMatrixProduct(IndexedTensor a, IndexedTensor b, TensorType reducedType) {
        if (a.type().dimensions().size() != 2 || b.type().dimensions().size() != 2) {
            throw new IllegalArgumentException("Wrong dimension sizes for tensors for matrix-matrix product");
        }
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(reducedType);
        DimensionSizes sizesA = a.dimensionSizes();
        DimensionSizes sizesB = b.dimensionSizes();
        int iaToReduced = reducedType.indexOfDimension(a.type().dimensions().get(0).name()).get();
        int ibToReduced = reducedType.indexOfDimension(b.type().dimensions().get(0).name()).get();
        long strideA = iaToReduced < ibToReduced ? sizesB.size(0) : 1L;
        long strideB = ibToReduced < iaToReduced ? sizesA.size(0) : 1L;
        long rowLengthA = sizesA.size(1);
        long rowLengthB = sizesB.size(1);
        long commonSize = Math.min(rowLengthA, rowLengthB);
        Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator);
        for (long ia = 0; ia < sizesA.size(0); ia++) {
            for (long ib = 0; ib < sizesB.size(0); ib++) {
                agg.reset();
                for (long ic = 0; ic < commonSize; ic++) {
                    double va = a.get(ia * rowLengthA + ic);
                    double vb = b.get(ib * rowLengthB + ic);
                    agg.aggregate(combinator.applyAsDouble(va, vb));
                }
                builder.cellByDirectIndex(ia * strideA + ib * strideB, agg.aggregatedValue());
            }
        }
        return builder.build();
    }

    /**
     * General indexed evaluation. It iterates the unreduced dimensions of a and b (outer) and the shared
     * dimensions (inner). It relies on canOptimize having ensured that the reduced dimensions are exactly the
     * shared ones.
     */
    private Tensor evaluateGeneral(IndexedTensor a, IndexedTensor b, TensorType reducedType) {
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(reducedType);
        TensorType onlyInA = Reduce.outputType(a.type(), dimensions);
        TensorType onlyInB = Reduce.outputType(b.type(), dimensions);
        TensorType common = dimensionsInCommon(a, b);

        long[] stridesA = strides(a.type());
        long[] stridesB = strides(b.type());
        long[] stridesResult = strides(reducedType);

        int[] mapOnlyAToA = Join.mapIndexes(onlyInA, a.type());
        int[] mapCommonToA = Join.mapIndexes(common, a.type());
        int[] mapOnlyBToB = Join.mapIndexes(onlyInB, b.type());
        int[] mapCommonToB = Join.mapIndexes(common, b.type());
        int[] mapOnlyAToResult = Join.mapIndexes(onlyInA, reducedType);
        int[] mapOnlyBToResult = Join.mapIndexes(onlyInB, reducedType);

        MultiDimensionIterator ic = new MultiDimensionIterator(common);
        Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator);
        for (MultiDimensionIterator ia = new MultiDimensionIterator(onlyInA); ia.hasNext(); ia.next()) {
            for (MultiDimensionIterator ib = new MultiDimensionIterator(onlyInB); ib.hasNext(); ib.next()) {
                agg.reset();
                for (ic.reset(); ic.hasNext(); ic.next()) {
                    double va = a.get(toDirectIndex(ia, ic, stridesA, mapOnlyAToA, mapCommonToA));
                    double vb = b.get(toDirectIndex(ib, ic, stridesB, mapOnlyBToB, mapCommonToB));
                    agg.aggregate(combinator.applyAsDouble(va, vb));
                }
                builder.cellByDirectIndex(toDirectIndex(ia, ib, stridesResult, mapOnlyAToResult, mapOnlyBToResult),
                                          agg.aggregatedValue());
            }
        }
        return builder.build();
    }

    /** Combines the positions of two iterators into a direct (flat) index using the given strides and index maps. */
    private long toDirectIndex(MultiDimensionIterator iter, MultiDimensionIterator common, long[] strides, int[] map, int[] commonmap) {
        long directIndex = 0L;
        for (int i = 0; i < iter.length(); ++i) {
            directIndex += strides[map[i]] * iter.iterator[i];
        }
        for (int i = 0; i < common.length(); ++i) {
            directIndex += strides[commonmap[i]] * common.iterator[i];
        }
        return directIndex;
    }

    /** Returns row-major strides for the given type: the last dimension has stride 1. */
    private long[] strides(TensorType type) {
        long[] strides = new long[type.dimensions().size()];
        if (strides.length > 0) {
            strides[strides.length - 1] = 1L;
            for (int i = strides.length - 2; i >= 0; --i) {
                strides[i] = strides[i + 1] * type.dimensions().get(i + 1).size().get();
            }
        }
        return strides;
    }

    /**
     * Returns a type holding the dimensions whose names occur in both tensors. When both copies of a dimension
     * have a size, the smaller one is kept; if either lacks a size, that copy is kept.
     */
    private TensorType dimensionsInCommon(IndexedTensor a, IndexedTensor b) {
        TensorType.Builder builder = new TensorType.Builder();
        for (TensorType.Dimension aDim : a.type().dimensions()) {
            for (TensorType.Dimension bDim : b.type().dimensions()) {
                if ( ! aDim.name().equals(bDim.name())) {
                    continue;
                }
                if ( ! aDim.size().isPresent()) {
                    builder.set(aDim);
                } else if ( ! bDim.size().isPresent()) {
                    builder.set(bDim);
                } else {
                    builder.set(aDim.size().get() < bDim.size().get() ? aDim : bDim);
                }
            }
        }
        return builder.build();
    }

    /**
     * Returns true when exactly one dimension is reduced and it is the last dimension of both tensors.
     * With no explicit dimensions, the shared dimensions are taken as the reduced ones.
     *
     * @throws IllegalArgumentException if the reduced dimension is missing from either tensor
     */
    private boolean reduceDimensionIsInnermost(Tensor a, Tensor b) {
        List<String> reducingDimensions = dimensions;
        if (reducingDimensions.isEmpty()) {
            reducingDimensions = dimensionsInCommon((IndexedTensor) a, (IndexedTensor) b).dimensions().stream()
                    .map(TensorType.Dimension::name)
                    .collect(Collectors.toList());
        }
        if (reducingDimensions.size() != 1) {
            return false;
        }
        String dimension = reducingDimensions.get(0);
        int indexInA = a.type().indexOfDimension(dimension).orElseThrow(() ->
                new IllegalArgumentException("Reduce-Join dimension '" + dimension + "' missing in tensor A."));
        if (indexInA != a.type().dimensions().size() - 1) {
            return false;
        }
        int indexInB = b.type().indexOfDimension(dimension).orElseThrow(() ->
                new IllegalArgumentException("Reduce-Join dimension '" + dimension + "' missing in tensor B."));
        return indexInB >= b.type().dimensions().size() - 1;
    }

    @Override
    public String toString(ToStringContext context) {
        return "reduce_join(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " +
               combinator + ", " + aggregator + Reduce.commaSeparated(dimensions) + ")";
    }

    /**
     * Odometer-style iterator over all index combinations of a type's bound dimensions. The last dimension varies
     * fastest.
     */
    private static class MultiDimensionIterator {

        private final long[] bounds;
        private final long[] iterator;
        private int remaining;

        MultiDimensionIterator(TensorType type) {
            bounds = new long[type.dimensions().size()];
            iterator = new long[type.dimensions().size()];
            for (int i = 0; i < bounds.length; ++i) {
                bounds[i] = type.dimensions().get(i).size().get();
            }
            reset();
        }

        public int length() {
            return iterator.length;
        }

        public boolean hasNext() {
            return remaining > 0;
        }

        /** Rewinds to all-zero indexes. A type with no dimensions yields exactly one (empty) combination. */
        public void reset() {
            remaining = 1;
            for (int i = iterator.length - 1; i >= 0; --i) {
                iterator[i] = 0L;
                remaining = (int) ((long) remaining * bounds[i]);
            }
        }

        /** Advances to the next combination, carrying into more significant dimensions on overflow. */
        public void next() {
            for (int i = iterator.length - 1; i >= 0; --i) {
                iterator[i]++;
                if (iterator[i] < bounds[i]) {
                    break;
                }
                iterator[i] = 0L;
            }
            remaining--;
        }

        @Override
        public String toString() {
            return Arrays.toString(iterator);
        }
    }
}

