/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor.functions;

import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Concatenates two tensors along a single indexed dimension.
 * <p>
 * Cells of the first argument keep their labels in the concatenation dimension. Cells of the second
 * argument are shifted by the first argument's length in that dimension. An argument that lacks the
 * concatenation dimension is treated as having it with size 1.
 * <p>
 * Both arguments must have only indexed dimensions. A dimension other than the concatenation
 * dimension that is present in both arguments must have the same size in each.
 */
@Beta
public class Concat extends PrimitiveTensorFunction {

    private final TensorFunction argumentA;
    private final TensorFunction argumentB;
    private final String dimension;

    /**
     * Creates a concatenation of the tensors produced by two functions.
     *
     * @param argumentA the function producing the tensor placed first in the concatenation dimension
     * @param argumentB the function producing the tensor placed after argumentA in the concatenation dimension
     * @param dimension the name of the dimension to concatenate along
     * @throws NullPointerException if any argument is null
     */
    public Concat(TensorFunction argumentA, TensorFunction argumentB, String dimension) {
        Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
        Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
        Objects.requireNonNull(dimension, "The dimension cannot be null");
        this.argumentA = argumentA;
        this.argumentB = argumentB;
        this.dimension = dimension;
    }

    /** Returns the two argument functions, first argument first. */
    @Override
    public List<TensorFunction> functionArguments() {
        return ImmutableList.of(argumentA, argumentB);
    }

    /**
     * Returns a concatenation along the same dimension over the given arguments.
     *
     * @throws IllegalArgumentException if {@code arguments} does not contain exactly two functions
     */
    @Override
    public TensorFunction replaceArguments(List<TensorFunction> arguments) {
        if (arguments.size() != 2)
            throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size());
        return new Concat(arguments.get(0), arguments.get(1), dimension);
    }

    /** Returns a concatenation over the primitive forms of both arguments. */
    @Override
    public PrimitiveTensorFunction toPrimitive() {
        return new Concat(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension);
    }

    @Override
    public String toString(ToStringContext context) {
        return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + dimension + ")";
    }

    /**
     * Evaluates both arguments and returns their concatenation along {@link #dimension}.
     *
     * @throws IllegalArgumentException if an argument has a non-indexed dimension, or if a shared
     *                                  non-concatenation dimension differs in size between the arguments
     */
    @Override
    public Tensor evaluate(EvaluationContext context) {
        Tensor a = argumentA.evaluate(context);
        Tensor b = argumentB.evaluate(context);
        a = ensureIndexedDimension(dimension, a);
        b = ensureIndexedDimension(dimension, b);

        IndexedTensor aIndexed = asIndexed(a);
        IndexedTensor bIndexed = asIndexed(b);

        TensorType concatType = concatType(a, b);
        DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
        Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);

        // b's cells start right after a's last position in the concatenation dimension.
        int aDimensionLength = aIndexed.type().indexOfDimension(dimension)
                                       .map(d -> aIndexed.dimensionSizes().size(d))
                                       .orElseThrow(() -> new IllegalStateException(
                                               "Concat dimension '" + dimension + "' missing from " + aIndexed.type()));
        int[] aToIndexes = mapIndexes(a.type(), concatType);
        int[] bToIndexes = mapIndexes(b.type(), concatType);
        concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder); // b's cells
        concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder);                // a's cells
        return builder.build();
    }

    /**
     * Writes every cell of {@code b} into {@code builder}, combined with each subspace of {@code a}.
     * <p>
     * {@code a} is split into subspaces over its dimensions other than the concatenation dimension. For
     * each subspace address, every cell of {@code b} is merged with that address by
     * {@link #combineAddresses}, with b's label in the concatenation dimension shifted by {@code offset}.
     * Cells whose labels disagree with the subspace in a shared dimension are skipped.
     */
    private void concatenateTo(IndexedTensor a, IndexedTensor b, int offset, TensorType concatType,
                               int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) {
        int concatDimensionIndex = concatType.indexOfDimension(dimension).get(); // loop-invariant, looked up once
        Set<String> otherADimensions = a.type().dimensionNames().stream()
                                        .filter(d -> ! d.equals(dimension))
                                        .collect(Collectors.toSet());
        for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext(); ) {
            IndexedTensor.SubspaceIterator iaSubspace = ia.next();
            TensorAddress aAddress = iaSubspace.address();
            for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext(); ) {
                IndexedTensor.SubspaceIterator ibSubspace = ib.next();
                while (ibSubspace.hasNext()) {
                    Tensor.Cell bCell = ibSubspace.next();
                    TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes,
                                                                     concatType, offset, concatDimensionIndex);
                    if (combinedAddress == null) continue; // disagrees with aAddress in a shared dimension
                    builder.cell(combinedAddress, bCell.getValue());
                }
                // NOTE(review): iaSubspace is only read via address() and never advanced, so this reset
                // appears to have no effect; it is kept to preserve existing behaviour.
                iaSubspace.reset();
            }
        }
    }

    /**
     * Returns {@code tensor} with the given dimension guaranteed to be present and indexed.
     * <p>
     * If the dimension is absent, the tensor is multiplied by a single-cell tensor of type
     * {@code tensor(dimensionName[1])} with value 1.0, which adds the dimension with size 1.
     *
     * @throws IllegalArgumentException if the dimension is present but not indexed, or if the tensor has
     *                                  any non-indexed dimension
     */
    private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor) {
        Optional<TensorType.Dimension> existing = tensor.type().dimension(dimensionName);
        if (existing.isPresent() && ! existing.get().isIndexed())
            throw new IllegalArgumentException("Concat in dimension '" + dimensionName +
                                               "' requires that dimension to be indexed or absent, but got a tensor with type " +
                                               tensor.type());
        // Checked for every argument, not only those lacking the dimension. Otherwise a mixed tensor
        // would fail later with a ClassCastException.
        if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed()))
            throw new IllegalArgumentException("Concat requires an indexed tensor, but got a tensor with type " + tensor.type());
        if (existing.isPresent()) return tensor;

        Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1.0, 0).build();
        return tensor.multiply(unitTensor);
    }

    /**
     * Downcasts a tensor already validated by {@link #ensureIndexedDimension}.
     *
     * @throws IllegalArgumentException if the tensor is not an {@link IndexedTensor}
     */
    private static IndexedTensor asIndexed(Tensor tensor) {
        if ( ! (tensor instanceof IndexedTensor))
            throw new IllegalArgumentException("Concat requires an indexed tensor, but got a tensor with type " + tensor.type());
        return (IndexedTensor) tensor;
    }

    /**
     * Returns the result type: the combination of both argument types, built with
     * {@code TensorType.Builder(a.type(), b.type())}.
     * <p>
     * When that combined type gives the concatenation dimension a size, the size is replaced by the sum
     * of both arguments' sizes.
     */
    private TensorType concatType(Tensor a, Tensor b) {
        TensorType.Builder builder = new TensorType.Builder(a.type(), b.type());
        // NOTE(review): this assumes that a bound combined dimension implies both inputs are bound. If only one
        // is bound, the .get() calls below would throw NoSuchElementException — confirm against TensorType.Builder.
        if (builder.getDimension(dimension).get().size().isPresent())
            builder.set(TensorType.Dimension.indexed(dimension,
                                                     a.type().dimension(dimension).get().size().get() +
                                                     b.type().dimension(dimension).get().size().get()));
        return builder.build();
    }

    /**
     * Computes the concrete dimension sizes of the result.
     * <p>
     * The concatenation dimension gets the sum of both arguments' lengths. Every other dimension gets
     * the size of whichever argument has it; an argument lacking a dimension counts as size 0.
     *
     * @throws IllegalArgumentException if a non-concatenation dimension has two different non-zero sizes
     */
    private static DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) {
        DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size());
        for (int i = 0; i < concatSizes.dimensions(); i++) {
            String currentDimension = concatType.dimensions().get(i).name();
            int aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0);
            int bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0);
            if (currentDimension.equals(concatDimension)) {
                concatSizes.set(i, aSize + bSize);
                continue;
            }
            if (aSize != 0 && bSize != 0 && aSize != bSize)
                throw new IllegalArgumentException("Dimension " + currentDimension + " must be of the same size when concatenating " +
                                                   a.type() + " and " + b.type() + " along dimension " + concatDimension +
                                                   ", but was " + aSize + " and " + bSize);
            concatSizes.set(i, Math.max(aSize, bSize));
        }
        return concatSizes.build();
    }

    /**
     * Merges two addresses into an address of {@code concatType}.
     * <p>
     * In the concatenation dimension, b's label shifted by {@code concatOffset} wins.
     *
     * @return the merged address, or null if a and b disagree on a label in some other shared dimension
     */
    private static TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
                                                  TensorType concatType, int concatOffset, int concatDimensionIndex) {
        int[] combinedLabels = new int[concatType.dimensions().size()];
        Arrays.fill(combinedLabels, -1); // -1 marks "not yet set"
        mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset);
        boolean compatible = mapContent(b, combinedLabels, bToIndexes, concatDimensionIndex, concatOffset);
        if ( ! compatible) return null;
        return TensorAddress.of(combinedLabels);
    }

    /**
     * Maps each dimension of {@code fromType} to its index in {@code toType}.
     *
     * @return an array with one entry per dimension of fromType, holding -1 if toType lacks that dimension
     */
    private static int[] mapIndexes(TensorType fromType, TensorType toType) {
        int[] toIndexes = new int[fromType.dimensions().size()];
        for (int i = 0; i < fromType.dimensions().size(); i++)
            toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1);
        return toIndexes;
    }

    /**
     * Copies the labels of {@code from} into {@code to} at the positions given by {@code indexMap}.
     * <p>
     * The label in the concatenation dimension is shifted by {@code concatOffset} and always overwritten.
     * Any other position that already holds a different label makes the copy fail; {@code to} may then
     * be partially written.
     *
     * @return true if all labels were compatible, false on the first conflict
     */
    private static boolean mapContent(TensorAddress from, int[] to, int[] indexMap, int concatDimension, int concatOffset) {
        for (int i = 0; i < from.size(); i++) {
            int toIndex = indexMap[i];
            if (concatDimension == toIndex) {
                to[toIndex] = from.intLabel(i) + concatOffset;
                continue;
            }
            if (to[toIndex] != -1 && to[toIndex] != from.intLabel(i)) return false;
            to[toIndex] = from.intLabel(i);
        }
        return true;
    }

}

