/*
 * Decompiled with CFR 0.152.
 */
package org.wildfly.security.sasl.otp;

import java.nio.charset.StandardCharsets;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.sasl.SaslException;
import org.wildfly.common.Assert;
import org.wildfly.security._private.ElytronMessages;
import org.wildfly.security.auth.callback.ExtendedChoiceCallback;
import org.wildfly.security.auth.callback.ParameterCallback;
import org.wildfly.security.password.spec.OneTimePasswordAlgorithmSpec;
import org.wildfly.security.sasl.otp.OTPUtil;
import org.wildfly.security.sasl.util.AbstractSaslClient;
import org.wildfly.security.sasl.util.StringPrep;
import org.wildfly.security.util.ByteStringBuilder;
import org.wildfly.security.util.CodePointIterator;

/**
 * Client side of the SASL one-time password (OTP) mechanism.
 *
 * <p>The exchange has two steps:
 * <ol>
 *   <li>With an empty initial challenge, the client sends {@code [authzid] NUL username}.</li>
 *   <li>The server sends {@code <algorithm> <sequence> <seed> ext...}. The client answers with
 *       {@code <response-type>:<otp>}. For the {@code init-hex}/{@code init-word} response types
 *       it also appends a re-initialisation: {@code :<digest-alg> <sequence> <seed>:<new-otp>}.</li>
 * </ol>
 *
 * <p>The OTP is computed locally when the callback handler supplies a pass phrase. Otherwise the
 * handler is asked for a precomputed one-time password, and the algorithm parameters are supplied
 * through a {@link ParameterCallback}.
 */
final class OTPSaslClient extends AbstractSaslClient {
    /** Negotiation state: nothing sent yet; the next message carries the authorization id and user name. */
    private static final int ST_NEW = 1;
    /** Negotiation state: waiting for the server's OTP challenge. */
    private static final int ST_CHALLENGE_RESPONSE = 2;
    /**
     * StringPrep flags applied to identities and pass phrases. This is the literal value present in the
     * decompiled code. It presumably corresponds to a named StringPrep profile constant — TODO confirm
     * and reference that constant instead.
     */
    private static final long STRING_PREP_PROFILE = 16383L;
    /** Sequence number of a new OTP sequence when the user supplies a new pass phrase. */
    private static final int DEFAULT_NEW_SEQUENCE_NUMBER = 499;
    /** Length of a freshly generated seed for re-initialisation. */
    private static final int NEW_SEED_LENGTH = 10;
    /**
     * Sequence numbers below this value make {@code init-word} the default response type. The
     * presumed intent is to prompt re-initialisation before the sequence runs out.
     */
    private static final int REINIT_SEQUENCE_THRESHOLD = 10;

    private final SecureRandom secureRandom;
    private final String[] alternateDictionary;
    private NameCallback nameCallback;
    private String userName;

    OTPSaslClient(String mechanismName, SecureRandom secureRandom, String[] alternateDictionary, String protocol, String serverName, CallbackHandler callbackHandler, String authorizationId) {
        super(mechanismName, protocol, serverName, callbackHandler, authorizationId, true);
        this.secureRandom = secureRandom;
        this.alternateDictionary = alternateDictionary;
    }

    @Override
    public void init() {
        setNegotiationState(ST_NEW);
    }

    /**
     * Dispatches the incoming server message to the handler for the current negotiation state.
     *
     * @param state the current negotiation state ({@link #ST_NEW} or {@link #ST_CHALLENGE_RESPONSE})
     * @param challenge the server's message (may be {@code null} in the initial state)
     * @return the client's response bytes
     * @throws SaslException if the message is malformed or a callback fails
     */
    @Override
    protected byte[] evaluateMessage(int state, byte[] challenge) throws SaslException {
        switch (state) {
            case ST_NEW:
                return evaluateInitialChallenge(challenge);
            case ST_CHALLENGE_RESPONSE:
                return evaluateOtpChallenge(challenge);
            default:
                throw Assert.impossibleSwitchCase(state);
        }
    }

    /** Builds the first client message: {@code [authzid] NUL username}, each StringPrep-encoded. */
    private byte[] evaluateInitialChallenge(byte[] challenge) throws SaslException {
        if (challenge != null && challenge.length != 0) {
            throw ElytronMessages.log.mechInitialChallengeMustBeEmpty(getMechanismName()).toSaslException();
        }
        String authorizationId = getAuthorizationId();
        OTPUtil.validateAuthorizationId(authorizationId);
        nameCallback = authorizationId == null ? new NameCallback("User name") : new NameCallback("User name", authorizationId);
        handleCallbacks(nameCallback);
        userName = nameCallback.getName();
        OTPUtil.validateUserName(userName);

        ByteStringBuilder response = new ByteStringBuilder();
        if (authorizationId != null) {
            StringPrep.encode(authorizationId, response, STRING_PREP_PROFILE);
        }
        response.append((byte) 0);
        StringPrep.encode(userName, response, STRING_PREP_PROFILE);
        setNegotiationState(ST_CHALLENGE_RESPONSE);
        return response.toArray();
    }

    /**
     * Parses the server's OTP challenge ({@code <algorithm> <sequence> <seed> ext...}). It then obtains
     * a pass phrase or a precomputed OTP through callbacks and builds the final response.
     */
    private byte[] evaluateOtpChallenge(byte[] challenge) throws SaslException {
        if (challenge == null) {
            throw invalidMessage();
        }
        CodePointIterator cpi = CodePointIterator.ofUtf8Bytes(challenge);
        CodePointIterator di = cpi.delimitedBy(' ');
        String algorithm = di.drainToString();
        OTPUtil.validateAlgorithm(algorithm);
        OTPUtil.skipDelims(di, cpi);
        int sequenceNumber = parseSequenceNumber(di.drainToString());
        OTPUtil.validateSequenceNumber(sequenceNumber);
        OTPUtil.skipDelims(di, cpi);
        String seed = di.drainToString();
        OTPUtil.validateSeed(seed);
        OTPUtil.skipDelims(di, cpi);
        if (!di.drainToString().startsWith("ext")) {
            throw invalidMessage();
        }
        // Only trailing delimiters may follow the extensions token.
        if (cpi.hasNext()) {
            OTPUtil.skipDelims(di, cpi);
            if (cpi.hasNext()) {
                throw invalidMessage();
            }
        }

        int defaultResponseTypeChoice = OTPUtil.getResponseTypeChoiceIndex(sequenceNumber < REINIT_SEQUENCE_THRESHOLD ? "init-word" : "word");
        ExtendedChoiceCallback responseTypeChoiceCallback = new ExtendedChoiceCallback("One-time password response type", OTPUtil.RESPONSE_TYPES, defaultResponseTypeChoice, false, true);
        PasswordCallback passPhraseCallback = new PasswordCallback("Pass phrase", false);
        handleCallbacks(nameCallback, responseTypeChoiceCallback, passPhraseCallback);
        String responseType = selectedResponseType(responseTypeChoiceCallback);
        char[] passPhraseChars = passPhraseCallback.getPassword();
        passPhraseCallback.clearPassword();

        String otp;
        if (passPhraseChars != null) {
            // Compute the OTP locally from the pass phrase.
            String passPhrase = getPasswordFromPasswordChars(passPhraseChars);
            OTPUtil.validatePassPhrase(passPhrase);
            if (seed.equals(passPhrase)) {
                throw ElytronMessages.log.mechOTPPassPhraseAndSeedMustNotMatch().toSaslException();
            }
            otp = OTPUtil.formatOTP(OTPUtil.generateOTP(algorithm, passPhrase, seed, sequenceNumber), responseType, alternateDictionary);
        } else {
            // No pass phrase: request a precomputed OTP, passing the challenge parameters to the handler.
            ParameterCallback parameterCallback = new ParameterCallback(OneTimePasswordAlgorithmSpec.class);
            parameterCallback.setParameterSpec(new OneTimePasswordAlgorithmSpec(algorithm, seed.getBytes(StandardCharsets.US_ASCII), sequenceNumber));
            PasswordCallback otpCallback = new PasswordCallback("One-time password", false);
            handleCallbacks(nameCallback, responseTypeChoiceCallback, parameterCallback, otpCallback);
            responseType = selectedResponseType(responseTypeChoiceCallback);
            otp = getOTP(otpCallback);
        }
        negotiationComplete();
        return createOTPResponse(algorithm, seed, otp, responseType);
    }

    /**
     * Parses the sequence-number token of a server challenge.
     *
     * @throws SaslException if the token is not a decimal integer. The token comes from the server,
     *         so a malformed value is reported as an invalid message rather than leaking a
     *         {@link NumberFormatException}.
     */
    private int parseSequenceNumber(String token) throws SaslException {
        try {
            return Integer.parseInt(token);
        } catch (NumberFormatException e) {
            SaslException ex = invalidMessage();
            // addSuppressed never conflicts with an already-initialised cause, unlike initCause.
            ex.addSuppressed(e);
            throw ex;
        }
    }

    /** Creates the "invalid message received" exception for this mechanism. */
    private SaslException invalidMessage() {
        return ElytronMessages.log.mechInvalidMessageReceived(getMechanismName()).toSaslException();
    }

    /**
     * Returns the response type selected through the callback. Falls back to the callback's default
     * choice when no selection (or an empty selection) was made.
     */
    private static String selectedResponseType(ExtendedChoiceCallback callback) {
        int[] selected = callback.getSelectedIndexes();
        int index = selected != null && selected.length > 0 ? selected[0] : callback.getDefaultChoice();
        return OTPUtil.RESPONSE_TYPES[index];
    }

    /** Nothing to release: this client holds no closeable resources. */
    @Override
    public void dispose() throws SaslException {
    }

    /**
     * Builds the final response {@code <responseType>:<otp>}. For the {@code init-*} response types,
     * {@code :<re-initialisation>} is appended as well.
     *
     * @throws SaslException if the response type is unknown or re-initialisation fails
     */
    private byte[] createOTPResponse(String algorithm, String seed, String otp, String responseType) throws SaslException {
        ByteStringBuilder response = new ByteStringBuilder();
        response.append(responseType);
        response.append(':');
        switch (responseType) {
            case "hex":
            case "word":
                response.append(otp);
                break;
            case "init-hex":
            case "init-word":
                response.append(otp);
                response.append(':');
                response.append(createReinitialization(algorithm, seed, responseType));
                break;
            default:
                throw ElytronMessages.log.mechInvalidOTPResponseType().toSaslException();
        }
        return response.toArray();
    }

    /**
     * Generates a new seed that differs from {@code seed}. It then obtains the new sequence's first
     * OTP, either computed from a new pass phrase or supplied by the callback handler, and returns
     * the encoded re-initialisation part of the response.
     */
    private ByteStringBuilder createReinitialization(String algorithm, String seed, String responseType) throws SaslException {
        Random random = secureRandom != null ? secureRandom : ThreadLocalRandom.current();
        String newSeed;
        do {
            newSeed = OTPUtil.generateRandomAlphanumericString(NEW_SEED_LENGTH, random);
        } while (newSeed.equals(seed));

        PasswordCallback passPhraseCallback = new PasswordCallback("New pass phrase", false);
        handleCallbacks(nameCallback, passPhraseCallback);
        char[] newPassPhraseChars = passPhraseCallback.getPassword();
        passPhraseCallback.clearPassword();

        String newAlgorithm;
        int newSequenceNumber;
        String newOTP;
        if (newPassPhraseChars != null) {
            newAlgorithm = algorithm;
            newSequenceNumber = DEFAULT_NEW_SEQUENCE_NUMBER;
            String newPassPhrase = getPasswordFromPasswordChars(newPassPhraseChars);
            OTPUtil.validatePassPhrase(newPassPhrase);
            if (newSeed.equals(newPassPhrase)) {
                throw ElytronMessages.log.mechOTPPassPhraseAndSeedMustNotMatch().toSaslException();
            }
            newOTP = OTPUtil.formatOTP(OTPUtil.generateOTP(newAlgorithm, newPassPhrase, newSeed, newSequenceNumber), responseType, alternateDictionary);
        } else {
            // Offer default parameters; the handler may replace them, so re-read and validate afterwards.
            OneTimePasswordAlgorithmSpec defaultAlgorithmSpec = new OneTimePasswordAlgorithmSpec(algorithm, newSeed.getBytes(StandardCharsets.US_ASCII), DEFAULT_NEW_SEQUENCE_NUMBER);
            ParameterCallback parameterCallback = new ParameterCallback(defaultAlgorithmSpec, OneTimePasswordAlgorithmSpec.class);
            PasswordCallback otpCallback = new PasswordCallback("New one-time password", false);
            handleCallbacks(nameCallback, parameterCallback, otpCallback);
            newOTP = getOTP(otpCallback);
            OneTimePasswordAlgorithmSpec algorithmSpec = (OneTimePasswordAlgorithmSpec) parameterCallback.getParameterSpec();
            if (algorithmSpec == null) {
                throw ElytronMessages.log.mechNoPasswordGiven(getMechanismName()).toSaslException();
            }
            newAlgorithm = algorithmSpec.getAlgorithm();
            OTPUtil.validateAlgorithm(newAlgorithm);
            newSequenceNumber = algorithmSpec.getSequenceNumber();
            OTPUtil.validateSequenceNumber(newSequenceNumber);
            newSeed = new String(algorithmSpec.getSeed(), StandardCharsets.US_ASCII);
            OTPUtil.validateSeed(newSeed);
        }
        return createInitResponse(newAlgorithm, newSeed, newSequenceNumber, newOTP);
    }

    /**
     * Encodes {@code <digest-algorithm> <sequence> <seed>:<otp>}, where the digest algorithm name is
     * derived from the OTP algorithm.
     *
     * @throws SaslException if the OTP algorithm has no known message digest
     */
    private ByteStringBuilder createInitResponse(String newAlgorithm, String newSeed, int newSequenceNumber, String newOTP) throws SaslException {
        String newDigestAlgorithm;
        try {
            newDigestAlgorithm = OTPUtil.messageDigestAlgorithm(newAlgorithm);
        } catch (NoSuchAlgorithmException e) {
            throw ElytronMessages.log.mechInvalidOTPAlgorithm(newAlgorithm).toSaslException();
        }
        ByteStringBuilder initResponse = new ByteStringBuilder();
        initResponse.append(newDigestAlgorithm);
        initResponse.append(' ');
        initResponse.appendNumber(newSequenceNumber);
        initResponse.append(' ');
        initResponse.append(newSeed);
        initResponse.append(':');
        initResponse.append(newOTP);
        return initResponse;
    }

    /**
     * Takes the one-time password out of the callback and clears the callback.
     *
     * @throws SaslException if the handler supplied no password
     */
    private String getOTP(PasswordCallback passwordCallback) throws SaslException {
        char[] passwordChars = passwordCallback.getPassword();
        passwordCallback.clearPassword();
        if (passwordChars == null) {
            throw ElytronMessages.log.mechNoPasswordGiven(getMechanismName()).toSaslException();
        }
        return getPasswordFromPasswordChars(passwordChars);
    }

    /**
     * StringPrep-encodes the given characters and decodes the result as UTF-8.
     * {@code passwordChars} is always zeroed afterwards, even if encoding fails.
     */
    private String getPasswordFromPasswordChars(char[] passwordChars) {
        try {
            ByteStringBuilder b = new ByteStringBuilder();
            StringPrep.encode(passwordChars, b, STRING_PREP_PROFILE);
            return new String(b.toArray(), StandardCharsets.UTF_8);
        } finally {
            Arrays.fill(passwordChars, '\u0000');
        }
    }
}

