/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.schema;

import com.yahoo.schema.DistributableResource;
import com.yahoo.schema.FeatureNames;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.ml.OnnxModelInfo;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
 * An ONNX model resource declared in a schema.
 *
 * <p>In addition to what {@link DistributableResource} holds, this class keeps
 * the mapping from ONNX input/output names to Vespa-side names, optional model
 * metadata ({@link OnnxModelInfo}), and optional settings for stateless
 * evaluation (execution mode, thread counts and GPU device).
 */
public class OnnxModel extends DistributableResource {

    /** Model metadata; {@code null} until {@link #setModelInfo(OnnxModelInfo)} is called. */
    private OnnxModelInfo modelInfo;

    /** ONNX input name -> validated Vespa source (string form of a {@link Reference}). */
    private final Map<String, String> inputMap = new HashMap<>();

    /** ONNX output name -> Vespa name (string form of a {@link Reference}). */
    private final Map<String, String> outputMap = new HashMap<>();

    // Optional stateless-evaluation settings; null means "not set".
    private String statelessExecutionMode;
    private Integer statelessInterOpThreads;
    private Integer statelessIntraOpThreads;
    private GpuDevice gpuDevice;

    /**
     * Creates a model with the given name and no file reference.
     *
     * @param name the name of this model
     */
    public OnnxModel(String name) {
        super(name);
    }

    /**
     * Creates a model with the given name and file name, then validates it.
     *
     * @param name     the name of this model
     * @param fileName the file name passed on to the superclass
     */
    public OnnxModel(String name, String fileName) {
        super(name, fileName);
        // validate() is not declared in this class; presumably inherited from DistributableResource.
        this.validate();
    }

    /**
     * Always rejects the URI: ONNX models cannot be referenced by URI.
     *
     * @param uri ignored
     * @throws IllegalArgumentException always
     */
    @Override
    public void setUri(String uri) {
        throw new IllegalArgumentException("URI for ONNX models are not currently supported");
    }

    /**
     * Maps an ONNX input name to a Vespa source, replacing any existing mapping.
     *
     * @param onnxName  the input name in the ONNX model, not null
     * @param vespaName the Vespa source to bind to the input, not null
     * @throws IllegalArgumentException if {@code vespaName} is not an accepted input source
     */
    public void addInputNameMapping(String onnxName, String vespaName) {
        this.addInputNameMapping(onnxName, vespaName, true);
    }

    /**
     * Checks that {@code source} is an accepted input source and returns its
     * normalized string form.
     *
     * <p>A source is accepted when it is one of:
     * <ul>
     *   <li>a simple reference that {@link FeatureNames#isSimpleFeature} accepts,</li>
     *   <li>a simple {@code rankingExpression} reference that has a simple argument,</li>
     *   <li>anything {@link Reference#simple(String)} does not recognize but
     *       {@link Reference#fromIdentifier(String)} accepts.</li>
     * </ul>
     *
     * @param source the source string to validate
     * @return the string form of the resulting {@link Reference}
     * @throws IllegalArgumentException if the source parses as a simple reference
     *         but is neither of the two accepted simple forms
     */
    private String validateInputSource(String source) {
        Optional<Reference> optRef = Reference.simple(source);
        if (optRef.isEmpty()) {
            // Not on simple-reference form: treat it as an identifier.
            // NOTE(review): fromIdentifier presumably rejects invalid identifiers itself - confirm.
            return Reference.fromIdentifier(source).toString();
        }

        Reference ref = optRef.get();
        if (FeatureNames.isSimpleFeature(ref)) {
            return ref.toString();
        }
        if (ref.isSimple() && "rankingExpression".equals(ref.name()) && ref.simpleArgument().isPresent()) {
            return ref.toString();
        }
        throw new IllegalArgumentException("invalid input for ONNX model " + this.getName() + ": " + source);
    }

    /**
     * Maps an ONNX input name to a Vespa source.
     *
     * <p>The source is validated even when the mapping ends up not being stored.
     *
     * @param onnxName  the input name in the ONNX model, not null
     * @param vespaName the Vespa source to bind to the input, not null
     * @param overwrite whether to replace an existing mapping for {@code onnxName};
     *                  when false an existing mapping is left untouched
     * @throws NullPointerException     if {@code onnxName} or {@code vespaName} is null
     * @throws IllegalArgumentException if {@code vespaName} is not an accepted input source
     */
    public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) {
        Objects.requireNonNull(onnxName, "Onnx name cannot be null");
        Objects.requireNonNull(vespaName, "Vespa name cannot be null");
        String source = this.validateInputSource(vespaName);
        if (overwrite || !this.inputMap.containsKey(onnxName)) {
            this.inputMap.put(onnxName, source);
        }
    }

    /**
     * Maps an ONNX output name to a Vespa name, replacing any existing mapping.
     *
     * @param onnxName  the output name in the ONNX model, not null
     * @param vespaName the Vespa-side name for the output, not null
     */
    public void addOutputNameMapping(String onnxName, String vespaName) {
        this.addOutputNameMapping(onnxName, vespaName, true);
    }

    /**
     * Maps an ONNX output name to a Vespa name.
     *
     * <p>The Vespa name is converted with {@link Reference#fromIdentifier(String)}
     * even when the mapping ends up not being stored.
     *
     * @param onnxName  the output name in the ONNX model, not null
     * @param vespaName the Vespa-side name for the output, not null
     * @param overwrite whether to replace an existing mapping for {@code onnxName};
     *                  when false an existing mapping is left untouched
     * @throws NullPointerException if {@code onnxName} or {@code vespaName} is null
     */
    public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) {
        Objects.requireNonNull(onnxName, "Onnx name cannot be null");
        Objects.requireNonNull(vespaName, "Vespa name cannot be null");
        Reference ref = Reference.fromIdentifier(vespaName);
        if (overwrite || !this.outputMap.containsKey(onnxName)) {
            this.outputMap.put(onnxName, ref.toString());
        }
    }

    /**
     * Sets the model metadata and adds a default name mapping for every input
     * and output it lists.
     *
     * <p>Defaults map each ONNX name to {@link OnnxModelInfo#asValidIdentifier}
     * of that name and never replace a mapping that was already added.
     *
     * @param modelInfo the model metadata, not null
     * @throws NullPointerException if {@code modelInfo} is null
     */
    public void setModelInfo(OnnxModelInfo modelInfo) {
        Objects.requireNonNull(modelInfo, "Onnx model info cannot be null");
        for (String onnxName : modelInfo.getInputs()) {
            this.addInputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false);
        }
        for (String onnxName : modelInfo.getOutputs()) {
            this.addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false);
        }
        // Assigned last, so a failure while adding mappings leaves the previous model info in place.
        this.modelInfo = modelInfo;
    }

    /** Returns an unmodifiable view of the ONNX input name to Vespa source mapping. */
    public Map<String, String> getInputMap() {
        return Collections.unmodifiableMap(this.inputMap);
    }

    /** Returns an unmodifiable view of the ONNX output name to Vespa name mapping. */
    public Map<String, String> getOutputMap() {
        return Collections.unmodifiableMap(this.outputMap);
    }

    /** Returns the default output reported by the model info, or the empty string if no model info is set. */
    public String getDefaultOutput() {
        return this.modelInfo != null ? this.modelInfo.getDefaultOutput() : "";
    }

    /**
     * Returns the tensor type the model info reports for the given ONNX name.
     *
     * @param onnxName   the ONNX name to look up
     * @param inputTypes input types passed through to the model info
     * @return the reported type, or {@link TensorType#empty} if no model info is set
     */
    TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) {
        return this.modelInfo != null ? this.modelInfo.getTensorType(onnxName, inputTypes) : TensorType.empty;
    }

    /**
     * Sets the stateless execution mode.
     *
     * <p>Only {@code "parallel"} and {@code "sequential"} are accepted, matched
     * case-insensitively and stored in lower case. Any other value, including
     * null, is ignored and leaves the current setting unchanged.
     *
     * @param executionMode the requested execution mode
     */
    public void setStatelessExecutionMode(String executionMode) {
        if ("parallel".equalsIgnoreCase(executionMode)) {
            this.statelessExecutionMode = "parallel";
        } else if ("sequential".equalsIgnoreCase(executionMode)) {
            this.statelessExecutionMode = "sequential";
        }
    }

    /** Returns the stateless execution mode, or empty if none has been set. */
    public Optional<String> getStatelessExecutionMode() {
        return Optional.ofNullable(this.statelessExecutionMode);
    }

    /**
     * Sets the stateless inter-op thread count. Negative values are ignored.
     *
     * @param interOpThreads the thread count
     */
    public void setStatelessInterOpThreads(int interOpThreads) {
        if (interOpThreads >= 0) {
            this.statelessInterOpThreads = interOpThreads;
        }
    }

    /** Returns the stateless inter-op thread count, or empty if none has been set. */
    public Optional<Integer> getStatelessInterOpThreads() {
        return Optional.ofNullable(this.statelessInterOpThreads);
    }

    /**
     * Sets the stateless intra-op thread count. Negative values are ignored.
     *
     * @param intraOpThreads the thread count
     */
    public void setStatelessIntraOpThreads(int intraOpThreads) {
        if (intraOpThreads >= 0) {
            this.statelessIntraOpThreads = intraOpThreads;
        }
    }

    /** Returns the stateless intra-op thread count, or empty if none has been set. */
    public Optional<Integer> getStatelessIntraOpThreads() {
        return Optional.ofNullable(this.statelessIntraOpThreads);
    }

    /**
     * Sets the GPU device. A negative device number is ignored and leaves the
     * current setting unchanged.
     *
     * @param deviceNumber the GPU device number
     * @param required     stored as the {@code required} component of the {@link GpuDevice}
     */
    public void setGpuDevice(int deviceNumber, boolean required) {
        if (deviceNumber >= 0) {
            this.gpuDevice = new GpuDevice(deviceNumber, required);
        }
    }

    /** Returns the GPU device, or empty if none has been set. */
    public Optional<GpuDevice> getGpuDevice() {
        return Optional.ofNullable(this.gpuDevice);
    }

    /**
     * A GPU device selection.
     *
     * @param deviceNumber the device number, never negative
     * @param required     the required flag supplied by the caller
     */
    public record GpuDevice(int deviceNumber, boolean required) {

        /**
         * @throws IllegalArgumentException if {@code deviceNumber} is negative
         */
        public GpuDevice {
            if (deviceNumber < 0) {
                throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber);
            }
        }
    }
}

