/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.models.evaluation;

import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * A named ONNX model plus the mapping between the model's own input/output names and the
 * names used by callers of {@link #evaluate(Map, String)}: each input is fed from a
 * {@code source} name and each output is exposed under an {@code outputAs} name.
 *
 * <p>Mappings may be added explicitly before {@link #load()}. If none are added, one mapping
 * per model input/output is derived from the loaded model itself.
 */
class OnnxModel
implements AutoCloseable {
    final List<InputSpec> inputSpecs = new ArrayList<InputSpec>();
    final List<OutputSpec> outputSpecs = new ArrayList<OutputSpec>();
    private final String name;
    private final File modelFile;
    private final OnnxEvaluatorOptions options;
    private final OnnxRuntime onnx;
    /** The loaded evaluator, or null while the model is not loaded (or after close()). */
    private OnnxEvaluator evaluator;

    /**
     * Maps the ONNX model input {@code onnxName} to be fed from the caller-side name {@code source}.
     *
     * @throws IllegalStateException if the model is currently loaded
     */
    void addInputMapping(String onnxName, String source) {
        if (this.evaluator != null) {
            throw new IllegalStateException("input mapping must be added before load()");
        }
        this.inputSpecs.add(new InputSpec(onnxName, source));
    }

    /**
     * Exposes the ONNX model output {@code onnxName} under the caller-side name {@code outputAs}.
     *
     * @throws IllegalStateException if the model is currently loaded
     */
    void addOutputMapping(String onnxName, String outputAs) {
        if (this.evaluator != null) {
            throw new IllegalStateException("output mapping must be added before load()");
        }
        this.outputSpecs.add(new OutputSpec(onnxName, outputAs));
    }

    OnnxModel(String name, File modelFile, OnnxEvaluatorOptions options, OnnxRuntime onnx) {
        this.name = name;
        this.modelFile = modelFile;
        this.options = options;
        this.onnx = onnx;
    }

    /** Returns the name of this model. */
    public String name() {
        return this.name;
    }

    /**
     * Creates the evaluator for the model file and resolves the tensor types of all inputs and
     * outputs. Does nothing if the model is already loaded.
     *
     * <p>The evaluator is published only after type resolution succeeds. If resolution fails,
     * the freshly created evaluator is closed, so it is not leaked, and the model stays unloaded.
     *
     * @throws IllegalArgumentException if explicitly configured mappings do not match the
     *         model's actual inputs or outputs
     */
    public void load() {
        if (this.evaluator != null) {
            return;
        }
        OnnxEvaluator loaded = this.onnx.evaluatorOf(this.modelFile.getPath(), this.options);
        try {
            this.fillInputTypes(loaded.getInputs());
            this.fillOutputTypes(loaded.getOutputs());
        } catch (RuntimeException e) {
            try {
                loaded.close();
            } catch (RuntimeException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
        this.evaluator = loaded;
    }

    /**
     * Resolves input types from the model's actual inputs, keyed by ONNX input name.
     * If no input mappings were configured, one spec per model input is created, using the
     * entry's id as the source name. Otherwise the configured specs must match the model's
     * inputs one-to-one by ONNX name, and each spec receives the model's type.
     *
     * @throws IllegalArgumentException on a count mismatch or an unknown configured input
     */
    void fillInputTypes(Map<String, OnnxEvaluator.IdAndType> wantedTypes) {
        if (this.inputSpecs.isEmpty()) {
            for (Map.Entry<String, OnnxEvaluator.IdAndType> entry : wantedTypes.entrySet()) {
                String name = entry.getKey();
                String source = entry.getValue().id();
                TensorType tType = entry.getValue().type();
                InputSpec spec = new InputSpec(name, source, tType);
                this.inputSpecs.add(spec);
            }
        } else {
            if (wantedTypes.size() != this.inputSpecs.size()) {
                throw new IllegalArgumentException("Onnx model " + this.name() + ": Mismatch between " + this.inputSpecs.size() + " configured inputs and " + wantedTypes.size() + " actual model inputs");
            }
            for (InputSpec spec : this.inputSpecs) {
                OnnxEvaluator.IdAndType entry = wantedTypes.get(spec.onnxName);
                if (entry == null) {
                    throw new IllegalArgumentException("Onnx model " + this.name() + ": No type in actual model for configured input " + spec.onnxName);
                }
                spec.wantedType = entry.type();
            }
        }
    }

    /**
     * Resolves output types from the model's actual outputs, keyed by ONNX output name.
     * If no output mappings were configured, one spec per model output is created, using the
     * entry's id as the exposed name. Otherwise the configured specs must match the model's
     * outputs one-to-one by ONNX name, and each spec receives the model's type.
     *
     * @throws IllegalArgumentException on a count mismatch or an unknown configured output
     */
    void fillOutputTypes(Map<String, OnnxEvaluator.IdAndType> outputTypes) {
        if (this.outputSpecs.isEmpty()) {
            for (Map.Entry<String, OnnxEvaluator.IdAndType> entry : outputTypes.entrySet()) {
                String name = entry.getKey();
                String as = entry.getValue().id();
                TensorType tType = entry.getValue().type();
                OutputSpec spec = new OutputSpec(name, as, tType);
                this.outputSpecs.add(spec);
            }
        } else {
            if (outputTypes.size() != this.outputSpecs.size()) {
                throw new IllegalArgumentException("Onnx model " + this.name() + ": Mismatch between " + this.outputSpecs.size() + " configured outputs and " + outputTypes.size() + " actual model outputs");
            }
            for (OutputSpec spec : this.outputSpecs) {
                OnnxEvaluator.IdAndType entry = outputTypes.get(spec.onnxName);
                if (entry == null) {
                    throw new IllegalArgumentException("Onnx model " + this.name() + ": No type in actual model for configured output " + spec.onnxName);
                }
                spec.expectedType = entry.type();
            }
        }
    }

    /**
     * Returns a new map from each input's source name to its wanted tensor type.
     * Types of explicitly configured inputs are null until {@link #load()} has resolved them.
     */
    public Map<String, TensorType> inputs() {
        HashMap<String, TensorType> map = new HashMap<String, TensorType>();
        for (InputSpec spec : this.inputSpecs) {
            map.put(spec.source, spec.wantedType);
        }
        return map;
    }

    /**
     * Returns a new map from each output's exposed name to its expected tensor type.
     * Types of explicitly configured outputs are null until {@link #load()} has resolved them.
     */
    public Map<String, TensorType> outputs() {
        HashMap<String, TensorType> map = new HashMap<String, TensorType>();
        for (OutputSpec spec : this.outputSpecs) {
            map.put(spec.outputAs, spec.expectedType);
        }
        return map;
    }

    /**
     * Evaluates the model and returns a single output.
     *
     * @param inputs input tensors keyed by source name; every mapped input must be present
     * @param output the exposed ({@code outputAs}) name of the output to return
     * @return the computed output tensor
     * @throws IllegalArgumentException if an input is missing or no output is exposed as {@code output}
     * @throws IllegalStateException if the model is not loaded
     */
    public Tensor evaluate(Map<String, Tensor> inputs, String output) {
        HashMap<String, Tensor> mapped = new HashMap<String, Tensor>();
        for (InputSpec spec : this.inputSpecs) {
            Tensor val = inputs.get(spec.source);
            if (val == null) {
                throw new IllegalArgumentException("evaluate ONNX model " + this.name() + ": missing input from source " + spec.source);
            }
            mapped.put(spec.onnxName, val);
        }
        String onnxName = null;
        for (OutputSpec spec : this.outputSpecs) {
            if (spec.outputAs.equals(output)) {
                // Last match wins, which agrees with how outputs() resolves duplicate names.
                onnxName = spec.onnxName;
            }
        }
        if (onnxName == null) {
            throw new IllegalArgumentException("evaluate ONNX model " + this.name() + ": no output available as: " + output);
        }
        return this.evaluator().evaluate(mapped, onnxName);
    }

    /** Returns the loaded evaluator, or throws if the model is not loaded. */
    private OnnxEvaluator evaluator() {
        if (this.evaluator == null) {
            throw new IllegalStateException("ONNX model has not been loaded.");
        }
        return this.evaluator;
    }

    /**
     * Closes the underlying evaluator, if any. Safe to call when the model was never loaded,
     * and safe to call more than once. The field is cleared before closing, so a closed
     * evaluator can neither be closed again nor used by a later evaluate().
     */
    @Override
    public void close() {
        OnnxEvaluator toClose = this.evaluator;
        this.evaluator = null;
        if (toClose != null) {
            toClose.close();
        }
    }

    /** Maps an ONNX model input name to the caller-side source name, plus its resolved type. */
    static class InputSpec {
        String onnxName;
        String source;
        TensorType wantedType;

        InputSpec(String name, String source, TensorType tType) {
            this.onnxName = name;
            this.source = source;
            this.wantedType = tType;
        }

        /** Creates a spec whose type is resolved later by fillInputTypes. */
        InputSpec(String name, String source) {
            this(name, source, null);
        }
    }

    /** Maps an ONNX model output name to the caller-side exposed name, plus its resolved type. */
    static class OutputSpec {
        String onnxName;
        String outputAs;
        TensorType expectedType;

        OutputSpec(String name, String as, TensorType tType) {
            this.onnxName = name;
            this.outputAs = as;
            this.expectedType = tType;
        }

        /** Creates a spec whose type is resolved later by fillOutputTypes. */
        OutputSpec(String name, String as) {
            this(name, as, null);
        }
    }
}

