/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.rankingexpression.importer;

import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter;
import ai.vespa.rankingexpression.importer.operations.Constant;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.text.ExpressionFormatter;
import com.yahoo.yolean.Exceptions;
import java.io.File;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Base class for model importers which first build an {@link IntermediateGraph} and then
 * convert that graph into an {@link ImportedModel}.
 *
 * <p>Subclasses implement {@link #canImport(String)} and {@link #importModel(String, String)},
 * and can use {@link #convertIntermediateGraphToModel} to turn the graph they have built into
 * the resulting model.</p>
 */
public abstract class ModelImporter implements MlModelImporter {

    private static final Logger log = Logger.getLogger(ModelImporter.class.getName());

    /**
     * Returns whether this importer can import the given model.
     *
     * @param modelPath NOTE(review): the argument's meaning is defined by {@link MlModelImporter};
     *                  presumably the path of the model - confirm against the interface
     */
    @Override
    public abstract boolean canImport(String modelPath);

    /**
     * Imports the model at the given file by delegating to {@link #importModel(String, String)}
     * with the string form of the file.
     *
     * @param modelName the name of the model
     * @param modelPath the file to import, passed on as {@code modelPath.toString()}
     * @return the imported model
     */
    @Override
    public final ImportedModel importModel(String modelName, File modelPath) {
        return importModel(modelName, modelPath.toString());
    }

    /**
     * Imports the model found at the given path.
     *
     * @param modelName the name of the model
     * @param modelPath the path of the model, as a string
     * @return the imported model
     */
    public abstract ImportedModel importModel(String modelName, String modelPath);

    /**
     * Converts an intermediate graph to an imported model.
     *
     * <p>The graph is optimized, then its signatures and the expressions reachable from the
     * signature outputs are added to the model. Finally warnings are reported and the types of
     * the constants in the graph are logged.</p>
     *
     * @param graph       the graph to convert; it is modified by {@code graph.optimize()}
     * @param modelSource the source of the model, passed on to the created {@link ImportedModel}
     * @param modelType   the type of the model, passed on to the created {@link ImportedModel}
     * @return a new model named after the graph
     */
    protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph,
                                                                   String modelSource,
                                                                   ImportedMlModel.ModelType modelType) {
        ImportedModel model = new ImportedModel(graph.name(), modelSource, modelType);

        // Supplier form: the (potentially large) graph dump is only built when FINER is enabled,
        // and it shows the graph as it was before optimization
        log.log(Level.FINER, () -> "Intermediate graph created from '" + modelSource + "':\n" +
                                   ExpressionFormatter.inTwoColumnMode(70, 50).format(graph.toFullString()));

        graph.optimize();

        importSignatures(graph, model);
        importExpressions(graph, model);
        reportWarnings(graph, model);
        logVariableTypes(graph);

        return model;
    }

    /** Copies the inputs and outputs of every signature in the graph to the model. */
    private static void importSignatures(IntermediateGraph graph, ImportedModel model) {
        for (String signatureName : graph.signatures()) {
            ImportedModel.Signature signature = model.signature(signatureName);
            for (Map.Entry<String, String> input : graph.inputs(signatureName).entrySet()) {
                signature.input(input.getKey(), input.getValue());
            }
            for (Map.Entry<String, String> output : graph.outputs(signatureName).entrySet()) {
                // Only the output key is converted to a Vespa name, the value is used as-is
                signature.output(IntermediateOperation.vespaName(output.getKey()), output.getValue());
            }
        }
    }

    /** Returns whether the name of the given operation is an output of any signature of the model. */
    private static boolean isSignatureOutput(ImportedModel model, IntermediateOperation operation) {
        for (ImportedModel.Signature signature : model.signatures().values()) {
            for (String outputName : signature.outputs().values()) {
                if (outputName.equals(operation.name())) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Imports the expression of every signature output. Outputs which do not produce a
     * function, or whose import throws an {@link IllegalArgumentException}, are recorded on
     * the signature as skipped together with the reason.
     */
    private static void importExpressions(IntermediateGraph graph, ImportedModel model) {
        for (ImportedModel.Signature signature : model.signatures().values()) {
            for (String outputName : signature.outputs().values()) {
                try {
                    Optional<TensorFunction> function = importExpression(graph.get(outputName), model);
                    if (function.isEmpty()) {
                        signature.skippedOutput(outputName, "No valid output function could be found.");
                    }
                }
                catch (IllegalArgumentException e) {
                    signature.skippedOutput(outputName, Exceptions.toMessageString(e));
                }
            }
        }
    }

    /**
     * Imports the given operation, and recursively its inputs, into the model.
     *
     * @return the function of the operation, or empty if the operation has no type
     */
    private static Optional<TensorFunction> importExpression(IntermediateOperation operation, ImportedModel model) {
        if (model.expressions().containsKey(operation.name())) {
            return operation.function(); // already imported
        }
        if (operation.type().isEmpty()) {
            return Optional.empty();
        }
        if (operation.isConstant()) {
            return importConstant(operation, model);
        }

        importExpressionInputs(operation, model);
        importRankingExpression(operation, model);
        importArgumentExpression(operation, model);
        importFunctionExpression(operation, model);

        return operation.function();
    }

    /** Imports each input of the given operation. */
    private static void importExpressionInputs(IntermediateOperation operation, ImportedModel model) {
        for (IntermediateOperation input : operation.inputs()) {
            importExpression(input, model);
        }
    }

    /**
     * Adds the value of a constant operation to the model, unless a constant with the same
     * Vespa name is already present. Rank 0 tensors are added as small constants, all other
     * tensors as large constants. Values which are not tensor values are not added.
     *
     * @return the function of the operation
     * @throws IllegalArgumentException if the operation has no constant value
     */
    private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) {
        String name = operation.vespaName();
        if (model.hasLargeConstant(name) || model.hasSmallConstant(name)) {
            return operation.function(); // already imported
        }

        Value value = operation.getConstantValue().orElseThrow(() ->
                new IllegalArgumentException("Operation '" + operation.vespaName() + "' " +
                                             "is constant but does not have a value."));
        if (!(value instanceof TensorValue)) {
            return operation.function();
        }

        Tensor tensor = value.asTensor();
        if (tensor.type().rank() == 0) {
            model.smallConstant(name, tensor);
        } else {
            model.largeConstant(name, tensor);
        }
        return operation.function();
    }

    /**
     * Adds the function of the given operation to the model as a ranking expression named after
     * the operation, unless the operation has no function or the expression already exists.
     * For signature outputs whose type differs from the standard type, the function is wrapped
     * in a {@link Rename} to the dimension names of the standard type.
     *
     * @throws RuntimeException if the function cannot be parsed as a ranking expression
     */
    private static void importRankingExpression(IntermediateOperation operation, ImportedModel model) {
        if (operation.function().isEmpty()) {
            return;
        }
        String name = operation.name();
        if (model.expressions().containsKey(name)) {
            return;
        }

        TensorFunction function = operation.function().get();
        if (isSignatureOutput(model, operation)) {
            OrderedTensorType operationType = operation.type().get();
            OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType);
            if (!operationType.equals(standardNamingType)) {
                List<String> renameFrom = operationType.dimensionNames();
                List<String> renameTo = standardNamingType.dimensionNames();
                function = new Rename(function, renameFrom, renameTo);
            }
        }

        try {
            model.expression(name, new RankingExpression(name, function.toString()));
        }
        catch (ParseException e) {
            // Deliberately not an IllegalArgumentException: importExpressions catches those and
            // would then record the output as skipped instead of failing the import
            throw new RuntimeException("Imported function " + function + " cannot be parsed as a ranking expression", e);
        }
    }

    /** Declares the given operation as a model input, using its standard type, if it is an input. */
    private static void importArgumentExpression(IntermediateOperation operation, ImportedModel model) {
        if (!operation.isInput()) {
            return;
        }
        // The type is present here: importExpression returns early for operations without a type
        OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get());
        model.input(operation.vespaName(), standardNamingConvention.type());
    }

    /**
     * Adds the ranking expression function of the given operation, if it has one, to the model.
     *
     * @throws RuntimeException if the function cannot be parsed as a ranking expression
     */
    private static void importFunctionExpression(IntermediateOperation operation, ImportedModel model) {
        if (operation.rankingExpressionFunction().isEmpty()) {
            return;
        }
        TensorFunction function = operation.rankingExpressionFunction().get();
        try {
            String functionName = operation.rankingExpressionFunctionName();
            model.function(functionName, new RankingExpression(functionName, function.toString()));
        }
        catch (ParseException e) {
            throw new RuntimeException("Model function " + function + " cannot be parsed as a ranking expression", e);
        }
    }

    /** Reports the warnings of every operation reachable from a signature output, once per operation. */
    private static void reportWarnings(IntermediateGraph graph, ImportedModel model) {
        // One shared set for the whole pass, so operations feeding several outputs are visited once
        Set<String> processed = new HashSet<>();
        for (ImportedModel.Signature signature : model.signatures().values()) {
            for (String outputName : signature.outputs().values()) {
                reportWarnings(graph.get(outputName), model, processed);
            }
        }
    }

    /**
     * Logs the warnings of the given operation and, recursively, of its inputs.
     *
     * @param operation the operation to report warnings for
     * @param model     currently unused; kept so warnings can be attached to the model later
     * @param processed the names of the operations already visited; updated by this method
     */
    private static void reportWarnings(IntermediateOperation operation, ImportedModel model, Set<String> processed) {
        // Mark before descending, so a cycle in the graph cannot cause unbounded recursion.
        // add returns false when the name was already present, i.e. the operation is already handled.
        if (!processed.add(operation.name())) {
            return;
        }
        for (String warning : operation.warnings()) {
            log.log(Level.FINE, () -> "Warning for operation '" + operation.name() + "': " + warning);
        }
        for (IntermediateOperation input : operation.inputs()) {
            reportWarnings(input, model, processed);
        }
    }

    /** Logs the name, Vespa name and type of every constant in the graph which has a type. */
    private static void logVariableTypes(IntermediateGraph graph) {
        for (IntermediateOperation operation : graph.operations().values()) {
            if (operation instanceof Constant && operation.type().isPresent()) {
                log.info("Importing model variable " + operation.name() + " as " + operation.vespaName() +
                         " of type " + operation.type().get());
            }
        }
    }

}

