/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchdefinition;

import com.yahoo.config.FileReference;
import com.yahoo.path.Path;
import com.yahoo.searchdefinition.MapEvaluationTypeContext;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.vespa.model.AbstractService;
import com.yahoo.vespa.model.utils.FileSender;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import onnx.Onnx;

public class OnnxModel {
    private final String name;
    private PathType pathType = PathType.FILE;
    private String path = null;
    private String fileReference = "";
    private String defaultOutput = null;
    private Map<String, String> inputMap = new HashMap<String, String>();
    private Map<String, String> outputMap = new HashMap<String, String>();
    private Map<String, Onnx.TypeProto> inputTypes = new HashMap<String, Onnx.TypeProto>();
    private Map<String, Onnx.TypeProto> outputTypes = new HashMap<String, Onnx.TypeProto>();
    private Map<String, TensorType> vespaTypes = new HashMap<String, TensorType>();

    public OnnxModel(String name) {
        this.name = name;
    }

    public OnnxModel(String name, String fileName) {
        this(name);
        this.path = fileName;
        this.validate();
    }

    public void setFileName(String fileName) {
        Objects.requireNonNull(fileName, "Filename cannot be null");
        this.path = fileName;
        this.pathType = PathType.FILE;
    }

    public void setUri(String uri) {
        throw new IllegalArgumentException("URI for ONNX models are not currently supported");
    }

    public PathType getPathType() {
        return this.pathType;
    }

    public void setDefaultOutput(String onnxName) {
        Objects.requireNonNull(onnxName, "Name cannot be null");
        this.defaultOutput = onnxName;
    }

    public void addInputNameMapping(String onnxName, String vespaName) {
        this.addInputNameMapping(onnxName, vespaName, true);
    }

    public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) {
        Objects.requireNonNull(onnxName, "Onnx name cannot be null");
        Objects.requireNonNull(vespaName, "Vespa name cannot be null");
        if (overwrite || !this.inputMap.containsKey(onnxName)) {
            this.inputMap.put(onnxName, vespaName);
        }
    }

    public void addOutputNameMapping(String onnxName, String vespaName) {
        this.addOutputNameMapping(onnxName, vespaName, true);
    }

    public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) {
        Objects.requireNonNull(onnxName, "Onnx name cannot be null");
        Objects.requireNonNull(vespaName, "Vespa name cannot be null");
        if (overwrite || !this.outputMap.containsKey(onnxName)) {
            this.outputMap.put(onnxName, vespaName);
        }
    }

    public void addInputType(String onnxName, Onnx.TypeProto type) {
        Objects.requireNonNull(onnxName, "Onnx name cannot be null");
        Objects.requireNonNull(type, "Tensor type cannot be null");
        this.inputTypes.put(onnxName, type);
    }

    public void addOutputType(String onnxName, Onnx.TypeProto type) {
        Objects.requireNonNull(onnxName, "Onnx name cannot be null");
        Objects.requireNonNull(type, "Tensor type cannot be null");
        this.outputTypes.put(onnxName, type);
    }

    public void sendTo(Collection<? extends AbstractService> services) {
        FileReference reference = this.pathType == PathType.FILE ? FileSender.sendFileToServices(this.path, services) : FileSender.sendUriToServices(this.path, services);
        this.fileReference = reference.value();
    }

    public String getName() {
        return this.name;
    }

    public String getFileName() {
        return this.path;
    }

    public Path getFilePath() {
        return Path.fromString((String)this.path);
    }

    public String getUri() {
        return this.path;
    }

    public String getFileReference() {
        return this.fileReference;
    }

    public Map<String, String> getInputMap() {
        return Collections.unmodifiableMap(this.inputMap);
    }

    public Map<String, String> getOutputMap() {
        return Collections.unmodifiableMap(this.outputMap);
    }

    public String getDefaultOutput() {
        return this.defaultOutput;
    }

    public void validate() {
        if (this.path == null || this.path.isEmpty()) {
            throw new IllegalArgumentException("ONNX models must have a file or uri.");
        }
    }

    public String toString() {
        StringBuilder b = new StringBuilder();
        b.append("onnx-model '").append(this.name).append(this.pathType == PathType.FILE ? "' from file '" : " from uri ").append(this.path).append("' with ref '").append(this.fileReference).append("'");
        return b.toString();
    }

    public TensorType getTensorType(String onnxName, MapEvaluationTypeContext context) {
        Onnx.TypeProto onnxOutputType = this.outputTypes.get(onnxName);
        if (onnxOutputType == null) {
            throw new IllegalArgumentException("Could not find type for output '" + onnxName + "' in '" + this.name + "'");
        }
        if (OnnxModel.allDimensionSizesAreKnown(onnxOutputType)) {
            return this.vespaTypes.computeIfAbsent(onnxName, v -> OnnxModel.typeFrom(onnxOutputType));
        }
        return this.getTensorTypeWithUnknownDimensions(onnxOutputType, context);
    }

    private static boolean allDimensionSizesAreKnown(Onnx.TypeProto type) {
        return type.getTensorType().getShape().getDimList().stream().noneMatch(d -> d.hasDimParam() && !d.hasDimValue() || d.getDimValue() == -1L);
    }

    private TensorType getTensorTypeWithUnknownDimensions(Onnx.TypeProto onnxOutputType, MapEvaluationTypeContext context) {
        long unboundSize = 0L;
        HashMap<String, Long> symbolicSizes = new HashMap<String, Long>();
        for (String onnxInputName : this.inputTypes.keySet()) {
            Onnx.TypeProto onnxType = this.inputTypes.get(onnxInputName);
            if (OnnxModel.allDimensionSizesAreKnown(onnxType)) continue;
            Optional<TensorType> vespaType = this.resolveInputType(onnxInputName, context);
            if (vespaType.isEmpty()) {
                return TensorType.empty;
            }
            List<Onnx.TensorShapeProto.Dimension> onnxDimensions = onnxType.getTensorType().getShape().getDimList();
            List vespaDimensions = vespaType.get().dimensions();
            if (vespaDimensions.size() != onnxDimensions.size()) {
                return TensorType.empty;
            }
            for (int i = 0; i < vespaDimensions.size(); ++i) {
                if (((TensorType.Dimension)vespaDimensions.get(i)).size().isEmpty()) continue;
                Long size = (Long)((TensorType.Dimension)vespaDimensions.get(i)).size().get();
                if (onnxDimensions.get(i).getDimValue() == -1L) {
                    if (unboundSize != 0L && unboundSize != size) {
                        throw new IllegalArgumentException("Found conflicting sizes for unbound dimension for type '" + onnxOutputType + "' in ONNX model '" + this.name + "'");
                    }
                    unboundSize = size;
                    continue;
                }
                if (!onnxDimensions.get(i).hasDimParam()) continue;
                String symbolicName = onnxDimensions.get(i).getDimParam();
                if (symbolicSizes.containsKey(symbolicName) && !((Long)symbolicSizes.get(symbolicName)).equals(size)) {
                    throw new IllegalArgumentException("Found conflicting sizes for symbolic dimension '" + symbolicName + "' for input '" + onnxInputName + "' in ONNX model '" + this.name + "'");
                }
                symbolicSizes.put(symbolicName, size);
            }
        }
        return OnnxModel.typeFrom(onnxOutputType, symbolicSizes, unboundSize);
    }

    private Optional<TensorType> resolveInputType(String onnxInputName, MapEvaluationTypeContext context) {
        String source = this.inputMap.get(onnxInputName);
        if (source != null) {
            Optional reference = Reference.simple((String)source);
            if (reference.isPresent()) {
                return Optional.of(context.getType((Reference)reference.get()));
            }
            ExpressionFunction func = context.getFunction(source);
            if (func != null) {
                return Optional.of(func.getBody().type((TypeContext)context));
            }
        }
        return Optional.empty();
    }

    private static TensorType typeFrom(Onnx.TypeProto type) {
        return OnnxModel.typeFrom(type, null, 0L);
    }

    private static TensorType typeFrom(Onnx.TypeProto type, Map<String, Long> symbolicSizes, long unboundSize) {
        String dimensionPrefix = "d";
        Onnx.TensorShapeProto shape = type.getTensorType().getShape();
        TensorType.Builder builder = new TensorType.Builder(OnnxModel.toValueType(type.getTensorType().getElemType()));
        for (int i = 0; i < shape.getDimCount(); ++i) {
            HashSet<Long> unknownSizes;
            String dimensionName = dimensionPrefix + i;
            Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
            long onnxDimensionSize = onnxDimension.getDimValue();
            if (onnxDimension.hasDimParam() && symbolicSizes != null && symbolicSizes.containsKey(onnxDimension.getDimParam())) {
                onnxDimensionSize = symbolicSizes.get(onnxDimension.getDimParam());
            }
            if (onnxDimensionSize == 0L && symbolicSizes != null && (unknownSizes = new HashSet<Long>(symbolicSizes.values())).size() == 1) {
                onnxDimensionSize = (Long)unknownSizes.iterator().next();
            }
            if (onnxDimensionSize < 0L) {
                onnxDimensionSize = unboundSize;
            }
            if (onnxDimensionSize <= 0L) {
                throw new IllegalArgumentException("Unable to determine fixed dimension size when converting from ONNX type: " + type + " to Vespa tensor type.");
            }
            builder.indexed(dimensionName, onnxDimensionSize);
        }
        return builder.build();
    }

    private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) {
        switch (dataType) {
            case FLOAT: {
                return TensorType.Value.FLOAT;
            }
            case DOUBLE: {
                return TensorType.Value.DOUBLE;
            }
            case BOOL: {
                return TensorType.Value.FLOAT;
            }
            case INT8: {
                return TensorType.Value.FLOAT;
            }
            case INT16: {
                return TensorType.Value.FLOAT;
            }
            case INT32: {
                return TensorType.Value.FLOAT;
            }
            case INT64: {
                return TensorType.Value.FLOAT;
            }
            case UINT8: {
                return TensorType.Value.FLOAT;
            }
            case UINT16: {
                return TensorType.Value.FLOAT;
            }
            case UINT32: {
                return TensorType.Value.FLOAT;
            }
            case UINT64: {
                return TensorType.Value.FLOAT;
            }
        }
        throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + " cannot be converted to a Vespa tensor type");
    }

    public static enum PathType {
        FILE,
        URI;

    }
}

