/*
 * Decompiled with CFR 0.152.
 */
package net.snowflake.client.jdbc.internal.grpc.alts.internal;

import java.io.IOException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import net.snowflake.client.jdbc.internal.google.common.annotations.VisibleForTesting;
import net.snowflake.client.jdbc.internal.google.common.base.Preconditions;
import net.snowflake.client.jdbc.internal.google.common.base.Strings;
import net.snowflake.client.jdbc.internal.google.protobuf.ByteString;
import net.snowflake.client.jdbc.internal.grpc.ChannelLogger;
import net.snowflake.client.jdbc.internal.grpc.Status;
import net.snowflake.client.jdbc.internal.grpc.alts.internal.AltsChannelCrypter;
import net.snowflake.client.jdbc.internal.grpc.alts.internal.AltsClientOptions;
import net.snowflake.client.jdbc.internal.grpc.alts.internal.AltsHandshakerOptions;
import net.snowflake.client.jdbc.internal.grpc.alts.internal.AltsHandshakerStub;
import net.snowflake.client.jdbc.internal.grpc.alts.internal.AltsTsiFrameProtector;
import net.snowflake.client.jdbc.internal.grpc.alts.internal.HandshakeProtocol;
import net.snowflake.client.jdbc.internal.grpc.alts.internal.HandshakerReq;
import net.snowflake.client.jdbc.internal.grpc.alts.internal.HandshakerResp;
import net.snowflake.client.jdbc.internal.grpc.alts.internal.HandshakerResult;
import net.snowflake.client.jdbc.internal.grpc.alts.internal.HandshakerServiceGrpc;
import net.snowflake.client.jdbc.internal.grpc.alts.internal.HandshakerStatus;
import net.snowflake.client.jdbc.internal.grpc.alts.internal.NextHandshakeMessageReq;
import net.snowflake.client.jdbc.internal.grpc.alts.internal.ServerHandshakeParameters;
import net.snowflake.client.jdbc.internal.grpc.alts.internal.StartClientHandshakeReq;
import net.snowflake.client.jdbc.internal.grpc.alts.internal.StartServerHandshakeReq;

/**
 * Drives one ALTS handshake by exchanging {@link HandshakerReq}/{@link HandshakerResp} messages
 * with the handshaker service through an {@link AltsHandshakerStub}.
 *
 * <p>A handshake begins with {@link #startClientHandshake()} or
 * {@link #startServerHandshake(ByteBuffer)} and continues with {@link #next(ByteBuffer)}. Each of
 * these throws {@link IllegalStateException} once {@link #isFinished()} is true, i.e. after the
 * service has returned a result or a non-OK status. The stub is closed when either of those
 * responses arrives, or when {@link #close()} is called.
 */
class AltsHandshakerClient {
    // NOTE(review): upstream grpc-java advertises "grpc" as the ALTS application protocol; this
    // value looks like a package-relocation artifact of shading. Confirm the peer/handshaker
    // service accepts it before relying on interop with non-shaded ALTS endpoints.
    private static final String APPLICATION_PROTOCOL = "net.snowflake.client.jdbc.internal.grpc";
    private static final String RECORD_PROTOCOL = "ALTSRP_GCM_AES128_REKEY";
    private static final int KEY_LENGTH = AltsChannelCrypter.getKeyLength();
    private final AltsHandshakerStub handshakerStub;
    private final AltsHandshakerOptions handshakerOptions;
    /** Result reported by the handshaker service; null until a response carries one. */
    private HandshakerResult result;
    /** Status from the most recent handshaker response; null before the first response. */
    private HandshakerStatus status;
    private final ChannelLogger logger;
    /** Guards {@link #close()} so the stub is closed at most once. */
    private boolean closed = false;

    /**
     * Creates a client that talks to the handshaker service through {@code stub}.
     *
     * @param stub gRPC stub for the handshaker service, wrapped in an {@link AltsHandshakerStub}
     * @param options handshake options; client-specific fields are read when this is an
     *     {@link AltsClientOptions}
     * @param logger logger used for DEBUG-level handshake tracing
     */
    AltsHandshakerClient(HandshakerServiceGrpc.HandshakerServiceStub stub, AltsHandshakerOptions options, ChannelLogger logger) {
        this.handshakerStub = new AltsHandshakerStub(stub);
        this.handshakerOptions = options;
        this.logger = logger;
    }

    /** Creates a client around an existing {@link AltsHandshakerStub} (test hook). */
    @VisibleForTesting
    AltsHandshakerClient(AltsHandshakerStub handshakerStub, AltsHandshakerOptions options, ChannelLogger logger) {
        this.handshakerStub = handshakerStub;
        this.handshakerOptions = options;
        this.logger = logger;
    }

    /** Returns the application protocol this client advertises in its start requests. */
    static String getApplicationProtocol() {
        return APPLICATION_PROTOCOL;
    }

    /** Returns the record protocol this client advertises in its start requests. */
    static String getRecordProtocol() {
        return RECORD_PROTOCOL;
    }

    /**
     * Fills in the client-start portion of {@code req}. This covers protocols, RPC versions,
     * target name/identities (for {@link AltsClientOptions}) and the maximum frame size.
     */
    private void setStartClientFields(HandshakerReq.Builder req) {
        StartClientHandshakeReq.Builder startClientReq = StartClientHandshakeReq.newBuilder().setHandshakeSecurityProtocol(HandshakeProtocol.ALTS).addApplicationProtocols(APPLICATION_PROTOCOL).addRecordProtocols(RECORD_PROTOCOL);
        if (this.handshakerOptions.getRpcProtocolVersions() != null) {
            startClientReq.setRpcVersions(this.handshakerOptions.getRpcProtocolVersions());
        }
        if (this.handshakerOptions instanceof AltsClientOptions) {
            AltsClientOptions clientOptions = (AltsClientOptions)this.handshakerOptions;
            if (!Strings.isNullOrEmpty(clientOptions.getTargetName())) {
                startClientReq.setTargetName(clientOptions.getTargetName());
            }
            for (String serviceAccount : clientOptions.getTargetServiceAccounts()) {
                startClientReq.addTargetIdentitiesBuilder().setServiceAccount(serviceAccount);
            }
        }
        startClientReq.setMaxFrameSize(AltsTsiFrameProtector.getMaxFrameSize());
        req.setClientStart(startClientReq);
    }

    /**
     * Fills in the server-start portion of {@code req}. The remaining bytes of {@code inBytes}
     * are copied without moving the buffer's position.
     */
    private void setStartServerFields(HandshakerReq.Builder req, ByteBuffer inBytes) {
        ServerHandshakeParameters serverParameters = ServerHandshakeParameters.newBuilder().addRecordProtocols(RECORD_PROTOCOL).build();
        StartServerHandshakeReq.Builder startServerReq = StartServerHandshakeReq.newBuilder().addApplicationProtocols(APPLICATION_PROTOCOL).putHandshakeParameters(HandshakeProtocol.ALTS.getNumber(), serverParameters).setInBytes(ByteString.copyFrom(inBytes.duplicate()));
        if (this.handshakerOptions.getRpcProtocolVersions() != null) {
            startServerReq.setRpcVersions(this.handshakerOptions.getRpcProtocolVersions());
        }
        startServerReq.setMaxFrameSize(AltsTsiFrameProtector.getMaxFrameSize());
        req.setServerStart(startServerReq);
    }

    /**
     * Returns true once the handshaker service has returned a result or a non-OK status.
     */
    public boolean isFinished() {
        if (this.result != null) {
            return true;
        }
        return this.status != null && this.status.getCode() != Status.Code.OK.value();
    }

    /** Returns the status from the most recent response, or null if none has been received. */
    public HandshakerStatus getStatus() {
        return this.status;
    }

    /** Returns the handshake result, or null if the service has not reported one. */
    public HandshakerResult getResult() {
        return this.result;
    }

    /**
     * Returns a copy of the first {@code KEY_LENGTH} bytes of the result's key data.
     *
     * @return the key bytes, or null if there is no result yet
     * @throws IllegalStateException if the result carries fewer than {@code KEY_LENGTH} bytes
     */
    public byte[] getKey() {
        if (this.result == null) {
            return null;
        }
        if (this.result.getKeyData().size() < KEY_LENGTH) {
            throw new IllegalStateException("Could not get enough key data from the handshake.");
        }
        byte[] key = new byte[KEY_LENGTH];
        this.result.getKeyData().substring(0, KEY_LENGTH).copyTo(key, 0);
        return key;
    }

    /**
     * Records the status and any result from {@code resp}. The stub is closed once a result
     * arrives or the status is non-OK.
     *
     * @throws GeneralSecurityException if the response status is not OK
     */
    private void handleResponse(HandshakerResp resp) throws GeneralSecurityException {
        this.status = resp.getStatus();
        if (resp.hasResult()) {
            this.result = resp.getResult();
            this.close();
        }
        if (this.status.getCode() != Status.Code.OK.value()) {
            String error = "Handshaker service error: " + this.status.getDetails();
            this.logger.log(ChannelLogger.ChannelLogLevel.DEBUG, error);
            this.close();
            throw new GeneralSecurityException(error);
        }
    }

    /**
     * Sends {@code req} through the stub and returns the service's response, with DEBUG tracing
     * around the exchange.
     *
     * @throws GeneralSecurityException wrapping any {@link IOException} or
     *     {@link InterruptedException} raised by the stub
     */
    private HandshakerResp sendRequest(HandshakerReq req) throws GeneralSecurityException {
        this.logger.log(ChannelLogger.ChannelLogLevel.DEBUG, "Send ALTS handshake request to upstream");
        HandshakerResp resp;
        try {
            resp = this.handshakerStub.send(req);
        }
        catch (IOException e) {
            throw new GeneralSecurityException(e);
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag; wrapping the exception would otherwise hide the
            // interruption from callers further up the stack.
            Thread.currentThread().interrupt();
            throw new GeneralSecurityException(e);
        }
        this.logger.log(ChannelLogger.ChannelLogLevel.DEBUG, "Receive ALTS handshake response from upstream");
        return resp;
    }

    /**
     * Advances {@code inBytes} past the bytes the handshaker service reports as consumed.
     *
     * @throws GeneralSecurityException if the reported count is negative or exceeds the bytes
     *     remaining in {@code inBytes} (the bytes that were forwarded); the stub is closed first
     */
    private void consumeBytes(ByteBuffer inBytes, HandshakerResp resp) throws GeneralSecurityException {
        int consumed = resp.getBytesConsumed();
        int available = inBytes.remaining();
        if (consumed < 0 || consumed > available) {
            this.close();
            throw new GeneralSecurityException("Handshaker service reported " + consumed + " bytes consumed, but only " + available + " bytes were available");
        }
        // Called through Buffer, presumably so the bytecode binds to Buffer.position(int), which
        // exists on Java 8, rather than the covariant ByteBuffer override added in Java 9.
        ((Buffer)inBytes).position(inBytes.position() + consumed);
    }

    /**
     * Sends the client-start request and processes the response.
     *
     * @return the outgoing frames from the handshaker response, as a read-only buffer
     * @throws IllegalStateException if the handshake has already finished
     * @throws GeneralSecurityException if the request fails or the service returns a non-OK status
     */
    public ByteBuffer startClientHandshake() throws GeneralSecurityException {
        Preconditions.checkState(!this.isFinished(), "Handshake has already finished.");
        HandshakerReq.Builder req = HandshakerReq.newBuilder();
        this.setStartClientFields(req);
        HandshakerResp resp = this.sendRequest(req.build());
        this.handleResponse(resp);
        return resp.getOutFrames().asReadOnlyByteBuffer();
    }

    /**
     * Sends the server-start request carrying the remaining bytes of {@code inBytes}, then
     * processes the response.
     *
     * @param inBytes handshake bytes to forward; its position is advanced by the number of bytes
     *     the service reports as consumed
     * @return the outgoing frames from the handshaker response, as a read-only buffer
     * @throws IllegalStateException if the handshake has already finished
     * @throws GeneralSecurityException if the request fails, the service returns a non-OK status,
     *     or the reported consumed-byte count is invalid
     */
    public ByteBuffer startServerHandshake(ByteBuffer inBytes) throws GeneralSecurityException {
        Preconditions.checkState(!this.isFinished(), "Handshake has already finished.");
        HandshakerReq.Builder req = HandshakerReq.newBuilder();
        this.setStartServerFields(req, inBytes);
        HandshakerResp resp = this.sendRequest(req.build());
        this.handleResponse(resp);
        this.consumeBytes(inBytes, resp);
        return resp.getOutFrames().asReadOnlyByteBuffer();
    }

    /**
     * Sends a next-message request carrying the remaining bytes of {@code inBytes}, then
     * processes the response.
     *
     * @param inBytes handshake bytes to forward; its position is advanced by the number of bytes
     *     the service reports as consumed
     * @return the outgoing frames from the handshaker response, as a read-only buffer
     * @throws IllegalStateException if the handshake has already finished
     * @throws GeneralSecurityException if the request fails, the service returns a non-OK status,
     *     or the reported consumed-byte count is invalid
     */
    public ByteBuffer next(ByteBuffer inBytes) throws GeneralSecurityException {
        Preconditions.checkState(!this.isFinished(), "Handshake has already finished.");
        HandshakerReq.Builder req = HandshakerReq.newBuilder().setNext(NextHandshakeMessageReq.newBuilder().setInBytes(ByteString.copyFrom(inBytes.duplicate())).build());
        HandshakerResp resp = this.sendRequest(req.build());
        this.handleResponse(resp);
        this.consumeBytes(inBytes, resp);
        return resp.getOutFrames().asReadOnlyByteBuffer();
    }

    /** Closes the handshaker stub; calls after the first are no-ops. */
    public void close() {
        if (this.closed) {
            return;
        }
        this.closed = true;
        this.handshakerStub.close();
    }
}

