/*
 * Decompiled with CFR 0.152.
 */
package io.trino.proxy;

import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonFactoryBuilder;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.JsonToken;
import com.google.common.hash.HashCode;
import com.google.common.hash.HashFunction;
import com.google.common.hash.Hashing;
import com.google.common.net.MediaType;
import com.google.common.util.concurrent.FluentFuture;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.SettableFuture;
import com.google.inject.Inject;
import io.airlift.concurrent.Threads;
import io.airlift.jaxrs.AsyncResponseHandler;
import io.airlift.log.Logger;
import io.airlift.units.Duration;
import io.trino.plugin.base.util.JsonUtils;
import io.trino.proxy.ForProxy;
import io.trino.proxy.JsonWebTokenHandler;
import io.trino.proxy.ProxyConfig;
import io.trino.proxy.ProxyException;
import jakarta.annotation.PreDestroy;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.ws.rs.DELETE;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.QueryParam;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.container.AsyncResponse;
import jakarta.ws.rs.container.Suspended;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.UriBuilder;
import jakarta.ws.rs.core.UriInfo;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.security.cert.X509Certificate;
import java.util.Base64;
import java.util.Collections;
import java.util.Locale;
import java.util.Objects;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import okhttp3.Call;
import okhttp3.Callback;
import okhttp3.Headers;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.ResponseBody;

@Path(value="/")
public class ProxyResource {
    private static final Logger log = Logger.get(ProxyResource.class);
    private static final String X509_ATTRIBUTE = "jakarta.servlet.request.X509Certificate";
    private static final Duration ASYNC_TIMEOUT = new Duration(2.0, TimeUnit.MINUTES);
    private static final JsonFactory JSON_FACTORY = ((JsonFactoryBuilder)JsonUtils.jsonFactoryBuilder().disable(JsonFactory.Feature.CANONICALIZE_FIELD_NAMES)).build();
    private final ExecutorService executor = Executors.newCachedThreadPool(Threads.daemonThreadsNamed((String)"proxy-%s"));
    private final OkHttpClient httpClient;
    private final JsonWebTokenHandler jwtHandler;
    private final URI remoteUri;
    private final HashFunction hmac;
    private static final MediaType JSON = MediaType.create((String)"application", (String)"json");

    @Inject
    public ProxyResource(@ForProxy OkHttpClient httpClient, JsonWebTokenHandler jwtHandler, ProxyConfig config) {
        this.httpClient = Objects.requireNonNull(httpClient, "httpClient is null");
        this.jwtHandler = Objects.requireNonNull(jwtHandler, "jwtHandler is null");
        this.remoteUri = Objects.requireNonNull(config.getUri(), "uri is null");
        this.hmac = Hashing.hmacSha256((byte[])ProxyResource.loadSharedSecret(config.getSharedSecretFile()));
    }

    @PreDestroy
    public void shutdown() {
        this.executor.shutdownNow();
    }

    @GET
    @Path(value="/v1/info")
    @Produces(value={"application/json"})
    public void getInfo(@Context HttpServletRequest servletRequest, @Suspended AsyncResponse asyncResponse) {
        Request.Builder request = new Request.Builder().get().url(UriBuilder.fromUri((URI)this.remoteUri).replacePath("/v1/info").build(new Object[0]).toString());
        this.performRequest(servletRequest, asyncResponse, request, response -> ProxyResource.responseWithHeaders(Response.ok((Object)response.getBody()), response));
    }

    @POST
    @Path(value="/v1/statement")
    @Produces(value={"application/json"})
    public void postStatement(String statement, @Context HttpServletRequest servletRequest, @Context UriInfo uriInfo, @Suspended AsyncResponse asyncResponse) {
        Request.Builder request = new Request.Builder().post(RequestBody.create((String)statement, (okhttp3.MediaType)okhttp3.MediaType.parse((String)"application/json"))).url(UriBuilder.fromUri((URI)this.remoteUri).replacePath("/v1/statement").build(new Object[0]).toString());
        this.performRequest(servletRequest, asyncResponse, request, response -> this.buildResponse(uriInfo, (ProxyResponse)response));
    }

    @GET
    @Path(value="/v1/proxy")
    @Produces(value={"application/json"})
    public void getNext(@QueryParam(value="uri") String uri, @QueryParam(value="hmac") String hash, @Context HttpServletRequest servletRequest, @Context UriInfo uriInfo, @Suspended AsyncResponse asyncResponse) {
        if (!this.hmac.hashString((CharSequence)uri, StandardCharsets.UTF_8).equals((Object)HashCode.fromString((String)hash))) {
            throw ProxyResource.badRequest(Response.Status.FORBIDDEN, "Failed to validate HMAC of URI");
        }
        Request.Builder request = new Request.Builder().get().url(uri);
        this.performRequest(servletRequest, asyncResponse, request, response -> this.buildResponse(uriInfo, (ProxyResponse)response));
    }

    @DELETE
    @Path(value="/v1/proxy")
    @Produces(value={"application/json"})
    public void cancelQuery(@QueryParam(value="uri") String uri, @QueryParam(value="hmac") String hash, @Context HttpServletRequest servletRequest, @Suspended AsyncResponse asyncResponse) {
        if (!this.hmac.hashString((CharSequence)uri, StandardCharsets.UTF_8).equals((Object)HashCode.fromString((String)hash))) {
            throw ProxyResource.badRequest(Response.Status.FORBIDDEN, "Failed to validate HMAC of URI");
        }
        Request.Builder request = new Request.Builder().delete().url(uri);
        this.performRequest(servletRequest, asyncResponse, request, response -> ProxyResource.responseWithHeaders(Response.noContent(), response));
    }

    private void performRequest(HttpServletRequest servletRequest, AsyncResponse asyncResponse, Request.Builder requestBuilder, Function<ProxyResponse, Response> responseBuilder) {
        this.setupBearerToken(servletRequest, requestBuilder);
        for (String name : Collections.list(servletRequest.getHeaderNames())) {
            if (ProxyResource.isTrinoHeader(name) || name.equalsIgnoreCase("Cookie")) {
                for (String value : Collections.list(servletRequest.getHeaders(name))) {
                    requestBuilder.addHeader(name, value);
                }
                continue;
            }
            if (!name.equalsIgnoreCase("User-Agent")) continue;
            for (String value : Collections.list(servletRequest.getHeaders(name))) {
                requestBuilder.addHeader(name, "[Trino Proxy] " + value);
            }
        }
        Request request = requestBuilder.build();
        FluentFuture future = this.executeHttp(request).transform(responseBuilder::apply, (Executor)this.executor).catching(ProxyException.class, e -> (Response)ProxyResource.handleProxyException(request, e), MoreExecutors.directExecutor());
        this.setupAsyncResponse(asyncResponse, (ListenableFuture<Response>)future);
    }

    private Response buildResponse(UriInfo uriInfo, ProxyResponse response) {
        byte[] body = ProxyResource.rewriteResponse(response.getBody(), uri -> this.rewriteUri(uriInfo, (String)uri));
        return ProxyResource.responseWithHeaders(Response.ok((Object)body), response);
    }

    private String rewriteUri(UriInfo uriInfo, String uri) {
        return uriInfo.getAbsolutePathBuilder().replacePath("/v1/proxy").queryParam("uri", new Object[]{uri}).queryParam("hmac", new Object[]{this.hmac.hashString((CharSequence)uri, StandardCharsets.UTF_8)}).build(new Object[0]).toString();
    }

    private void setupAsyncResponse(AsyncResponse asyncResponse, ListenableFuture<Response> future) {
        AsyncResponseHandler.bindAsyncResponse((AsyncResponse)asyncResponse, future, (Executor)this.executor).withTimeout(ASYNC_TIMEOUT, () -> Response.status((Response.Status)Response.Status.BAD_GATEWAY).type(jakarta.ws.rs.core.MediaType.TEXT_PLAIN_TYPE).entity((Object)("Request to remote Trino server timed out after" + String.valueOf(ASYNC_TIMEOUT))).build());
    }

    private FluentFuture<ProxyResponse> executeHttp(Request request) {
        final SettableFuture future = SettableFuture.create();
        this.httpClient.newCall(request).enqueue(new Callback(){

            public void onFailure(Call call, IOException e) {
                future.setException((Throwable)e);
            }

            public void onResponse(Call call, okhttp3.Response response) {
                if (response.code() == Response.Status.NO_CONTENT.getStatusCode()) {
                    future.set((Object)new ProxyResponse(response.headers(), new byte[0]));
                    return;
                }
                if (response.code() != Response.Status.OK.getStatusCode()) {
                    try (ResponseBody body = response.body();){
                        future.setException((Throwable)new ProxyException(String.format("Bad status code from remote Trino server: %s: %s", response.code(), body.string())));
                        return;
                    }
                    catch (IOException e) {
                        future.setException((Throwable)e);
                        return;
                    }
                }
                String contentType = response.header("Content-Type");
                if (contentType == null) {
                    throw new ProxyException("No Content-Type set in response from remote Trino server");
                }
                if (!MediaType.parse((String)contentType).is(JSON)) {
                    throw new ProxyException("Bad Content-Type from remote Trino server:" + contentType);
                }
                try (ResponseBody body = response.body();){
                    future.set((Object)new ProxyResponse(response.headers(), body.bytes()));
                    return;
                }
                catch (IOException e) {
                    throw new ProxyException("Failed reading response from remote Trino server", e);
                }
            }
        });
        return FluentFuture.from((ListenableFuture)future);
    }

    private void setupBearerToken(HttpServletRequest servletRequest, Request.Builder requestBuilder) {
        if (!this.jwtHandler.isConfigured()) {
            return;
        }
        X509Certificate[] certs = (X509Certificate[])servletRequest.getAttribute(X509_ATTRIBUTE);
        if (certs == null || certs.length == 0) {
            throw ProxyResource.badRequest(Response.Status.FORBIDDEN, "No TLS certificate present for request");
        }
        String principal = certs[0].getSubjectX500Principal().getName();
        String accessToken = this.jwtHandler.getBearerToken(principal);
        requestBuilder.addHeader("Authorization", "Bearer " + accessToken);
    }

    private static <T> T handleProxyException(Request request, ProxyException e) {
        log.warn((Throwable)e, "Proxy request failed: %s %s", new Object[]{request.method(), request.url()});
        throw ProxyResource.badRequest(Response.Status.BAD_GATEWAY, e.getMessage());
    }

    private static WebApplicationException badRequest(Response.Status status, String message) {
        throw new WebApplicationException(Response.status((Response.Status)status).type(jakarta.ws.rs.core.MediaType.TEXT_PLAIN_TYPE).entity((Object)message).build());
    }

    private static boolean isTrinoHeader(String name) {
        return name.toLowerCase(Locale.ENGLISH).startsWith("x-trino-");
    }

    private static Response responseWithHeaders(Response.ResponseBuilder builder, ProxyResponse response) {
        response.getHeaders().names().forEach(name -> {
            if (ProxyResource.isTrinoHeader(name) || name.equalsIgnoreCase("Set-Cookie")) {
                builder.header(name, (Object)response.getHeaders().get(name));
            }
        });
        return builder.build();
    }

    private static byte[] rewriteResponse(byte[] input, Function<String, String> uriRewriter) {
        try {
            JsonToken token;
            JsonGenerator generator;
            ByteArrayOutputStream out;
            JsonParser parser;
            block8: {
                parser = JSON_FACTORY.createParser(input);
                out = new ByteArrayOutputStream(input.length * 2);
                generator = JSON_FACTORY.createGenerator((OutputStream)out);
                token = parser.nextToken();
                if (token != JsonToken.START_OBJECT) {
                    throw ProxyResource.invalidJson("bad start token: " + String.valueOf(token));
                }
                generator.copyCurrentEvent(parser);
                while (true) {
                    if ((token = parser.nextToken()) == null) {
                        throw ProxyResource.invalidJson("unexpected end of stream");
                    }
                    if (token == JsonToken.END_OBJECT) break block8;
                    if (token != JsonToken.FIELD_NAME) break;
                    String name = parser.getValueAsString();
                    if (!"nextUri".equals(name) && !"partialCancelUri".equals(name)) {
                        generator.copyCurrentStructure(parser);
                        continue;
                    }
                    token = parser.nextToken();
                    if (token != JsonToken.VALUE_STRING) {
                        throw ProxyResource.invalidJson(String.format("bad %s token: %s", name, token));
                    }
                    String value = parser.getValueAsString();
                    value = uriRewriter.apply(value);
                    generator.writeStringField(name, value);
                }
                throw ProxyResource.invalidJson("unexpected token: " + String.valueOf(token));
            }
            generator.copyCurrentEvent(parser);
            token = parser.nextToken();
            if (token != null) {
                throw ProxyResource.invalidJson("unexpected token after object close: " + String.valueOf(token));
            }
            generator.close();
            return out.toByteArray();
        }
        catch (IOException e) {
            throw new ProxyException(e);
        }
    }

    private static IOException invalidJson(String message) {
        return new IOException("Invalid JSON response from remote Trino server: " + message);
    }

    private static byte[] loadSharedSecret(File file) {
        try {
            return Base64.getMimeDecoder().decode(Files.readAllBytes(file.toPath()));
        }
        catch (IOException | IllegalArgumentException e) {
            throw new RuntimeException("Failed to load shared secret file: " + String.valueOf(file), e);
        }
    }

    public static class ProxyResponse {
        private final Headers headers;
        private final byte[] body;

        ProxyResponse(Headers headers, byte[] body) {
            this.headers = Objects.requireNonNull(headers, "headers is null");
            this.body = Objects.requireNonNull(body, "body is null");
        }

        public Headers getHeaders() {
            return this.headers;
        }

        public byte[] getBody() {
            return this.body;
        }
    }
}

