/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchdefinition.expressiontransforms;

import com.google.common.base.Joiner;
import com.yahoo.collections.Pair;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.logging.Logger;

/**
 * Replaces {@code tensorflow(...)} features in ranking expressions by the ranking expression
 * obtained from importing the referenced TensorFlow model.
 * <p>
 * The feature takes one to three constant string arguments: the model directory (relative to
 * {@code [application]/models}), an optional signature name and an optional output name.
 * When a converted expression for the feature is already stored in the application package it is
 * read back, together with its constants. Otherwise the model is imported and the converted
 * expression and constants are written to the application package's generated-models directories.
 */
public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {

    private static final Logger log = Logger.getLogger(TensorFlowFeatureConverter.class.getName());

    private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter();

    /** Models imported by this converter instance, keyed by their path below the models directory. */
    private final Map<Path, TensorFlowModel> importedModels = new HashMap<>();

    /**
     * Transforms {@code tensorflow} reference nodes and recurses into composite nodes.
     * Any other node is returned unchanged.
     */
    @Override
    public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
        if (node instanceof ReferenceNode) {
            return transformFeature((ReferenceNode) node, context);
        }
        if (node instanceof CompositeNode) {
            return super.transformChildren((CompositeNode) node, context);
        }
        return node;
    }

    /**
     * Converts a single feature reference if it is named {@code tensorflow}; other references are
     * returned as-is.
     *
     * @throws IllegalArgumentException wrapping any I/O or argument problem, with the feature in the message
     */
    private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
        if (!feature.getName().equals("tensorflow")) {
            return feature;
        }
        try {
            ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments());
            if (store.hasStoredModel()) {
                return transformFromStoredModel(store, context.rankProfile());
            }
            return transformFromTensorFlowModel(store, context.rankProfile(), context.queryProfiles());
        }
        catch (UncheckedIOException | IllegalArgumentException e) {
            throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e);
        }
    }

    /**
     * Imports the model (at most once per model path for this converter). It then selects the
     * requested signature and output, verifies the macros the expression needs, and stores the
     * converted expression and constants in the application package.
     */
    private ExpressionNode transformFromTensorFlowModel(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles) {
        TensorFlowModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
                                                               k -> tensorFlowImporter.importModel(store.tensorFlowModelDir()));
        TensorFlowModel.Signature signature = chooseSignature(model, store.arguments().signature());
        String output = chooseOutput(signature, store.arguments().output());
        RankingExpression expression = model.expressions().get(output);
        if (expression == null) {
            // Fail with an IllegalArgumentException (wrapped with context by transformFeature)
            // rather than a bare NullPointerException further down.
            throw new IllegalArgumentException("Model has no expression for output '" + output + "'");
        }
        verifyRequiredMacros(expression, model.requiredMacros(), profile, queryProfiles);
        store.writeConverted(expression);
        model.smallConstants().forEach((name, value) -> transformSmallConstant(store, profile, name, value));
        model.largeConstants().forEach((name, value) -> transformLargeConstant(store, profile, name, value));
        return expression.getRoot();
    }

    /**
     * Reads a previously stored conversion. Small constants are added to the rank profile;
     * large constants are added to the search definition unless a constant with that name
     * already exists there.
     */
    private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) {
        for (Pair<String, Tensor> constant : store.readSmallConstants()) {
            profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
        }
        for (RankingConstant rankingConstant : store.readLargeConstants()) {
            if (!profile.getSearch().getRankingConstants().containsKey(rankingConstant.getName())) {
                profile.getSearch().addRankingConstant(rankingConstant);
            }
        }
        return store.readConverted().getRoot();
    }

    /**
     * Returns the named signature. If no name is given, returns the model's only signature.
     *
     * @throws IllegalArgumentException if the named signature is missing, or no name is given
     *         and the model has zero or several signatures
     */
    private TensorFlowModel.Signature chooseSignature(TensorFlowModel importResult, Optional<String> signatureName) {
        if (!signatureName.isPresent()) {
            if (importResult.signatures().isEmpty()) {
                throw new IllegalArgumentException("No signatures are available");
            }
            if (importResult.signatures().size() > 1) {
                throw new IllegalArgumentException("Model has multiple signatures (" +
                                                   Joiner.on(", ").join(importResult.signatures().keySet()) +
                                                   "), one must be specified as a second argument to tensorflow()");
            }
            return importResult.signatures().values().iterator().next();
        }
        TensorFlowModel.Signature signature = importResult.signatures().get(signatureName.get());
        if (signature == null) {
            throw new IllegalArgumentException("Model does not have the specified signature '" + signatureName.get() + "'");
        }
        return signature;
    }

    /**
     * Returns the value mapped to the named output in the signature. If no name is given,
     * returns the value of the signature's only output.
     *
     * @throws IllegalArgumentException if the output is missing or was skipped during import, or
     *         no name is given and the signature has zero or several outputs
     */
    private String chooseOutput(TensorFlowModel.Signature signature, Optional<String> outputName) {
        if (!outputName.isPresent()) {
            if (signature.outputs().isEmpty()) {
                throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature));
            }
            if (signature.outputs().size() > 1) {
                throw new IllegalArgumentException(signature + " has multiple outputs (" +
                                                   Joiner.on(", ").join(signature.outputs().keySet()) +
                                                   "), one must be specified as a third argument to tensorflow()");
            }
            return signature.outputs().values().iterator().next();
        }
        String output = signature.outputs().get(outputName.get());
        if (output == null) {
            if (signature.skippedOutputs().containsKey(outputName.get())) {
                throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " +
                                                   signature.skippedOutputs().get(outputName.get()));
            }
            throw new IllegalArgumentException("Model does not have the specified output '" + outputName.get() + "'");
        }
        return output;
    }

    /** Stores a small constant alongside the model and adds it to the rank profile. */
    private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
        store.writeSmallConstant(constantName, constantValue);
        profile.addConstant(constantName, asValue(constantValue));
    }

    /**
     * Writes a large constant to the application package. It is registered as a ranking constant
     * of the search definition unless one with the same name is already present.
     */
    private void transformLargeConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
        Path constantPath = store.writeLargeConstant(constantName, constantValue);
        if (!profile.getSearch().getRankingConstants().containsKey(constantName)) {
            log.info("Adding constant '" + constantName + "' of type " + constantValue.type());
            profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), constantPath.toString()));
        }
    }

    /**
     * Describes the outputs skipped during import, to be appended to an error message.
     * Returns an empty string if nothing was skipped.
     */
    private String skippedOutputsDescription(TensorFlowModel.Signature signature) {
        if (signature.skippedOutputs().isEmpty()) {
            return "";
        }
        StringBuilder b = new StringBuilder(": ");
        signature.skippedOutputs().forEach((name, reason) -> {
            // Separate entries so that several skipped outputs do not run together.
            if (b.length() > 2) {
                b.append("; ");
            }
            b.append("Skipping output '").append(name).append("': ").append(reason);
        });
        return b.toString();
    }

    /**
     * Checks that every macro referenced by the expression that the model requires is defined in
     * the rank profile with a type assignable to the required type.
     *
     * @throws IllegalArgumentException if a required macro is missing, has an undeterminable type,
     *         or has an incompatible type
     */
    private void verifyRequiredMacros(RankingExpression expression, Map<String, TensorType> requiredMacros,
                                      RankProfile profile, QueryProfileRegistry queryProfiles) {
        List<String> macroNames = new ArrayList<>();
        addMacroNamesIn(expression.getRoot(), macroNames);
        for (String macroName : macroNames) {
            TensorType requiredType = requiredMacros.get(macroName);
            if (requiredType == null) {
                continue; // not a model placeholder
            }
            RankProfile.Macro macro = profile.getMacros().get(macroName);
            if (macro == null) {
                throw new IllegalArgumentException("Model refers Placeholder '" + macroName + "' of type " + requiredType +
                                                   " but this macro is not present in " + profile);
            }
            TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles));
            if (actualType == null) {
                throw new IllegalArgumentException("Model refers Placeholder '" + macroName + "' of type " + requiredType +
                                                   " which must be produced by a macro in the rank profile, but this macro references a feature which is not declared");
            }
            if (!actualType.isAssignableTo(requiredType)) {
                throw new IllegalArgumentException("Model refers Placeholder '" + macroName + "' of type " + requiredType +
                                                   " which must be produced by a macro in the rank profile, but this macro produces type " + actualType);
            }
        }
    }

    /** Collects the names of all reference nodes without an output, found recursively below {@code node}. */
    private void addMacroNamesIn(ExpressionNode node, List<String> names) {
        if (node instanceof ReferenceNode) {
            ReferenceNode referenceNode = (ReferenceNode) node;
            if (referenceNode.getOutput() == null) {
                names.add(referenceNode.getName());
            }
        }
        else if (node instanceof CompositeNode) {
            for (ExpressionNode child : ((CompositeNode) node).children()) {
                addMacroNamesIn(child, names);
            }
        }
    }

    /** Wraps a tensor as a rank-expression value; rank-0 tensors become plain doubles. */
    private Value asValue(Tensor tensor) {
        if (tensor.type().rank() == 0) {
            return new DoubleValue(tensor.asDouble());
        }
        return new TensorValue(tensor);
    }

    /** The parsed arguments of a {@code tensorflow(modelPath[, signature[, output]])} feature. */
    private static class FeatureArguments {

        private final Path modelPath;
        private final Optional<String> signature;
        private final Optional<String> output;

        /**
         * @throws IllegalArgumentException if there are no arguments, more than three, or a
         *         non-constant or badly quoted argument
         */
        public FeatureArguments(Arguments arguments) {
            if (arguments.isEmpty()) {
                throw new IllegalArgumentException("A tensorflow node must take an argument pointing to the tensorflow model directory under [application]/models");
            }
            if (arguments.expressions().size() > 3) {
                throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments");
            }
            this.modelPath = Path.fromString(asString(arguments.expressions().get(0)));
            this.signature = optionalArgument(1, arguments);
            this.output = optionalArgument(2, arguments);
        }

        public Path modelPath() {
            return modelPath;
        }

        public Optional<String> signature() {
            return signature;
        }

        public Optional<String> output() {
            return output;
        }

        /** The file holding the model's small constants, one tab-separated line per constant. */
        public Path smallConstantsPath() {
            return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt");
        }

        /** The directory holding one descriptor file per large constant. */
        public Path largeConstantsPath() {
            return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
        }

        /** The file holding the converted expression for this signature/output combination. */
        public Path expressionPath() {
            return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath)
                                                                     .append("expressions")
                                                                     .append(expressionFileName());
        }

        /** Returns "[signature.][output.]expression", or "single.expression" when neither is given. */
        private String expressionFileName() {
            StringBuilder fileName = new StringBuilder();
            signature.ifPresent(s -> fileName.append(s).append("."));
            output.ifPresent(s -> fileName.append(s).append("."));
            if (fileName.length() == 0) {
                fileName.append("single.");
            }
            fileName.append("expression");
            return fileName.toString();
        }

        private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
            if (argumentIndex >= arguments.expressions().size()) {
                return Optional.empty();
            }
            return Optional.of(asString(arguments.expressions().get(argumentIndex)));
        }

        private String asString(ExpressionNode node) {
            if (!(node instanceof ConstantNode)) {
                throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node + "'");
            }
            return stripQuotes(((ConstantNode) node).sourceString());
        }

        /**
         * Removes surrounding single or double quotes. Unquoted (including empty) strings are
         * returned unchanged.
         *
         * @throws IllegalArgumentException if the string starts with a quote but does not end with one
         */
        private String stripQuotes(String s) {
            if (s.isEmpty() || !isQuoteSign(s.codePointAt(0))) {
                return s;
            }
            // A lone quote character has no separate end quote.
            if (s.length() < 2 || !isQuoteSign(s.codePointAt(s.length() - 1))) {
                throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote");
            }
            return s.substring(1, s.length() - 1);
        }

        private boolean isQuoteSign(int c) {
            return c == '\'' || c == '"';
        }
    }

    /** Reads and writes the converted form of one model feature in the application package. */
    private static class ModelStore {

        private final ApplicationPackage application;
        private final FeatureArguments arguments;

        public ModelStore(ApplicationPackage application, Arguments arguments) {
            this.application = application;
            this.arguments = new FeatureArguments(arguments);
        }

        public FeatureArguments arguments() {
            return arguments;
        }

        /**
         * Returns whether a converted expression file exists for this feature. An application
         * package that throws UnsupportedOperationException is treated as having none.
         */
        public boolean hasStoredModel() {
            try {
                return application.getFile(arguments.expressionPath()).exists();
            }
            catch (UnsupportedOperationException e) {
                return false;
            }
        }

        /** The model's source directory below the application's models directory. */
        public File tensorFlowModelDir() {
            return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath()));
        }

        public void writeConverted(RankingExpression expression) {
            application.getFile(arguments.expressionPath()).writeFile(new StringReader(expression.getRoot().toString()));
        }

        /**
         * Parses the stored converted expression.
         *
         * @throws UncheckedIOException if it cannot be read
         * @throws IllegalStateException if it cannot be parsed
         */
        public RankingExpression readConverted() {
            try (Reader reader = application.getFile(arguments.expressionPath()).createReader()) {
                return new RankingExpression(reader);
            }
            catch (IOException e) {
                throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e);
            }
            catch (ParseException e) {
                throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
            }
        }

        /** Reads the large-constant descriptors, each stored as "name:type:path". */
        public List<RankingConstant> readLargeConstants() {
            try {
                List<RankingConstant> constants = new ArrayList<>();
                for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) {
                    String content;
                    try (Reader reader = constantFile.createReader()) {
                        content = IOUtils.readAll(reader);
                    }
                    // Limit the split so that a path containing ':' is kept whole.
                    String[] parts = content.split(":", 3);
                    constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
                }
                return constants;
            }
            catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        }

        /**
         * Writes a large constant. The tensor goes in binary form under the generated models
         * directory, and a "name:type:path" descriptor goes under {@link FeatureArguments#largeConstantsPath()}.
         *
         * @return the path of the binary tensor file, adjusted by {@link #correct(Path)}
         */
        public Path writeLargeConstant(String name, Tensor constant) {
            Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath()).append("constants");
            Path constantPath = constantsPath.append(name + ".tbf");
            application.getFile(arguments.largeConstantsPath().append(name + ".constant"))
                       .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
            createIfNeeded(constantsPath);
            IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
            return correct(constantPath);
        }

        /**
         * Reads the small constants, stored one per line as "name\ttype\tvalue".
         * Returns an empty list if the file does not exist.
         */
        private List<Pair<String, Tensor>> readSmallConstants() {
            ApplicationFile file = application.getFile(arguments.smallConstantsPath());
            if (!file.exists()) {
                return Collections.emptyList();
            }
            List<Pair<String, Tensor>> constants = new ArrayList<>();
            try (BufferedReader reader = new BufferedReader(file.createReader())) {
                String line;
                while ((line = reader.readLine()) != null) {
                    if (line.isEmpty()) {
                        continue;
                    }
                    // Limit the split so that the tensor value is kept whole.
                    String[] parts = line.split("\t", 3);
                    String name = parts[0];
                    TensorType type = TensorType.fromSpec(parts[1]);
                    Tensor tensor = Tensor.from(type, parts[2]);
                    constants.add(new Pair<>(name, tensor));
                }
                return constants;
            }
            catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        }

        /** Appends a small constant as a "name\ttype\tvalue" line. */
        public void writeSmallConstant(String name, Tensor constant) {
            application.getFile(arguments.smallConstantsPath())
                       .appendFile(name + "\t" + constant.type().toString() + "\t" + constant.toString() + "\n");
        }

        /**
         * Prefixes the path with ".preprocessed" when the application root's absolute path ends
         * with ".preprocessed" and the path does not already contain that element.
         */
        private Path correct(Path path) {
            if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(".preprocessed")
                && !path.elements().contains(".preprocessed")) {
                return Path.fromString(".preprocessed").append(path);
            }
            return path;
        }

        /** Creates the directory (and parents) if it does not exist. */
        private void createIfNeeded(Path path) {
            File dir = application.getFileReference(path);
            if (!dir.exists() && !dir.mkdirs()) {
                throw new IllegalStateException("Could not create " + dir);
            }
        }
    }
}

