/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchdefinition.expressiontransforms;

import com.yahoo.searchdefinition.ImmutableSearch;
import com.yahoo.searchdefinition.OnnxModel;
import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import java.util.List;
import java.util.regex.Pattern;

public class OnnxModelTransformer
extends ExpressionTransformer<RankProfileTransformContext> {
    public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
        if (node instanceof ReferenceNode) {
            return this.transformFeature((ReferenceNode)node, context);
        }
        if (node instanceof CompositeNode) {
            return super.transformChildren((CompositeNode)node, (TransformContext)context);
        }
        return node;
    }

    private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
        if (context.rankProfile() == null) {
            return feature;
        }
        if (context.rankProfile().getSearch() == null) {
            return feature;
        }
        return OnnxModelTransformer.transformFeature(feature, context.rankProfile().getSearch());
    }

    public static ReferenceNode transformFeature(ReferenceNode feature, ImmutableSearch search) {
        String name;
        OnnxModel onnxModel;
        String modelConfigName;
        if (!feature.getName().equals("onnxModel")) {
            return feature;
        }
        Arguments arguments = feature.getArguments();
        if (arguments.isEmpty()) {
            throw new IllegalArgumentException("An onnxModel feature must take an argument referring to a onnx-model config or a ONNX file.");
        }
        if (arguments.expressions().size() > 2) {
            throw new IllegalArgumentException("An onnxModel feature can have at most 2 arguments.");
        }
        if (arguments.expressions().get(0) instanceof ReferenceNode) {
            modelConfigName = ((ExpressionNode)arguments.expressions().get(0)).toString();
            onnxModel = search.onnxModels().get(modelConfigName);
            if (onnxModel == null) {
                throw new IllegalArgumentException("onnxModel argument '" + modelConfigName + "' config not found");
            }
        } else if (arguments.expressions().get(0) instanceof ConstantNode) {
            String path = OnnxModelTransformer.asString((ExpressionNode)arguments.expressions().get(0));
            modelConfigName = OnnxModelTransformer.asValidIdentifier(path);
            onnxModel = search.onnxModels().get(modelConfigName);
            if (onnxModel == null) {
                onnxModel = new OnnxModel(modelConfigName, path);
                search.onnxModels().add(onnxModel);
            }
        } else {
            throw new IllegalArgumentException("Illegal argument to onnxModel: '" + arguments.expressions().get(0) + "'");
        }
        String output = null;
        if (feature.getOutput() != null) {
            output = feature.getOutput();
            if (!OnnxModelTransformer.hasOutputMapping(onnxModel, output)) {
                onnxModel.addOutputNameMapping(output, output);
            }
        } else if (arguments.expressions().size() > 1 && !OnnxModelTransformer.hasOutputMapping(onnxModel, output = OnnxModelTransformer.asValidIdentifier(name = OnnxModelTransformer.asString((ExpressionNode)arguments.expressions().get(1))))) {
            onnxModel.addOutputNameMapping(name, output);
        }
        ReferenceNode argument = new ReferenceNode(modelConfigName);
        return new ReferenceNode("onnxModel", List.of(argument), output);
    }

    private static boolean hasOutputMapping(OnnxModel onnxModel, String as) {
        return onnxModel.getOutputMap().stream().anyMatch(m -> m.getVespaName().equals(as));
    }

    private static String asString(ExpressionNode node) {
        if (!(node instanceof ConstantNode)) {
            throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
        }
        return OnnxModelTransformer.stripQuotes(((ConstantNode)node).sourceString());
    }

    private static String stripQuotes(String s) {
        if (!OnnxModelTransformer.isQuoteSign(s.codePointAt(0))) {
            return s;
        }
        if (!OnnxModelTransformer.isQuoteSign(s.codePointAt(s.length() - 1))) {
            throw new IllegalArgumentException("argument [" + s + "] is missing endquote");
        }
        return s.substring(1, s.length() - 1);
    }

    private static boolean isQuoteSign(int c) {
        return c == 39 || c == 34;
    }

    private static String asValidIdentifier(String str) {
        return str.replaceAll("[^\\w\\d\\$@_]", "_");
    }
}

