/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.triton;

import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.triton.TritonOnnxClient;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;

/**
 * {@link OnnxEvaluator} that delegates evaluation of a single named model to a
 * {@link TritonOnnxClient}.
 *
 * <p>When constructed in explicit control mode, the evaluator asks the client to load the
 * model on construction and to unload it on {@link #close()}. In both modes the model
 * metadata is fetched from the client at construction time and is used to answer the
 * input/output description methods.
 *
 * <p>NOTE(review): {@code modelMetadata} is reassigned when an evaluation is retried; if
 * instances are shared between threads, confirm whether that field needs safe publication.
 */
class TritonOnnxEvaluator implements OnnxEvaluator {
    private static final Logger log = Logger.getLogger(TritonOnnxEvaluator.class.getName());

    private final String modelName;
    private final TritonOnnxClient triton;
    private final boolean isExplicitControlMode;
    // Refreshed on every loadModel() call, i.e. at construction and before a retry.
    private TritonOnnxClient.ModelMetadata modelMetadata;

    /**
     * Creates the evaluator and immediately loads the model (if in explicit control mode)
     * and fetches its metadata.
     *
     * @param client the client used for all model operations
     * @param modelName the name of the model to evaluate
     * @param isExplicitControlMode whether this evaluator should load and unload the model itself
     * @throws RuntimeException if the client fails to load the model or fetch its metadata
     */
    TritonOnnxEvaluator(TritonOnnxClient client, String modelName, boolean isExplicitControlMode) {
        this.modelName = modelName;
        this.triton = client;
        this.isExplicitControlMode = isExplicitControlMode;
        loadModel();
    }

    /**
     * Loads the model through the client when in explicit control mode, then refreshes the
     * cached model metadata.
     *
     * @throws RuntimeException wrapping the client's {@code TritonException} on failure
     */
    private void loadModel() {
        try {
            if (isExplicitControlMode) {
                triton.loadModel(modelName);
            }
            modelMetadata = triton.getModelMetadata(modelName);
        } catch (TritonOnnxClient.TritonException e) {
            throw new RuntimeException("Failed to load model: " + modelName, e);
        }
    }

    /**
     * Evaluates the model and returns a single named output.
     *
     * @param inputs the input tensors keyed by input name
     * @param output the name of the output to return
     * @return the value mapped to {@code output} in the evaluation result, as returned by
     *         {@link Map#get(Object)}
     * @throws RuntimeException if evaluation fails, also after one retry
     */
    @Override
    public Tensor evaluate(Map<String, Tensor> inputs, String output) {
        return evaluate(inputs).get(output);
    }

    /**
     * Evaluates the model, retrying once after reloading the model if the first attempt fails.
     *
     * @param inputs the input tensors keyed by input name
     * @return the output tensors returned by the client
     * @throws RuntimeException if evaluation fails, also after one retry
     */
    @Override
    public Map<String, Tensor> evaluate(Map<String, Tensor> inputs) {
        return evaluate(inputs, true);
    }

    /**
     * Evaluates the model through the client.
     *
     * <p>If the client throws a {@code TritonException} and {@code allowRetry} is set, the
     * model is reloaded and evaluation is attempted exactly once more. Should the reload or
     * the second attempt fail, the first failure is attached to the propagated exception as
     * a suppressed exception so that the original cause is not lost.
     *
     * @param inputs the input tensors keyed by input name
     * @param allowRetry whether a failed evaluation may be retried after reloading the model
     * @return the output tensors returned by the client
     * @throws RuntimeException if evaluation fails and no (further) retry is allowed, or if
     *         reloading the model fails
     */
    private Map<String, Tensor> evaluate(Map<String, Tensor> inputs, boolean allowRetry) {
        try {
            return triton.evaluate(modelName, inputs);
        } catch (TritonOnnxClient.TritonException e) {
            if (!allowRetry) {
                throw new RuntimeException("Failed to evaluate model: " + modelName, e);
            }
            log.warning(() -> "Retrying to evaluate model: " + modelName);
            try {
                loadModel();
                return evaluate(inputs, false);
            } catch (RuntimeException retryFailure) {
                // Keep the first failure visible; addSuppressed rejects self-suppression.
                if (retryFailure != e) {
                    retryFailure.addSuppressed(e);
                }
                throw retryFailure;
            }
        }
    }

    /**
     * Describes the model inputs based on the cached model metadata.
     *
     * @return a new mutable map from input name to its id and tensor type
     */
    @Override
    public Map<String, OnnxEvaluator.IdAndType> getInputs() {
        return toIdAndTypeMap(modelMetadata.inputs());
    }

    /**
     * Describes the model outputs based on the cached model metadata.
     *
     * @return a new mutable map from output name to its id and tensor type
     */
    @Override
    public Map<String, OnnxEvaluator.IdAndType> getOutputs() {
        return toIdAndTypeMap(modelMetadata.outputs());
    }

    /**
     * Returns the input name to tensor type map held by the cached model metadata.
     *
     * @return the metadata's own input map (not copied)
     */
    @Override
    public Map<String, TensorType> getInputInfo() {
        return modelMetadata.inputs();
    }

    /**
     * Returns the output name to tensor type map held by the cached model metadata.
     *
     * @return the metadata's own output map (not copied)
     */
    @Override
    public Map<String, TensorType> getOutputInfo() {
        return modelMetadata.outputs();
    }

    /** Asks the client to unload the model when in explicit control mode; otherwise does nothing. */
    @Override
    public void close() {
        if (isExplicitControlMode) {
            triton.unloadModel(modelName);
        }
    }

    /** Builds a fresh map pairing each name with an {@code IdAndType} that uses the name as id. */
    private static Map<String, OnnxEvaluator.IdAndType> toIdAndTypeMap(Map<String, TensorType> types) {
        Map<String, OnnxEvaluator.IdAndType> result = new HashMap<>();
        types.forEach((name, type) -> result.put(name, new OnnxEvaluator.IdAndType(name, type)));
        return result;
    }
}

