/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.pqc.crypto.mlkem;

import org.bouncycastle.pqc.crypto.mlkem.MLKEMEngine;
import org.bouncycastle.pqc.crypto.mlkem.Poly;
import org.bouncycastle.pqc.crypto.mlkem.PolyVec;
import org.bouncycastle.pqc.crypto.mlkem.Symmetric;
import org.bouncycastle.util.Arrays;

/*
 * Multiple versions of this class in jar - see https://www.benf.org/other/cfr/multi-version-jar.html
 */
/**
 * IND-CPA public-key encryption layer (K-PKE in FIPS 203) used by ML-KEM: key generation,
 * encryption, decryption, and the (un)packing of keys and ciphertexts.
 */
class MLKEMIndCpa {
    /** Number of coefficients in one polynomial. */
    private static final int KYBER_N = 256;
    /** Modulus q; uniformly sampled coefficients must be strictly below it. */
    private static final int KYBER_Q = 3329;
    /** Length in bytes of the public (matrix) seed and of the noise seed. */
    private static final int SEED_BYTES = 32;

    private final MLKEMEngine engine;
    private final int kyberK;
    private final int indCpaPublicKeyBytes;
    private final int polyVecBytes;
    private final int indCpaBytes;
    private final int polyVecCompressedBytes;
    private final int polyCompressedBytes;
    private final Symmetric symmetric;
    /**
     * Number of XOF blocks squeezed up front when sampling one matrix entry: the smallest
     * whole number of blocks holding more than 472 bytes, where
     * 472 = floor((12 * 256 / 8) * 2^12 / 3329). Mirrors GEN_MATRIX_NBLOCKS in the Kyber
     * reference implementation.
     */
    public final int KyberGenerateMatrixNBlocks;

    /**
     * Caches the parameter-set sizes and the symmetric primitives of {@code engine}.
     *
     * @param engine the ML-KEM engine providing parameter sizes and symmetric primitives
     */
    public MLKEMIndCpa(MLKEMEngine engine) {
        this.engine = engine;
        this.kyberK = engine.getKyberK();
        this.indCpaPublicKeyBytes = engine.getKyberPublicKeyBytes();
        this.polyVecBytes = engine.getKyberPolyVecBytes();
        this.indCpaBytes = engine.getKyberIndCpaBytes();
        this.polyVecCompressedBytes = engine.getKyberPolyVecCompressedBytes();
        this.polyCompressedBytes = engine.getKyberPolyCompressedBytes();
        this.symmetric = engine.getSymmetric();
        int blockBytes = this.symmetric.xofBlockBytes;
        // 12 * N / 8 * 2^12 / q evaluates to 472 for N = 256, q = 3329.
        this.KyberGenerateMatrixNBlocks =
                (12 * KYBER_N / 8 * (1 << 12) / KYBER_Q + blockBytes) / blockBytes;
    }

    /**
     * Deterministically derives an IND-CPA key pair from the seed {@code d}.
     *
     * @param d key-generation seed; hashed together with the byte {@code kyberK}
     * @return a two-element array: {packed public key, packed secret key}
     */
    byte[][] generateKeyPair(byte[] d) {
        // Split G(d || k) into the public matrix seed and the noise seed.
        byte[] buf = new byte[2 * SEED_BYTES];
        this.symmetric.hash_g(buf, Arrays.append(d, (byte) this.kyberK));
        byte[] publicSeed = new byte[SEED_BYTES];
        byte[] noiseSeed = new byte[SEED_BYTES];
        System.arraycopy(buf, 0, publicSeed, 0, SEED_BYTES);
        System.arraycopy(buf, SEED_BYTES, noiseSeed, 0, SEED_BYTES);

        PolyVec[] aMatrix = this.newMatrix();
        this.generateMatrix(aMatrix, publicSeed, false);

        // Noise nonces run 0..k-1 for the secret vector, then k..2k-1 for the error vector.
        PolyVec secretKey = new PolyVec(this.engine);
        PolyVec e = new PolyVec(this.engine);
        byte nonce = 0;
        for (int i = 0; i < this.kyberK; ++i) {
            secretKey.getVectorIndex(i).getEta1Noise(noiseSeed, nonce++);
        }
        for (int i = 0; i < this.kyberK; ++i) {
            e.getVectorIndex(i).getEta1Noise(noiseSeed, nonce++);
        }
        secretKey.polyVecNtt();
        e.polyVecNtt();

        PolyVec publicKey = new PolyVec(this.engine);
        for (int i = 0; i < this.kyberK; ++i) {
            Poly row = publicKey.getVectorIndex(i);
            PolyVec.pointwiseAccountMontgomery(row, aMatrix[i], secretKey, this.engine);
            row.convertToMont();
        }
        publicKey.addPoly(e);
        publicKey.reducePoly();
        return new byte[][]{this.packPublicKey(publicKey, publicSeed), this.packSecretKey(secretKey)};
    }

    /**
     * Encrypts {@code msg} under a packed public key, using {@code coins} as the noise seed.
     *
     * @param publicKeyInput packed public key (polynomial vector followed by the 32-byte seed)
     * @param msg message to encode via {@code Poly.fromMsg}
     * @param coins seed for all noise sampled during encryption
     * @return packed ciphertext of {@code indCpaBytes} bytes
     */
    public byte[] encrypt(byte[] publicKeyInput, byte[] msg, byte[] coins) {
        PolyVec publicKeyPolyVec = new PolyVec(this.engine);
        byte[] seed = this.unpackPublicKey(publicKeyPolyVec, publicKeyInput);
        Poly k = new Poly(this.engine);
        k.fromMsg(msg);

        PolyVec[] aMatrixTranspose = this.newMatrix();
        this.generateMatrix(aMatrixTranspose, seed, true);

        // Nonces: 0..k-1 for sp, k..2k-1 for the error vector, 2k for the scalar error.
        PolyVec sp = new PolyVec(this.engine);
        PolyVec errorPolyVector = new PolyVec(this.engine);
        Poly errorPoly = new Poly(this.engine);
        byte nonce = 0;
        for (int i = 0; i < this.kyberK; ++i) {
            sp.getVectorIndex(i).getEta1Noise(coins, nonce++);
        }
        for (int i = 0; i < this.kyberK; ++i) {
            errorPolyVector.getVectorIndex(i).getEta2Noise(coins, nonce++);
        }
        errorPoly.getEta2Noise(coins, nonce);
        sp.polyVecNtt();

        PolyVec bp = new PolyVec(this.engine);
        for (int i = 0; i < this.kyberK; ++i) {
            PolyVec.pointwiseAccountMontgomery(bp.getVectorIndex(i), aMatrixTranspose[i], sp, this.engine);
        }
        Poly v = new Poly(this.engine);
        PolyVec.pointwiseAccountMontgomery(v, publicKeyPolyVec, sp, this.engine);
        bp.polyVecInverseNttToMont();
        v.polyInverseNttToMont();
        bp.addPoly(errorPolyVector);
        v.addCoeffs(errorPoly);
        v.addCoeffs(k);
        bp.reducePoly();
        v.reduce();
        return this.packCipherText(bp, v);
    }

    /** Concatenates the compressed vector {@code b} and the compressed polynomial {@code v}. */
    private byte[] packCipherText(PolyVec b, Poly v) {
        byte[] outBuf = new byte[this.indCpaBytes];
        System.arraycopy(b.compressPolyVec(), 0, outBuf, 0, this.polyVecCompressedBytes);
        System.arraycopy(v.compressPoly(), 0, outBuf, this.polyVecCompressedBytes, this.polyCompressedBytes);
        return outBuf;
    }

    /**
     * Splits {@code cipherText} at {@code polyVecCompressedBytes}. The prefix is decompressed
     * into {@code b} and the remainder into {@code v}.
     */
    private void unpackCipherText(PolyVec b, Poly v, byte[] cipherText) {
        byte[] compressedPolyVecCipherText = Arrays.copyOfRange(cipherText, 0, this.polyVecCompressedBytes);
        b.decompressPolyVec(compressedPolyVecCipherText);
        byte[] compressedPolyCipherText =
                Arrays.copyOfRange(cipherText, this.polyVecCompressedBytes, cipherText.length);
        v.decompressPoly(compressedPolyCipherText);
    }

    /**
     * Serializes a public key as the encoded polynomial vector followed by the 32-byte seed.
     *
     * @param publicKeyPolyVec the public polynomial vector
     * @param seed the public matrix seed (first 32 bytes are used)
     * @return the packed public key of {@code indCpaPublicKeyBytes} bytes
     */
    public byte[] packPublicKey(PolyVec publicKeyPolyVec, byte[] seed) {
        byte[] buf = new byte[this.indCpaPublicKeyBytes];
        System.arraycopy(publicKeyPolyVec.toBytes(), 0, buf, 0, this.polyVecBytes);
        System.arraycopy(seed, 0, buf, this.polyVecBytes, SEED_BYTES);
        return buf;
    }

    /**
     * Decodes a packed public key into {@code publicKeyPolyVec} and returns its matrix seed.
     *
     * @param publicKeyPolyVec destination for the decoded polynomial vector
     * @param publicKey packed public key
     * @return a fresh copy of the 32 seed bytes located after the polynomial vector
     */
    public byte[] unpackPublicKey(PolyVec publicKeyPolyVec, byte[] publicKey) {
        byte[] outputSeed = new byte[SEED_BYTES];
        publicKeyPolyVec.fromBytes(publicKey);
        System.arraycopy(publicKey, this.polyVecBytes, outputSeed, 0, SEED_BYTES);
        return outputSeed;
    }

    /** Serializes the secret polynomial vector. */
    public byte[] packSecretKey(PolyVec secretKeyPolyVec) {
        return secretKeyPolyVec.toBytes();
    }

    /** Decodes a serialized secret key into {@code secretKeyPolyVec}. */
    public void unpackSecretKey(PolyVec secretKeyPolyVec, byte[] secretKey) {
        secretKeyPolyVec.fromBytes(secretKey);
    }

    /**
     * Fills {@code aMatrix} with polynomials sampled by rejection from the XOF seeded with
     * {@code seed}. Entry (i, j) absorbs the index bytes in order (i, j) when
     * {@code transposed} is true and (j, i) otherwise.
     *
     * @param aMatrix {@code kyberK} rows, each with at least {@code kyberK} entries, to overwrite
     * @param seed public matrix seed
     * @param transposed whether to generate the transpose of the matrix
     */
    public void generateMatrix(PolyVec[] aMatrix, byte[] seed, boolean transposed) {
        int blockBytes = this.symmetric.xofBlockBytes;
        int initialLen = this.KyberGenerateMatrixNBlocks * blockBytes;
        // The extra 2 bytes leave room for leftover bytes carried in front of a refill block.
        byte[] buf = new byte[initialLen + 2];
        for (int i = 0; i < this.kyberK; ++i) {
            for (int j = 0; j < this.kyberK; ++j) {
                if (transposed) {
                    this.symmetric.xofAbsorb(seed, (byte) i, (byte) j);
                } else {
                    this.symmetric.xofAbsorb(seed, (byte) j, (byte) i);
                }
                Poly entry = aMatrix[i].getVectorIndex(j);
                this.symmetric.xofSqueezeBlocks(buf, 0, initialLen);
                int buflen = initialLen;
                int ctr = rejectionSampling(entry, 0, KYBER_N, buf, buflen);
                while (ctr < KYBER_N) {
                    // Move the 0-2 trailing bytes that did not form a full 3-byte group to the front.
                    int off = buflen % 3;
                    System.arraycopy(buf, buflen - off, buf, 0, off);
                    // Squeeze exactly the one block that is consumed, so no XOF output is skipped.
                    this.symmetric.xofSqueezeBlocks(buf, off, blockBytes);
                    buflen = off + blockBytes;
                    ctr += rejectionSampling(entry, ctr, KYBER_N - ctr, buf, buflen);
                }
            }
        }
    }

    /**
     * Parses {@code inpBuf} three bytes at a time into two 12-bit candidates. Each candidate
     * below {@code KYBER_Q} is written to {@code outputBuffer} starting at index
     * {@code coeffOff}, until {@code len} values are written or the input is exhausted.
     *
     * @return the number of coefficients written
     */
    private static int rejectionSampling(Poly outputBuffer, int coeffOff, int len, byte[] inpBuf, int inpBufLen) {
        int ctr = 0;
        for (int pos = 0; ctr < len && pos + 3 <= inpBufLen; pos += 3) {
            int b0 = inpBuf[pos] & 0xFF;
            int b1 = inpBuf[pos + 1] & 0xFF;
            int b2 = inpBuf[pos + 2] & 0xFF;
            short val0 = (short) ((b0 | (b1 << 8)) & 0xFFF);
            short val1 = (short) (((b1 >> 4) | (b2 << 4)) & 0xFFF);
            if (val0 < KYBER_Q) {
                outputBuffer.setCoeffIndex(coeffOff + ctr, val0);
                ++ctr;
            }
            if (ctr < len && val1 < KYBER_Q) {
                outputBuffer.setCoeffIndex(coeffOff + ctr, val1);
                ++ctr;
            }
        }
        return ctr;
    }

    /**
     * Decrypts a packed ciphertext with a packed secret key.
     *
     * @param secretKey packed secret polynomial vector
     * @param cipherText packed ciphertext (compressed vector followed by compressed polynomial)
     * @return the message bytes produced by {@code Poly.toMsg}
     */
    public byte[] decrypt(byte[] secretKey, byte[] cipherText) {
        PolyVec bp = new PolyVec(this.engine);
        Poly v = new Poly(this.engine);
        this.unpackCipherText(bp, v, cipherText);
        PolyVec secretKeyPolyVec = new PolyVec(this.engine);
        this.unpackSecretKey(secretKeyPolyVec, secretKey);
        bp.polyVecNtt();
        Poly mp = new Poly(this.engine);
        PolyVec.pointwiseAccountMontgomery(mp, secretKeyPolyVec, bp, this.engine);
        mp.polyInverseNttToMont();
        // Combines mp with v; see Poly.polySubtract for the operand order.
        mp.polySubtract(v);
        mp.reduce();
        return mp.toMsg();
    }

    /** Allocates {@code kyberK} freshly constructed {@code PolyVec} rows for a matrix. */
    private PolyVec[] newMatrix() {
        PolyVec[] matrix = new PolyVec[this.kyberK];
        for (int i = 0; i < this.kyberK; ++i) {
            matrix[i] = new PolyVec(this.engine);
        }
        return matrix;
    }
}

