/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.models.handler;

import ai.vespa.models.evaluation.FunctionEvaluator;
import ai.vespa.models.evaluation.Model;
import ai.vespa.models.evaluation.ModelsEvaluator;
import com.yahoo.container.jdisc.HttpRequest;
import com.yahoo.container.jdisc.HttpResponse;
import com.yahoo.container.jdisc.ThreadedHttpRequestHandler;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.slime.Cursor;
import com.yahoo.slime.JsonFormat;
import com.yahoo.slime.Slime;
import com.yahoo.tensor.Tensor;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URI;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Locale;
import java.util.Optional;
import java.util.concurrent.Executor;

public class ModelsEvaluationHandler
extends ThreadedHttpRequestHandler {
    public static final String API_ROOT = "model-evaluation";
    public static final String VERSION_V1 = "v1";
    public static final String EVALUATE = "eval";
    private final ModelsEvaluator modelsEvaluator;

    public ModelsEvaluationHandler(ModelsEvaluator modelsEvaluator, Executor executor) {
        super(executor);
        this.modelsEvaluator = modelsEvaluator;
    }

    public HttpResponse handle(HttpRequest request) {
        Path path = new Path(request);
        Optional<String> apiName = path.segment(0);
        Optional<String> version = path.segment(1);
        Optional<String> modelName = path.segment(2);
        if (!apiName.isPresent() || !apiName.get().equalsIgnoreCase(API_ROOT)) {
            return new ErrorResponse(404, "unknown API");
        }
        if (!version.isPresent() || !version.get().equalsIgnoreCase(VERSION_V1)) {
            return new ErrorResponse(404, "unknown API version");
        }
        if (!modelName.isPresent()) {
            return this.listAllModels(request);
        }
        if (!this.modelsEvaluator.models().containsKey(modelName.get())) {
            return new ErrorResponse(404, "no model with name '" + modelName.get() + "' found");
        }
        Model model = this.modelsEvaluator.models().get(modelName.get());
        if (path.segments() == 3) {
            if (model.functions().size() > 1) {
                return this.listModelDetails(request, modelName.get());
            }
            return this.listTypeDetails(request, modelName.get());
        }
        if (path.segments() == 4) {
            if (!path.segment(3).get().equalsIgnoreCase(EVALUATE)) {
                return this.listTypeDetails(request, modelName.get(), path.segment(3).get());
            }
            if (model.functions().stream().anyMatch(f -> f.getName().equalsIgnoreCase(EVALUATE))) {
                return this.listTypeDetails(request, modelName.get(), path.segment(3).get());
            }
            if (model.functions().size() <= 1) {
                return this.evaluateModel(request, modelName.get());
            }
            return new ErrorResponse(404, "attempt to evaluate model without specifying function");
        }
        if (path.segments() == 5 && path.segment(4).get().equalsIgnoreCase(EVALUATE)) {
            return this.evaluateModel(request, modelName.get(), path.segment(3).get());
        }
        return new ErrorResponse(404, "unrecognized request");
    }

    private HttpResponse listAllModels(HttpRequest request) {
        Slime slime = new Slime();
        Cursor root = slime.setObject();
        for (String modelName : this.modelsEvaluator.models().keySet()) {
            root.setString(modelName, this.baseUrl(request) + modelName);
        }
        return new Response(200, JsonFormat.toJsonBytes((Slime)slime));
    }

    private HttpResponse listModelDetails(HttpRequest request, String modelName) {
        Model model = this.modelsEvaluator.models().get(modelName);
        Slime slime = new Slime();
        Cursor root = slime.setObject();
        for (ExpressionFunction func : model.functions()) {
            root.setString(func.getName(), this.baseUrl(request) + modelName + "/" + func.getName());
        }
        return new Response(200, JsonFormat.toJsonBytes((Slime)slime));
    }

    private HttpResponse listTypeDetails(HttpRequest request, String modelName) {
        return this.listTypeDetails(request, this.modelsEvaluator.evaluatorOf(modelName, new String[0]));
    }

    private HttpResponse listTypeDetails(HttpRequest request, String modelName, String signatureAndOutput) {
        return this.listTypeDetails(request, this.modelsEvaluator.evaluatorOf(modelName, signatureAndOutput));
    }

    private HttpResponse listTypeDetails(HttpRequest request, FunctionEvaluator evaluator) {
        Slime slime = new Slime();
        Cursor root = slime.setObject();
        Cursor bindings = root.setArray("bindings");
        for (String bindingName : evaluator.context().names()) {
            if (bindingName.startsWith("constant(") || bindingName.startsWith("rankingExpression(")) continue;
            Cursor binding = bindings.addObject();
            binding.setString("name", bindingName);
            binding.setString("type", "");
        }
        return new Response(200, JsonFormat.toJsonBytes((Slime)slime));
    }

    private HttpResponse evaluateModel(HttpRequest request, String modelName) {
        return this.evaluateModel(request, this.modelsEvaluator.evaluatorOf(modelName, new String[0]));
    }

    private HttpResponse evaluateModel(HttpRequest request, String modelName, String signatureAndOutput) {
        return this.evaluateModel(request, this.modelsEvaluator.evaluatorOf(modelName, signatureAndOutput));
    }

    private HttpResponse evaluateModel(HttpRequest request, FunctionEvaluator evaluator) {
        for (String bindingName : evaluator.context().names()) {
            this.property(request, bindingName).ifPresent(s -> evaluator.bind(bindingName, Tensor.from((String)s)));
        }
        Tensor result = evaluator.evaluate();
        return new Response(200, com.yahoo.tensor.serialization.JsonFormat.encode((Tensor)result));
    }

    private Optional<String> property(HttpRequest request, String name) {
        return Optional.ofNullable(request.getProperty(name));
    }

    private String baseUrl(HttpRequest request) {
        URI uri = request.getUri();
        StringBuilder sb = new StringBuilder();
        sb.append(uri.getScheme()).append("://").append(uri.getHost());
        if (uri.getPort() >= 0) {
            sb.append(":").append(uri.getPort());
        }
        sb.append("/").append(API_ROOT).append("/").append(VERSION_V1).append("/");
        return sb.toString();
    }

    private static class ErrorResponse
    extends Response {
        ErrorResponse(int code, String data) {
            super(code, "{\"error\":\"" + data + "\"}");
        }
    }

    private static class Response
    extends HttpResponse {
        private final byte[] data;

        Response(int code, byte[] data) {
            super(code);
            this.data = data;
        }

        Response(int code, String data) {
            this(code, data.getBytes(Charset.forName("UTF-8")));
        }

        public String getContentType() {
            return "application/json";
        }

        public void render(OutputStream outputStream) throws IOException {
            outputStream.write(this.data);
        }
    }

    private static class Path {
        private final String[] segments;

        public Path(HttpRequest httpRequest) {
            this.segments = Path.splitPath(httpRequest);
        }

        Optional<String> segment(int index) {
            return index < 0 || index >= this.segments.length ? Optional.empty() : Optional.of(this.segments[index]);
        }

        int segments() {
            return this.segments.length;
        }

        private static String[] splitPath(HttpRequest request) {
            String path = request.getUri().getPath().toLowerCase();
            if (path.startsWith("/")) {
                path = path.substring("/".length());
            }
            if (path.endsWith("/")) {
                path = path.substring(0, path.length() - 1);
            }
            return path.split("/");
        }
    }
}

