/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.vespa.model.ml;

import com.fasterxml.jackson.core.JsonEncoding;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.ml.OnnxModelProbe;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import onnx.Onnx;

/**
 * Input, output and initializer information for an ONNX model referenced from an application package.
 *
 * <p>The information is obtained either by parsing the ONNX model file itself (in which case a JSON
 * summary is also written under {@code ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR}), or, when
 * the model file is absent, by reading such a previously generated JSON summary.
 */
public class OnnxModelInfo {

    /** Reused JSON reader for model info summaries; only used for parsing. */
    private static final ObjectMapper JSON_MAPPER = new ObjectMapper();

    /** Matches every character that is not a word character, '$' or '@'. */
    private static final Pattern INVALID_IDENTIFIER_CHARS = Pattern.compile("[^\\w\\d\\$@_]");

    private final ApplicationPackage app;
    private final String modelPath;
    private final String defaultOutput;
    private final Map<String, OnnxTypeInfo> inputs;
    private final Map<String, OnnxTypeInfo> outputs;
    // Cache of Vespa types for outputs whose dimension sizes are all statically known.
    private final Map<String, TensorType> vespaTypes = new HashMap<>();
    private final Set<String> initializers;

    private OnnxModelInfo(ApplicationPackage app, String path, Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, Set<String> initializers, String defaultOutput) {
        this.app = app;
        this.modelPath = path;
        this.inputs = Collections.unmodifiableMap(inputs);
        this.outputs = Collections.unmodifiableMap(outputs);
        this.defaultOutput = defaultOutput;
        this.initializers = Set.copyOf(initializers);
    }

    /** Returns the model path as recorded in the model info summary; may be null if it was not recorded. */
    public String getModelPath() {
        return modelPath;
    }

    /** Returns the ONNX names of the model inputs (unmodifiable). */
    public Set<String> getInputs() {
        return inputs.keySet();
    }

    /** Returns the ONNX names of the model outputs (unmodifiable). */
    public Set<String> getOutputs() {
        return outputs.keySet();
    }

    /** Returns the names of the model graph initializers (unmodifiable). */
    public Set<String> getInitializers() {
        return initializers;
    }

    /** Returns the name of the first model output, or the empty string if the model has no outputs. */
    public String getDefaultOutput() {
        return defaultOutput;
    }

    /**
     * Returns the Vespa tensor type of the given model output.
     *
     * <p>If every dimension size of the output is statically known, the type is computed once and cached.
     * Otherwise, sizes are resolved from the given input types. The model is probed when input types are
     * given and some dimension still cannot be resolved. If probing does not yield a type, the type is
     * derived from the resolved symbolic and unbound dimension sizes.
     *
     * @param onnxName   the ONNX name of the output
     * @param inputTypes Vespa types of the model inputs, keyed by ONNX input name
     * @return the output type, or {@code TensorType.empty} if some dimension size remains unresolved
     * @throws IllegalArgumentException if the model has no such output, or the input types imply conflicting sizes
     */
    public TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) {
        OnnxTypeInfo onnxTypeInfo = outputs.get(onnxName);
        if (onnxTypeInfo == null) {
            throw new IllegalArgumentException("Could not find type for output '" + onnxName + "'");
        }
        if ( ! onnxTypeInfo.containsUnknownDimensionSizes()) {
            return vespaTypes.computeIfAbsent(onnxName, name -> onnxTypeInfo.toVespaTensorType());
        }

        Set<Long> unboundSizes = new HashSet<>();
        Map<String, Long> symbolicSizes = new HashMap<>();
        resolveUnknownDimensionSizes(inputTypes, symbolicSizes, unboundSizes);

        TensorType type = TensorType.empty;
        if ( ! inputTypes.isEmpty() && onnxTypeInfo.needModelProbe(symbolicSizes)) {
            type = OnnxModelProbe.probeModel(app, Path.fromString(modelPath), onnxName, inputTypes);
        }
        if (type.equals(TensorType.empty)) {
            type = onnxTypeInfo.toVespaTensorType(symbolicSizes, unboundSizes);
        }
        return type;
    }

    /**
     * Collects the sizes that the given Vespa input types bind to this model's unknown input dimensions.
     *
     * <p>Inputs without a given type, or whose rank differs from the ONNX type, are skipped. So are
     * dimensions whose Vespa size is absent.
     *
     * @param inputTypes    Vespa types of the model inputs, keyed by ONNX input name
     * @param symbolicSizes receives the size bound to each symbolic dimension name
     * @param unboundSizes  receives the sizes bound to dimensions whose ONNX size is -1
     * @throws IllegalArgumentException if a symbolic dimension, or the unbound dimensions collectively, get more than one size
     */
    private void resolveUnknownDimensionSizes(Map<String, TensorType> inputTypes, Map<String, Long> symbolicSizes, Set<Long> unboundSizes) {
        for (Map.Entry<String, OnnxTypeInfo> input : inputs.entrySet()) {
            String onnxName = input.getKey();
            OnnxTypeInfo onnxType = input.getValue();
            TensorType vespaType = inputTypes.get(onnxName);
            if (vespaType == null || vespaType.dimensions().size() != onnxType.dimensions().size()) continue;

            for (int i = 0; i < vespaType.dimensions().size(); ++i) {
                TensorType.Dimension vespaDimension = vespaType.dimensions().get(i);
                if (vespaDimension.size().isEmpty()) continue;
                Long size = vespaDimension.size().get();
                OnnxDimensionInfo onnxDimension = onnxType.dimensions().get(i);

                if (onnxDimension.getSize() == -1L) {
                    // All unbound dimensions across all inputs must agree on a single size.
                    unboundSizes.add(size);
                    if (unboundSizes.size() > 1) {
                        throw new IllegalArgumentException("Found conflicting sizes for unbound dimension for type '" + onnxType + "'");
                    }
                    continue;
                }
                if ( ! onnxDimension.hasSymbolicName()) continue;

                String symbolicName = onnxDimension.getSymbolicName();
                if (symbolicSizes.containsKey(symbolicName) && ! symbolicSizes.get(symbolicName).equals(size)) {
                    throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension '" + symbolicName + "' for input '" + onnxName + "'");
                }
                symbolicSizes.put(symbolicName, size);
            }
        }
    }

    /**
     * Loads model info for the ONNX model at the given application package path.
     *
     * <p>The model file itself is preferred; if it is absent, a previously generated model info summary is used.
     *
     * @throws IllegalArgumentException if neither the model nor its generated info exists, or parsing fails
     */
    public static OnnxModelInfo load(String path, ApplicationPackage app) {
        Path pathInApplicationPackage = Path.fromString(path);
        if (app.getFile(pathInApplicationPackage).exists()) {
            return loadFromFile(pathInApplicationPackage, app);
        }
        if (app.getFile(generatedModelInfoPath(pathInApplicationPackage)).exists()) {
            return loadFromGeneratedInfo(pathInApplicationPackage, app);
        }
        throw new IllegalArgumentException("Unable to find ONNX model '" + path + "'");
    }

    /** Returns whether either the ONNX model file or its generated model info exists in the application package. */
    public static boolean modelExists(String path, ApplicationPackage app) {
        Path pathInApplicationPackage = Path.fromString(path);
        return app.getFile(pathInApplicationPackage).exists()
               || app.getFile(generatedModelInfoPath(pathInApplicationPackage)).exists();
    }

    /**
     * Parses the ONNX model file, stores a generated JSON summary of it, and returns the model info.
     *
     * @throws IllegalArgumentException wrapping any I/O or protobuf parse failure (including failure to open or close the file)
     */
    private static OnnxModelInfo loadFromFile(Path path, ApplicationPackage app) {
        try (InputStream inputStream = app.getFile(path).createInputStream()) {
            Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
            String json = onnxModelToJson(model, path);
            storeGeneratedInfo(json, path, app);
            return jsonToModelInfo(json, app);
        } catch (IOException e) {
            throw new IllegalArgumentException("Unable to parse ONNX model", e);
        }
    }

    /**
     * Reads model info from the previously generated JSON summary for the given model path.
     *
     * @throws IllegalArgumentException wrapping any I/O or JSON parse failure
     */
    private static OnnxModelInfo loadFromGeneratedInfo(Path path, ApplicationPackage app) {
        try {
            String json = readGeneratedInfo(path, app);
            return jsonToModelInfo(json, app);
        } catch (IOException e) {
            throw new IllegalArgumentException("Unable to parse ONNX model", e);
        }
    }

    /** Returns the contents of the generated model info file for the given model path. */
    private static String readGeneratedInfo(Path path, ApplicationPackage app) throws IOException {
        ApplicationFile file = app.getFile(generatedModelInfoPath(path));
        try (Reader reader = file.createReader()) {
            return IOUtils.readAll(reader);
        }
    }

    /** Writes (overwrites) the generated model info file for the given model path. */
    private static void storeGeneratedInfo(String json, Path path, ApplicationPackage app) throws IOException {
        IOUtils.writeFile(app.getFileReference(generatedModelInfoPath(path)), json, false);
    }

    /** Returns the path of the generated model info file: the sanitized model path plus ".modelinfo.json". */
    private static Path generatedModelInfoPath(Path path) {
        String fileName = asValidIdentifier(path.getRelative()) + ".modelinfo.json";
        return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName);
    }

    /**
     * Serializes the model path, the graph inputs and outputs (name, element type, dimensions), and the
     * initializer names into the JSON format read by {@link #jsonToModelInfo}.
     */
    private static String onnxModelToJson(Onnx.ModelProto model, Path path) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        try (JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8)) {
            g.writeStartObject();
            g.writeStringField("path", path.toString());

            g.writeArrayFieldStart("inputs");
            for (Onnx.ValueInfoProto valueInfo : model.getGraph().getInputList()) {
                onnxTypeToJson(g, valueInfo);
            }
            g.writeEndArray();

            g.writeArrayFieldStart("outputs");
            for (Onnx.ValueInfoProto valueInfo : model.getGraph().getOutputList()) {
                onnxTypeToJson(g, valueInfo);
            }
            g.writeEndArray();

            g.writeArrayFieldStart("initializers");
            for (Onnx.TensorProto initializer : model.getGraph().getInitializerList()) {
                g.writeStartObject();
                g.writeStringField("name", initializer.getName());
                g.writeEndObject();
            }
            g.writeEndArray();

            g.writeEndObject();
        }
        // The generator wrote UTF-8 bytes; decode them as such rather than with the platform charset.
        return out.toString(StandardCharsets.UTF_8);
    }

    /**
     * Builds model info from a JSON summary as produced by {@link #onnxModelToJson}.
     *
     * <p>The "path" and "initializers" fields are optional. The first output, if any, becomes the
     * default output.
     *
     * @throws IOException if the JSON cannot be parsed
     */
    public static OnnxModelInfo jsonToModelInfo(String json, ApplicationPackage app) throws IOException {
        JsonNode root = JSON_MAPPER.readTree(json);
        Map<String, OnnxTypeInfo> inputs = new HashMap<>();
        Map<String, OnnxTypeInfo> outputs = new HashMap<>();
        Set<String> initializers = new HashSet<>();
        String defaultOutput = "";
        String path = root.has("path") ? root.get("path").textValue() : null;

        for (JsonNode input : root.get("inputs")) {
            inputs.put(input.get("name").textValue(), jsonToTypeInfo(input));
        }
        for (JsonNode output : root.get("outputs")) {
            outputs.put(output.get("name").textValue(), jsonToTypeInfo(output));
        }
        if (root.get("outputs").has(0)) {
            defaultOutput = root.get("outputs").get(0).get("name").textValue();
        }
        JsonNode initializerRoot = root.get("initializers");
        if (initializerRoot != null) {
            for (JsonNode initializer : initializerRoot) {
                initializers.add(initializer.get("name").textValue());
            }
        }
        return new OnnxModelInfo(app, path, inputs, outputs, initializers, defaultOutput);
    }

    /**
     * Writes one value info as {"name", "type", "dim": [...]}.
     *
     * <p>Each dimension is written as {"type":"param","size":name} when it has a dim_param, and as
     * {"type":"value","size":dim_value} otherwise.
     */
    private static void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException {
        g.writeStartObject();
        g.writeStringField("name", valueInfo.getName());
        g.writeStringField("type", onnxValueTypeToString(valueInfo.getType().getTensorType().getElemType()));
        g.writeArrayFieldStart("dim");
        for (Onnx.TensorShapeProto.Dimension dim : valueInfo.getType().getTensorType().getShape().getDimList()) {
            g.writeStartObject();
            if (dim.hasDimParam()) {
                g.writeStringField("type", "param");
                g.writeStringField("size", dim.getDimParam());
            } else {
                g.writeStringField("type", "value");
                g.writeNumberField("size", dim.getDimValue());
            }
            g.writeEndObject();
        }
        g.writeEndArray();
        g.writeEndObject();
    }

    /** Parses one value info object written by {@link #onnxTypeToJson} into a type description. */
    private static OnnxTypeInfo jsonToTypeInfo(JsonNode node) {
        TensorType.Value valueType = stringToValueType(node.get("type").textValue());
        OnnxTypeInfo type = new OnnxTypeInfo(valueType);
        for (JsonNode dim : node.get("dim")) {
            if (dim.get("type").textValue().equals("param")) {
                type.addDimension(dim.get("size").textValue());
            } else {
                type.addDimension(dim.get("size").longValue());
            }
        }
        return type;
    }

    /**
     * Maps an ONNX element type to the value type name stored in the JSON summary.
     *
     * @throws IllegalArgumentException for element types that have no mapping
     */
    private static String onnxValueTypeToString(Onnx.TensorProto.DataType dataType) {
        switch (dataType) {
            case FLOAT:
                return "float";
            case DOUBLE:
                return "double";
            // Boolean and integer ONNX element types are represented with float cells.
            case BOOL:
            case INT8:
            case INT16:
            case INT32:
            case INT64:
            case UINT8:
            case UINT16:
            case UINT32:
            case UINT64:
                return "float";
            default:
                throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + " cannot be converted to a Vespa tensor type");
        }
    }

    /** Maps a value type name from the JSON summary ("float" or "double") to a Vespa tensor value type. */
    private static TensorType.Value stringToValueType(String type) {
        switch (type) {
            case "float":
                return TensorType.Value.FLOAT;
            case "double":
                return TensorType.Value.DOUBLE;
            default:
                throw new IllegalArgumentException("Unknown tensor value type: " + type);
        }
    }

    /**
     * Returns the given string with every character that is not an ASCII letter, digit, '_', '$' or '@'
     * replaced by '_'.
     */
    public static String asValidIdentifier(String str) {
        return INVALID_IDENTIFIER_CHARS.matcher(str).replaceAll("_");
    }

    /** Element type and dimensions of one ONNX model input or output. */
    private static class OnnxTypeInfo {

        private final TensorType.Value valueType;
        private final List<OnnxDimensionInfo> dimensions = new ArrayList<>();

        OnnxTypeInfo(TensorType.Value valueType) {
            this.valueType = valueType;
        }

        /** Appends a dimension with a numeric size (0 or negative means the size is not known). */
        void addDimension(long value) {
            dimensions.add(new OnnxDimensionInfo(value));
        }

        /** Appends a dimension identified by a symbolic name. */
        void addDimension(String param) {
            dimensions.add(new OnnxDimensionInfo(param));
        }

        /** Returns whether any dimension is symbolic or has a non-positive size. */
        boolean containsUnknownDimensionSizes() {
            return dimensions.stream().anyMatch(OnnxDimensionInfo::unknownDimensionSize);
        }

        TensorType.Value valueType() {
            return valueType;
        }

        List<OnnxDimensionInfo> dimensions() {
            return dimensions;
        }

        /** Converts to a Vespa tensor type using only statically known sizes. */
        TensorType toVespaTensorType() {
            return toVespaTensorType(null, null);
        }

        /**
         * Converts to a Vespa tensor type with indexed dimensions named "d0", "d1", ...
         *
         * <p>A symbolic dimension takes its size from {@code symbolicSizes} when present. A dimension
         * still of size 0 takes the single distinct value of {@code symbolicSizes}, if there is exactly
         * one. A negative size takes a size from {@code unboundSizes}, if that set is non-empty.
         *
         * @param symbolicSizes sizes by symbolic dimension name, or null
         * @param unboundSizes  sizes for unbound (-1) dimensions, or null
         * @return the type, or {@code TensorType.empty} if any dimension size remains non-positive
         */
        TensorType toVespaTensorType(Map<String, Long> symbolicSizes, Set<Long> unboundSizes) {
            TensorType.Builder builder = new TensorType.Builder(valueType);
            for (int i = 0; i < dimensions.size(); ++i) {
                OnnxDimensionInfo onnxDimension = dimensions.get(i);
                long size = onnxDimension.getSize();
                if (onnxDimension.hasSymbolicName() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getSymbolicName())) {
                    size = symbolicSizes.get(onnxDimension.getSymbolicName());
                }
                if (size == 0L && symbolicSizes != null) {
                    Set<Long> distinctSizes = new HashSet<>(symbolicSizes.values());
                    if (distinctSizes.size() == 1) {
                        size = distinctSizes.iterator().next();
                    }
                }
                if (size < 0L && unboundSizes != null && ! unboundSizes.isEmpty()) {
                    size = unboundSizes.iterator().next();
                }
                if (size <= 0L) {
                    return TensorType.empty;
                }
                builder.indexed("d" + i, size);
            }
            return builder.build();
        }

        /**
         * Returns whether some dimension cannot be resolved statically: either a symbolic dimension missing
         * from {@code symbolicSizes} (or that map is null), or a non-symbolic dimension of size 0.
         */
        boolean needModelProbe(Map<String, Long> symbolicSizes) {
            for (OnnxDimensionInfo onnxDimension : dimensions) {
                if (onnxDimension.hasSymbolicName()) {
                    if (symbolicSizes == null || ! symbolicSizes.containsKey(onnxDimension.getSymbolicName())) {
                        return true;
                    }
                } else if (onnxDimension.getSize() == 0L) {
                    return true;
                }
            }
            return false;
        }

        @Override
        public String toString() {
            return "(" + valueType.id() + ")[" + dimensions.stream().map(OnnxDimensionInfo::toString).collect(Collectors.joining(",")) + "]";
        }
    }

    /** A single ONNX dimension: either a numeric size or a symbolic name (with size 0). */
    private static class OnnxDimensionInfo {

        private final long size;
        private final String symbolicName;

        OnnxDimensionInfo(long size) {
            this.size = size;
            this.symbolicName = null;
        }

        OnnxDimensionInfo(String symbolicName) {
            this.size = 0L;
            this.symbolicName = symbolicName;
        }

        long getSize() {
            return size;
        }

        String getSymbolicName() {
            return symbolicName;
        }

        boolean hasSymbolicName() {
            return symbolicName != null;
        }

        /** Returns whether this dimension is symbolic or has a non-positive size. */
        boolean unknownDimensionSize() {
            return hasSymbolicName() || size <= 0L;
        }

        @Override
        public String toString() {
            return hasSymbolicName() ? "\"" + symbolicName + "\"" : Long.toString(size);
        }
    }
}

