/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.pqc.crypto.mlkem;

import java.security.SecureRandom;
import org.bouncycastle.pqc.crypto.mlkem.MLKEMIndCpa;
import org.bouncycastle.pqc.crypto.mlkem.MLKEMPrivateKeyParameters;
import org.bouncycastle.pqc.crypto.mlkem.MLKEMPublicKeyParameters;
import org.bouncycastle.pqc.crypto.mlkem.PolyVec;
import org.bouncycastle.pqc.crypto.mlkem.Symmetric;
import org.bouncycastle.util.Arrays;

/**
 * Core engine for the ML-KEM key-encapsulation mechanism.
 *
 * <p>An instance is bound to one module rank {@code k} (2, 3 or 4). All derived
 * byte lengths are computed once in the constructor and exposed through getters
 * so that collaborating classes ({@link MLKEMIndCpa}, {@link PolyVec}, ...) can
 * size their buffers consistently.
 *
 * <p>{@link #init(SecureRandom)} must be called before
 * {@link #generateKemKeyPair()}, because that is the only place the random
 * source is assigned.
 */
class MLKEMEngine {
    private SecureRandom random;
    private final MLKEMIndCpa indCpa;
    public static final int KyberN = 256;
    public static final int KyberQ = 3329;
    public static final int KyberQinv = 62209;
    public static final int KyberSymBytes = 32;
    private static final int KyberSharedSecretBytes = 32;
    public static final int KyberPolyBytes = 384;
    private static final int KyberEta2 = 2;
    private static final int KyberIndCpaMsgBytes = 32;
    private final int KyberK;
    private final int KyberPolyVecBytes;
    private final int KyberPolyCompressedBytes;
    private final int KyberPolyVecCompressedBytes;
    private final int KyberEta1;
    private final int KyberIndCpaPublicKeyBytes;
    private final int KyberIndCpaSecretKeyBytes;
    private final int KyberIndCpaBytes;
    private final int KyberPublicKeyBytes;
    private final int KyberSecretKeyBytes;
    private final int KyberCipherTextBytes;
    private final int CryptoBytes;
    private final int CryptoSecretKeyBytes;
    private final int CryptoPublicKeyBytes;
    private final int CryptoCipherTextBytes;
    private final int sessionKeyLength;
    private final Symmetric symmetric;

    /** @return the symmetric primitive set (hash / KDF functions) used by this engine */
    public Symmetric getSymmetric() {
        return symmetric;
    }

    /** @return the eta2 parameter, which is the same for every supported {@code k} */
    public static int getKyberEta2() {
        return KyberEta2;
    }

    /** @return the length in bytes of the IND-CPA message */
    public static int getKyberIndCpaMsgBytes() {
        return KyberIndCpaMsgBytes;
    }

    /** @return the ciphertext length in bytes (same value as {@link #getKyberCipherTextBytes()}) */
    public int getCryptoCipherTextBytes() {
        return CryptoCipherTextBytes;
    }

    /** @return the public key length in bytes (same value as {@link #getKyberPublicKeyBytes()}) */
    public int getCryptoPublicKeyBytes() {
        return CryptoPublicKeyBytes;
    }

    /** @return the secret key length in bytes (same value as {@link #getKyberSecretKeyBytes()}) */
    public int getCryptoSecretKeyBytes() {
        return CryptoSecretKeyBytes;
    }

    /** @return the shared secret length in bytes */
    public int getCryptoBytes() {
        return CryptoBytes;
    }

    /** @return the KEM ciphertext length in bytes */
    public int getKyberCipherTextBytes() {
        return KyberCipherTextBytes;
    }

    /** @return the KEM secret key length in bytes */
    public int getKyberSecretKeyBytes() {
        return KyberSecretKeyBytes;
    }

    /** @return the IND-CPA public key length in bytes */
    public int getKyberIndCpaPublicKeyBytes() {
        return KyberIndCpaPublicKeyBytes;
    }

    /** @return the IND-CPA secret key length in bytes */
    public int getKyberIndCpaSecretKeyBytes() {
        return KyberIndCpaSecretKeyBytes;
    }

    /** @return the IND-CPA ciphertext length in bytes */
    public int getKyberIndCpaBytes() {
        return KyberIndCpaBytes;
    }

    /** @return the KEM public key length in bytes */
    public int getKyberPublicKeyBytes() {
        return KyberPublicKeyBytes;
    }

    /** @return the length in bytes of one compressed polynomial */
    public int getKyberPolyCompressedBytes() {
        return KyberPolyCompressedBytes;
    }

    /** @return the module rank {@code k} this engine was built for */
    public int getKyberK() {
        return KyberK;
    }

    /** @return the length in bytes of an uncompressed polynomial vector ({@code k * 384}) */
    public int getKyberPolyVecBytes() {
        return KyberPolyVecBytes;
    }

    /** @return the length in bytes of a compressed polynomial vector */
    public int getKyberPolyVecCompressedBytes() {
        return KyberPolyVecCompressedBytes;
    }

    /** @return the eta1 parameter selected for this engine's {@code k} */
    public int getKyberEta1() {
        return KyberEta1;
    }

    /**
     * Builds an engine for the given module rank and derives every size from it.
     *
     * @param k module rank; must be 2, 3 or 4
     * @throws IllegalArgumentException if {@code k} is not one of the supported values
     */
    public MLKEMEngine(int k) {
        KyberK = k;
        switch (k) {
            case 2: {
                KyberEta1 = 3;
                KyberPolyCompressedBytes = 128;
                KyberPolyVecCompressedBytes = k * 320;
                break;
            }
            case 3: {
                KyberEta1 = 2;
                KyberPolyCompressedBytes = 128;
                KyberPolyVecCompressedBytes = k * 320;
                break;
            }
            case 4: {
                KyberEta1 = 2;
                KyberPolyCompressedBytes = 160;
                KyberPolyVecCompressedBytes = k * 352;
                break;
            }
            default: {
                throw new IllegalArgumentException("K: " + k + " is not supported for Crystals Kyber");
            }
        }
        // The session key length is the same (32 bytes) for every supported k.
        sessionKeyLength = 32;

        KyberPolyVecBytes = k * KyberPolyBytes;
        KyberIndCpaPublicKeyBytes = KyberPolyVecBytes + KyberSymBytes;
        KyberIndCpaSecretKeyBytes = KyberPolyVecBytes;
        KyberIndCpaBytes = KyberPolyVecCompressedBytes + KyberPolyCompressedBytes;
        KyberPublicKeyBytes = KyberIndCpaPublicKeyBytes;
        // Secret key layout: IND-CPA secret key || public key || two 32-byte trailers
        // (kemDecrypt reads the trailers at offsets length-64 and length-32).
        KyberSecretKeyBytes = KyberIndCpaSecretKeyBytes + KyberIndCpaPublicKeyBytes + 2 * KyberSymBytes;
        KyberCipherTextBytes = KyberIndCpaBytes;
        CryptoBytes = KyberSharedSecretBytes;
        CryptoSecretKeyBytes = KyberSecretKeyBytes;
        CryptoPublicKeyBytes = KyberPublicKeyBytes;
        CryptoCipherTextBytes = KyberCipherTextBytes;

        symmetric = new Symmetric.ShakeSymmetric();
        // MLKEMIndCpa receives this engine; every size field is already assigned above.
        indCpa = new MLKEMIndCpa(this);
    }

    /**
     * Sets the random source used by {@link #generateKemKeyPair()}.
     *
     * @param random source of randomness for the key generation seeds
     */
    public void init(SecureRandom random) {
        this.random = random;
    }

    /**
     * Delegates to {@link PolyVec#checkModulus} for the supplied encoding.
     *
     * @param t encoded value to check
     * @return {@code true} when {@code PolyVec.checkModulus} reports a negative result
     */
    boolean checkModulus(byte[] t) {
        return PolyVec.checkModulus(this, t) < 0;
    }

    /**
     * Generates a key pair from two fresh 32-byte seeds drawn from the random
     * source supplied to {@link #init(SecureRandom)}.
     *
     * <p>NOTE(review): throws {@code NullPointerException} if {@code init} was never called.
     *
     * @return the key components, laid out as described on
     *         {@link #generateKemKeyPairInternal(byte[], byte[])}
     */
    public byte[][] generateKemKeyPair() {
        byte[] d = new byte[KyberSymBytes];
        byte[] z = new byte[KyberSymBytes];
        random.nextBytes(d);
        random.nextBytes(z);
        return generateKemKeyPairInternal(d, z);
    }

    /**
     * Deterministically derives a key pair from the seeds {@code d} and {@code z}.
     *
     * @param d seed handed to the IND-CPA key generator
     * @param z seed stored with the private key; {@link #kemDecrypt} reads it
     *          back as the last 32 bytes of the encoded secret key
     * @return six arrays, in order: the public key without its last 32 bytes,
     *         the last 32 bytes of the public key, the IND-CPA secret key,
     *         the 32-byte {@code hash_h} of the public key, {@code z}, and {@code d || z}
     */
    public byte[][] generateKemKeyPairInternal(byte[] d, byte[] z) {
        byte[][] indCpaKeyPair = indCpa.generateKeyPair(d);

        byte[] s = new byte[KyberIndCpaSecretKeyBytes];
        System.arraycopy(indCpaKeyPair[1], 0, s, 0, KyberIndCpaSecretKeyBytes);

        byte[] hashedPublicKey = new byte[KyberSymBytes];
        symmetric.hash_h(hashedPublicKey, indCpaKeyPair[0], 0);

        byte[] outputPublicKey = new byte[KyberIndCpaPublicKeyBytes];
        System.arraycopy(indCpaKeyPair[0], 0, outputPublicKey, 0, KyberIndCpaPublicKeyBytes);

        // The public key is returned split at the start of its 32-byte tail.
        int split = outputPublicKey.length - KyberSymBytes;
        return new byte[][]{
            Arrays.copyOfRange(outputPublicKey, 0, split),
            Arrays.copyOfRange(outputPublicKey, split, outputPublicKey.length),
            s,
            hashedPublicKey,
            z,
            Arrays.concatenate(d, z)
        };
    }

    /**
     * Encapsulates a shared secret for the given public key.
     *
     * @param publicKey recipient public key
     * @param randBytes at least 32 bytes; the first 32 are used as the message seed
     * @return a two-element array: {@code [0]} the shared secret,
     *         {@code [1]} the ciphertext
     */
    byte[][] kemEncrypt(MLKEMPublicKeyParameters publicKey, byte[] randBytes) {
        byte[] publicKeyInput = publicKey.getEncoded();
        byte[] buf = new byte[2 * KyberSymBytes];
        byte[] kr = new byte[2 * KyberSymBytes];

        // buf = message seed || hash_h(public key) (hash written at offset 32)
        System.arraycopy(randBytes, 0, buf, 0, KyberSymBytes);
        symmetric.hash_h(buf, publicKeyInput, KyberSymBytes);

        // kr = hash_g(buf): first half becomes the shared secret, second half the encryption coins
        symmetric.hash_g(kr, buf);

        byte[] outputCipherText = indCpa.encrypt(
            publicKeyInput,
            Arrays.copyOfRange(buf, 0, KyberSymBytes),
            Arrays.copyOfRange(kr, KyberSymBytes, kr.length));

        byte[] outputSharedSecret = new byte[sessionKeyLength];
        System.arraycopy(kr, 0, outputSharedSecret, 0, outputSharedSecret.length);

        return new byte[][]{outputSharedSecret, outputCipherText};
    }

    /**
     * Decapsulates a ciphertext, returning either the shared secret or, when
     * re-encryption does not reproduce the ciphertext, a key derived from the
     * secret key's 32-byte trailer and the ciphertext (implicit rejection).
     *
     * <p>NOTE(review): the ciphertext length is not validated here; a buffer
     * shorter than {@link #getKyberCipherTextBytes()} will fail with an index
     * exception from the array copies below - confirm callers check it.
     *
     * @param privateKey recipient private key
     * @param cipherText ciphertext produced by {@link #kemEncrypt}
     * @return the session key, {@code sessionKeyLength} bytes long
     */
    byte[] kemDecrypt(MLKEMPrivateKeyParameters privateKey, byte[] cipherText) {
        byte[] secretKey = privateKey.getEncoded();
        byte[] buf = new byte[2 * KyberSymBytes];
        byte[] kr = new byte[2 * KyberSymBytes];

        // Everything after the IND-CPA secret key is handed to encrypt() as the public key.
        byte[] publicKey = Arrays.copyOfRange(secretKey, KyberIndCpaSecretKeyBytes, secretKey.length);

        // buf = decrypted message || the 32 bytes stored 64 bytes before the end of the secret key
        System.arraycopy(indCpa.decrypt(secretKey, cipherText), 0, buf, 0, KyberSymBytes);
        System.arraycopy(secretKey, KyberSecretKeyBytes - 2 * KyberSymBytes, buf, KyberSymBytes, KyberSymBytes);
        symmetric.hash_g(kr, buf);

        // Rejection key input: last 32 bytes of the secret key || ciphertext.
        // kdf is called with the same array as output and input.
        byte[] implicit_rejection = new byte[KyberSymBytes + KyberCipherTextBytes];
        System.arraycopy(secretKey, KyberSecretKeyBytes - KyberSymBytes, implicit_rejection, 0, KyberSymBytes);
        System.arraycopy(cipherText, 0, implicit_rejection, KyberSymBytes, KyberCipherTextBytes);
        symmetric.kdf(implicit_rejection, implicit_rejection);

        // Re-encrypt and compare; on mismatch the rejection key replaces the real one.
        byte[] cmp = indCpa.encrypt(
            publicKey,
            Arrays.copyOfRange(buf, 0, KyberSymBytes),
            Arrays.copyOfRange(kr, KyberSymBytes, kr.length));
        boolean fail = !Arrays.constantTimeAreEqual(cipherText, cmp);
        cmov(kr, implicit_rejection, KyberSymBytes, fail);

        return Arrays.copyOfRange(kr, 0, sessionKeyLength);
    }

    /**
     * Conditional move: overwrites {@code r[0..xlen)} with {@code x[0..xlen)}
     * when {@code b} is true and leaves {@code r} unchanged otherwise.
     *
     * <p>The selection is done with a bit mask so that the same loads, stores
     * and loop iterations happen for both values of {@code b}; the flag is
     * secret-dependent in {@link #kemDecrypt}, so the copy must not branch on it.
     *
     * @param r    destination buffer, updated in place
     * @param x    source buffer
     * @param xlen number of bytes to move
     * @param b    whether to perform the move
     */
    private void cmov(byte[] r, byte[] x, int xlen, boolean b) {
        // All ones when b is true, all zeros otherwise; b is inspected only here.
        final int mask = -(b ? 1 : 0);
        for (int i = 0; i < xlen; i++) {
            // r ^= (r ^ x) yields x when the mask is set, r ^= 0 keeps r otherwise.
            r[i] ^= (byte) (mask & (r[i] ^ x[i]));
        }
    }
}

