/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor.functions;

import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import com.yahoo.tensor.impl.StringTensorAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Tensor function which concatenates two argument tensors along a named dimension.
 *
 * <p>Two evaluation strategies are implemented:
 * <ul>
 *   <li>when both evaluated arguments are {@link IndexedTensor} instances, {@link #oldEvaluate} is used;</li>
 *   <li>otherwise the generic {@link Helper}, which works on string labels, is used.</li>
 * </ul>
 *
 * <p>The result type is in both cases computed by {@link TypeResolver#concat}.
 *
 * @param <NAMETYPE> the name type used by the argument functions
 */
public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {

    private final TensorFunction<NAMETYPE> argumentA;
    private final TensorFunction<NAMETYPE> argumentB;
    private final String dimension;

    /**
     * Creates a concat function.
     *
     * @param argumentA the function producing the first tensor
     * @param argumentB the function producing the second tensor
     * @param dimension the name of the dimension to concatenate along
     * @throws NullPointerException if any argument is null
     */
    public Concat(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, String dimension) {
        this.argumentA = Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
        this.argumentB = Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
        this.dimension = Objects.requireNonNull(dimension, "The dimension cannot be null");
    }

    /** Returns the two argument functions, first argument first. */
    @Override
    public List<TensorFunction<NAMETYPE>> arguments() {
        return List.of(argumentA, argumentB);
    }

    /**
     * Returns a new concat over the given arguments, along the same dimension as this.
     *
     * @throws IllegalArgumentException if the list does not contain exactly two arguments
     */
    @Override
    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
        if (arguments.size() != 2) {
            throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size());
        }
        return new Concat<>(arguments.get(0), arguments.get(1), dimension);
    }

    /** Returns a concat over the primitive forms of both arguments. */
    @Override
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
        return new Concat<>(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension);
    }

    @Override
    public String toString(ToStringContext<NAMETYPE> context) {
        return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + dimension + ")";
    }

    @Override
    public int hashCode() {
        return Objects.hash("concat", argumentA, argumentB, dimension);
    }

    /** Returns the type resolved by {@link TypeResolver#concat} from the two argument types. */
    @Override
    public TensorType type(TypeContext<NAMETYPE> context) {
        return TypeResolver.concat(argumentA.type(context), argumentB.type(context), dimension);
    }

    /**
     * Evaluates both arguments and concatenates the results.
     * The indexed-only implementation is chosen when both results are {@link IndexedTensor}s,
     * the generic {@link Helper} otherwise.
     */
    @Override
    public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
        Tensor a = argumentA.evaluate(context);
        Tensor b = argumentB.evaluate(context);
        if (a instanceof IndexedTensor && b instanceof IndexedTensor) {
            return oldEvaluate(a, b);
        }
        return new Helper(a, b, dimension).result;
    }

    /**
     * Concatenation of two indexed tensors.
     * Both tensors are first given the concat dimension if they lack it, then the cells of
     * {@code b} are written shifted by the length of {@code a} in the concat dimension,
     * and finally the cells of {@code a} are written unshifted.
     */
    private Tensor oldEvaluate(Tensor a, Tensor b) {
        TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension);
        a = ensureIndexedDimension(dimension, a, concatType.valueType());
        b = ensureIndexedDimension(dimension, b, concatType.valueType());
        IndexedTensor aIndexed = (IndexedTensor) a;
        IndexedTensor bIndexed = (IndexedTensor) b;

        DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
        Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);

        long aDimensionLength = aIndexed.type().indexOfDimension(dimension)
                .map(d -> aIndexed.dimensionSizes().size((int) d))
                .orElseThrow(() -> new IllegalStateException(
                        "Concat dimension '" + dimension + "' is missing from type " + aIndexed.type()));

        int[] aToIndexes = mapIndexes(a.type(), concatType);
        int[] bToIndexes = mapIndexes(b.type(), concatType);
        // Each call writes the cell values of its *second* tensor argument, hence the swapped order.
        concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder);
        concatenateTo(bIndexed, aIndexed, 0L, concatType, bToIndexes, aToIndexes, builder);
        return builder.build();
    }

    /**
     * Writes the cell values of {@code b} into the builder. For every subspace of {@code a} (over all
     * dimensions of {@code a} except the concat dimension) each cell of {@code b} is combined with the
     * address of that subspace; combinations with conflicting labels are skipped.
     *
     * @param offset the amount added to labels in the concat dimension
     */
    private void concatenateTo(IndexedTensor a, IndexedTensor b, long offset, TensorType concatType,
                               int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) {
        Set<String> otherADimensions = a.type().dimensionNames().stream()
                .filter(name -> !name.equals(dimension))
                .collect(Collectors.toSet());
        for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext(); ) {
            IndexedTensor.SubspaceIterator iaSubspace = ia.next();
            TensorAddress aAddress = iaSubspace.address();
            for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext(); ) {
                IndexedTensor.SubspaceIterator ibSubspace = ib.next();
                while (ibSubspace.hasNext()) {
                    Tensor.Cell bCell = ibSubspace.next();
                    TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes,
                                                                     concatType, offset, dimension);
                    if (combinedAddress == null) {
                        continue; // the two addresses disagree in a shared dimension
                    }
                    double value = bCell.getValue();
                    builder.cell(combinedAddress, value);
                }
                iaSubspace.reset();
            }
        }
    }

    /**
     * Returns a tensor which has the given dimension as an indexed dimension.
     * A tensor already having it is returned as-is; a tensor lacking it is multiplied by a
     * single-cell unit tensor over that dimension.
     *
     * @throws IllegalArgumentException if the dimension is present but not indexed, or if it is
     *                                  absent and the tensor has mapped dimensions
     */
    private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor, TensorType.Value combinedValueType) {
        Optional<TensorType.Dimension> existing = tensor.type().dimension(dimensionName);
        if (existing.isPresent()) {
            if (!existing.get().isIndexed()) {
                throw new IllegalArgumentException("Concat in dimension '" + dimensionName + "' requires that dimension to be indexed or absent, but got a tensor with type " + tensor.type());
            }
            return tensor;
        }
        if (tensor.type().hasMappedDimensions()) {
            throw new IllegalArgumentException("Concat requires an indexed tensor, but got a tensor with type " + tensor.type());
        }
        TensorType unitType = new TensorType.Builder(combinedValueType).indexed(dimensionName, 1L).build();
        Tensor unitTensor = Tensor.Builder.of(unitType).cell(1.0f, 0L).build();
        return tensor.multiply(unitTensor);
    }

    /**
     * Computes the size of each dimension of the concatenated tensor.
     * The concat dimension gets the sum of both sizes. Any other dimension gets the smaller size
     * when both tensors have it with different sizes, and the larger of the two otherwise
     * (a dimension missing from a tensor counts as size 0).
     */
    private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) {
        DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size());
        for (int i = 0; i < concatSizes.dimensions(); i++) {
            String currentDimension = concatType.dimensions().get(i).name();
            long aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size((int) d)).orElse(0L);
            long bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size((int) d)).orElse(0L);
            if (currentDimension.equals(concatDimension)) {
                concatSizes.set(i, aSize + bSize);
            } else if (aSize != 0L && bSize != 0L && aSize != bSize) {
                concatSizes.set(i, Math.min(aSize, bSize));
            } else {
                concatSizes.set(i, Math.max(aSize, bSize));
            }
        }
        return concatSizes.build();
    }

    /**
     * Combines two addresses into one address of the concat type.
     *
     * @return the combined address, or null if {@code b} has a label in a non-concat dimension
     *         which differs from the label already set from {@code a}
     */
    private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
                                           TensorType concatType, long concatOffset, String concatDimension) {
        long[] combinedLabels = new long[concatType.dimensions().size()];
        Arrays.fill(combinedLabels, -1L); // -1 marks "not yet set"
        int concatDimensionIndex = concatType.indexOfDimension(concatDimension)
                .orElseThrow(() -> new IllegalStateException(
                        "Concat dimension '" + concatDimension + "' is missing from type " + concatType));
        // The first mapping starts from all-unset labels, so its result is not checked.
        mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset);
        boolean compatible = mapContent(b, combinedLabels, bToIndexes, concatDimensionIndex, concatOffset);
        return compatible ? TensorAddress.of(combinedLabels) : null;
    }

    /**
     * Returns, for each dimension of {@code fromType} in order, the index of the same-named
     * dimension in {@code toType}, or -1 if {@code toType} has no such dimension.
     */
    private int[] mapIndexes(TensorType fromType, TensorType toType) {
        int[] toIndexes = new int[fromType.dimensions().size()];
        for (int i = 0; i < toIndexes.length; i++) {
            toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1);
        }
        return toIndexes;
    }

    /**
     * Copies the numeric labels of {@code from} into {@code to} at the positions given by {@code indexMap}.
     * The label in the concat dimension is shifted by {@code concatOffset} and always overwrites.
     *
     * @return false as soon as a non-concat position already holds a different label, true otherwise
     */
    private boolean mapContent(TensorAddress from, long[] to, int[] indexMap, int concatDimension, long concatOffset) {
        for (int i = 0; i < from.size(); i++) {
            int toIndex = indexMap[i];
            if (concatDimension == toIndex) {
                to[toIndex] = from.numericLabel(i) + concatOffset;
            } else {
                if (to[toIndex] != -1L && to[toIndex] != from.numericLabel(i)) {
                    return false;
                }
                to[toIndex] = from.numericLabel(i);
            }
        }
        return true;
    }

    /**
     * Generic concatenation working on string labels. The result is computed in the constructor
     * and left in {@link #result}.
     */
    static class Helper {

        ConcatPlan plan;
        Tensor result;

        Helper(Tensor a, Tensor b, String dimension) {
            this.plan = new ConcatPlan(a.type(), b.type(), dimension);
            CellVectorMapMap aData = decompose(a, plan.splitInfoA);
            CellVectorMapMap bData = decompose(b, plan.splitInfoB);
            this.result = merge(aData, bData);
        }

        /**
         * Returns the common length of all cell vectors in the given data, or 1 if there are none.
         *
         * @throws IllegalArgumentException if the vectors do not all have the same length
         */
        static int concatDimensionSize(CellVectorMapMap data) {
            Set<Integer> sizes = new HashSet<>();
            for (CellVectorMap bySeparate : data.map.values()) {
                for (CellVector vector : bySeparate.map.values()) {
                    sizes.add(vector.values.size());
                }
            }
            if (sizes.isEmpty()) {
                return 1;
            }
            if (sizes.size() == 1) {
                return sizes.iterator().next();
            }
            throw new IllegalArgumentException("inconsistent size of concat dimension, had " + sizes.size() + " different values");
        }

        /**
         * Builds a result address by taking, for each result dimension in order, the next label
         * from the source dictated by the plan; the concat dimension gets {@code concatDimIdx}.
         */
        TensorAddress combine(TensorAddress match, TensorAddress leftOnly, TensorAddress rightOnly, int concatDimIdx) {
            String[] labels = new String[plan.resultType.rank()];
            int out = 0;
            int m = 0;
            int a = 0;
            int b = 0;
            for (ConcatPlan.CombineHow how : plan.combineHow) {
                switch (how) {
                    case left:
                        labels[out++] = leftOnly.label(a++);
                        break;
                    case right:
                        labels[out++] = rightOnly.label(b++);
                        break;
                    case both:
                        labels[out++] = match.label(m++);
                        break;
                    case concat:
                        labels[out++] = String.valueOf(concatDimIdx);
                        break;
                    default:
                        throw new IllegalArgumentException("cannot handle: " + how);
                }
            }
            // The labels array is freshly allocated above and not retained here.
            return StringTensorAddress.unsafeOf(labels);
        }

        /**
         * Builds the result tensor. For every common address present in both inputs, each
         * left-only/right-only address pair yields the left vector at concat indexes from 0
         * followed by the right vector at concat indexes from the concat size of {@code a}.
         */
        Tensor merge(CellVectorMapMap a, CellVectorMapMap b) {
            Tensor.Builder builder = Tensor.Builder.of(plan.resultType);
            int aConcatSize = concatDimensionSize(a);
            for (Map.Entry<TensorAddress, CellVectorMap> entry : a.map.entrySet()) {
                TensorAddress common = entry.getKey();
                if (!b.map.containsKey(common)) {
                    continue; // no matching cells in b
                }
                CellVectorMap lhs = entry.getValue();
                CellVectorMap rhs = b.map.get(common);
                for (Map.Entry<TensorAddress, CellVector> left : lhs.map.entrySet()) {
                    for (Map.Entry<TensorAddress, CellVector> right : rhs.map.entrySet()) {
                        writeCells(builder, common, left.getKey(), right.getKey(), left.getValue(), 0);
                        writeCells(builder, common, left.getKey(), right.getKey(), right.getValue(), aConcatSize);
                    }
                }
            }
            return builder.build();
        }

        /** Writes one cell vector to the builder, placing element i at concat index i + offset. */
        private void writeCells(Tensor.Builder builder, TensorAddress common, TensorAddress leftOnly,
                                TensorAddress rightOnly, CellVector cells, int offset) {
            for (int i = 0; i < cells.values.size(); i++) {
                double value = cells.values.get(i);
                builder.cell(combine(common, leftOnly, rightOnly, i + offset), value);
            }
        }

        /**
         * Splits the cells of a tensor into vectors along the concat dimension, keyed first by
         * the labels of the common dimensions and then by the labels of the separate dimensions.
         * A tensor without the concat dimension gets all its values at concat index 0.
         */
        CellVectorMapMap decompose(Tensor input, SplitHow how) {
            // Scratch arrays reused for every cell.
            // NOTE(review): this relies on TensorAddress.of copying the array - confirm.
            String[] commonLabels = new String[(int) how.numCommon()];
            String[] separateLabels = new String[(int) how.numSeparate()];
            CellVectorMapMap result = new CellVectorMapMap();
            for (Iterator<Tensor.Cell> iter = input.cellIterator(); iter.hasNext(); ) {
                Tensor.Cell cell = iter.next();
                TensorAddress addr = cell.getKey();
                long ccDimIndex = 0L;
                int commonIdx = 0;
                int separateIdx = 0;
                for (int i = 0; i < how.handleDims.size(); i++) {
                    DimType dimType = how.handleDims.get(i);
                    switch (dimType) {
                        case common:
                            commonLabels[commonIdx++] = addr.label(i);
                            break;
                        case separate:
                            separateLabels[separateIdx++] = addr.label(i);
                            break;
                        case concat:
                            ccDimIndex = addr.numericLabel(i);
                            break;
                        default:
                            throw new IllegalArgumentException("cannot handle: " + dimType);
                    }
                }
                TensorAddress commonAddr = TensorAddress.of(commonLabels);
                TensorAddress separateAddr = TensorAddress.of(separateLabels);
                result.lookupCreate(commonAddr).lookupCreate(separateAddr).setValue((int) ccDimIndex, cell.getValue());
            }
            return result;
        }
    }

    /**
     * Describes how the dimensions of two argument types map to the result type:
     * how to split each argument's addresses, and from where each result label is taken.
     */
    static class ConcatPlan {

        final TensorType resultType;
        final String concatDimension;
        SplitHow splitInfoA = new SplitHow();
        SplitHow splitInfoB = new SplitHow();
        List<CombineHow> combineHow = new ArrayList<>();

        /**
         * Creates the plan by merging the dimension lists of both types by name.
         * NOTE(review): the merge assumes both dimension lists are ordered by name - confirm against TensorType.
         */
        ConcatPlan(TensorType aType, TensorType bType, String concatDimension) {
            this.resultType = TypeResolver.concat(aType, bType, concatDimension);
            this.concatDimension = concatDimension;
            List<TensorType.Dimension> aDims = aType.dimensions();
            List<TensorType.Dimension> bDims = bType.dimensions();
            int i = 0;
            int j = 0;
            while (i < aDims.size() && j < bDims.size()) {
                String aName = aDims.get(i).name();
                String bName = bDims.get(j).name();
                int cmp = aName.compareTo(bName);
                if (cmp == 0) {
                    bothAandB(aName);
                    i++;
                    j++;
                } else if (cmp < 0) {
                    aOnly(aName);
                    i++;
                } else {
                    bOnly(bName);
                    j++;
                }
            }
            while (i < aDims.size()) {
                aOnly(aDims.get(i++).name());
            }
            while (j < bDims.size()) {
                bOnly(bDims.get(j++).name());
            }
            // Fewer entries than result dimensions: the concat dimension was in neither argument,
            // so add it at its position in the result type.
            if (combineHow.size() < resultType.rank()) {
                int concatIndex = resultType.indexOfDimension(concatDimension)
                        .orElseThrow(() -> new IllegalStateException(
                                "Concat dimension '" + concatDimension + "' is missing from type " + resultType));
                combineHow.add(concatIndex, CombineHow.concat);
            }
        }

        /** Registers a dimension found only in the first argument. */
        void aOnly(String dimName) {
            if (dimName.equals(concatDimension)) {
                splitInfoA.handleDims.add(DimType.concat);
                combineHow.add(CombineHow.concat);
            } else {
                splitInfoA.handleDims.add(DimType.separate);
                combineHow.add(CombineHow.left);
            }
        }

        /** Registers a dimension found only in the second argument. */
        void bOnly(String dimName) {
            if (dimName.equals(concatDimension)) {
                splitInfoB.handleDims.add(DimType.concat);
                combineHow.add(CombineHow.concat);
            } else {
                splitInfoB.handleDims.add(DimType.separate);
                combineHow.add(CombineHow.right);
            }
        }

        /** Registers a dimension found in both arguments. */
        void bothAandB(String dimName) {
            if (dimName.equals(concatDimension)) {
                splitInfoA.handleDims.add(DimType.concat);
                splitInfoB.handleDims.add(DimType.concat);
                combineHow.add(CombineHow.concat);
            } else {
                splitInfoA.handleDims.add(DimType.common);
                splitInfoB.handleDims.add(DimType.common);
                combineHow.add(CombineHow.both);
            }
        }

        /** The source of the label for one result dimension. */
        enum CombineHow {
            left,
            right,
            both,
            concat
        }
    }

    /** How to treat each dimension, by position, of one argument's addresses. */
    static class SplitHow {

        List<DimType> handleDims = new ArrayList<>();

        /** Returns the number of dimensions handled as {@link DimType#common}. */
        long numCommon() {
            return handleDims.stream().filter(t -> t == DimType.common).count();
        }

        /** Returns the number of dimensions handled as {@link DimType#separate}. */
        long numSeparate() {
            return handleDims.stream().filter(t -> t == DimType.separate).count();
        }
    }

    /** Cell vectors keyed by common-dimensions address, then by separate-dimensions address. */
    static class CellVectorMapMap {

        Map<TensorAddress, CellVectorMap> map = new HashMap<>();

        /** Returns the entry for the given address, creating an empty one if absent. */
        CellVectorMap lookupCreate(TensorAddress addr) {
            return map.computeIfAbsent(addr, k -> new CellVectorMap());
        }
    }

    /** Cell vectors keyed by separate-dimensions address. */
    static class CellVectorMap {

        Map<TensorAddress, CellVector> map = new HashMap<>();

        /** Returns the vector for the given address, creating an empty one if absent. */
        CellVector lookupCreate(TensorAddress addr) {
            return map.computeIfAbsent(addr, k -> new CellVector());
        }
    }

    /** A growable vector of cell values, indexed by position in the concat dimension. */
    static class CellVector {

        ArrayList<Double> values = new ArrayList<>();

        /** Sets the value at the given index, first padding the vector with 0.0 up to that index if needed. */
        void setValue(int ccDimIndex, double value) {
            while (values.size() <= ccDimIndex) {
                values.add(0.0);
            }
            values.set(ccDimIndex, value);
        }
    }

    /** The role of a dimension of one argument. */
    enum DimType {
        common,
        separate,
        concat
    }
}

