/*
 * Decompiled with CFR 0.152.
 */
package org.codelibs.jcifs.smb.internal.smb2;

import java.security.GeneralSecurityException;
import java.security.Key;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicLong;
import javax.crypto.Cipher;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import org.bouncycastle.crypto.BlockCipher;
import org.bouncycastle.crypto.CipherParameters;
import org.bouncycastle.crypto.InvalidCipherTextException;
import org.bouncycastle.crypto.engines.AESEngine;
import org.bouncycastle.crypto.modes.AEADBlockCipher;
import org.bouncycastle.crypto.modes.CCMBlockCipher;
import org.bouncycastle.crypto.params.AEADParameters;
import org.bouncycastle.crypto.params.KeyParameter;
import org.codelibs.jcifs.smb.CIFSException;
import org.codelibs.jcifs.smb.DialectVersion;
import org.codelibs.jcifs.smb.internal.smb2.Smb2TransformHeader;

/**
 * Per-session encryption state for SMB 3.x: the negotiated cipher, the dialect, and the
 * encryption/decryption keys used to seal and open messages wrapped in an SMB2 TRANSFORM_HEADER.
 *
 * <p>AES-128-CCM uses an 11-byte nonce and AES-128-GCM a 12-byte nonce. The rest of the
 * transform header's 16-byte Nonce field is left zero. Both ciphers use a 16-byte authentication
 * tag, which is carried in the transform header's Signature field. The transform header bytes
 * returned by {@code Smb2TransformHeader#getAssociatedData()} are authenticated as AEAD
 * associated data.
 *
 * <p>The nonce counter is an {@link AtomicLong}, and every encrypt/decrypt call builds its own
 * cipher instance.
 */
public class Smb2EncryptionContext {
    /** Cipher identifier for AES-128-CCM. */
    public static final int CIPHER_AES_128_CCM = 1;
    /** Cipher identifier for AES-128-GCM. */
    public static final int CIPHER_AES_128_GCM = 2;
    /** SMB 3.1.1 transform header flag marking the payload as encrypted. */
    public static final int TRANSFORM_FLAG_ENCRYPTED = 1;

    /** Size in bytes of the SMB2 TRANSFORM_HEADER that precedes the ciphertext. */
    private static final int TRANSFORM_HEADER_SIZE = 52;
    /** Size in bytes of the transform header's Nonce field. */
    private static final int NONCE_FIELD_SIZE = 16;
    /** Number of leading nonce bytes taken from the message counter. */
    private static final int NONCE_COUNTER_BYTES = 8;
    /** Nonce length consumed by AES-128-CCM. */
    private static final int CCM_NONCE_LENGTH = 11;
    /** Nonce length consumed by AES-128-GCM. */
    private static final int GCM_NONCE_LENGTH = 12;
    /** Authentication tag length, in bytes, for both supported ciphers. */
    private static final int AUTH_TAG_LENGTH = 16;

    private final int cipherId;
    private final DialectVersion dialect;
    private final byte[] encryptionKey;
    private final byte[] decryptionKey;
    private final AtomicLong nonceCounter = new AtomicLong(0L);
    private final SecureRandom secureRandom = new SecureRandom();

    /**
     * Creates an encryption context.
     *
     * @param cipherId negotiated cipher ({@link #CIPHER_AES_128_CCM} or {@link #CIPHER_AES_128_GCM})
     * @param dialect negotiated SMB dialect; selects how the transform header Flags field is set
     * @param encryptionKey key used for outgoing messages; defensively copied
     * @param decryptionKey key used for incoming messages; defensively copied
     * @throws NullPointerException if either key is {@code null}
     */
    public Smb2EncryptionContext(int cipherId, DialectVersion dialect, byte[] encryptionKey, byte[] decryptionKey) {
        this.cipherId = cipherId;
        this.dialect = dialect;
        this.encryptionKey = encryptionKey.clone();
        this.decryptionKey = decryptionKey.clone();
    }

    /** @return the negotiated cipher identifier */
    public int getCipherId() {
        return this.cipherId;
    }

    /** @return the negotiated SMB dialect */
    public DialectVersion getDialect() {
        return this.dialect;
    }

    /**
     * Produces a fresh value for the transform header's 16-byte Nonce field.
     *
     * <p>The first 8 bytes are a big-endian per-context counter, so a nonce is never repeated
     * within this context. The remaining bytes of the cipher nonce (3 for CCM, 4 for GCM) are
     * random. Bytes beyond the cipher's nonce length are reserved and left zero.
     *
     * @return a new 16-byte nonce field
     */
    public byte[] generateNonce() {
        byte[] nonce = new byte[NONCE_FIELD_SIZE];
        long counter = this.nonceCounter.incrementAndGet();
        System.arraycopy(longToBytes(counter), 0, nonce, 0, NONCE_COUNTER_BYTES);
        byte[] randomBytes = new byte[getNonceLength() - NONCE_COUNTER_BYTES];
        this.secureRandom.nextBytes(randomBytes);
        System.arraycopy(randomBytes, 0, nonce, NONCE_COUNTER_BYTES, randomBytes.length);
        return nonce;
    }

    /**
     * Encrypts an SMB2 message and prefixes it with a populated transform header.
     *
     * @param message plaintext SMB2 message
     * @param sessionId session the message belongs to (written into the transform header)
     * @return transform header (52 bytes) followed by the ciphertext
     * @throws CIFSException if encryption fails for any reason
     */
    public byte[] encryptMessage(byte[] message, long sessionId) throws CIFSException {
        try {
            byte[] nonce = generateNonce();
            Smb2TransformHeader transformHeader = new Smb2TransformHeader(nonce, message.length, getTransformFlags(), sessionId);
            byte[] associatedData = transformHeader.getAssociatedData();

            // Both cipher paths produce ciphertext || tag.
            byte[] sealed;
            if (isGCMCipher()) {
                Cipher cipher = createGCMCipher(true, nonce);
                cipher.updateAAD(associatedData);
                sealed = cipher.doFinal(message);
            } else {
                AEADBlockCipher cipher = createCCMCipher(true, nonce);
                // The header is authenticated, not encrypted: feed it as AAD, never as payload.
                cipher.processAADBytes(associatedData, 0, associatedData.length);
                sealed = processAead(cipher, message);
            }

            int tagLength = getAuthTagLength();
            int ciphertextLength = sealed.length - tagLength;
            byte[] authTag = Arrays.copyOfRange(sealed, ciphertextLength, sealed.length);
            // The Signature field is outside the associated data, so it can be set after sealing.
            transformHeader.setSignature(authTag);

            byte[] result = new byte[TRANSFORM_HEADER_SIZE + ciphertextLength];
            transformHeader.encode(result, 0);
            System.arraycopy(sealed, 0, result, TRANSFORM_HEADER_SIZE, ciphertextLength);
            return result;
        } catch (Exception e) {
            throw new CIFSException("Failed to encrypt message", e);
        }
    }

    /**
     * Decrypts and authenticates a message that starts with an SMB2 transform header.
     *
     * @param encryptedMessage transform header (52 bytes) followed by the ciphertext
     * @return the decrypted SMB2 message
     * @throws CIFSException if the input is shorter than a transform header, or if decryption or
     *     tag verification fails
     */
    public byte[] decryptMessage(byte[] encryptedMessage) throws CIFSException {
        if (encryptedMessage == null || encryptedMessage.length < TRANSFORM_HEADER_SIZE) {
            throw new CIFSException("Encrypted message is shorter than the transform header");
        }
        try {
            Smb2TransformHeader transformHeader = Smb2TransformHeader.decode(encryptedMessage, 0);
            byte[] associatedData = transformHeader.getAssociatedData();
            byte[] nonce = transformHeader.getNonce();
            byte[] authTag = transformHeader.getSignature();

            // Both JCE GCM and BouncyCastle CCM expect ciphertext || tag when decrypting.
            int ciphertextLength = encryptedMessage.length - TRANSFORM_HEADER_SIZE;
            byte[] sealed = new byte[ciphertextLength + authTag.length];
            System.arraycopy(encryptedMessage, TRANSFORM_HEADER_SIZE, sealed, 0, ciphertextLength);
            System.arraycopy(authTag, 0, sealed, ciphertextLength, authTag.length);

            if (isGCMCipher()) {
                Cipher cipher = createGCMCipher(false, nonce);
                cipher.updateAAD(associatedData);
                return cipher.doFinal(sealed);
            }
            AEADBlockCipher cipher = createCCMCipher(false, nonce);
            cipher.processAADBytes(associatedData, 0, associatedData.length);
            return processAead(cipher, sealed);
        } catch (Exception e) {
            throw new CIFSException("Failed to decrypt message", e);
        }
    }

    /** @return {@code true} when the negotiated cipher is AES-128-GCM; otherwise CCM is used */
    private boolean isGCMCipher() {
        return this.cipherId == CIPHER_AES_128_GCM;
    }

    /** @return number of nonce bytes the negotiated cipher actually consumes */
    private int getNonceLength() {
        return isGCMCipher() ? GCM_NONCE_LENGTH : CCM_NONCE_LENGTH;
    }

    /**
     * @return AES key length in bytes for the negotiated cipher
     * @throws IllegalArgumentException if the cipher id is not a supported one
     */
    private int getKeyLength() {
        if (this.cipherId == CIPHER_AES_128_CCM || this.cipherId == CIPHER_AES_128_GCM) {
            return 16;
        }
        throw new IllegalArgumentException("Unsupported cipher: " + this.cipherId);
    }

    /** @return authentication tag length in bytes */
    private int getAuthTagLength() {
        return AUTH_TAG_LENGTH;
    }

    /**
     * Value for the transform header Flags field. SMB 3.1.1 uses the "encrypted" flag, while
     * earlier 3.x dialects carry the encryption algorithm identifier in that field.
     */
    private int getTransformFlags() {
        if (this.dialect.atLeast(DialectVersion.SMB311)) {
            return TRANSFORM_FLAG_ENCRYPTED;
        }
        return this.cipherId;
    }

    /**
     * Builds a JCE AES-GCM cipher.
     *
     * @param encrypt {@code true} for encryption with the encryption key, {@code false} for
     *     decryption with the decryption key
     * @param nonce transform header Nonce field; only its first 12 bytes are used
     */
    private Cipher createGCMCipher(boolean encrypt, byte[] nonce) throws GeneralSecurityException {
        SecretKeySpec keySpec = new SecretKeySpec(encrypt ? this.encryptionKey : this.decryptionKey, "AES");
        Cipher cipher = Cipher.getInstance("AES/GCM/NoPadding");
        GCMParameterSpec gcmSpec = new GCMParameterSpec(getAuthTagLength() * 8, Arrays.copyOf(nonce, GCM_NONCE_LENGTH));
        cipher.init(encrypt ? Cipher.ENCRYPT_MODE : Cipher.DECRYPT_MODE, keySpec, gcmSpec);
        return cipher;
    }

    /**
     * Builds a BouncyCastle AES-CCM cipher. Associated data must be supplied by the caller via
     * {@code processAADBytes} before any payload is processed.
     *
     * @param encrypt {@code true} for encryption with the encryption key, {@code false} for
     *     decryption with the decryption key
     * @param nonce transform header Nonce field; only its first 11 bytes are used
     */
    private AEADBlockCipher createCCMCipher(boolean encrypt, byte[] nonce) {
        CCMBlockCipher cipher = new CCMBlockCipher(new AESEngine());
        KeyParameter keyParam = new KeyParameter(encrypt ? this.encryptionKey : this.decryptionKey);
        // An 11-byte nonce leaves a 4-byte length field, so CCM payloads are not capped at 64 KiB.
        AEADParameters params = new AEADParameters(keyParam, getAuthTagLength() * 8, Arrays.copyOf(nonce, CCM_NONCE_LENGTH), null);
        cipher.init(encrypt, params);
        return cipher;
    }

    /**
     * Runs a BouncyCastle AEAD cipher (AAD already supplied) over {@code input}.
     *
     * @return on encryption, ciphertext || tag; on decryption, the verified plaintext
     * @throws InvalidCipherTextException if tag verification fails on decryption
     */
    private static byte[] processAead(AEADBlockCipher cipher, byte[] input) throws InvalidCipherTextException {
        byte[] output = new byte[cipher.getOutputSize(input.length)];
        int len = cipher.processBytes(input, 0, input.length, output, 0);
        len += cipher.doFinal(output, len);
        return len == output.length ? output : Arrays.copyOf(output, len);
    }

    /** Encodes {@code value} as 8 big-endian bytes. */
    private static byte[] longToBytes(long value) {
        byte[] bytes = new byte[8];
        for (int i = 0; i < 8; ++i) {
            bytes[i] = (byte) (value >>> 8 * (7 - i));
        }
        return bytes;
    }
}

