/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.models.handler;

import ai.vespa.models.evaluation.FunctionEvaluator;
import ai.vespa.models.evaluation.Model;
import ai.vespa.models.evaluation.ModelsEvaluator;
import com.yahoo.container.jdisc.HttpRequest;
import com.yahoo.container.jdisc.HttpResponse;
import com.yahoo.container.jdisc.ThreadedHttpRequestHandler;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.slime.Cursor;
import com.yahoo.slime.JsonFormat;
import com.yahoo.slime.Slime;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.yolean.Exceptions;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URI;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Executor;

public class ModelsEvaluationHandler
extends ThreadedHttpRequestHandler {
    private static final String missingValueKey = "missing-value";
    public static final String API_ROOT = "model-evaluation";
    public static final String VERSION_V1 = "v1";
    public static final String EVALUATE = "eval";
    private final ModelsEvaluator modelsEvaluator;

    public ModelsEvaluationHandler(ModelsEvaluator modelsEvaluator, Executor executor) {
        super(executor);
        this.modelsEvaluator = modelsEvaluator;
    }

    public HttpResponse handle(HttpRequest request) {
        Path path = new Path(request);
        Optional<String> apiName = path.segment(0);
        Optional<String> version = path.segment(1);
        Optional<String> modelName = path.segment(2);
        try {
            if (apiName.isEmpty() || !apiName.get().equalsIgnoreCase(API_ROOT)) {
                throw new IllegalArgumentException("unknown API");
            }
            if (version.isEmpty() || !version.get().equalsIgnoreCase(VERSION_V1)) {
                throw new IllegalArgumentException("unknown API version");
            }
            if (modelName.isEmpty()) {
                return this.listAllModels(request);
            }
            Model model = this.modelsEvaluator.requireModel(modelName.get());
            Optional<Integer> evalSegment = path.lastIndexOf(EVALUATE);
            String[] function = path.range(3, evalSegment);
            if (evalSegment.isPresent()) {
                return this.evaluateModel(request, model, function);
            }
            return this.listModelInformation(request, model, function);
        }
        catch (IllegalArgumentException e) {
            return new ErrorResponse(404, Exceptions.toMessageString((Throwable)e));
        }
        catch (IllegalStateException e) {
            return new ErrorResponse(400, Exceptions.toMessageString((Throwable)e));
        }
    }

    private HttpResponse evaluateModel(HttpRequest request, Model model, String[] function) {
        FunctionEvaluator evaluator = model.evaluatorOf(function);
        this.property(request, missingValueKey).ifPresent(missingValue -> evaluator.setMissingValue(Tensor.from((String)missingValue)));
        for (Map.Entry argument : evaluator.function().argumentTypes().entrySet()) {
            Optional<String> value = this.property(request, (String)argument.getKey());
            if (!value.isPresent()) continue;
            try {
                evaluator.bind((String)argument.getKey(), Tensor.from((TensorType)((TensorType)argument.getValue()), (String)value.get()));
            }
            catch (IllegalArgumentException e) {
                evaluator.bind((String)argument.getKey(), value.get());
            }
        }
        Tensor result = evaluator.evaluate();
        return new Response(200, com.yahoo.tensor.serialization.JsonFormat.encode((Tensor)result));
    }

    private HttpResponse listAllModels(HttpRequest request) {
        Slime slime = new Slime();
        Cursor root = slime.setObject();
        for (String modelName : this.modelsEvaluator.models().keySet()) {
            root.setString(modelName, this.baseUrl(request) + modelName);
        }
        return new Response(200, JsonFormat.toJsonBytes((Slime)slime));
    }

    private HttpResponse listModelInformation(HttpRequest request, Model model, String[] function) {
        Slime slime = new Slime();
        Cursor root = slime.setObject();
        root.setString("model", model.name());
        if (function.length == 0) {
            this.listFunctions(request, model, root);
        } else {
            this.listFunctionDetails(request, model, function, root);
        }
        return new Response(200, JsonFormat.toJsonBytes((Slime)slime));
    }

    private void listFunctions(HttpRequest request, Model model, Cursor cursor) {
        Cursor functions = cursor.setArray("functions");
        for (ExpressionFunction func : model.functions()) {
            Cursor function = functions.addObject();
            this.listFunctionDetails(request, model, new String[]{func.getName()}, function);
        }
    }

    private void listFunctionDetails(HttpRequest request, Model model, String[] function, Cursor cursor) {
        String compactedFunction = String.join((CharSequence)".", function);
        FunctionEvaluator evaluator = model.evaluatorOf(function);
        cursor.setString("function", compactedFunction);
        cursor.setString("info", this.baseUrl(request) + model.name() + "/" + compactedFunction);
        cursor.setString(EVALUATE, this.baseUrl(request) + model.name() + "/" + compactedFunction + "/eval");
        Cursor bindings = cursor.setArray("arguments");
        for (Map.Entry argument : evaluator.function().argumentTypes().entrySet()) {
            Cursor binding = bindings.addObject();
            binding.setString("name", (String)argument.getKey());
            binding.setString("type", ((TensorType)argument.getValue()).toString());
        }
    }

    private Optional<String> property(HttpRequest request, String name) {
        return Optional.ofNullable(request.getProperty(name));
    }

    private String baseUrl(HttpRequest request) {
        URI uri = request.getUri();
        StringBuilder sb = new StringBuilder();
        sb.append(uri.getScheme()).append("://");
        if (request.getHeader("Host") != null) {
            sb.append(request.getHeader("Host"));
        } else {
            sb.append(uri.getHost());
            if (uri.getPort() >= 0) {
                sb.append(":").append(uri.getPort());
            }
        }
        sb.append("/").append(API_ROOT).append("/").append(VERSION_V1).append("/");
        return sb.toString();
    }

    private static class ErrorResponse
    extends Response {
        ErrorResponse(int code, String data) {
            super(code, "{\"error\":\"" + data + "\"}");
        }
    }

    private static class Response
    extends HttpResponse {
        private final byte[] data;

        Response(int code, byte[] data) {
            super(code);
            this.data = data;
        }

        Response(int code, String data) {
            this(code, data.getBytes(Charset.forName("UTF-8")));
        }

        public String getContentType() {
            return "application/json";
        }

        public void render(OutputStream outputStream) throws IOException {
            outputStream.write(this.data);
        }
    }

    private static class Path {
        private final String[] segments;

        public Path(HttpRequest httpRequest) {
            this.segments = Path.splitPath(httpRequest);
        }

        Optional<String> segment(int index) {
            return index < 0 || index >= this.segments.length ? Optional.empty() : Optional.of(this.segments[index]);
        }

        Optional<Integer> lastIndexOf(String segment) {
            for (int i = this.segments.length - 1; i >= 0; --i) {
                if (!this.segments[i].equalsIgnoreCase(segment)) continue;
                return Optional.of(i);
            }
            return Optional.empty();
        }

        public String[] range(int start, Optional<Integer> end) {
            return Arrays.copyOfRange(this.segments, start, end.isPresent() ? end.get() : this.segments.length);
        }

        private static String[] splitPath(HttpRequest request) {
            String path = request.getUri().getPath().toLowerCase();
            if (path.startsWith("/")) {
                path = path.substring("/".length());
            }
            if (path.endsWith("/")) {
                path = path.substring(0, path.length() - 1);
            }
            return path.split("/");
        }
    }
}

