/*
 * Decompiled with CFR 0.152.
 */
package com.databricks.jdbc.auth;

import com.databricks.internal.apache.http.client.entity.UrlEncodedFormEntity;
import com.databricks.internal.apache.http.client.methods.CloseableHttpResponse;
import com.databricks.internal.apache.http.client.methods.HttpGet;
import com.databricks.internal.apache.http.client.methods.HttpPost;
import com.databricks.internal.apache.http.client.utils.URIBuilder;
import com.databricks.internal.apache.http.message.BasicNameValuePair;
import com.databricks.internal.apache.http.util.EntityUtils;
import com.databricks.internal.fasterxml.jackson.databind.JsonNode;
import com.databricks.internal.fasterxml.jackson.databind.ObjectMapper;
import com.databricks.internal.nimbusds.jwt.JWTClaimsSet;
import com.databricks.internal.nimbusds.jwt.SignedJWT;
import com.databricks.internal.sdk.core.CredentialsProvider;
import com.databricks.internal.sdk.core.DatabricksConfig;
import com.databricks.internal.sdk.core.HeaderFactory;
import com.databricks.jdbc.api.internal.IDatabricksConnectionContext;
import com.databricks.jdbc.dbclient.IDatabricksHttpClient;
import com.databricks.jdbc.dbclient.impl.http.DatabricksHttpClientFactory;
import com.databricks.jdbc.exception.DatabricksDriverException;
import com.databricks.jdbc.exception.DatabricksParsingException;
import com.databricks.jdbc.log.JdbcLogger;
import com.databricks.jdbc.log.JdbcLoggerFactory;
import com.databricks.jdbc.model.telemetry.enums.DatabricksDriverErrorCode;
import java.awt.Desktop;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.lang.invoke.CallSite;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.URI;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.text.ParseException;
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

public class AzureExternalBrowserProvider
implements CredentialsProvider {
    private static final JdbcLogger LOGGER = JdbcLoggerFactory.getLogger(AzureExternalBrowserProvider.class);
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
    private static final String DEFAULT_SCOPE = "offline_access sql";
    private static final String CODE_CHALLENGE_METHOD = "S256";
    private static final String GRANT_TYPE_AUTHORIZATION_CODE = "authorization_code";
    private static final String GRANT_TYPE_REFRESH_TOKEN = "refresh_token";
    private static final String CONTENT_TYPE_FORM_URLENCODED = "application/x-www-form-urlencoded";
    private static final String ACCEPT_JSON = "application/json";
    private static final List<Integer> DEFAULT_PORT_RANGE = List.of(Integer.valueOf(8020), Integer.valueOf(8021), Integer.valueOf(8022), Integer.valueOf(8023), Integer.valueOf(8024));
    private final String hostname;
    private final String clientId;
    private final IDatabricksConnectionContext connectionContext;
    private final IDatabricksHttpClient httpClient;
    private final OAuthCallbackServer callbackServer;
    private final int callbackPort;
    private String accessToken;
    private String refreshToken;
    private LocalDateTime tokenExpiry;
    private OAuthConfig oauthConfig;

    public AzureExternalBrowserProvider(IDatabricksConnectionContext connectionContext, int availablePort) throws DatabricksParsingException {
        this.connectionContext = connectionContext;
        this.hostname = connectionContext.getHost();
        this.clientId = connectionContext.getClientId();
        this.httpClient = DatabricksHttpClientFactory.getInstance().getClient(connectionContext);
        this.callbackServer = new OAuthCallbackServer();
        this.callbackPort = availablePort;
    }

    @Override
    public String authType() {
        return "azure-oauth-u2m";
    }

    @Override
    public HeaderFactory configure(DatabricksConfig databricksConfig) {
        return () -> {
            this.ensureValidTokens();
            HashMap<String, CallSite> headers = new HashMap<String, CallSite>();
            headers.put("Authorization", (CallSite)((Object)("Bearer " + this.accessToken)));
            return headers;
        };
    }

    private void ensureValidTokens() {
        if (this.accessToken == null || this.isTokenExpired(this.accessToken)) {
            try {
                if (this.refreshToken != null) {
                    try {
                        this.refreshAccessToken();
                        return;
                    }
                    catch (Exception e) {
                        LOGGER.warn("Token refresh failed, re-authenticating: {}", e.getMessage());
                    }
                }
                this.performOAuthFlow();
            }
            catch (Exception e) {
                LOGGER.error(e, "Failed to obtain OAuth tokens");
                throw new DatabricksDriverException("Failed to obtain OAuth tokens", (Throwable)e, DatabricksDriverErrorCode.AUTH_ERROR);
            }
        }
    }

    private void refreshAccessToken() throws Exception {
        if (this.refreshToken == null) {
            throw new DatabricksDriverException("No refresh token available", DatabricksDriverErrorCode.AUTH_ERROR);
        }
        if (this.oauthConfig == null) {
            this.oauthConfig = this.fetchOAuthConfig();
        }
        HttpPost request = new HttpPost(this.oauthConfig.getTokenEndpoint());
        request.setHeader("Content-Type", CONTENT_TYPE_FORM_URLENCODED);
        request.setHeader("Accept", ACCEPT_JSON);
        List<BasicNameValuePair> params = List.of(new BasicNameValuePair("grant_type", GRANT_TYPE_REFRESH_TOKEN), new BasicNameValuePair("client_id", this.clientId), new BasicNameValuePair(GRANT_TYPE_REFRESH_TOKEN, this.refreshToken));
        request.setEntity(new UrlEncodedFormEntity(params));
        try (CloseableHttpResponse response = this.httpClient.execute(request);){
            String responseBody = EntityUtils.toString(response.getEntity());
            JsonNode tokenResponse = OBJECT_MAPPER.readTree(responseBody);
            if (tokenResponse.has("error")) {
                String error = tokenResponse.get("error").asText();
                throw new DatabricksDriverException("Token refresh failed: " + error, DatabricksDriverErrorCode.AUTH_ERROR);
            }
            this.accessToken = tokenResponse.get("access_token").asText();
            if (tokenResponse.has(GRANT_TYPE_REFRESH_TOKEN)) {
                this.refreshToken = tokenResponse.get(GRANT_TYPE_REFRESH_TOKEN).asText();
            }
            if (tokenResponse.has("expires_in")) {
                int expiresIn = tokenResponse.get("expires_in").asInt();
                this.tokenExpiry = LocalDateTime.now().plusSeconds(expiresIn);
            } else {
                this.tokenExpiry = this.parseTokenExpiration(this.accessToken);
            }
            LOGGER.debug("Successfully refreshed OAuth tokens");
        }
        catch (Exception e) {
            LOGGER.error(e, "Token refresh failed");
            throw new DatabricksDriverException("Failed to refresh access token", (Throwable)e, DatabricksDriverErrorCode.AUTH_ERROR);
        }
    }

    private void performOAuthFlow() throws Exception {
        this.oauthConfig = this.fetchOAuthConfig();
        PKCEChallenge pkce = this.generatePKCEChallenge();
        String redirectUri = "http://localhost:" + this.callbackPort;
        String state = this.generateRandomString(32);
        String authUrl = this.buildAuthorizationUrl(pkce, redirectUri, state);
        LOGGER.debug("Starting OAuth callback server on port {}", this.callbackPort);
        this.callbackServer.start(this.callbackPort);
        try {
            Thread.sleep(200L);
            this.openBrowser(authUrl);
            OAuthCallback callback = this.callbackServer.waitForCallback(300L, TimeUnit.SECONDS);
            if (!state.equals(callback.getState())) {
                String error = String.format("OAuth state parameter mismatch. Expected: %s, Received: %s", state, callback.getState());
                LOGGER.error(error);
                throw new DatabricksDriverException(error, DatabricksDriverErrorCode.AUTH_ERROR);
            }
            this.exchangeCodeForTokens(callback.getCode(), pkce.getVerifier(), redirectUri);
        }
        catch (Exception e) {
            String errorMessage = String.format("OAuth flow failed for Azure U2M. Error: %s", e);
            LOGGER.error(e, errorMessage);
            throw e;
        }
        finally {
            LOGGER.debug("Stopping OAuth callback server");
            this.callbackServer.stop();
        }
    }

    private OAuthConfig fetchOAuthConfig() {
        OAuthConfig oAuthConfig;
        block8: {
            String configUrl = "https://" + this.hostname + "/oidc/.well-known/oauth-authorization-server";
            LOGGER.debug("Fetching OAuth configuration from: {}", configUrl);
            HttpGet request = new HttpGet(configUrl);
            request.setHeader("Accept", ACCEPT_JSON);
            CloseableHttpResponse response = this.httpClient.execute(request);
            try {
                String responseBody = EntityUtils.toString(response.getEntity());
                JsonNode config = OBJECT_MAPPER.readTree(responseBody);
                String authorizationEndpoint = config.get("authorization_endpoint").asText();
                String tokenEndpoint = config.get("token_endpoint").asText();
                String issuer = config.get("issuer").asText();
                this.validateOAuthConfig(authorizationEndpoint, tokenEndpoint, issuer);
                oAuthConfig = new OAuthConfig(authorizationEndpoint, tokenEndpoint, issuer);
                if (response == null) break block8;
            }
            catch (Throwable throwable) {
                try {
                    if (response != null) {
                        try {
                            response.close();
                        }
                        catch (Throwable throwable2) {
                            throwable.addSuppressed(throwable2);
                        }
                    }
                    throw throwable;
                }
                catch (Exception e) {
                    LOGGER.error(e, "Failed to fetch OAuth config from {}", configUrl);
                    throw new DatabricksDriverException("Unable to fetch OAuth configuration", (Throwable)e, DatabricksDriverErrorCode.AUTH_ERROR);
                }
            }
            response.close();
        }
        return oAuthConfig;
    }

    private void validateOAuthConfig(String authorizationEndpoint, String tokenEndpoint, String issuer) {
        if (authorizationEndpoint == null || authorizationEndpoint.trim().isEmpty()) {
            throw new DatabricksDriverException("OAuth configuration is missing authorization endpoint", DatabricksDriverErrorCode.AUTH_ERROR);
        }
        if (tokenEndpoint == null || tokenEndpoint.trim().isEmpty()) {
            throw new DatabricksDriverException("OAuth configuration is missing token endpoint", DatabricksDriverErrorCode.AUTH_ERROR);
        }
        if (issuer == null || issuer.trim().isEmpty()) {
            throw new DatabricksDriverException("OAuth configuration is missing issuer", DatabricksDriverErrorCode.AUTH_ERROR);
        }
    }

    private PKCEChallenge generatePKCEChallenge() throws NoSuchAlgorithmException {
        SecureRandom random = new SecureRandom();
        byte[] verifierBytes = new byte[32];
        random.nextBytes(verifierBytes);
        String verifier = Base64.getUrlEncoder().withoutPadding().encodeToString(verifierBytes);
        MessageDigest digest = MessageDigest.getInstance("SHA-256");
        byte[] challengeBytes = digest.digest(verifier.getBytes(StandardCharsets.UTF_8));
        String challenge = Base64.getUrlEncoder().withoutPadding().encodeToString(challengeBytes);
        return new PKCEChallenge(verifier, challenge);
    }

    private String buildAuthorizationUrl(PKCEChallenge pkce, String redirectUri, String state) throws Exception {
        URIBuilder builder = new URIBuilder(this.oauthConfig.getAuthorizationEndpoint());
        builder.addParameter("response_type", "code");
        builder.addParameter("client_id", this.clientId);
        builder.addParameter("scope", DEFAULT_SCOPE);
        builder.addParameter("redirect_uri", redirectUri);
        builder.addParameter("state", state);
        builder.addParameter("code_challenge", pkce.getChallenge());
        builder.addParameter("code_challenge_method", CODE_CHALLENGE_METHOD);
        return builder.build().toString();
    }

    private void openBrowser(String authUrl) {
        LOGGER.debug("If the browser doesn't open automatically, please manually navigate to: {}", authUrl);
        if (Desktop.isDesktopSupported() && Desktop.getDesktop().isSupported(Desktop.Action.BROWSE)) {
            try {
                Desktop.getDesktop().browse(new URI(authUrl));
            }
            catch (Exception e) {
                LOGGER.warn("Failed to open browser automatically: {}. Please manually open your browser and navigate to: {}", e.getMessage(), authUrl);
            }
        } else {
            LOGGER.warn("Desktop browsing not supported on this platform. Please manually open your browser and navigate to: {}", authUrl);
        }
    }

    private void exchangeCodeForTokens(String code, String codeVerifier, String redirectUri) throws Exception {
        HttpPost request = new HttpPost(this.oauthConfig.getTokenEndpoint());
        request.setHeader("Content-Type", CONTENT_TYPE_FORM_URLENCODED);
        request.setHeader("Accept", ACCEPT_JSON);
        List<BasicNameValuePair> params = List.of(new BasicNameValuePair("grant_type", GRANT_TYPE_AUTHORIZATION_CODE), new BasicNameValuePair("client_id", this.clientId), new BasicNameValuePair("code", code), new BasicNameValuePair("redirect_uri", redirectUri), new BasicNameValuePair("code_verifier", codeVerifier));
        request.setEntity(new UrlEncodedFormEntity(params));
        try (CloseableHttpResponse response = this.httpClient.execute(request);){
            String responseBody = EntityUtils.toString(response.getEntity());
            JsonNode tokenResponse = OBJECT_MAPPER.readTree(responseBody);
            if (tokenResponse.has("error")) {
                String error = tokenResponse.get("error").asText();
                String errorDescription = tokenResponse.has("error_description") ? tokenResponse.get("error_description").asText() : error;
                throw new DatabricksDriverException("OAuth token error: " + errorDescription, DatabricksDriverErrorCode.AUTH_ERROR);
            }
            this.accessToken = tokenResponse.get("access_token").asText();
            String string = this.refreshToken = tokenResponse.has(GRANT_TYPE_REFRESH_TOKEN) ? tokenResponse.get(GRANT_TYPE_REFRESH_TOKEN).asText() : null;
            if (tokenResponse.has("expires_in")) {
                int expiresIn = tokenResponse.get("expires_in").asInt();
                this.tokenExpiry = LocalDateTime.now().plusSeconds(expiresIn);
            } else {
                this.tokenExpiry = this.parseTokenExpiration(this.accessToken);
            }
        }
        catch (Exception e) {
            LOGGER.error(e, "Token exchange failed");
            throw new DatabricksDriverException("Failed to exchange code for tokens", (Throwable)e, DatabricksDriverErrorCode.AUTH_ERROR);
        }
    }

    private boolean isTokenExpired(String token) {
        try {
            LocalDateTime expiration = this.parseTokenExpiration(token);
            return LocalDateTime.now().isAfter(expiration);
        }
        catch (Exception e) {
            LOGGER.warn("Could not parse token expiration: {}", e.getMessage());
            return true;
        }
    }

    private LocalDateTime parseTokenExpiration(String token) throws ParseException {
        SignedJWT signedJWT = SignedJWT.parse(token);
        JWTClaimsSet claims = signedJWT.getJWTClaimsSet();
        if (claims.getExpirationTime() == null) {
            throw new DatabricksDriverException("Token has no expiration time", DatabricksDriverErrorCode.AUTH_ERROR);
        }
        Instant expirationTime = Instant.ofEpochMilli(claims.getExpirationTime().getTime());
        return expirationTime.atZone(ZoneId.systemDefault()).toLocalDateTime();
    }

    private String generateRandomString(int length) {
        SecureRandom random = new SecureRandom();
        byte[] bytes = new byte[length];
        random.nextBytes(bytes);
        return Base64.getUrlEncoder().withoutPadding().encodeToString(bytes);
    }

    private static class OAuthCallbackServer {
        private ServerSocket serverSocket;
        private CompletableFuture<OAuthCallback> callbackFuture;
        private volatile boolean isReady = false;
        private final Object serverLock = new Object();

        private OAuthCallbackServer() {
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        public void start(int port) throws IOException {
            Object object = this.serverLock;
            synchronized (object) {
                this.serverSocket = new ServerSocket(port);
                this.serverSocket.setReuseAddress(true);
                this.serverSocket.setSoTimeout(300000);
                this.callbackFuture = new CompletableFuture();
                Thread serverThread = new Thread(() -> {
                    block5: {
                        try {
                            Object object = this.serverLock;
                            synchronized (object) {
                                this.isReady = true;
                                this.serverLock.notifyAll();
                            }
                            LOGGER.debug("OAuth callback server started and ready on port {}", port);
                            Socket clientSocket = this.serverSocket.accept();
                            LOGGER.debug("OAuth callback connection accepted from {}", clientSocket.getInetAddress());
                            this.handleCallback(clientSocket);
                        }
                        catch (IOException e) {
                            if (this.serverSocket.isClosed()) break block5;
                            LOGGER.error(e, "Error handling OAuth callback");
                            this.callbackFuture.completeExceptionally(e);
                        }
                    }
                });
                serverThread.setDaemon(true);
                serverThread.start();
                Object object2 = this.serverLock;
                synchronized (object2) {
                    while (!this.isReady) {
                        try {
                            this.serverLock.wait(1000L);
                        }
                        catch (InterruptedException e) {
                            Thread.currentThread().interrupt();
                            throw new IOException("Server startup interrupted", e);
                        }
                    }
                }
                try {
                    Thread.sleep(200L);
                }
                catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new IOException("Server startup interrupted", e);
                }
                LOGGER.debug("OAuth callback server is ready and listening on port {}", port);
            }
        }

        /*
         * Enabled aggressive block sorting
         * Enabled unnecessary exception pruning
         * Enabled aggressive exception aggregation
         */
        private void handleCallback(Socket clientSocket) throws IOException {
            try (BufferedReader reader = new BufferedReader(new InputStreamReader(clientSocket.getInputStream()));
                 PrintWriter writer = new PrintWriter(clientSocket.getOutputStream());){
                String requestLine = reader.readLine();
                if (requestLine == null) {
                    LOGGER.warn("Received empty request from OAuth callback");
                    return;
                }
                LOGGER.debug("Received OAuth callback request: {}", requestLine);
                String[] parts = requestLine.split(" ");
                if (parts.length < 2) {
                    LOGGER.warn("Invalid OAuth callback request format: {}", requestLine);
                    this.sendErrorResponse(writer, "Invalid request");
                    return;
                }
                String query = parts[1];
                if (query.startsWith("/?")) {
                    query = query.substring(2);
                } else if (query.equals("/")) {
                    LOGGER.debug("Received root path request, sending success page");
                    this.sendSuccessResponse(writer);
                    return;
                }
                Map<String, String> params = this.parseQueryString(query);
                String code = params.get("code");
                String state = params.get("state");
                String error = params.get("error");
                LOGGER.debug("OAuth callback parameters - code: {}, state: {}, error: {}", code != null ? "present" : "null", state != null ? "present" : "null", error);
                this.sendSuccessResponse(writer);
                if (error != null) {
                    LOGGER.error("OAuth error received: {}", error);
                    this.callbackFuture.completeExceptionally(new DatabricksDriverException("OAuth error: " + error, DatabricksDriverErrorCode.AUTH_ERROR));
                    return;
                }
                if (code != null) {
                    LOGGER.debug("OAuth authorization code received successfully");
                    this.callbackFuture.complete(new OAuthCallback(code, state));
                    return;
                }
                LOGGER.error("No authorization code received in OAuth callback");
                this.callbackFuture.completeExceptionally(new DatabricksDriverException("No authorization code received", DatabricksDriverErrorCode.AUTH_ERROR));
                return;
            }
        }

        private void sendSuccessResponse(PrintWriter writer) {
            writer.println("HTTP/1.1 200 OK");
            writer.println("Content-Type: text/html; charset=UTF-8");
            writer.println("Connection: close");
            writer.println();
            writer.println("<!DOCTYPE html>");
            writer.println("<html><head><title>OAuth Login Success</title></head>");
            writer.println("<body><h1>OAuth Login Successful!</h1>");
            writer.println("<p>You have successfully logged in using OAuth. You may now close this tab.</p>");
            writer.println("</body></html>");
            writer.flush();
            LOGGER.debug("Sent success response to OAuth callback");
        }

        private void sendErrorResponse(PrintWriter writer, String error) {
            writer.println("HTTP/1.1 400 Bad Request");
            writer.println("Content-Type: text/html; charset=UTF-8");
            writer.println("Connection: close");
            writer.println();
            writer.println("<!DOCTYPE html>");
            writer.println("<html><head><title>OAuth Error</title></head>");
            writer.println("<body><h1>OAuth Error</h1>");
            writer.println("<p>" + error + "</p>");
            writer.println("</body></html>");
            writer.flush();
            LOGGER.warn("Sent error response to OAuth callback: {}", error);
        }

        private Map<String, String> parseQueryString(String query) {
            HashMap<String, String> params = new HashMap<String, String>();
            for (String pair : query.split("&")) {
                String[] keyValue = pair.split("=", 2);
                if (keyValue.length != 2) continue;
                try {
                    String key = URLDecoder.decode(keyValue[0], StandardCharsets.UTF_8);
                    String value = URLDecoder.decode(keyValue[1], StandardCharsets.UTF_8);
                    params.put(key, value);
                }
                catch (Exception e) {
                    LOGGER.warn("Failed to decode query parameter: {}", pair);
                }
            }
            return params;
        }

        public OAuthCallback waitForCallback(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
            return this.callbackFuture.get(timeout, unit);
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        public void stop() {
            Object object = this.serverLock;
            synchronized (object) {
                if (this.serverSocket != null && !this.serverSocket.isClosed()) {
                    try {
                        this.serverSocket.close();
                        LOGGER.debug("OAuth callback server stopped");
                    }
                    catch (IOException e) {
                        LOGGER.warn("Error closing callback server: {}", e.getMessage());
                    }
                }
            }
        }
    }

    private static class OAuthConfig {
        private final String authorizationEndpoint;
        private final String tokenEndpoint;
        private final String issuer;

        public OAuthConfig(String authorizationEndpoint, String tokenEndpoint, String issuer) {
            this.authorizationEndpoint = authorizationEndpoint;
            this.tokenEndpoint = tokenEndpoint;
            this.issuer = issuer;
        }

        public String getAuthorizationEndpoint() {
            return this.authorizationEndpoint;
        }

        public String getTokenEndpoint() {
            return this.tokenEndpoint;
        }

        public String getIssuer() {
            return this.issuer;
        }
    }

    private static class PKCEChallenge {
        private final String verifier;
        private final String challenge;

        public PKCEChallenge(String verifier, String challenge) {
            this.verifier = verifier;
            this.challenge = challenge;
        }

        public String getVerifier() {
            return this.verifier;
        }

        public String getChallenge() {
            return this.challenge;
        }
    }

    private static class OAuthCallback {
        private final String code;
        private final String state;

        public OAuthCallback(String code, String state) {
            this.code = code;
            this.state = state;
        }

        public String getCode() {
            return this.code;
        }

        public String getState() {
            return this.state;
        }
    }
}

