/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.schema.expressiontransforms;

import com.yahoo.path.Path;
import com.yahoo.schema.OnnxModel;
import com.yahoo.schema.RankProfile;
import com.yahoo.schema.expressiontransforms.RankProfileTransformContext;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import com.yahoo.vespa.model.ml.ConvertedModel;
import com.yahoo.vespa.model.ml.FeatureArguments;
import com.yahoo.vespa.model.ml.ModelName;
import java.util.List;
import java.util.regex.Pattern;

public class OnnxModelTransformer
extends ExpressionTransformer<RankProfileTransformContext> {
    public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
        if (node instanceof ReferenceNode) {
            return this.transformFeature((ReferenceNode)node, context);
        }
        if (node instanceof CompositeNode) {
            return super.transformChildren((CompositeNode)node, (TransformContext)context);
        }
        return node;
    }

    private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
        if (context.rankProfile() == null) {
            return feature;
        }
        if (context.rankProfile().schema() == null) {
            return feature;
        }
        return OnnxModelTransformer.transformFeature(feature, context.rankProfile());
    }

    public static ExpressionNode transformFeature(ReferenceNode feature, RankProfile profile) {
        String featureName = feature.getName();
        if (!featureName.equals("onnxModel") && !featureName.equals("onnx")) {
            return feature;
        }
        Arguments arguments = feature.getArguments();
        if (arguments.isEmpty()) {
            throw new IllegalArgumentException("An " + featureName + " feature must take an argument referring to a onnx-model config or an ONNX file.");
        }
        if (arguments.expressions().size() > 3) {
            throw new IllegalArgumentException("An " + featureName + " feature can have at most 3 arguments.");
        }
        String modelConfigName = OnnxModelTransformer.getModelConfigName(feature.reference());
        OnnxModel onnxModel = profile.onnxModels().get(modelConfigName);
        if (onnxModel == null) {
            String path = OnnxModelTransformer.asString((ExpressionNode)arguments.expressions().get(0));
            ModelName modelName = new ModelName(null, Path.fromString((String)path), true);
            ConvertedModel convertedModel = ConvertedModel.fromStore(profile.schema().applicationPackage(), modelName, path, profile);
            FeatureArguments featureArguments = new FeatureArguments(arguments);
            return convertedModel.expression(featureArguments, null);
        }
        String defaultOutput = onnxModel.getOutputMap().get(onnxModel.getDefaultOutput());
        String output = OnnxModelTransformer.getModelOutput(feature.reference(), defaultOutput);
        if (!onnxModel.getOutputMap().containsValue(output)) {
            throw new IllegalArgumentException(featureName + " argument '" + output + "' output not found in model '" + onnxModel.getFileName() + "'");
        }
        return new ReferenceNode("onnxModel", List.of(new ReferenceNode(modelConfigName)), output);
    }

    public static String getModelConfigName(Reference reference) {
        if (reference.arguments().size() > 0) {
            ExpressionNode expr = (ExpressionNode)reference.arguments().expressions().get(0);
            if (expr instanceof ReferenceNode) {
                return expr.toString();
            }
            if (expr instanceof ConstantNode) {
                return OnnxModelTransformer.asValidIdentifier(expr);
            }
        }
        return null;
    }

    public static String getModelOutput(Reference reference, String defaultOutput) {
        if (reference.output() != null) {
            return reference.output();
        }
        if (reference.arguments().expressions().size() == 2) {
            return OnnxModelTransformer.asValidIdentifier((ExpressionNode)reference.arguments().expressions().get(1));
        }
        if (reference.arguments().expressions().size() > 2) {
            return OnnxModelTransformer.asValidIdentifier((ExpressionNode)reference.arguments().expressions().get(2));
        }
        return defaultOutput;
    }

    public static String stripQuotes(String s) {
        if (OnnxModelTransformer.isNotQuoteSign(s.codePointAt(0))) {
            return s;
        }
        if (OnnxModelTransformer.isNotQuoteSign(s.codePointAt(s.length() - 1))) {
            throw new IllegalArgumentException("argument [" + s + "] is missing end quote");
        }
        return s.substring(1, s.length() - 1);
    }

    public static String asValidIdentifier(String str) {
        return str.replaceAll("[^\\w\\d\\$@_]", "_");
    }

    private static String asValidIdentifier(ExpressionNode node) {
        return OnnxModelTransformer.asValidIdentifier(OnnxModelTransformer.asString(node));
    }

    private static boolean isNotQuoteSign(int c) {
        return c != 39 && c != 34;
    }

    public static String asString(ExpressionNode node) {
        if (!(node instanceof ConstantNode)) {
            throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
        }
        return OnnxModelTransformer.stripQuotes(node.toString());
    }
}

