/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchdefinition.expressiontransforms;

import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.expressiontransforms.MLImportFeatureConverter;
import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext;
import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
import com.yahoo.searchlib.rankingexpression.integration.ml.OnnxImporter;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

/**
 * Replaces {@code onnx(...)} feature references in a ranking expression with the
 * expression obtained from the referenced ONNX model.
 *
 * <p>Imported models are cached per model path for the lifetime of this converter
 * instance, so repeated references to the same model path import it only once.
 * The cache is a plain {@link HashMap}; this class does no synchronization of its own.
 */
public class OnnxFeatureConverter
extends MLImportFeatureConverter {
    private final OnnxImporter onnxImporter = new OnnxImporter();
    /** Models already imported by this converter, keyed by the model path given in the feature. */
    private final Map<Path, ImportedModel> importedModels = new HashMap<>();

    /**
     * Transforms a single expression node.
     *
     * <p>{@code onnx} feature references are replaced by the model expression. Composite
     * nodes have their children transformed via {@code transformChildren}. Any other
     * node is returned unchanged.
     *
     * @param node    the expression node to transform
     * @param context the transform context, supplying the rank profile and query profiles
     * @return the transformed node, or {@code node} itself when no transformation applies
     */
    public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
        if (node instanceof ReferenceNode) {
            return this.transformFeature((ReferenceNode) node, context);
        }
        if (node instanceof CompositeNode) {
            // Passed without an up-cast so this compiles whether the parameter type is
            // the generic context type or TransformContext.
            return super.transformChildren((CompositeNode) node, context);
        }
        return node;
    }

    /**
     * Replaces an {@code onnx(...)} reference with the expression from the referenced model.
     * References with any other name are returned unchanged.
     *
     * @param feature the reference node to inspect
     * @param context the transform context, supplying the rank profile and query profiles
     * @return the replacement expression, or {@code feature} if it is not an onnx reference
     * @throws IllegalArgumentException if the feature arguments are invalid, or if an
     *         {@link UncheckedIOException} or {@link IllegalArgumentException} is thrown while
     *         resolving the model; the offending feature is included in the message
     */
    private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
        if (!feature.getName().equals("onnx")) {
            return feature;
        }
        try {
            OnnxFeatureArguments arguments = new OnnxFeatureArguments(feature.getArguments());
            MLImportFeatureConverter.ModelStore store = new MLImportFeatureConverter.ModelStore(
                    context.rankProfile().getSearch().sourceApplication(), arguments);
            // NOTE(review): presumably a "stored model" is an already-converted copy kept in the
            // application package, which avoids re-importing the ONNX file; confirm in ModelStore.
            if (!store.hasStoredModel()) {
                return this.transformFromOnnxModel(store, context.rankProfile(), context.queryProfiles());
            }
            return this.transformFromStoredModel(store, context.rankProfile());
        }
        catch (UncheckedIOException | IllegalArgumentException e) {
            throw new IllegalArgumentException("Could not use Onnx model from " + feature, e);
        }
    }

    /**
     * Imports the ONNX model referenced by {@code store} and builds the expression from it.
     *
     * <p>The import result is cached per model path, so later references to the same path
     * reuse the already imported model.
     *
     * @param store         the model store describing which model to use and where it lives
     * @param profile       the rank profile the expression belongs to
     * @param queryProfiles the query profile registry of the application
     * @return the expression derived from the imported model
     */
    private ExpressionNode transformFromOnnxModel(MLImportFeatureConverter.ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles) {
        ImportedModel model = this.importedModels.computeIfAbsent(
                store.arguments().modelPath(),
                path -> this.onnxImporter.importModel(store.arguments().modelName(), store.modelDir()));
        return this.transformFromImportedModel(model, store, profile, queryProfiles);
    }

    /**
     * Parsed arguments of an {@code onnx(model, [output])} feature.
     *
     * <p>The first argument is the model path and the optional second argument selects the
     * output. The signature is always {@code "default"}.
     */
    static class OnnxFeatureArguments
    extends MLImportFeatureConverter.FeatureArguments {
        /**
         * Parses the arguments of an onnx feature.
         *
         * @param arguments the feature's arguments
         * @throws IllegalArgumentException if there are no arguments or more than two
         */
        public OnnxFeatureArguments(Arguments arguments) {
            if (arguments.isEmpty()) {
                throw new IllegalArgumentException("An onnx node must take an argument pointing to the onnx model directory under [application]/models");
            }
            // Only indices 0 and 1 are read below. A third argument would otherwise be
            // silently ignored, so reject it as the message states.
            if (arguments.expressions().size() > 2) {
                throw new IllegalArgumentException("An onnx feature can have at most 2 arguments");
            }
            this.modelPath = Path.fromString(this.asString(arguments.expressions().get(0)));
            this.output = this.optionalArgument(1, arguments);
            this.signature = Optional.of("default");
        }
    }
}

