/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.vespa.model.ml;

import com.fasterxml.jackson.core.JsonEncoding;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.ml.OnnxModelInfo;
import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.Map;

/**
 * Determines the output tensor type of an ONNX model for a given set of input types by running
 * the external {@code vespa-analyze-onnx-model --probe-types} tool.
 *
 * <p>Probed results are cached in a file under
 * {@code ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR}, one line per probe. Each line is
 * formatted as {@code <contextKey>\t<tensorType>}. This lets later lookups with the same output
 * name and input types skip the external call.
 */
public class OnnxModelProbe {

    /** Name of the external analysis executable, resolved by the OS via the process PATH. */
    private static final String BINARY = "vespa-analyze-onnx-model";

    /** Reused for every parse; an ObjectMapper that is not reconfigured is safe to share. */
    private static final ObjectMapper JSON_MAPPER = new ObjectMapper();

    /**
     * Returns the type of {@code outputName} when the model is fed {@code inputTypes}.
     *
     * <p>It first consults the cached probe results. If nothing is cached and the model file exists,
     * it invokes the external tool. A non-empty result from the tool is appended to the cache.
     * Failures are printed to stderr and result in {@link TensorType#empty} (or whatever type was
     * determined before the failure).
     *
     * @param app        application package containing the model
     * @param modelPath  path of the ONNX model inside the application package
     * @param outputName name of the model output whose type is wanted
     * @param inputTypes type of each named model input
     * @return the probed output type, or {@link TensorType#empty} if it could not be determined
     */
    static TensorType probeModel(ApplicationPackage app, Path modelPath, String outputName, Map<String, TensorType> inputTypes) {
        TensorType outputType = TensorType.empty;
        String contextKey = createContextKey(outputName, inputTypes);
        try {
            outputType = readProbedOutputType(app, modelPath, contextKey);
            if (outputType.equals(TensorType.empty) && app.getFile(modelPath).exists()) {
                String modelFile = app.getFileReference(modelPath).getAbsolutePath();
                String jsonOutput = callVespaAnalyzeOnnxModel(createJsonInput(modelFile, inputTypes));
                outputType = outputTypeFromJson(jsonOutput, outputName);
                // Only successful probes are cached, so a failed probe is retried next time.
                if ( ! outputType.equals(TensorType.empty)) {
                    writeProbedOutputType(app, modelPath, contextKey, outputType);
                }
            }
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
            e.printStackTrace(System.err);
        }
        catch (IOException | IllegalArgumentException e) {
            e.printStackTrace(System.err);
        }
        return outputType;
    }

    /**
     * Builds the cache key identifying one probe.
     *
     * <p>The format is {@code onnxName:in1:type1,in2:type2}, with inputs sorted by name so the key
     * does not depend on map iteration order. With no inputs the key is just {@code onnxName}.
     */
    private static String createContextKey(String onnxName, Map<String, TensorType> inputTypes) {
        StringBuilder key = new StringBuilder().append(onnxName).append(":");
        inputTypes.entrySet().stream()
                  .sorted(Map.Entry.comparingByKey())
                  .forEachOrdered(e -> key.append(e.getKey()).append(":").append(e.getValue()).append(","));
        // Drop the trailing separator: the last ',' or, with no inputs, the ':' after the name.
        return key.substring(0, key.length() - 1);
    }

    /** Returns the application-package path of the probe cache file for the given model path. */
    private static Path probedOutputTypesPath(Path path) {
        String fileName = OnnxModelInfo.asValidIdentifier(path.getRelative()) + ".probed_output_types";
        return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName);
    }

    /**
     * Appends a probed output type to the cache for {@code modelPath}.
     *
     * <p>The cache key is built from {@code output} and {@code inputTypes}.
     */
    static void writeProbedOutputType(ApplicationPackage app, Path modelPath, String output, Map<String, TensorType> inputTypes, TensorType type) throws IOException {
        writeProbedOutputType(app, modelPath, createContextKey(output, inputTypes), type);
    }

    /** Appends a {@code contextKey\ttype} line to the probe cache file of {@code modelPath}. */
    private static void writeProbedOutputType(ApplicationPackage app, Path modelPath, String contextKey, TensorType type) throws IOException {
        String path = app.getFileReference(probedOutputTypesPath(modelPath)).getAbsolutePath();
        IOUtils.writeFile(path, contextKey + "\t" + type + "\n", true);
    }

    /**
     * Looks up a cached probe result.
     *
     * @return the first cached type whose key equals {@code contextKey}, or {@link TensorType#empty}
     *         if the cache file is missing or has no such entry. Malformed lines (no tab) are skipped.
     */
    private static TensorType readProbedOutputType(ApplicationPackage app, Path modelPath, String contextKey) throws IOException {
        ApplicationFile file = app.getFile(probedOutputTypesPath(modelPath));
        if ( ! file.exists()) {
            return TensorType.empty;
        }
        try (BufferedReader reader = new BufferedReader(file.createReader())) {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] parts = line.split("\t");
                if (parts.length < 2) continue; // skip truncated or foreign lines instead of throwing
                if (parts[0].equals(contextKey)) {
                    return TensorType.fromSpec(parts[1]);
                }
            }
        }
        return TensorType.empty;
    }

    /**
     * Extracts the type of {@code outputName} from the tool's JSON response.
     *
     * <p>The expected response shape is {@code {"outputs": {"<name>": "<tensor type spec>"}}}.
     *
     * @return the parsed type, or {@link TensorType#empty} if the response has no such output
     */
    private static TensorType outputTypeFromJson(String json, String outputName) throws IOException {
        JsonNode root = JSON_MAPPER.readTree(json);
        // For empty input, readTree may return null (older Jackson versions) or a missing node.
        if (root == null || ! root.isObject() || ! root.has("outputs")) {
            return TensorType.empty;
        }
        JsonNode outputs = root.get("outputs");
        if ( ! outputs.has(outputName)) {
            return TensorType.empty;
        }
        return TensorType.fromSpec(outputs.get(outputName).asText());
    }

    /**
     * Builds the tool's JSON request.
     *
     * <p>The request shape is {@code {"model": "<path>", "inputs": {"<name>": "<type>", ...}}}.
     */
    private static String createJsonInput(String modelPath, Map<String, TensorType> inputTypes) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        try (JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8)) {
            g.writeStartObject();
            g.writeStringField("model", modelPath);
            g.writeObjectFieldStart("inputs");
            for (Map.Entry<String, TensorType> input : inputTypes.entrySet()) {
                g.writeStringField(input.getKey(), input.getValue().toString());
            }
            g.writeEndObject();
            g.writeEndObject();
        }
        // Decode with the same charset the generator encoded with, not the platform default.
        return new String(out.toByteArray(), StandardCharsets.UTF_8);
    }

    /**
     * Runs the analysis tool with {@code jsonInput} on stdin and returns its stdout, decoded as UTF-8.
     *
     * <p>The child's stderr is discarded.
     *
     * @throws IllegalArgumentException if the tool exits with a non-zero return code
     */
    private static String callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException {
        ProcessBuilder processBuilder = new ProcessBuilder(BINARY, "--probe-types");
        processBuilder.redirectError(ProcessBuilder.Redirect.DISCARD);
        Process process = processBuilder.start();
        try {
            // Closing stdin signals end of input to the tool.
            try (OutputStream os = process.getOutputStream()) {
                os.write(jsonInput.getBytes(StandardCharsets.UTF_8));
            }
            String output;
            try (InputStream inputStream = process.getInputStream()) {
                output = new String(inputStream.readAllBytes(), StandardCharsets.UTF_8);
            }
            int returnCode = process.waitFor();
            if (returnCode != 0) {
                throw new IllegalArgumentException("Error from '" + BINARY + "'. Return code: " + returnCode + ". Output:\n" + output);
            }
            return output;
        }
        finally {
            // No-op if the tool already exited; otherwise avoids leaving it running after a failure.
            process.destroy();
        }
    }
}

