/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;

import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.tensorflow.framework.NodeDef;

/**
 * Imported TensorFlow reshape operation.
 *
 * <p>The first input is the tensor to reshape; the second input must be a constant
 * holding the requested shape. The output type gets one indexed dimension per cell of
 * the shape constant, named {@code <vespaName>_<index>}. The reshaped value is
 * produced by joining the input with a generated 0/1 transformation tensor and
 * summing over the input dimensions.
 */
public class Reshape
extends TensorFlowOperation {
    public Reshape(NodeDef node, List<TensorFlowOperation> inputs, int port) {
        super(node, inputs, port);
    }

    /**
     * Derives the output type from the constant shape input.
     *
     * @return the output type, or {@code null} while the two input types are not yet available
     * @throws IllegalArgumentException if the shape input is not a constant, or if a
     *         negative (inferred) shape entry cannot be resolved against the input size
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if (!this.allInputTypesPresent(2)) {
            return null;
        }
        TensorFlowOperation newShape = this.inputs.get(1);
        if (!newShape.getConstantValue().isPresent()) {
            throw new IllegalArgumentException("Reshape in " + this.node.getName() + ": shape input must be a constant.");
        }
        Tensor shape = newShape.getConstantValue().get().asTensor();
        OrderedTensorType inputType = this.inputs.get(0).type().get();
        OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(this.node);
        int dimensionIndex = 0;
        Iterator<Tensor.Cell> cellIterator = shape.cellIterator();
        while (cellIterator.hasNext()) {
            long size = cellIterator.next().getValue().longValue();
            if (size < 0) {
                // A negative entry asks for the size to be inferred from the input.
                size = this.inferDimensionSize(shape, inputType);
            }
            outputTypeBuilder.add(TensorType.Dimension.indexed(String.format("%s_%d", this.vespaName(), dimensionIndex), size));
            ++dimensionIndex;
        }
        return outputTypeBuilder.build();
    }

    /**
     * Computes the size of the single inferred dimension so that the output holds
     * exactly as many cells as the input: {@code inputSize / product(known sizes)}.
     *
     * <p>The product of the known sizes is obtained by negating the product of all
     * shape cells, which presumes the inferred entry is written as {@code -1}.
     *
     * @throws IllegalArgumentException if the known sizes multiply to zero, if more than
     *         one entry is negative, or if the input size is not divisible by the product
     */
    private long inferDimensionSize(Tensor shape, OrderedTensorType inputType) {
        long knownProduct = -1 * (long)shape.reduce(Reduce.Aggregator.prod).asDouble();
        long inputSize = TensorConverter.tensorSize(inputType.type()).longValue();
        // knownProduct == 0: a zero-sized dimension; knownProduct < 0: an even number of
        // negative entries. Either way the missing size is not well defined.
        if (knownProduct <= 0 || inputSize % knownProduct != 0) {
            throw new IllegalArgumentException("Reshape in " + this.node.getName() + ": cannot infer dimension size for input of size " + inputSize + " given a product of known sizes of " + knownProduct + ".");
        }
        return inputSize / knownProduct;
    }

    /**
     * Builds the tensor function performing the reshape of the first input.
     *
     * @return the reshape function, or {@code null} while input types or input
     *         functions are not yet available
     */
    @Override
    protected TensorFunction lazyGetFunction() {
        if (!this.allInputTypesPresent(2)) {
            return null;
        }
        if (!this.allInputFunctionsPresent(2)) {
            return null;
        }
        OrderedTensorType inputType = this.inputs.get(0).type().get();
        TensorFunction inputFunction = this.inputs.get(0).function().get();
        return Reshape.reshape(inputFunction, inputType.type(), this.type.type());
    }

    /** Registers every dimension of the output type with the renamer. */
    @Override
    public void addDimensionNameConstraints(DimensionRenamer renamer) {
        for (TensorType.Dimension dimension : this.type.type().dimensions()) {
            renamer.addDimension(dimension.name());
        }
    }

    /**
     * Creates a function reshaping the result of {@code inputFunction} from
     * {@code inputType} to {@code outputType}.
     *
     * <p>A transformation tensor over the dimensions of both types is generated whose
     * cell expression compares the flattened input index with the flattened output
     * index. The input is multiplied with that tensor and the input dimensions are
     * summed away, leaving only the output dimensions.
     *
     * @param inputFunction the function producing the tensor to reshape
     * @param inputType     the type of the tensor produced by {@code inputFunction}
     * @param outputType    the requested type; must have the same total size as {@code inputType}
     * @return the function producing the reshaped tensor
     * @throws IllegalArgumentException if the two types differ in total size
     */
    public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
        if (!TensorConverter.tensorSize(inputType).equals(TensorConverter.tensorSize(outputType))) {
            throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
        }
        ExpressionNode unrollFrom = Reshape.unrollTensorExpression(inputType);
        ExpressionNode unrollTo = Reshape.unrollTensorExpression(outputType);
        ComparisonNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, unrollTo);

        TensorType transformationType = new TensorType.Builder(new TensorType[]{inputType, outputType}).build();
        TensorFunction transformTensor = new Generate(transformationType, new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());

        TensorFunction joined = new Join(inputFunction, transformTensor, ScalarFunctions.multiply());
        List<String> inputDimensionNames = inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList());
        return new Reduce(joined, Reduce.Aggregator.sum, inputDimensionNames);
    }

    /**
     * Builds the expression for the flattened index of a cell of {@code type}:
     * each dimension reference multiplied by the product of the sizes of the dimensions
     * after it, all added together. The last dimension has no multiplier. A rank-0
     * type yields the constant zero.
     */
    private static ExpressionNode unrollTensorExpression(TensorType type) {
        if (type.rank() == 0) {
            return new ConstantNode(DoubleValue.zero);
        }
        List<ExpressionNode> children = new ArrayList<>();
        List<ArithmeticOperator> operators = new ArrayList<>();
        // Product of the sizes of the dimensions already visited (those to the right).
        // Kept as a long so that large tensors do not silently wrap around.
        long stride = 1;
        for (int i = type.dimensions().size() - 1; i >= 0; --i) {
            TensorType.Dimension dimension = type.dimensions().get(i);
            children.add(0, new ReferenceNode(dimension.name()));
            if (stride > 1) {
                operators.add(0, ArithmeticOperator.MULTIPLY);
                children.add(0, new ConstantNode(new DoubleValue((double)stride)));
            }
            stride *= TensorConverter.dimensionSize(dimension);
            if (i > 0) {
                operators.add(0, ArithmeticOperator.PLUS);
            }
        }
        return new ArithmeticNode(children, operators);
    }
}

