/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchdefinition.expressiontransforms;

import com.google.common.base.Joiner;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import java.io.File;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * Replaces {@code tensorflow(modelPath[, signature[, output]])} features in ranking expressions with the
 * ranking expression imported from the referenced TensorFlow model.
 * <p>
 * The model is read from {@code ApplicationPackage.MODELS_DIR} + {@code modelPath}. The converted expression
 * and the model's constants are written below the generated models directories of the application package.
 * When a converted expression is already stored there, it and its ranking constants are read back instead of
 * importing the model again.
 */
public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {

    private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter();

    /** Models imported by this instance, keyed by model path, so each model is imported at most once per instance. */
    private final Map<Path, TensorFlowModel> importedModels = new HashMap<>();

    /**
     * Transforms the given node.
     * <p>
     * Reference nodes are handed to {@link #transformFeature}. Composite nodes have their children transformed
     * recursively. All other nodes are returned unchanged.
     */
    public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
        if (node instanceof ReferenceNode)
            return transformFeature((ReferenceNode) node, context);
        if (node instanceof CompositeNode)
            return super.transformChildren((CompositeNode) node, context);
        return node;
    }

    /**
     * Returns the expression a {@code tensorflow(...)} feature converts to. Any other feature is returned unchanged.
     *
     * @throws IllegalArgumentException if the feature arguments are invalid or the model could not be used
     *         (I/O and argument failures are wrapped with the offending feature in the message)
     */
    private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
        if ( ! feature.getName().equals("tensorflow")) return feature;

        try {
            ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments());
            if (store.hasStoredModel())
                return transformFromStoredModel(store, context.rankProfile());
            return transformFromTensorFlowModel(store, context.rankProfile());
        }
        catch (UncheckedIOException | IllegalArgumentException e) {
            throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e);
        }
    }

    /**
     * Imports the TensorFlow model (at most once per model path) and selects the requested signature and output.
     * It then writes the converted expression and the model constants to the store, registers the constants
     * in the profile's search, and returns the root of the converted expression.
     */
    private ExpressionNode transformFromTensorFlowModel(ModelStore store, RankProfile profile) {
        TensorFlowModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
                                                               path -> tensorFlowImporter.importModel(store.tensorFlowModelDir()));
        TensorFlowModel.Signature signature = chooseSignature(model, store.arguments().signature());
        String output = chooseOutput(signature, store.arguments().output());
        RankingExpression expression = model.expressions().get(output);
        store.writeConverted(expression);
        model.constants().forEach((name, value) -> transformConstant(store, profile, name, value));
        return expression.getRoot();
    }

    /**
     * Registers the previously stored ranking constants that the search does not already have, and returns
     * the root of the previously stored converted expression.
     */
    private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) {
        for (RankingConstant constant : store.readRankingConstants()) {
            if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName()))
                profile.getSearch().addRankingConstant(constant);
        }
        return store.readConverted().getRoot();
    }

    /**
     * Returns the named signature.
     * <p>
     * If no name is given, returns the model's only signature.
     *
     * @throws IllegalArgumentException if the named signature does not exist, or if no name is given and the
     *         model has zero or several signatures
     */
    private TensorFlowModel.Signature chooseSignature(TensorFlowModel importResult, Optional<String> signatureName) {
        if ( ! signatureName.isPresent()) {
            if (importResult.signatures().size() == 0)
                throw new IllegalArgumentException("No signatures are available");
            if (importResult.signatures().size() > 1)
                throw new IllegalArgumentException("Model has multiple signatures (" +
                                                   Joiner.on(", ").join(importResult.signatures().keySet()) +
                                                   "), one must be specified as a second argument to tensorflow()");
            return importResult.signatures().values().iterator().next();
        }

        TensorFlowModel.Signature signature = importResult.signatures().get(signatureName.get());
        if (signature == null)
            throw new IllegalArgumentException("Model does not have the specified signature '" + signatureName.get() + "'");
        return signature;
    }

    /**
     * Returns the value mapped to the named output in the signature. The caller uses it as the key of the
     * model expression to use.
     * <p>
     * If no name is given, returns the value of the signature's only output.
     *
     * @throws IllegalArgumentException if the named output does not exist or was skipped during import, or if no
     *         name is given and the signature has zero or several outputs
     */
    private String chooseOutput(TensorFlowModel.Signature signature, Optional<String> outputName) {
        if ( ! outputName.isPresent()) {
            if (signature.outputs().size() == 0)
                throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature));
            if (signature.outputs().size() > 1)
                throw new IllegalArgumentException(signature + " has multiple outputs (" +
                                                   Joiner.on(", ").join(signature.outputs().keySet()) +
                                                   "), one must be specified as a third argument to tensorflow()");
            return signature.outputs().values().iterator().next();
        }

        String output = signature.outputs().get(outputName.get());
        if (output == null) {
            if (signature.skippedOutputs().containsKey(outputName.get()))
                throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " +
                                                   signature.skippedOutputs().get(outputName.get()));
            throw new IllegalArgumentException("Model does not have the specified output '" + outputName.get() + "'");
        }
        return output;
    }

    /**
     * Writes the constant to the store. If the search has no ranking constant with this name yet, also adds one
     * pointing at the written file.
     */
    private void transformConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
        Path constantPath = store.writeConstant(constantName, constantValue);
        if ( ! profile.getSearch().getRankingConstants().containsKey(constantName))
            profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), constantPath.toString()));
    }

    /**
     * Returns a suffix describing why outputs were skipped. The suffix is ": " followed by one
     * comma-separated entry per skipped output, or an empty string if none were skipped.
     */
    private String skippedOutputsDescription(TensorFlowModel.Signature signature) {
        if (signature.skippedOutputs().isEmpty()) return "";

        List<String> descriptions = new ArrayList<>();
        signature.skippedOutputs().forEach((name, reason) -> descriptions.add("Skipping output '" + name + "': " + reason));
        return ": " + String.join(", ", descriptions);
    }

    /**
     * The parsed arguments of a {@code tensorflow} feature: a required model path, an optional signature name,
     * and an optional output name. Each argument must be a constant, optionally quoted with ' or ".
     */
    private static class FeatureArguments {

        private final Path modelPath;
        private final Optional<String> signature;
        private final Optional<String> output;

        /** @throws IllegalArgumentException if there are no arguments, more than 3, or a non-constant argument */
        public FeatureArguments(Arguments arguments) {
            if (arguments.isEmpty())
                throw new IllegalArgumentException("A tensorflow node must take an argument pointing to the tensorflow model directory under [application]/models");
            if (arguments.expressions().size() > 3)
                throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments");

            modelPath = Path.fromString(asString(arguments.expressions().get(0)));
            signature = optionalArgument(1, arguments);
            output = optionalArgument(2, arguments);
        }

        public Path modelPath() { return modelPath; }

        public Optional<String> signature() { return signature; }

        public Optional<String> output() { return output; }

        /** Returns the directory holding one {@code <name>.constant} description file per stored ranking constant. */
        public Path rankingConstantsPath() {
            return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
        }

        /** Returns the path of the stored converted expression for this signature/output combination. */
        public Path expressionPath() {
            return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath)
                                                                     .append("expressions")
                                                                     .append(expressionFileName());
        }

        /**
         * Returns "[signature.][output.]expression".
         * <p>
         * If neither a signature nor an output is given, returns "single.expression" so that the file name
         * is never just "expression".
         */
        private String expressionFileName() {
            StringBuilder fileName = new StringBuilder();
            signature.ifPresent(s -> fileName.append(s).append("."));
            output.ifPresent(s -> fileName.append(s).append("."));
            if (fileName.length() == 0)
                fileName.append("single.");
            fileName.append("expression");
            return fileName.toString();
        }

        /** Returns the argument at the given index as a string, or empty if there are not that many arguments. */
        private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
            if (argumentIndex >= arguments.expressions().size())
                return Optional.empty();
            return Optional.of(asString(arguments.expressions().get(argumentIndex)));
        }

        /**
         * Returns the unquoted source string of a constant argument.
         *
         * @throws IllegalArgumentException if the node is not a constant
         */
        private String asString(ExpressionNode node) {
            if ( ! (node instanceof ConstantNode))
                throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node + "'");
            return stripQuotes(((ConstantNode) node).sourceString());
        }

        /**
         * Removes one leading and one trailing quote sign. An unquoted string is returned as-is.
         *
         * @throws IllegalArgumentException if the string starts with a quote sign but has no separate closing one
         */
        private String stripQuotes(String s) {
            // Guard the empty string: codePointAt(0) would throw StringIndexOutOfBoundsException.
            if (s.isEmpty() || ! isQuoteSign(s.codePointAt(0))) return s;
            // A lone quote character is both first and last, but has no closing quote of its own.
            if (s.length() < 2 || ! isQuoteSign(s.codePointAt(s.length() - 1)))
                throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote");
            return s.substring(1, s.length() - 1);
        }

        private boolean isQuoteSign(int c) {
            return c == '\'' || c == '"';
        }

    }

    /** Reads the source TensorFlow model from, and reads/writes converted artifacts to, the application package. */
    private static class ModelStore {

        private final ApplicationPackage application;
        private final FeatureArguments arguments;

        /** @throws IllegalArgumentException if the feature arguments are invalid */
        public ModelStore(ApplicationPackage application, Arguments arguments) {
            this.application = application;
            this.arguments = new FeatureArguments(arguments);
        }

        public FeatureArguments arguments() { return arguments; }

        /** Returns whether a converted expression for these arguments already exists in the application package. */
        public boolean hasStoredModel() {
            try {
                return application.getFile(arguments.expressionPath()).exists();
            }
            catch (UnsupportedOperationException e) {
                // NOTE(review): presumably thrown by application package implementations without file access;
                // treated as "nothing stored" so the model is imported instead — confirm.
                return false;
            }
        }

        /** Returns the directory of the source TensorFlow model: the models directory + the model path. */
        public File tensorFlowModelDir() {
            return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath()));
        }

        /** Writes the string form of the expression's root to the expression path. */
        public void writeConverted(RankingExpression expression) {
            application.getFile(arguments.expressionPath()).writeFile(new StringReader(expression.getRoot().toString()));
        }

        /**
         * Reads and parses the stored converted expression.
         *
         * @throws UncheckedIOException if the file could not be read
         * @throws IllegalStateException if the stored expression could not be parsed
         */
        public RankingExpression readConverted() {
            try (Reader reader = application.getFile(arguments.expressionPath()).createReader()) {
                return new RankingExpression(reader);
            }
            catch (IOException e) {
                throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e);
            }
            catch (ParseException e) {
                throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
            }
        }

        /**
         * Reads all stored ranking constant descriptions. Each description has the form {@code name:type:path}.
         *
         * @throws UncheckedIOException if a description could not be read
         * @throws IllegalArgumentException if a description does not have the expected form
         */
        public List<RankingConstant> readRankingConstants() {
            try {
                List<RankingConstant> constants = new ArrayList<>();
                for (ApplicationFile constantFile : application.getFile(arguments.rankingConstantsPath()).listFiles()) {
                    String content;
                    try (Reader reader = constantFile.createReader()) {
                        content = IOUtils.readAll(reader);
                    }
                    constants.add(parseRankingConstant(content));
                }
                return constants;
            }
            catch (IOException e) {
                throw new UncheckedIOException("Could not read ranking constants from " + arguments.rankingConstantsPath(), e);
            }
        }

        /** Parses a {@code name:type:path} description as written by {@link #writeConstant}. */
        private RankingConstant parseRankingConstant(String content) {
            String[] parts = content.split(":");
            if (parts.length < 3)
                throw new IllegalArgumentException("Malformed ranking constant description '" + content + "' in " +
                                                   arguments.rankingConstantsPath() + ": Expected name:type:path");
            return new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]);
        }

        /**
         * Writes the constant to two places:
         * <ul>
         *   <li>its binary encoding, to {@code MODELS_GENERATED_DIR/modelPath/constants/<name>.tbf};</li>
         *   <li>a {@code name:type:path} description, to the ranking constants path.</li>
         * </ul>
         *
         * @return the constant file path recorded in the description
         */
        public Path writeConstant(String name, Tensor constant) {
            Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath()).append("constants");
            Path constantPath = constantsPath.append(name + ".tbf");

            // If the application package root is a ".preprocessed" directory, the recorded path gets a
            // ".preprocessed" prefix unless it already contains it. The binary file itself is still written
            // relative to the application package.
            // NOTE(review): presumably so the recorded path resolves from the original application directory — confirm.
            Path recordedConstantPath = constantPath;
            if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(".preprocessed")
                && ! constantPath.elements().contains(".preprocessed"))
                recordedConstantPath = Path.fromString(".preprocessed").append(constantPath);

            application.getFile(arguments.rankingConstantsPath().append(name + ".constant"))
                       .writeFile(new StringReader(name + ":" + constant.type() + ":" + recordedConstantPath));

            createIfNeeded(constantsPath);
            IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
            return recordedConstantPath;
        }

        /** Creates the directory at the given application path if it does not exist. */
        private void createIfNeeded(Path path) {
            File dir = application.getFileReference(path);
            if ( ! dir.exists() && ! dir.mkdirs())
                throw new IllegalStateException("Could not create " + dir);
        }

    }

}

