/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Map;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Softmax over the trailing dimensions of its single input, starting at a configurable axis.
 *
 * <p>The computation is split into two graph nodes. A {@code SoftmaxPartialOperation} is inserted
 * by the constructor as this operation's first input and computes {@code exp(x - max(x))}. This
 * operation then divides that result by its sum. Both stages reduce over the same dimensions (see
 * {@link #reduceDimensions()}).
 */
public class Softmax extends IntermediateOperation {

    private final IntermediateOperation.AttributeMap attributeMap;

    /**
     * Creates a softmax operation and inserts its partial (exponentiation) stage in front of input 0.
     *
     * @param modelName the name of the model this operation belongs to
     * @param nodeName the name of this node; the partial stage is named {@code nodeName + "_partial"}
     * @param inputs the inputs of this operation; only the first one is read
     * @param attributeMap node attributes; an optional {@code "axis"} attribute selects the first
     *                     dimension to normalize over
     */
    public Softmax(String modelName, String nodeName, List<IntermediateOperation> inputs, IntermediateOperation.AttributeMap attributeMap) {
        super(modelName, nodeName, inputs);
        this.attributeMap = attributeMap;
        insert(new SoftmaxPartialOperation(modelName, nodeName, null), 0);
    }

    /** The output type is the type of the (partial-stage) input, unchanged. */
    @Override
    protected OrderedTensorType lazyGetType() {
        if ( ! allInputTypesPresent(1)) {
            return null;
        }
        return inputs.get(0).type().get();
    }

    /** Divides the exponentiated input by its sum over {@link #reduceDimensions()}. */
    @Override
    protected TensorFunction lazyGetFunction() {
        if ( ! allInputFunctionsPresent(1)) {
            return null;
        }
        List<String> reduceDimensions = reduceDimensions();
        TensorFunction input = inputs.get(0).function().get();
        TensorFunction sum = new Reduce(input, Reduce.Aggregator.sum, reduceDimensions);
        return new Join(input, sum, ScalarFunctions.divide());
    }

    /**
     * Returns a copy of this operation with the given inputs.
     *
     * <p>NOTE(review): the constructor inserts a fresh partial stage in front of {@code inputs.get(0)}.
     * If callers pass inputs that already start with this operation's partial stage, the copy would
     * have two chained partial stages. Confirm against how {@code withInputs} is called.
     */
    @Override
    public Softmax withInputs(List<IntermediateOperation> inputs) {
        return new Softmax(modelName(), name(), inputs, attributeMap);
    }

    @Override
    public String operationName() {
        return "SoftMax";
    }

    /**
     * Returns the names of the input dimensions to normalize over: every dimension from the axis to
     * the last one, inclusive.
     *
     * <p>The axis defaults to 0 for rank-1 inputs and to 1 otherwise. An {@code "axis"} attribute
     * overrides the default, and a negative axis is counted from the end.
     *
     * <p>NOTE(review): if the axis is {@code >= rank}, the returned list is empty. Confirm how
     * {@code Reduce} treats an empty dimension list.
     *
     * @throws IllegalArgumentException if the axis is still negative after adding the rank
     */
    private List<String> reduceDimensions() {
        OrderedTensorType inputType = inputs.get(0).type().get();
        int rank = inputType.rank();
        // Presumably the first dimension is treated as a batch dimension for rank >= 2 - TODO confirm
        int defaultAxis = rank == 1 ? 0 : 1;
        int requestedAxis = attributeMap.get("axis").map(value -> (int) value.asDouble()).orElse(defaultAxis);
        int axis = requestedAxis < 0 ? rank + requestedAxis : requestedAxis;
        if (axis < 0) {
            // Without this check, dimensions().get(axis) would fail with an opaque index error
            throw new IllegalArgumentException("Softmax axis " + requestedAxis + " of '" + name() +
                                               "' is out of range for an input of rank " + rank);
        }
        List<String> reduceDimensions = new ArrayList<>();
        for (int i = axis; i < rank; ++i) {
            reduceDimensions.add(inputType.dimensions().get(i).name());
        }
        return reduceDimensions;
    }

    /**
     * First softmax stage: computes {@code exp(x - max(x))} over the owning softmax's reduce dimensions.
     * Subtracting the maximum before exponentiating does not change the final ratio.
     */
    private class SoftmaxPartialOperation extends IntermediateOperation {

        /** The owning softmax node's name, without the "_partial" suffix, so copies keep the same name. */
        private final String baseNodeName;

        private SoftmaxPartialOperation(String modelName, String nodeName, List<IntermediateOperation> inputs) {
            super(modelName, nodeName + "_partial", inputs != null ? inputs : Collections.emptyList());
            this.baseNodeName = nodeName;
        }

        /**
         * Returns the input type unchanged.
         *
         * <p>Side effect: marks both the input and this operation with {@code exportAsRankingFunction}.
         * Each is referenced twice in the generated functions (once inside a reduce, once in a join).
         * Presumably they are exported so they are not inlined twice. TODO confirm.
         */
        @Override
        protected OrderedTensorType lazyGetType() {
            if ( ! allInputTypesPresent(1)) {
                return null;
            }
            IntermediateOperation input = inputs.get(0);
            input.exportAsRankingFunction = true;
            exportAsRankingFunction = true;
            return input.type().get();
        }

        /** Computes {@code exp(input - max(input))}, with the max reduced over the softmax dimensions. */
        @Override
        protected TensorFunction lazyGetFunction() {
            if ( ! allInputFunctionsPresent(1)) {
                return null;
            }
            List<String> reduceDimensions = Softmax.this.reduceDimensions();
            TensorFunction input = inputs.get(0).function().get();
            TensorFunction max = new Reduce(input, Reduce.Aggregator.max, reduceDimensions);
            TensorFunction shifted = new Join(input, max, ScalarFunctions.subtract());
            return new Map(shifted, ScalarFunctions.exp());
        }

        /** Returns a copy with the given inputs and the same name (no repeated "_partial" suffix). */
        @Override
        public SoftmaxPartialOperation withInputs(List<IntermediateOperation> inputs) {
            return new SoftmaxPartialOperation(modelName(), baseNodeName, inputs);
        }

        @Override
        public String operationName() {
            return "SoftMaxPartial";
        }
    }
}

