/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor.functions;

import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.DenseSubspaceFunction;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;

public class MapSubspaces<NAMETYPE extends Name>
extends PrimitiveTensorFunction<NAMETYPE> {
    private final TensorFunction<NAMETYPE> argument;
    private final DenseSubspaceFunction<NAMETYPE> function;

    private MapSubspaces(TensorFunction<NAMETYPE> argument, DenseSubspaceFunction<NAMETYPE> function) {
        this.argument = argument;
        this.function = function;
    }

    public MapSubspaces(TensorFunction<NAMETYPE> argument, String functionArg, TensorFunction<NAMETYPE> function) {
        this(argument, new DenseSubspaceFunction<NAMETYPE>(functionArg, function));
        Objects.requireNonNull(argument, "The argument cannot be null");
        Objects.requireNonNull(functionArg, "The functionArg cannot be null");
        Objects.requireNonNull(function, "The function cannot be null");
    }

    private TensorType outputType(TensorType inputType) {
        TensorType m = inputType.mappedSubtype();
        TensorType d = this.function.outputType(inputType.indexedSubtype());
        if (m.rank() == 0) {
            return d;
        }
        if (d.rank() == 0) {
            return TypeResolver.map(m);
        }
        TensorType.Value cellType = d.valueType();
        HashMap<String, TensorType.Dimension> dims = new HashMap<String, TensorType.Dimension>();
        for (TensorType.Dimension dim : m.dimensions()) {
            dims.put(dim.name(), dim);
        }
        for (TensorType.Dimension dim : d.dimensions()) {
            TensorType.Dimension old = dims.put(dim.name(), dim);
            if (old == null) continue;
            throw new IllegalArgumentException("dimension name collision in map_subspaces: " + m + " vs " + d);
        }
        return new TensorType(cellType, dims.values());
    }

    public TensorFunction<NAMETYPE> argument() {
        return this.argument;
    }

    @Override
    public List<TensorFunction<NAMETYPE>> arguments() {
        return List.of(this.argument);
    }

    @Override
    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
        if (arguments.size() != 1) {
            throw new IllegalArgumentException("MapSubspaces must have 1 argument, got " + arguments.size());
        }
        return new MapSubspaces<NAMETYPE>(arguments.get(0), this.function);
    }

    @Override
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
        return new MapSubspaces<NAMETYPE>(this.argument.toPrimitive(), this.function);
    }

    @Override
    public TensorType type(TypeContext<NAMETYPE> context) {
        return this.outputType(this.argument.type(context));
    }

    @Override
    public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
        Tensor input = this.argument().evaluate(context);
        TensorType inputType = input.type();
        TensorType inputTypeMapped = inputType.mappedSubtype();
        TensorType inputTypeDense = inputType.indexedSubtype();
        HashMap<TensorAddress, Tensor.Builder> builders = new HashMap<TensorAddress, Tensor.Builder>();
        Iterator<Tensor.Cell> iter = input.cellIterator();
        while (iter.hasNext()) {
            Tensor.Cell cell = iter.next();
            TensorAddress fullAddr = cell.getKey();
            TensorAddress.Builder mapAddrBuilder = new TensorAddress.Builder(inputTypeMapped);
            TensorAddress.Builder idxAddrBuilder = new TensorAddress.Builder(inputTypeDense);
            for (int i = 0; i < inputType.dimensions().size(); ++i) {
                TensorType.Dimension dim = inputType.dimensions().get(i);
                if (dim.isMapped()) {
                    mapAddrBuilder.add(dim.name(), fullAddr.objectLabel(i));
                    continue;
                }
                idxAddrBuilder.add(dim.name(), fullAddr.objectLabel(i));
            }
            TensorAddress mapAddr = mapAddrBuilder.build();
            Tensor.Builder builder = builders.computeIfAbsent(mapAddr, k -> Tensor.Builder.of(inputTypeDense));
            TensorAddress idxAddr = idxAddrBuilder.build();
            builder.cell(idxAddr, (double)cell.getValue());
        }
        TensorType outputType = this.outputType(input.type());
        TensorType denseOutputType = outputType.indexedSubtype();
        List<TensorType.Dimension> denseOutputDims = denseOutputType.dimensions();
        Tensor.Builder builder = Tensor.Builder.of(outputType);
        for (Map.Entry entry : builders.entrySet()) {
            TensorAddress mappedAddr = (TensorAddress)entry.getKey();
            Tensor denseInput = ((Tensor.Builder)entry.getValue()).build();
            Tensor denseOutput = this.function.map(denseInput);
            Iterator<Tensor.Cell> iter2 = denseOutput.cellIterator();
            while (iter2.hasNext()) {
                TensorType.Dimension dim;
                int i;
                Tensor.Cell cell = iter2.next();
                TensorAddress denseAddr = cell.getKey();
                TensorAddress.Builder addrBuilder = new TensorAddress.Builder(outputType);
                for (i = 0; i < inputTypeMapped.dimensions().size(); ++i) {
                    dim = inputTypeMapped.dimensions().get(i);
                    addrBuilder.add(dim.name(), mappedAddr.objectLabel(i));
                }
                for (i = 0; i < denseOutputDims.size(); ++i) {
                    dim = denseOutputDims.get(i);
                    addrBuilder.add(dim.name(), denseAddr.objectLabel(i));
                }
                builder.cell(addrBuilder.build(), (double)cell.getValue());
            }
        }
        return builder.build();
    }

    @Override
    public String toString(ToStringContext<NAMETYPE> context) {
        return "map_subspaces(" + this.argument.toString(context) + ", " + this.function + ")";
    }

    @Override
    public int hashCode() {
        return Objects.hash("map_subspaces", this.argument, this.function);
    }
}

