/*
 * Decompiled with CFR 0.152.
 */
package io.trino.server.security.oauth2;

import com.google.common.base.Strings;
import com.google.common.base.Verify;
import com.google.common.hash.Hashing;
import com.google.common.html.HtmlEscapers;
import com.google.common.io.Resources;
import com.google.inject.Inject;
import io.airlift.log.Logger;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwtBuilder;
import io.jsonwebtoken.JwtParser;
import io.jsonwebtoken.lang.NestedCollection;
import io.jsonwebtoken.security.Keys;
import io.trino.server.security.jwt.JwtUtil;
import io.trino.server.security.oauth2.ChallengeFailedException;
import io.trino.server.security.oauth2.ForRefreshTokens;
import io.trino.server.security.oauth2.NonceCookie;
import io.trino.server.security.oauth2.OAuth2Client;
import io.trino.server.security.oauth2.OAuth2Config;
import io.trino.server.security.oauth2.OAuth2TokenHandler;
import io.trino.server.security.oauth2.TokenPairSerializer;
import io.trino.server.ui.OAuth2WebUiInstalled;
import io.trino.server.ui.OAuthIdTokenCookie;
import io.trino.server.ui.OAuthWebUiCookie;
import jakarta.ws.rs.core.NewCookie;
import jakarta.ws.rs.core.Response;
import java.io.IOException;
import java.net.URI;
import java.net.URL;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.security.SecureRandom;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.TemporalAmount;
import java.util.Date;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import javax.crypto.SecretKey;

public class OAuth2Service {
    private static final Logger LOG = Logger.get(OAuth2Service.class);
    public static final String STATE = "state";
    public static final String NONCE = "nonce";
    public static final String OPENID_SCOPE = "openid";
    private static final String STATE_AUDIENCE_UI = "trino_oauth_ui";
    private static final String FAILURE_REPLACEMENT_TEXT = "<!-- ERROR_MESSAGE -->";
    private static final Random SECURE_RANDOM = new SecureRandom();
    public static final String HANDLER_STATE_CLAIM = "handler_state";
    private final OAuth2Client client;
    private final Optional<Duration> tokenExpiration;
    private final TokenPairSerializer tokenPairSerializer;
    private final String successHtml;
    private final String failureHtml;
    private final TemporalAmount challengeTimeout;
    private final SecretKey stateHmac;
    private final JwtParser jwtParser;
    private final OAuth2TokenHandler tokenHandler;
    private final boolean webUiOAuthEnabled;

    @Inject
    public OAuth2Service(OAuth2Client client, OAuth2Config oauth2Config, OAuth2TokenHandler tokenHandler, TokenPairSerializer tokenPairSerializer, @ForRefreshTokens Optional<Duration> tokenExpiration, Optional<OAuth2WebUiInstalled> webUiOAuthEnabled) throws IOException {
        this.client = Objects.requireNonNull(client, "client is null");
        this.successHtml = Resources.toString((URL)Resources.getResource(this.getClass(), (String)"/oauth2/success.html"), (Charset)StandardCharsets.UTF_8);
        this.failureHtml = Resources.toString((URL)Resources.getResource(this.getClass(), (String)"/oauth2/failure.html"), (Charset)StandardCharsets.UTF_8);
        Verify.verify((boolean)this.failureHtml.contains(FAILURE_REPLACEMENT_TEXT), (String)"login.html does not contain the replacement text", (Object[])new Object[0]);
        this.challengeTimeout = Duration.ofMillis(oauth2Config.getChallengeTimeout().toMillis());
        this.stateHmac = Keys.hmacShaKeyFor((byte[])oauth2Config.getStateKey().map(key -> Hashing.sha256().hashString((CharSequence)key, StandardCharsets.UTF_8).asBytes()).orElseGet(() -> OAuth2Service.secureRandomBytes(32)));
        this.jwtParser = JwtUtil.newJwtParserBuilder().verifyWith(this.stateHmac).requireAudience(STATE_AUDIENCE_UI).build();
        this.tokenHandler = Objects.requireNonNull(tokenHandler, "tokenHandler is null");
        this.tokenPairSerializer = Objects.requireNonNull(tokenPairSerializer, "tokenPairSerializer is null");
        this.tokenExpiration = Objects.requireNonNull(tokenExpiration, "tokenExpiration is null");
        this.webUiOAuthEnabled = webUiOAuthEnabled.isPresent();
    }

    public Response startOAuth2Challenge(URI callbackUri, Optional<String> handlerState) {
        Instant challengeExpiration = Instant.now().plus(this.challengeTimeout);
        String state = ((JwtBuilder)((NestedCollection)JwtUtil.newJwtBuilder().signWith((Key)this.stateHmac).audience().add((Object)STATE_AUDIENCE_UI)).and()).claim(HANDLER_STATE_CLAIM, handlerState.orElse(null)).expiration(Date.from(challengeExpiration)).compact();
        OAuth2Client.Request request = this.client.createAuthorizationRequest(state, callbackUri);
        Response.ResponseBuilder response = Response.seeOther((URI)request.getAuthorizationUri());
        request.getNonce().ifPresent(nce -> response.cookie(new NewCookie[]{NonceCookie.create(nce, challengeExpiration)}));
        return response.build();
    }

    public Response handleOAuth2Error(String state, String error, String errorDescription, String errorUri) {
        try {
            Claims stateClaims = this.parseState(state);
            Optional.ofNullable((String)stateClaims.get(HANDLER_STATE_CLAIM, String.class)).ifPresent(value -> this.tokenHandler.setTokenExchangeError((String)value, String.format("Authentication response could not be verified: error=%s, errorDescription=%s, errorUri=%s", error, errorDescription, errorDescription)));
        }
        catch (ChallengeFailedException | RuntimeException e) {
            LOG.debug((Throwable)e, "Authentication response could not be verified invalid state: state=%s", new Object[]{state});
            return Response.status((Response.Status)Response.Status.BAD_REQUEST).entity((Object)this.getInternalFailureHtml("Authentication response could not be verified")).cookie(new NewCookie[]{NonceCookie.delete()}).build();
        }
        LOG.debug("OAuth server returned an error: error=%s, error_description=%s, error_uri=%s, state=%s", new Object[]{error, errorDescription, errorUri, state});
        return Response.ok().entity((Object)this.getCallbackErrorHtml(error)).cookie(new NewCookie[]{NonceCookie.delete()}).build();
    }

    public Response finishOAuth2Challenge(String state, String code, URI callbackUri, Optional<String> nonce) {
        Optional<String> handlerState;
        try {
            Claims stateClaims = this.parseState(state);
            handlerState = Optional.ofNullable((String)stateClaims.get(HANDLER_STATE_CLAIM, String.class));
        }
        catch (ChallengeFailedException | RuntimeException e) {
            LOG.debug((Throwable)e, "Authentication response could not be verified invalid state: state=%s", new Object[]{state});
            return Response.status((Response.Status)Response.Status.BAD_REQUEST).entity((Object)this.getInternalFailureHtml("Authentication response could not be verified")).cookie(new NewCookie[]{NonceCookie.delete()}).build();
        }
        try {
            OAuth2Client.Response oauth2Response = this.client.getOAuth2Response(code, callbackUri, nonce);
            Instant cookieExpirationTime = this.tokenExpiration.map(expiration -> Instant.now().plus((TemporalAmount)expiration)).orElse(oauth2Response.getExpiration());
            if (handlerState.isEmpty()) {
                Response.ResponseBuilder builder = Response.seeOther((URI)URI.create("/ui/")).cookie(OAuthWebUiCookie.create(this.tokenPairSerializer.serialize(TokenPairSerializer.TokenPair.fromOAuth2Response(oauth2Response)), cookieExpirationTime)).cookie(new NewCookie[]{NonceCookie.delete()});
                if (oauth2Response.getIdToken().isPresent()) {
                    builder.cookie(OAuthIdTokenCookie.create(oauth2Response.getIdToken().get(), cookieExpirationTime));
                }
                return builder.build();
            }
            this.tokenHandler.setAccessToken(handlerState.get(), this.tokenPairSerializer.serialize(TokenPairSerializer.TokenPair.fromOAuth2Response(oauth2Response)));
            Response.ResponseBuilder builder = Response.ok((Object)this.getSuccessHtml());
            if (this.webUiOAuthEnabled) {
                builder.cookie(OAuthWebUiCookie.create(this.tokenPairSerializer.serialize(TokenPairSerializer.TokenPair.fromOAuth2Response(oauth2Response)), cookieExpirationTime));
                if (oauth2Response.getIdToken().isPresent()) {
                    builder.cookie(OAuthIdTokenCookie.create(oauth2Response.getIdToken().get(), cookieExpirationTime));
                }
            }
            return builder.cookie(new NewCookie[]{NonceCookie.delete()}).build();
        }
        catch (ChallengeFailedException | RuntimeException e) {
            LOG.debug((Throwable)e, "Authentication response could not be verified: state=%s", new Object[]{state});
            handlerState.ifPresent(value -> this.tokenHandler.setTokenExchangeError((String)value, String.format("Authentication response could not be verified: state=%s", value)));
            return Response.status((Response.Status)Response.Status.BAD_REQUEST).cookie(new NewCookie[]{NonceCookie.delete()}).entity((Object)this.getInternalFailureHtml("Authentication response could not be verified")).build();
        }
    }

    private Claims parseState(String state) throws ChallengeFailedException {
        try {
            return (Claims)this.jwtParser.parseSignedClaims((CharSequence)state).getPayload();
        }
        catch (RuntimeException e) {
            throw new ChallengeFailedException("State validation failed", e);
        }
    }

    public String getSuccessHtml() {
        return this.successHtml;
    }

    public String getCallbackErrorHtml(String errorCode) {
        return this.failureHtml.replace(FAILURE_REPLACEMENT_TEXT, OAuth2Service.getOAuth2ErrorMessage(errorCode));
    }

    public String getInternalFailureHtml(String errorMessage) {
        return this.failureHtml.replace(FAILURE_REPLACEMENT_TEXT, Strings.nullToEmpty((String)errorMessage));
    }

    private static byte[] secureRandomBytes(int count) {
        byte[] bytes = new byte[count];
        SECURE_RANDOM.nextBytes(bytes);
        return bytes;
    }

    private static String getOAuth2ErrorMessage(String errorCode) {
        return switch (errorCode) {
            case "access_denied" -> "OAuth2 server denied the login";
            case "unauthorized_client" -> "OAuth2 server does not allow request from this Trino server";
            case "server_error" -> "OAuth2 server had a failure";
            case "temporarily_unavailable" -> "OAuth2 server is temporarily unavailable";
            default -> "OAuth2 unknown error code: " + errorCode;
        };
    }
}

