/*
 * Decompiled with CFR 0.152.
 */
package io.r2dbc.postgresql.authentication;

import com.ongres.scram.client.ScramClient;
import com.ongres.scram.common.StringPreparation;
import com.ongres.scram.common.exception.ScramException;
import com.ongres.scram.common.util.TlsServerEndpoint;
import io.r2dbc.postgresql.authentication.AuthenticationHandler;
import io.r2dbc.postgresql.client.ConnectionContext;
import io.r2dbc.postgresql.message.backend.AuthenticationMessage;
import io.r2dbc.postgresql.message.backend.AuthenticationSASL;
import io.r2dbc.postgresql.message.backend.AuthenticationSASLContinue;
import io.r2dbc.postgresql.message.backend.AuthenticationSASLFinal;
import io.r2dbc.postgresql.message.frontend.FrontendMessage;
import io.r2dbc.postgresql.message.frontend.SASLInitialResponse;
import io.r2dbc.postgresql.message.frontend.SASLResponse;
import io.r2dbc.postgresql.util.Assert;
import io.r2dbc.postgresql.util.ByteBufferUtils;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import org.jspecify.annotations.Nullable;
import reactor.core.Exceptions;
import reactor.util.Logger;
import reactor.util.Loggers;

/**
 * {@link AuthenticationHandler} that drives the PostgreSQL SASL (SCRAM) authentication exchange.
 * <p>
 * The exchange consists of three backend messages which are expected in this order:
 * <ol>
 *     <li>{@link AuthenticationSASL}: creates the {@link ScramClient} from the advertised mechanisms and answers
 *     with a {@link SASLInitialResponse} carrying the client-first message.</li>
 *     <li>{@link AuthenticationSASLContinue}: feeds the server-first message to the SCRAM client and answers with a
 *     {@link SASLResponse} carrying the client-final message.</li>
 *     <li>{@link AuthenticationSASLFinal}: feeds the server-final message to the SCRAM client; no reply is sent.</li>
 * </ol>
 * When the connection has a valid {@link SSLSession}, {@code tls-server-end-point} channel binding data derived from
 * the peer's X.509 certificate is passed to the SCRAM client.
 * <p>
 * Instances keep the SCRAM client as mutable, unsynchronized state for the duration of one exchange.
 */
public class SASLAuthenticationHandler
implements AuthenticationHandler {

    private static final Logger LOG = Loggers.getLogger(SASLAuthenticationHandler.class);

    private final CharSequence password;

    private final String username;

    private final ConnectionContext context;

    /**
     * SCRAM client for the current exchange. {@code null} until an {@link AuthenticationSASL} message was handled.
     */
    private @Nullable ScramClient scramClient;

    /**
     * Create a new handler.
     *
     * @param password the password to authenticate with
     * @param username the username to authenticate with
     * @param context  the connection context, used to look up the SSL session for channel binding
     * @throws IllegalArgumentException if any argument is {@code null} (as raised by {@link Assert#requireNonNull})
     */
    public SASLAuthenticationHandler(CharSequence password, String username, ConnectionContext context) {
        this.password = Assert.requireNonNull(password, "password must not be null");
        this.username = Assert.requireNonNull(username, "username must not be null");
        this.context = Assert.requireNonNull(context, "context must not be null");
    }

    /**
     * Returns whether this handler can process the given authentication message.
     *
     * @param message the backend authentication message
     * @return {@code true} for {@link AuthenticationSASL}, {@link AuthenticationSASLContinue} and
     * {@link AuthenticationSASLFinal} messages
     */
    public static boolean supports(AuthenticationMessage message) {
        Assert.requireNonNull(message, "message must not be null");
        return message instanceof AuthenticationSASL
            || message instanceof AuthenticationSASLContinue
            || message instanceof AuthenticationSASLFinal;
    }

    /**
     * Handle one step of the SASL exchange.
     *
     * @param message the backend authentication message
     * @return the frontend message to send, or {@code null} after {@link AuthenticationSASLFinal}
     * @throws IllegalArgumentException if {@code message} is {@code null} or of an unsupported type
     * @throws IllegalStateException    if a continuation/final message arrives before {@link AuthenticationSASL}
     */
    @Override
    public @Nullable FrontendMessage handle(AuthenticationMessage message) {
        Assert.requireNonNull(message, "message must not be null");

        if (message instanceof AuthenticationSASL) {
            return this.handleAuthenticationSASL((AuthenticationSASL) message);
        }
        if (message instanceof AuthenticationSASLContinue) {
            return this.handleAuthenticationSASLContinue((AuthenticationSASLContinue) message);
        }
        if (message instanceof AuthenticationSASLFinal) {
            return this.handleAuthenticationSASLFinal((AuthenticationSASLFinal) message);
        }
        throw new IllegalArgumentException(String.format("Cannot handle %s message", message.getClass().getSimpleName()));
    }

    /**
     * Start the exchange: build the SCRAM client and produce the client-first message.
     */
    private FrontendMessage handleAuthenticationSASL(AuthenticationSASL message) {
        // Copy char-by-char instead of calling CharSequence.toString() so no intermediate String
        // holding the password is created.
        char[] passwordChars = new char[this.password.length()];
        for (int i = 0; i < passwordChars.length; ++i) {
            passwordChars[i] = this.password.charAt(i);
        }

        ScramClient.FinalBuildStage builder = ScramClient.builder()
            .advertisedMechanisms(message.getAuthenticationMechanisms())
            .username(this.username)
            .password(passwordChars)
            .stringPreparation(StringPreparation.POSTGRESQL_PREPARATION);

        SSLSession sslSession = this.context.getSslSession();
        if (sslSession != null && sslSession.isValid()) {
            builder.channelBinding("tls-server-end-point", SASLAuthenticationHandler.extractSslEndpoint(sslSession));
        }

        ScramClient client = builder.build();
        this.scramClient = client;
        return new SASLInitialResponse(ByteBufferUtils.encode(client.clientFirstMessage().toString()), client.getScramMechanism().getName());
    }

    /**
     * Compute {@code tls-server-end-point} channel binding data from the first peer certificate.
     *
     * @param sslSession the active SSL session
     * @return the channel binding data, or an empty array if the first peer certificate is absent, not an
     * {@link X509Certificate}, or cannot be processed (the failure is logged at debug level)
     */
    private static byte[] extractSslEndpoint(SSLSession sslSession) {
        try {
            Certificate[] certificates = sslSession.getPeerCertificates();
            if (certificates != null && certificates.length > 0 && certificates[0] instanceof X509Certificate) {
                return TlsServerEndpoint.getChannelBindingData((X509Certificate) certificates[0]);
            }
        } catch (CertificateException | SSLException e) {
            LOG.debug("Cannot extract X509Certificate from SSL session", (Throwable) e);
        }
        return new byte[0];
    }

    /**
     * Process the server-first message and produce the client-final message.
     */
    private FrontendMessage handleAuthenticationSASLContinue(AuthenticationSASLContinue message) {
        ScramClient client = this.requireScramClient(message);
        try {
            client.serverFirstMessage(ByteBufferUtils.decode(message.getData()));
            return new SASLResponse(ByteBufferUtils.encode(client.clientFinalMessage().toString()));
        } catch (ScramException e) {
            throw Exceptions.propagate((Throwable) e);
        }
    }

    /**
     * Process the server-final message. Nothing is sent in response.
     */
    private @Nullable FrontendMessage handleAuthenticationSASLFinal(AuthenticationSASLFinal message) {
        ScramClient client = this.requireScramClient(message);
        try {
            client.serverFinalMessage(ByteBufferUtils.decode(message.getAdditionalData()));
            return null;
        } catch (ScramException e) {
            throw Exceptions.propagate((Throwable) e);
        }
    }

    /**
     * Return the SCRAM client of the current exchange, failing with a descriptive error (instead of a
     * {@link NullPointerException}) if the exchange was never started via {@link AuthenticationSASL}.
     */
    private ScramClient requireScramClient(AuthenticationMessage message) {
        ScramClient client = this.scramClient;
        if (client == null) {
            throw new IllegalStateException(String.format("Received %s before AuthenticationSASL; SASL exchange not started", message.getClass().getSimpleName()));
        }
        return client;
    }
}

