/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.triton;

import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.triton.TritonOnnxClient;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * {@link OnnxEvaluator} that delegates model evaluation to a Triton inference server through a
 * {@link TritonOnnxClient}.
 *
 * <p>Model metadata (input and output tensor types) is fetched from the server on construction. If an
 * evaluation fails with a {@link TritonOnnxClient.TritonException}, the metadata is re-fetched and the
 * evaluation is retried exactly once.
 *
 * <p>When {@code isExplicitControlMode} is {@code true}, this evaluator also loads the model when the
 * server does not report it as ready (on construction and before the retry), and unloads it in
 * {@link #close()}.
 */
class TritonOnnxEvaluator implements OnnxEvaluator {
    private static final Logger log = Logger.getLogger(TritonOnnxEvaluator.class.getName());
    private final String modelName;
    private final TritonOnnxClient tritonClient;
    private final boolean isExplicitControlMode;
    // Not final: refreshed by loadModelIfNotReady() before a retry.
    private TritonOnnxClient.ModelMetadata modelMetadata;

    /**
     * @param tritonClient          client used for all server calls
     * @param modelName             name of the model on the Triton server
     * @param isExplicitControlMode whether this evaluator is responsible for loading/unloading the model
     */
    TritonOnnxEvaluator(TritonOnnxClient tritonClient, String modelName, boolean isExplicitControlMode) {
        this.modelName = modelName;
        this.tritonClient = tritonClient;
        this.isExplicitControlMode = isExplicitControlMode;
        loadModelIfNotReady();
    }

    /**
     * In explicit control mode, loads the model if the server does not report it as ready; then
     * (in every mode) refreshes the cached model metadata.
     */
    private void loadModelIfNotReady() {
        // Readiness is only queried in explicit control mode (short-circuit &&).
        if (isExplicitControlMode && !tritonClient.isModelReady(modelName)) {
            tritonClient.loadModel(modelName);
        }
        modelMetadata = tritonClient.getModelMetadata(modelName);
    }

    /**
     * Evaluates the model and returns the single requested output.
     *
     * @return the tensor for {@code output}, or {@code null} if the result contains no such output
     */
    @Override
    public Tensor evaluate(Map<String, Tensor> inputs, String output) {
        return evaluate(inputs).get(output);
    }

    /** Evaluates the model, retrying once after refreshing model state on a Triton failure. */
    @Override
    public Map<String, Tensor> evaluate(Map<String, Tensor> inputs) {
        return evaluate(inputs, true);
    }

    /**
     * Performs one evaluation attempt, optionally followed by a single retry.
     *
     * @param allowRetry whether a {@link TritonOnnxClient.TritonException} may trigger a reload and retry
     */
    private Map<String, Tensor> evaluate(Map<String, Tensor> inputs, boolean allowRetry) {
        try {
            return tritonClient.evaluate(modelName, modelMetadata, inputs);
        } catch (TritonOnnxClient.TritonException e) {
            if (!allowRetry) {
                throw e;
            }
            // Log the cause so it is visible even when the retry succeeds.
            log.log(Level.WARNING, "Retrying to evaluate model: " + modelName, e);
            try {
                loadModelIfNotReady();
                return evaluate(inputs, false);
            } catch (RuntimeException retryFailure) {
                // Keep the original failure attached to whatever ends up being thrown.
                retryFailure.addSuppressed(e);
                throw retryFailure;
            }
        }
    }

    @Override
    public Map<String, OnnxEvaluator.IdAndType> getInputs() {
        return toIdAndTypes(modelMetadata.inputs);
    }

    @Override
    public Map<String, OnnxEvaluator.IdAndType> getOutputs() {
        return toIdAndTypes(modelMetadata.outputs);
    }

    /** Builds a fresh mutable map from each name to an {@link OnnxEvaluator.IdAndType} using that name as id. */
    private static Map<String, OnnxEvaluator.IdAndType> toIdAndTypes(Map<String, TensorType> types) {
        Map<String, OnnxEvaluator.IdAndType> result = new HashMap<>();
        types.forEach((name, type) -> result.put(name, new OnnxEvaluator.IdAndType(name, type)));
        return result;
    }

    /** Returns the cached metadata's input map directly (not a copy). */
    @Override
    public Map<String, TensorType> getInputInfo() {
        return modelMetadata.inputs;
    }

    /** Returns the cached metadata's output map directly (not a copy). */
    @Override
    public Map<String, TensorType> getOutputInfo() {
        return modelMetadata.outputs;
    }

    /** Unloads the model from the server in explicit control mode; otherwise does nothing. */
    @Override
    public void close() {
        if (isExplicitControlMode) {
            tritonClient.unloadModel(modelName);
        }
    }
}

