/*
 * Decompiled with CFR 0.152.
 */
package com.oracle.bmc.auth.sasl;

import com.google.protobuf.ByteString;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.bmc.auth.BasicAuthenticationDetailsProvider;
import com.oracle.bmc.auth.ConfigurableRefreshOnNotAuthenticatedProvider;
import com.oracle.bmc.auth.sasl.OciAuthProviderCallback;
import com.oracle.bmc.auth.sasl.OciMechanism;
import com.oracle.bmc.http.signing.internal.PEMFileRSAPrivateKeySupplier;
import com.oracle.bmc.http.signing.internal.SignatureSigner;
import com.oracle.bmc.identity.auth.sasl.messages.OciSaslMessages;
import com.oracle.bmc.util.StreamUtils;
import com.oracle.bmc.util.internal.Validate;
import java.beans.ConstructorProperties;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.interfaces.RSAPrivateKey;
import java.time.Duration;
import java.time.OffsetDateTime;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslClientFactory;
import javax.security.sasl.SaslException;

public class OciSaslClient
implements SaslClient {
    /** Smallest challenge payload size (as reported by {@code ByteString.size()}) that this client accepts. */
    public static final int MIN_CHALLENGE_SIZE = 32;
    /** Largest challenge payload size (as reported by {@code ByteString.size()}) that this client accepts. */
    public static final int MAX_CHALLENGE_SIZE = 256;
    /** Shared signer used to sign the challenge message. */
    private static final SignatureSigner SIGNER = new SignatureSigner();
    /** Selected mechanism; supplies the mechanism name and the signing algorithm. */
    private final OciMechanism mechanism;
    /** Source of the key id, private key stream and passphrase; also used as the lock while building the key message. */
    private final BasicAuthenticationDetailsProvider authProvider;
    /** Intent string sent in the key message and included in the signed payload. */
    private final String intent;
    /** Key material captured when the key message is built and consumed when the challenge is signed. */
    private OciPrivateKey currentPrivateKey = null;
    /** Current step of the exchange; a new client starts at {@code KEY_ID}. */
    private State state = State.KEY_ID;

    /**
     * Creates a client positioned at the first step of the exchange.
     *
     * @param mechanism    mechanism that provides the mechanism name and signing algorithm
     * @param authProvider provider queried for the key id, private key and passphrase
     * @param intent       intent string sent to the server and included in the signed message
     */
    private OciSaslClient(OciMechanism mechanism, BasicAuthenticationDetailsProvider authProvider, String intent) {
        this.intent = intent;
        this.authProvider = authProvider;
        this.mechanism = mechanism;
    }

    /**
     * Returns the name reported by the mechanism this client was created with.
     *
     * @return the mechanism name
     */
    @Override
    public String getMechanismName() {
        final OciMechanism selected = this.mechanism;
        return selected.mechanismName();
    }

    /**
     * Advances the exchange by one step.
     *
     * <p>In the {@code KEY_ID} step the challenge is ignored and the serialized key message is
     * returned; in the {@code SIGNING} step the challenge is signed and the serialized response
     * is returned. In any other step an empty array is returned.
     *
     * @param challenge serialized challenge from the server; only read in the {@code SIGNING} step
     * @return the serialized message for the current step, or an empty array once complete
     * @throws SaslException if the challenge cannot be parsed or validated, or the key cannot be obtained
     */
    @Override
    public byte[] evaluateChallenge(byte[] challenge) throws SaslException {
        if (this.state == State.KEY_ID) {
            // The state moves forward before the key message is generated.
            this.state = State.SIGNING;
            OciSaslMessages.Key keyMessage = this.generateKeyMessage();
            return keyMessage.toByteArray();
        }
        if (this.state == State.SIGNING) {
            // The state only moves to COMPLETE after signing succeeded.
            OciSaslMessages.Response signedResponse = this.signChallenge(challenge);
            this.state = State.COMPLETE;
            return signedResponse.toByteArray();
        }
        return new byte[0];
    }

    /**
     * Captures fresh key material from the auth provider and builds the key message.
     *
     * <p>While holding the auth provider's monitor this method asks a
     * {@link ConfigurableRefreshOnNotAuthenticatedProvider} to refresh if it expires within five
     * minutes, closes any previously captured key stream, and stores the provider's current key
     * id, private key stream and passphrase for the signing step.
     *
     * <p>Decompiler note (CFR): "Removed try catching itself - possible behaviour change."
     *
     * @return the key message carrying the key id and the intent
     */
    private OciSaslMessages.Key generateKeyMessage() {
        synchronized (this.authProvider) {
            if (this.authProvider instanceof ConfigurableRefreshOnNotAuthenticatedProvider) {
                ((ConfigurableRefreshOnNotAuthenticatedProvider)this.authProvider).refreshIfExpiringWithin(Duration.ofMinutes(5L));
            }
            OciPrivateKey stale = this.currentPrivateKey;
            if (stale != null) {
                StreamUtils.closeQuietly((InputStream)stale.privateKey);
                this.currentPrivateKey = null;
            }
            // Query the provider in the same order the constructor arguments are listed.
            String keyId = this.authProvider.getKeyId();
            InputStream keyStream = this.authProvider.getPrivateKey();
            char[] passphrase = this.authProvider.getPassphraseCharacters();
            this.currentPrivateKey = new OciPrivateKey(keyId, keyStream, passphrase);
            return OciSaslMessages.Key.newBuilder()
                    .setKeyId(this.currentPrivateKey.keyId)
                    .setIntent(this.intent)
                    .build();
        }
    }

    /**
     * Validates the server challenge and signs it with the key captured for this exchange.
     *
     * <p>The signed message is the challenge bytes, followed by the UTF-8 bytes of the intent,
     * followed by the current epoch second written as an 8-byte long. The captured key material
     * is single-use: once the key has been read (or reading it failed) its stream is closed and
     * the reference is dropped. If the challenge itself is rejected the key material is kept.
     *
     * @param serializedChallenge serialized challenge message received from the server
     * @return the response carrying the epoch second and the signature
     * @throws SaslException if the challenge is invalid or the private key cannot be obtained
     */
    private OciSaslMessages.Response signChallenge(byte[] serializedChallenge) throws SaslException {
        Validate.isTrue(this.currentPrivateKey != null, "required: currentPrivateKey != null", new Object[0]);
        OciSaslMessages.Challenge challenge = this.getAndValidateChallenge(serializedChallenge);
        long epoch = OffsetDateTime.now().toEpochSecond();
        OciPrivateKey keyMaterial = this.currentPrivateKey;
        RSAPrivateKey privateKey;
        try {
            PEMFileRSAPrivateKeySupplier keySupplier = new PEMFileRSAPrivateKeySupplier(keyMaterial.privateKey, keyMaterial.passphraseCharacters);
            privateKey = (RSAPrivateKey)keySupplier.supplyKey().orElseThrow(() -> new SaslException("Unable to get private key"));
        }
        finally {
            // Release the stream only after the key has been requested, so this is safe
            // whether the supplier reads the stream in its constructor or in supplyKey().
            this.currentPrivateKey = null;
            StreamUtils.closeQuietly((InputStream)keyMaterial.privateKey);
        }
        byte[] challengeBytes = challenge.getChallenge().toByteArray();
        byte[] intentBytes = this.intent.getBytes(StandardCharsets.UTF_8);
        ByteBuffer messageToSign = ByteBuffer.allocate(challengeBytes.length + intentBytes.length + Long.BYTES);
        messageToSign.put(challengeBytes);
        messageToSign.put(intentBytes);
        messageToSign.putLong(epoch);
        byte[] signedMessage = SIGNER.sign(privateKey, messageToSign.array(), this.mechanism.algorithm().getJvmName());
        return OciSaslMessages.Response.newBuilder().setTime(epoch).setSignature(ByteString.copyFrom(signedMessage)).build();
    }

    /**
     * Parses the serialized challenge and checks that its payload size is within
     * {@link #MIN_CHALLENGE_SIZE} and {@link #MAX_CHALLENGE_SIZE} (inclusive).
     *
     * @param data serialized challenge message received from the server
     * @return the parsed challenge
     * @throws SaslException if {@code data} is null, cannot be parsed, or the payload size is out of range
     */
    private OciSaslMessages.Challenge getAndValidateChallenge(byte[] data) throws SaslException {
        if (data == null) {
            // Report a missing challenge as a protocol error rather than a NullPointerException.
            throw new SaslException("Challenge sent by the server is invalid");
        }
        try {
            OciSaslMessages.Challenge challenge = OciSaslMessages.Challenge.parseFrom(data);
            int challengeSize = challenge.getChallenge().size();
            if (challengeSize < MIN_CHALLENGE_SIZE || challengeSize > MAX_CHALLENGE_SIZE) {
                throw new SaslException("Challenge sent by the server doesn't have the right size - " + challengeSize);
            }
            return challenge;
        }
        catch (InvalidProtocolBufferException exc) {
            throw new SaslException("Challenge sent by the server is invalid", exc);
        }
    }

    /**
     * Reports that this mechanism has an initial response.
     *
     * @return always {@code true}
     */
    @Override
    public boolean hasInitialResponse() {
        final boolean sendsInitialResponse = true;
        return sendsInitialResponse;
    }

    /**
     * Reports whether the exchange has reached its final step.
     *
     * @return {@code true} once the state is {@code COMPLETE}
     */
    @Override
    public boolean isComplete() {
        return State.COMPLETE == this.state;
    }

    /**
     * Ignores its arguments and returns an empty array.
     *
     * <p>NOTE(review): the {@code SaslClient} contract describes an {@code IllegalStateException}
     * when no security layer was negotiated; this implementation returns an empty array
     * instead - confirm that callers rely on this before changing it.
     *
     * @param incoming ignored
     * @param offset   ignored
     * @param len      ignored
     * @return a new empty array
     */
    @Override
    public byte[] unwrap(byte[] incoming, int offset, int len) {
        return new byte[]{};
    }

    /**
     * Ignores its arguments and returns an empty array.
     *
     * <p>NOTE(review): the {@code SaslClient} contract describes an {@code IllegalStateException}
     * when no security layer was negotiated; this implementation returns an empty array
     * instead - confirm that callers rely on this before changing it.
     *
     * @param outgoing ignored
     * @param offset   ignored
     * @param len      ignored
     * @return a new empty array
     */
    @Override
    public byte[] wrap(byte[] outgoing, int offset, int len) {
        return new byte[]{};
    }

    /**
     * Returns no negotiated property, whatever name is asked for.
     *
     * @param propName ignored
     * @return always {@code null}
     */
    @Override
    public Object getNegotiatedProperty(String propName) {
        final Object noProperty = null;
        return noProperty;
    }

    /**
     * Releases key material that was captured for an exchange but never consumed.
     *
     * <p>If the exchange stopped after the key message was built (for example because the
     * server's challenge was rejected), the private key stream is still held; it is closed
     * quietly here and the reference is dropped. Calling this with nothing pending is a no-op.
     */
    @Override
    public void dispose() {
        OciPrivateKey pending = this.currentPrivateKey;
        this.currentPrivateKey = null;
        if (pending != null && pending.privateKey != null) {
            StreamUtils.closeQuietly((InputStream)pending.privateKey);
        }
    }

    /**
     * Immutable holder for the key material captured from the auth provider for one exchange:
     * the key id, the stream the private key is read from, and the passphrase characters.
     */
    private static final class OciPrivateKey {
        private final String keyId;
        private final InputStream privateKey;
        private final char[] passphraseCharacters;

        /**
         * Stores the given references as-is; nothing is copied or validated.
         *
         * @param keyId                key id sent to the server
         * @param privateKey           stream the private key is read from
         * @param passphraseCharacters passphrase handed to the key supplier together with the stream
         */
        @ConstructorProperties({"keyId", "privateKey", "passphraseCharacters"})
        public OciPrivateKey(String keyId, InputStream privateKey, char[] passphraseCharacters) {
            this.passphraseCharacters = passphraseCharacters;
            this.privateKey = privateKey;
            this.keyId = keyId;
        }
    }

    /**
     * Process-wide registry that maps generated string keys to auth providers, so a provider
     * can be referenced by a plain string key. Entries are only ever added by this class.
     */
    static class AuthProviderCache {
        private static final Map<String, BasicAuthenticationDetailsProvider> authProvidersCache = new ConcurrentHashMap<>();

        AuthProviderCache() {
        }

        /**
         * Registers the provider under a freshly generated random UUID string.
         *
         * @param authProvider provider to register
         * @return the generated key under which the provider was stored
         */
        static String cache(BasicAuthenticationDetailsProvider authProvider) {
            final String handle = UUID.randomUUID().toString();
            authProvidersCache.put(handle, authProvider);
            return handle;
        }

        /**
         * Looks up a previously registered provider.
         *
         * @param key key returned by {@link #cache}
         * @return the registered provider, or {@code null} if the key is unknown
         */
        static BasicAuthenticationDetailsProvider get(String key) {
            BasicAuthenticationDetailsProvider registered = authProvidersCache.get(key);
            return registered;
        }
    }

    public static class OciSaslClientFactory
    implements SaslClientFactory {
        /**
         * Creates a client for the first requested mechanism that {@link OciMechanism} recognizes.
         *
         * <p>Only {@code mechanisms} and {@code cbh} are used; the remaining arguments are ignored.
         *
         * @param mechanisms      candidate mechanism names, tried in order
         * @param authorizationId ignored
         * @param protocol        ignored
         * @param serverName      ignored
         * @param props           ignored
         * @param cbh             callback handler that supplies the intent and the auth provider
         * @return a new client for the first supported mechanism
         * @throws SaslException if none of the mechanisms is supported or the credentials cannot be obtained
         */
        @Override
        public SaslClient createSaslClient(String[] mechanisms, String authorizationId, String protocol, String serverName, Map<String, ?> props, CallbackHandler cbh) throws SaslException {
            OciMechanism ociMechanism = null;
            if (mechanisms != null) {
                for (String candidate : mechanisms) {
                    ociMechanism = OciMechanism.fromMechanismName(candidate);
                    if (ociMechanism != null) {
                        break;
                    }
                }
            }
            if (ociMechanism == null) {
                // Arrays.toString prints the requested names; wrapping the array in a list printed its identity hash.
                throw new SaslException(String.format("Requested mechanisms '%s' not supported. Supported mechanisms are '%s'.", java.util.Arrays.toString(mechanisms), OciMechanism.mechanismNames()));
            }
            Credentials credentials = this.getCredentials(cbh);
            return new OciSaslClient(ociMechanism, credentials.authProvider, credentials.payload);
        }

        /**
         * Lists the mechanism names known to {@link OciMechanism}.
         *
         * @param props ignored
         * @return a new array with the supported mechanism names
         */
        @Override
        public String[] getMechanismNames(Map<String, ?> props) {
            return OciMechanism.mechanismNames().toArray(new String[0]);
        }

        /**
         * Collects the intent and the auth provider from the callback handler.
         *
         * <p>The intent comes from a {@link NameCallback}, which the handler must support. The
         * auth provider comes from an {@code OciAuthProviderCallback} when the handler supplies
         * one; otherwise the {@link PasswordCallback} value is used as a key into
         * {@code AuthProviderCache}.
         *
         * @param callbackHandler handler that is asked for the three callbacks
         * @return the resolved provider and intent
         * @throws SaslException if the handler fails, supplies neither provider source, supplies
         *                       a key that is not registered in the cache, or supplies no intent
         */
        private Credentials getCredentials(CallbackHandler callbackHandler) throws SaslException {
            NameCallback nameCallback = new NameCallback("Payload");
            PasswordCallback passwordCallback = new PasswordCallback("AuthProviderKey", false);
            OciAuthProviderCallback authProviderCallback = new OciAuthProviderCallback();
            OciSaslClientFactory.execute(callbackHandler, nameCallback, true);
            OciSaslClientFactory.execute(callbackHandler, passwordCallback, false);
            OciSaslClientFactory.execute(callbackHandler, authProviderCallback, false);
            BasicAuthenticationDetailsProvider authProvider = authProviderCallback.authProvider();
            if (authProvider == null) {
                char[] cacheKey = passwordCallback.getPassword();
                if (cacheKey == null) {
                    throw new SaslException("Callback handler needs to support either PasswordCallback or OciAuthProviderCallback");
                }
                authProvider = AuthProviderCache.get(new String(cacheKey));
                if (authProvider == null) {
                    // An unknown key would otherwise surface later as a NullPointerException on first use.
                    throw new SaslException("No authentication provider is registered for the key supplied through PasswordCallback");
                }
            }
            String intent = nameCallback.getName();
            if (intent == null) {
                // The intent is converted to bytes when signing; fail here with a clear message instead.
                throw new SaslException("Callback handler did not supply a payload through NameCallback");
            }
            return new Credentials(authProvider, intent);
        }

        /**
         * Passes a single callback to the handler and translates its failures.
         *
         * @param callbackHandler handler to invoke
         * @param callback        callback to be filled in by the handler
         * @param required        whether an unsupported callback is an error; when {@code false}
         *                        the {@link UnsupportedCallbackException} is ignored
         * @param <T>             callback type
         * @throws SaslException if the handler throws {@link IOException}, or does not support a required callback
         */
        static <T extends Callback> void execute(CallbackHandler callbackHandler, T callback, boolean required) throws SaslException {
            final Callback[] request = new Callback[]{callback};
            try {
                callbackHandler.handle(request);
            }
            catch (IOException ioFailure) {
                throw new SaslException("Unexpected IOException during callback handler", ioFailure);
            }
            catch (UnsupportedCallbackException unsupported) {
                if (!required) {
                    // Optional callbacks may be left unanswered.
                    return;
                }
                String callbackType = unsupported.getCallback().getClass().getSimpleName();
                throw new SaslException(callbackType + " is not supported by the callback handler", unsupported);
            }
        }

        /**
         * Immutable pair of the auth provider and the payload (intent) obtained from the
         * callback handler.
         */
        private static final class Credentials {
            private final BasicAuthenticationDetailsProvider authProvider;
            private final String payload;

            /**
             * Stores both references as-is.
             *
             * @param authProvider provider resolved from the callbacks
             * @param payload      value read from the name callback
             */
            private Credentials(BasicAuthenticationDetailsProvider authProvider, String payload) {
                this.payload = payload;
                this.authProvider = authProvider;
            }
        }
    }

    private static enum State {
        KEY_ID,
        SIGNING,
        COMPLETE;

    }
}

