/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor.functions;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.DoubleBinaryOperator;

/**
 * The <i>join</i> tensor function: combines two argument tensors cell by cell.
 * <p>
 * For each pair of cells (one from each argument) whose labels agree in every dimension the two
 * arguments have in common, the result holds a cell at the combined address with the value
 * {@code combinator.applyAsDouble(aValue, bValue)}. The result type is computed by
 * {@link TypeResolver#join}.
 * <p>
 * Evaluation dispatches to a specialized strategy depending on the argument shapes. The cases are:
 * a pair of single-dimension indexed vectors, arguments with identical dimensionality, one
 * argument's dimensions contained in the other's, and the general case. Every strategy passes
 * values to the combinator in (a, b) order, even when it internally swaps the arguments.
 */
public class Join<NAMETYPE extends Name>
extends PrimitiveTensorFunction<NAMETYPE> {
    private final TensorFunction<NAMETYPE> argumentA;
    private final TensorFunction<NAMETYPE> argumentB;
    private final DoubleBinaryOperator combinator;

    /**
     * Creates a join of two tensor functions.
     *
     * @param argumentA  the first argument; its cell values are the first combinator operand
     * @param argumentB  the second argument; its cell values are the second combinator operand
     * @param combinator the function combining a value from A with the matching value from B
     * @throws NullPointerException if any argument is null
     */
    public Join(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, DoubleBinaryOperator combinator) {
        this.argumentA = Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
        this.argumentB = Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
        this.combinator = Objects.requireNonNull(combinator, "The combinator function cannot be null");
    }

    /**
     * Returns the type resulting from joining tensors of the two given types.
     *
     * @throws IllegalArgumentException if the types cannot be joined; the message names both types
     *                                  and the cause is the resolver's original exception
     */
    public static TensorType outputType(TensorType a, TensorType b) {
        try {
            return TypeResolver.join(a, b);
        }
        catch (IllegalArgumentException e) {
            throw new IllegalArgumentException("Can not join " + a + " and " + b, e);
        }
    }

    /** Returns the function used to combine a cell value of A with a cell value of B. */
    public DoubleBinaryOperator combinator() {
        return this.combinator;
    }

    @Override
    public List<TensorFunction<NAMETYPE>> arguments() {
        return ImmutableList.of(this.argumentA, this.argumentB);
    }

    @Override
    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
        if (arguments.size() != 2) {
            throw new IllegalArgumentException("Join must have 2 arguments, got " + arguments.size());
        }
        return new Join<>(arguments.get(0), arguments.get(1), this.combinator);
    }

    @Override
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
        return new Join<>(this.argumentA.toPrimitive(), this.argumentB.toPrimitive(), this.combinator);
    }

    @Override
    public String toString(ToStringContext<NAMETYPE> context) {
        return "join(" + this.argumentA.toString(context) + ", " + this.argumentB.toString(context) + ", " + this.combinator + ")";
    }

    @Override
    public int hashCode() {
        return Objects.hash("join", this.argumentA, this.argumentB, this.combinator);
    }

    @Override
    public TensorType type(TypeContext<NAMETYPE> context) {
        return outputType(this.argumentA.type(context), this.argumentB.type(context));
    }

    @Override
    public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
        Tensor a = this.argumentA.evaluate(context);
        Tensor b = this.argumentB.evaluate(context);
        TensorType joinedType = outputType(a.type(), b.type());
        return evaluate(a, b, joinedType, this.combinator);
    }

    /**
     * Joins two concrete tensors into a tensor of {@code joinedType}, choosing the cheapest
     * applicable strategy.
     */
    static Tensor evaluate(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
        // Two vectors over the same single indexed dimension: plain element-wise combination.
        if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b)
            && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name())) {
            return indexedVectorJoin((IndexedTensor) a, (IndexedTensor) b, joinedType, combinator);
        }
        // The join adds no dimensions to either argument: only cells at equal addresses combine.
        if (joinedType.dimensions().size() == a.type().dimensions().size()
            && joinedType.dimensions().size() == b.type().dimensions().size()) {
            return singleSpaceJoin(a, b, joinedType, combinator);
        }
        // One argument's dimensions are a subset of the other's. The flag records whether the
        // superspace argument is A, so the combinator operand order can be restored.
        if (a.type().dimensions().containsAll(b.type().dimensions())) {
            return subspaceJoin(b, a, joinedType, true, combinator);
        }
        if (b.type().dimensions().containsAll(a.type().dimensions())) {
            return subspaceJoin(a, b, joinedType, false, combinator);
        }
        return generalJoin(a, b, joinedType, combinator);
    }

    /** Returns whether the tensor's type has exactly one dimension, and that dimension is indexed. */
    private static boolean hasSingleIndexedDimension(Tensor tensor) {
        return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
    }

    /**
     * Combines two indexed vectors element by element. The result length is the shorter of the
     * two vector lengths.
     */
    private static Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) {
        long joinedRank = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
        Iterator<Double> aIterator = a.valueIterator();
        Iterator<Double> bIterator = b.valueIterator();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedRank).build());
        for (int i = 0; i < joinedRank; i++) {
            builder.cell(combinator.applyAsDouble(aIterator.next(), bIterator.next()), i);
        }
        return builder.build();
    }

    /**
     * Joins two tensors with the same dimensionality. A cell is produced only for each address
     * of A that also exists in B.
     */
    private static Tensor singleSpaceJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
        Tensor.Builder builder = Tensor.Builder.of(joinedType);
        for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) {
            Tensor.Cell aCell = i.next();
            TensorAddress key = aCell.getKey();
            if ( ! b.has(key)) continue;
            builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key)));
        }
        return builder.build();
    }

    /**
     * Joins a tensor whose dimensions are contained in the other's.
     *
     * @param reversedArgumentOrder true if {@code superspace} is the join's first argument (A),
     *                              so the combinator receives (superValue, subValue)
     */
    private static Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
        if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor) {
            return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder, combinator);
        }
        return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder, combinator);
    }

    /**
     * Subspace join for two indexed tensors. The superspace is iterated one subspace at a time,
     * each subspace spanning the subspace tensor's dimensions. Each such subspace is combined
     * positionally with the subspace tensor's values.
     * If either argument has no cells, the result is a tensor of the joined type built with
     * unset dimension sizes (presumably empty — confirm against DimensionSizes.Builder).
     */
    private static Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
        if (subspace.size() == 0L || superspace.size() == 0L) {
            return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build();
        }
        DimensionSizes joinedSizes = joinedSize(joinedType, subspace, superspace);
        IndexedTensor.Builder builder = (IndexedTensor.Builder) Tensor.Builder.of(joinedType, joinedSizes);

        // Dimensions found only in the superspace: iterating over these yields one subspace per combination.
        Set<String> superOnlyDimensionNames = new HashSet<>(superspace.type().dimensionNames());
        superOnlyDimensionNames.removeAll(subspace.type().dimensionNames());

        for (Iterator<IndexedTensor.SubspaceIterator> i = superspace.subspaceIterator(superOnlyDimensionNames, joinedSizes); i.hasNext(); ) {
            IndexedTensor.SubspaceIterator subspaceInSuper = i.next();
            joinSubspaces(subspace.valueIterator(), subspace.size(), subspaceInSuper, subspaceInSuper.size(), reversedArgumentOrder, builder, combinator);
        }
        return builder.build();
    }

    /**
     * Combines the first {@code min(subspaceSize, superspaceSize)} values pairwise. Each result is
     * written at the address of the corresponding superspace cell.
     */
    private static void joinSubspaces(Iterator<Double> subspace, long subspaceSize, Iterator<Tensor.Cell> superspace, long superspaceSize, boolean reversedArgumentOrder, IndexedTensor.Builder builder, DoubleBinaryOperator combinator) {
        long joinedLength = Math.min(subspaceSize, superspaceSize);
        if (reversedArgumentOrder) {
            for (int i = 0; i < joinedLength; i++) {
                Tensor.Cell supercell = superspace.next();
                builder.cell(supercell, combinator.applyAsDouble(supercell.getValue(), subspace.next()));
            }
        } else {
            for (int i = 0; i < joinedLength; i++) {
                Tensor.Cell supercell = superspace.next();
                builder.cell(supercell, combinator.applyAsDouble(subspace.next(), supercell.getValue()));
            }
        }
    }

    /**
     * Computes the dimension sizes of the joined tensor. A dimension present in both arguments
     * gets the smaller of the two sizes. A dimension present in only one argument keeps that
     * argument's size.
     */
    private static DimensionSizes joinedSize(TensorType joinedType, IndexedTensor a, IndexedTensor b) {
        DimensionSizes.Builder builder = new DimensionSizes.Builder(joinedType.dimensions().size());
        for (int i = 0; i < builder.dimensions(); ++i) {
            String dimensionName = joinedType.dimensions().get(i).name();
            Optional<Integer> aIndex = a.type().indexOfDimension(dimensionName);
            Optional<Integer> bIndex = b.type().indexOfDimension(dimensionName);
            if (aIndex.isPresent() && bIndex.isPresent()) {
                builder.set(i, Math.min(b.dimensionSizes().size(bIndex.get()), a.dimensionSizes().size(aIndex.get())));
            } else if (aIndex.isPresent()) {
                builder.set(i, a.dimensionSizes().size(aIndex.get()));
            } else if (bIndex.isPresent()) {
                builder.set(i, b.dimensionSizes().size(bIndex.get()));
            }
        }
        return builder.build();
    }

    /**
     * Subspace join for arbitrary tensor implementations. For every superspace cell, the
     * subspace cell at the projected address is looked up. The pair is combined if that
     * subspace cell exists.
     */
    private static Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
        int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type());
        Tensor.Builder builder = Tensor.Builder.of(joinedType);
        for (Iterator<Tensor.Cell> i = superspace.cellIterator(); i.hasNext(); ) {
            Tensor.Cell supercell = i.next();
            TensorAddress superAddress = supercell.getKey();
            TensorAddress subaddress = mapAddressToSubspace(superAddress, subspaceIndexes);
            if ( ! subspace.has(subaddress)) continue;

            double subspaceValue = subspace.get(subaddress);
            double superspaceValue = supercell.getValue();
            builder.cell(superAddress, reversedArgumentOrder
                                       ? combinator.applyAsDouble(superspaceValue, subspaceValue)
                                       : combinator.applyAsDouble(subspaceValue, superspaceValue));
        }
        return builder.build();
    }

    /**
     * Returns, for each dimension of {@code subtype}, the index of the dimension with the same
     * name in {@code supertype}.
     */
    private static int[] subspaceIndexes(TensorType supertype, TensorType subtype) {
        int[] subspaceIndexes = new int[subtype.dimensions().size()];
        for (int i = 0; i < subtype.dimensions().size(); ++i) {
            // Present by construction: callers only pass a subtype whose dimensions are contained in supertype.
            subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get();
        }
        return subspaceIndexes;
    }

    /** Projects a superspace address onto the subspace, keeping the labels at the given indexes in order. */
    private static TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
        String[] subspaceLabels = new String[subspaceIndexes.length];
        for (int i = 0; i < subspaceIndexes.length; ++i) {
            subspaceLabels[i] = superAddress.label(subspaceIndexes[i]);
        }
        return TensorAddress.of(subspaceLabels);
    }

    /** Joins tensors where neither argument's dimensions are contained in the other's. */
    private static Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
        if (a instanceof IndexedTensor && b instanceof IndexedTensor) {
            return indexedGeneralJoin((IndexedTensor) a, (IndexedTensor) b, joinedType, combinator);
        }
        return mappedHashJoin(a, b, joinedType, combinator);
    }

    /** General join of two indexed tensors, written into a builder sized by {@link #joinedSize}. */
    private static Tensor indexedGeneralJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
        DimensionSizes joinedSize = joinedSize(joinedType, a, b);
        Tensor.Builder builder = Tensor.Builder.of(joinedType, joinedSize);
        int[] aToIndexes = mapIndexes(a.type(), joinedType);
        int[] bToIndexes = mapIndexes(b.type(), joinedType);
        joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, builder, combinator);
        return builder.build();
    }

    /**
     * Adds to {@code builder} every combination of an A cell with the B cells that match it in the
     * shared dimensions. Both tensors are iterated within the joined dimension sizes.
     */
    private static void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize, int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder, DoubleBinaryOperator combinator) {
        Set<String> sharedDimensions = Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames());
        Set<String> dimensionsOnlyInA = Sets.difference(a.type().dimensionNames(), b.type().dimensionNames());
        DimensionSizes aIterateSize = joinedSizeOf(a.type(), joinedType, joinedSize);
        DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize);

        // For each combination of labels in the dimensions found only in A ...
        for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) {
            IndexedTensor.SubspaceIterator aSubspace = ia.next();
            // ... and each cell of A within that subspace ...
            while (aSubspace.hasNext()) {
                Tensor.Cell aCell = aSubspace.next();
                PartialAddress matchingBCells = partialAddress(a.type(), aSubspace.address(), sharedDimensions);
                // ... combine with every B cell having the same labels in the shared dimensions.
                for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) {
                    Tensor.Cell bCell = bSubspace.next();
                    TensorAddress joinedAddress = joinAddresses(aCell.getKey(), aToIndexes, bCell.getKey(), bToIndexes, joinedType);
                    double joinedValue = combinator.applyAsDouble(aCell.getValue(), bCell.getValue());
                    builder.cell(joinedAddress, joinedValue);
                }
            }
        }
    }

    /**
     * Builds a partial address holding the numeric labels of {@code address} for the dimensions
     * of {@code addressType} named in {@code retainDimensions}.
     */
    private static PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) {
        PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size());
        for (int i = 0; i < addressType.dimensions().size(); ++i) {
            String dimensionName = addressType.dimensions().get(i).name();
            if ( ! retainDimensions.contains(dimensionName)) continue;
            builder.add(dimensionName, address.numericLabel(i));
        }
        return builder.build();
    }

    /** Returns the joined sizes restricted to the dimensions of {@code type}, in joined-type order. */
    private static DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) {
        DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size());
        int dimensionIndex = 0;
        for (int i = 0; i < joinedType.dimensions().size(); ++i) {
            if ( ! type.dimensionNames().contains(joinedType.dimensions().get(i).name())) continue;
            builder.set(dimensionIndex++, joinedSizes.size(i));
        }
        return builder.build();
    }

    /**
     * Nested-loop join over all cell pairs. Pairs whose labels disagree in a shared dimension are
     * skipped. Cost is proportional to size(a) * size(b).
     */
    private static Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
        int[] aToIndexes = mapIndexes(a.type(), joinedType);
        int[] bToIndexes = mapIndexes(b.type(), joinedType);
        Tensor.Builder builder = Tensor.Builder.of(joinedType);
        for (Iterator<Tensor.Cell> aIterator = a.cellIterator(); aIterator.hasNext(); ) {
            Tensor.Cell aCell = aIterator.next();
            for (Iterator<Tensor.Cell> bIterator = b.cellIterator(); bIterator.hasNext(); ) {
                Tensor.Cell bCell = bIterator.next();
                TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aToIndexes, bCell.getKey(), bToIndexes, joinedType);
                if (combinedAddress == null) continue; // labels conflict in a shared dimension
                builder.cell(combinedAddress, combinator.applyAsDouble(aCell.getValue(), bCell.getValue()));
            }
        }
        return builder.build();
    }

    /**
     * Hash join. The smaller tensor's cells are indexed by their labels in the common dimensions;
     * the larger tensor is then scanned once, probing that index. Falls back to
     * {@link #mappedGeneralJoin} when the tensors share no (equal) dimensions.
     */
    private static Tensor mappedHashJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
        TensorType commonDimensionType = commonDimensions(a, b);
        if (commonDimensionType.dimensions().isEmpty()) {
            return mappedGeneralJoin(a, b, joinedType, combinator);
        }

        // Build the index over the smaller tensor. The swap is undone when calling the combinator.
        boolean swapTensors = a.size() > b.size();
        if (swapTensors) {
            Tensor temp = a;
            a = b;
            b = temp;
        }

        int[] aIndexesInCommon = mapIndexes(commonDimensionType, a.type());
        int[] bIndexesInCommon = mapIndexes(commonDimensionType, b.type());
        int[] aIndexesInJoined = mapIndexes(a.type(), joinedType);
        int[] bIndexesInJoined = mapIndexes(b.type(), joinedType);

        Map<TensorAddress, List<Tensor.Cell>> aCellsByCommonAddress = new HashMap<>();
        for (Iterator<Tensor.Cell> cellIterator = a.cellIterator(); cellIterator.hasNext(); ) {
            Tensor.Cell aCell = cellIterator.next();
            TensorAddress partialCommonAddress = partialCommonAddress(aCell, aIndexesInCommon);
            aCellsByCommonAddress.computeIfAbsent(partialCommonAddress, key -> new ArrayList<>()).add(aCell);
        }

        Tensor.Builder builder = Tensor.Builder.of(joinedType);
        for (Iterator<Tensor.Cell> cellIterator = b.cellIterator(); cellIterator.hasNext(); ) {
            Tensor.Cell bCell = cellIterator.next();
            TensorAddress partialCommonAddress = partialCommonAddress(bCell, bIndexesInCommon);
            List<Tensor.Cell> matchingACells = aCellsByCommonAddress.getOrDefault(partialCommonAddress, Collections.emptyList());
            for (Tensor.Cell aCell : matchingACells) {
                TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aIndexesInJoined, bCell.getKey(), bIndexesInJoined, joinedType);
                if (combinedAddress == null) continue; // labels conflict in a shared dimension
                double combinedValue = swapTensors
                                       ? combinator.applyAsDouble(bCell.getValue(), aCell.getValue())
                                       : combinator.applyAsDouble(aCell.getValue(), bCell.getValue());
                builder.cell(combinedAddress, combinedValue);
            }
        }
        return builder.build();
    }

    /**
     * Maps each dimension index of {@code fromType} to the index of the dimension with the same
     * name in {@code toType}, or -1 if {@code toType} has no such dimension.
     */
    static int[] mapIndexes(TensorType fromType, TensorType toType) {
        int[] toIndexes = new int[fromType.dimensions().size()];
        for (int i = 0; i < fromType.dimensions().size(); ++i) {
            toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1);
        }
        return toIndexes;
    }

    /**
     * Merges the labels of two addresses into one address of {@code joinedType}.
     *
     * @return the joined address, or null if the addresses disagree on a label in a shared dimension
     */
    private static TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, TensorType joinedType) {
        String[] joinedLabels = new String[joinedType.dimensions().size()];
        mapContent(a, joinedLabels, aToIndexes);
        boolean compatible = mapContent(b, joinedLabels, bToIndexes);
        if ( ! compatible) {
            return null;
        }
        return TensorAddress.of(joinedLabels);
    }

    /**
     * Copies the labels of {@code from} into {@code to} at the positions given by {@code indexMap}.
     *
     * @return false if a target position already holds a different label
     */
    private static boolean mapContent(TensorAddress from, String[] to, int[] indexMap) {
        for (int i = 0; i < from.size(); ++i) {
            int toIndex = indexMap[i];
            if (to[toIndex] != null && ! to[toIndex].equals(from.label(i))) {
                return false;
            }
            to[toIndex] = from.label(i);
        }
        return true;
    }

    /**
     * Returns a type holding the dimensions that are equal (per {@code Dimension.equals}) in both
     * tensors' types. Its value type comes from {@code TensorType.combinedValueType}.
     */
    private static TensorType commonDimensions(Tensor a, Tensor b) {
        TensorType aType = a.type();
        TensorType bType = b.type();
        TensorType.Builder typeBuilder = new TensorType.Builder(TensorType.combinedValueType(aType, bType));
        for (TensorType.Dimension aDim : aType.dimensions()) {
            for (TensorType.Dimension bDim : bType.dimensions()) {
                if (aDim.equals(bDim)) {
                    typeBuilder.set(bDim);
                }
            }
        }
        return typeBuilder.build();
    }

    /** Returns the cell's labels at the given dimension indexes, in order, as an address. */
    private static TensorAddress partialCommonAddress(Tensor.Cell cell, int[] indexMap) {
        TensorAddress address = cell.getKey();
        String[] labels = new String[indexMap.length];
        for (int i = 0; i < labels.length; ++i) {
            labels[i] = address.label(indexMap[i]);
        }
        return TensorAddress.of(labels);
    }
}

