/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.pqc.crypto.bike;

import java.security.SecureRandom;
import org.bouncycastle.crypto.digests.SHA3Digest;
import org.bouncycastle.crypto.digests.SHAKEDigest;
import org.bouncycastle.pqc.crypto.bike.BIKERandomGenerator;
import org.bouncycastle.pqc.crypto.bike.Utils;
import org.bouncycastle.pqc.math.linearalgebra.GF2mField;
import org.bouncycastle.pqc.math.linearalgebra.PolynomialGF2mSmallM;
import org.bouncycastle.util.Arrays;

/*
 * Multiple versions of this class in jar - see https://www.benf.org/other/cfr/multi-version-jar.html
 */
/**
 * Key generation, encapsulation and decapsulation routines for the BIKE key
 * encapsulation mechanism, parameterised by the values given to the constructor.
 *
 * <p>Polynomials are handled in two representations: packed byte arrays
 * ({@code R_BYTE} bytes per ring element) and "bit arrays" holding one
 * coefficient per {@code byte}. Conversion between the two is delegated to
 * {@link Utils}.
 */
class BIKEEngine {
    /** Output size in bytes of the SHA3-384 digest used by functionL / functionK. */
    private static final int SHA3_384_BYTES = 48;

    /**
     * Number of bytes drawn from the caller's {@link SecureRandom} per operation.
     * Kept at 64 so the amount of randomness consumed is unchanged, even though
     * only {@code L_BYTE} (encaps) or {@code 2 * L_BYTE} (key generation) bytes are used.
     */
    private static final int SEED_BYTES = 64;

    private final int r;
    private final int w;
    private final int hw;
    private final int t;
    private final int l;
    private final int nbIter;
    private final int tau;
    private final GF2mField field;
    private final PolynomialGF2mSmallM reductionPoly;
    private final int L_BYTE;
    private final int R_BYTE;

    /**
     * Creates an engine for one parameter set.
     *
     * @param r      block length in bits; also selects the decoder threshold
     *               (only 12323, 24659 and 40973 are supported by the decoder)
     * @param w      row weight; each secret polynomial gets weight {@code w / 2}
     * @param t      weight requested for the error vector produced by functionH
     * @param l      session key / seed length in bits
     * @param nbIter number of decoder iterations
     * @param tau    gap below the threshold that still marks a position "gray"
     */
    public BIKEEngine(int r, int w, int t, int l, int nbIter, int tau) {
        this.r = r;
        this.w = w;
        this.t = t;
        this.l = l;
        this.nbIter = nbIter;
        this.tau = tau;
        this.hw = w / 2;
        this.L_BYTE = l / 8;
        this.R_BYTE = (r + 7) / 8;
        this.field = new GF2mField(1);
        // NOTE(review): presumably builds the modulus x^r + 1 (degree-r monomial plus
        // the constant term) - confirm against PolynomialGF2mSmallM.
        this.reductionPoly = new PolynomialGF2mSmallM(this.field, r).addMonomial(0);
    }

    /**
     * @return the session key length in bytes ({@code l / 8})
     */
    public int getSessionKeySize() {
        return this.L_BYTE;
    }

    /**
     * Expands {@code seed} with SHAKE-256 into a packed error vector of
     * {@code 2 * R_BYTE} bytes, requesting weight {@code t} over {@code 2 * r} bits.
     */
    private byte[] functionH(byte[] seed) {
        SHAKEDigest digest = new SHAKEDigest(256);
        digest.update(seed, 0, seed.length);
        return BIKERandomGenerator.generateRandomByteArray(this.r * 2, 2 * this.R_BYTE, this.t, digest);
    }

    /**
     * Returns the first {@code L_BYTE} bytes of SHA3-384({@code e0 || e1}).
     */
    private byte[] functionL(byte[] e0, byte[] e1) {
        SHA3Digest digest = new SHA3Digest(384);
        digest.update(e0, 0, e0.length);
        digest.update(e1, 0, e1.length);
        return finishTruncated(digest);
    }

    /**
     * Returns the first {@code L_BYTE} bytes of SHA3-384({@code m || c0 || c1}).
     */
    private byte[] functionK(byte[] m, byte[] c0, byte[] c1) {
        SHA3Digest digest = new SHA3Digest(384);
        digest.update(m, 0, m.length);
        digest.update(c0, 0, c0.length);
        digest.update(c1, 0, c1.length);
        return finishTruncated(digest);
    }

    /** Finalises {@code digest} and keeps only the leading {@code L_BYTE} bytes. */
    private byte[] finishTruncated(SHA3Digest digest) {
        byte[] full = new byte[SHA3_384_BYTES];
        byte[] truncated = new byte[this.L_BYTE];
        digest.doFinal(full, 0);
        System.arraycopy(full, 0, truncated, 0, this.L_BYTE);
        return truncated;
    }

    /**
     * Generates a key pair, writing the results into the caller-supplied arrays.
     *
     * @param h0     output: first secret polynomial (packed)
     * @param h1     output: second secret polynomial (packed)
     * @param sigma  output: secret value used by {@link #decaps} when a ciphertext is rejected
     * @param h      output: public key {@code h1 * h0^-1} modulo the reduction polynomial (packed)
     * @param random source of the {@value #SEED_BYTES} seed bytes
     */
    public void genKeyPair(byte[] h0, byte[] h1, byte[] sigma, byte[] h, SecureRandom random) {
        byte[] seeds = new byte[SEED_BYTES];
        random.nextBytes(seeds);
        byte[] seed1 = new byte[this.L_BYTE];
        byte[] seed2 = new byte[this.L_BYTE];
        System.arraycopy(seeds, 0, seed1, 0, seed1.length);
        System.arraycopy(seeds, seed1.length, seed2, 0, seed2.length);

        // Both secret polynomials are drawn from the same SHAKE stream, h0 first.
        SHAKEDigest digest = new SHAKEDigest(256);
        digest.update(seed1, 0, seed1.length);
        byte[] h0Tmp = BIKERandomGenerator.generateRandomByteArray(this.r, this.R_BYTE, this.hw, digest);
        byte[] h1Tmp = BIKERandomGenerator.generateRandomByteArray(this.r, this.R_BYTE, this.hw, digest);
        System.arraycopy(h0Tmp, 0, h0, 0, h0.length);
        System.arraycopy(h1Tmp, 0, h1, 0, h1.length);

        byte[] h0Bits = new byte[this.r];
        byte[] h1Bits = new byte[this.r];
        Utils.fromByteArrayToBitArray(h0Bits, h0Tmp);
        Utils.fromByteArrayToBitArray(h1Bits, h1Tmp);
        PolynomialGF2mSmallM h0Poly = new PolynomialGF2mSmallM(this.field, Utils.removeLast0Bits(h0Bits));
        PolynomialGF2mSmallM h1Poly = new PolynomialGF2mSmallM(this.field, Utils.removeLast0Bits(h1Bits));

        // h = h1 * h0^-1
        PolynomialGF2mSmallM h0Inv = h0Poly.modInverseBigDeg(this.reductionPoly);
        PolynomialGF2mSmallM hPoly = h1Poly.modKaratsubaMultiplyBigDeg(h0Inv, this.reductionPoly);
        byte[] hByte = new byte[this.R_BYTE];
        Utils.fromBitArrayToByteArray(hByte, hPoly.getEncoded());
        System.arraycopy(hByte, 0, h, 0, h.length);

        System.arraycopy(seed2, 0, sigma, 0, sigma.length);
    }

    /**
     * Encapsulates a fresh session key under public key {@code h}.
     *
     * @param c0     output: first ciphertext component, {@code e0 + e1 * h} (packed)
     * @param c1     output: second ciphertext component, {@code m XOR L(e0, e1)}
     * @param k      output: session key {@code K(m, c0, c1)}
     * @param h      public key (packed)
     * @param random source of the {@value #SEED_BYTES} seed bytes; the first {@code L_BYTE} form {@code m}
     */
    public void encaps(byte[] c0, byte[] c1, byte[] k, byte[] h, SecureRandom random) {
        byte[] seeds = new byte[SEED_BYTES];
        random.nextBytes(seeds);
        byte[] m = new byte[this.L_BYTE];
        System.arraycopy(seeds, 0, m, 0, m.length);

        // (e0, e1) = H(m), split into two r-bit halves.
        byte[] eBytes = functionH(m);
        byte[] eBits = new byte[2 * this.r];
        Utils.fromByteArrayToBitArray(eBits, eBytes);
        byte[] e0Bits = Arrays.copyOfRange(eBits, 0, this.r);
        byte[] e1Bits = Arrays.copyOfRange(eBits, this.r, eBits.length);
        PolynomialGF2mSmallM e0 = new PolynomialGF2mSmallM(this.field, Utils.removeLast0Bits(e0Bits));
        PolynomialGF2mSmallM e1 = new PolynomialGF2mSmallM(this.field, Utils.removeLast0Bits(e1Bits));

        byte[] hBits = new byte[this.r];
        Utils.fromByteArrayToBitArray(hBits, h);
        PolynomialGF2mSmallM hPoly = new PolynomialGF2mSmallM(this.field, Utils.removeLast0Bits(hBits));

        // c0 = e0 + e1 * h
        PolynomialGF2mSmallM c0Poly = e0.add(e1.modKaratsubaMultiplyBigDeg(hPoly, this.reductionPoly));
        byte[] c0Bytes = new byte[this.R_BYTE];
        Utils.fromBitArrayToByteArray(c0Bytes, c0Poly.getEncoded());
        System.arraycopy(c0Bytes, 0, c0, 0, c0.length);

        // c1 = m XOR L(e0, e1)
        byte[] e0Bytes = new byte[this.R_BYTE];
        Utils.fromBitArrayToByteArray(e0Bytes, e0Bits);
        byte[] e1Bytes = new byte[this.R_BYTE];
        Utils.fromBitArrayToByteArray(e1Bytes, e1Bits);
        byte[] c1Tmp = Utils.xorBytes(m, functionL(e0Bytes, e1Bytes), this.L_BYTE);
        System.arraycopy(c1Tmp, 0, c1, 0, c1.length);

        byte[] kTmp = functionK(m, c0, c1);
        System.arraycopy(kTmp, 0, k, 0, kTmp.length);
    }

    /**
     * Decapsulates {@code (c0, c1)} and writes the session key into {@code k}.
     *
     * <p>If the decoder fails, or the decoded error vector does not match
     * {@code H(m')}, the key is derived from {@code sigma} instead of {@code m'}.
     * A decoder failure no longer causes a {@link NullPointerException}.
     *
     * @param k     output: session key
     * @param h0    first secret polynomial (packed)
     * @param h1    second secret polynomial (packed)
     * @param sigma secret rejection value produced by {@link #genKeyPair}
     * @param c0    first ciphertext component (packed)
     * @param c1    second ciphertext component
     */
    public void decaps(byte[] k, byte[] h0, byte[] h1, byte[] sigma, byte[] c0, byte[] c1) {
        byte[] c0Bits = new byte[this.r];
        byte[] h0Bits = new byte[this.r];
        Utils.fromByteArrayToBitArray(c0Bits, c0);
        Utils.fromByteArrayToBitArray(h0Bits, h0);
        byte[] c0Cut = Utils.removeLast0Bits(c0Bits);
        byte[] h0Cut = Utils.removeLast0Bits(h0Bits);

        int[] h0Compact = new int[this.hw];
        int[] h1Compact = new int[this.hw];
        convertToCompact(h0Compact, h0);
        convertToCompact(h1Compact, h1);

        byte[] syndrome = computeSyndrome(c0Cut, h0Cut);
        byte[] ePrimeBits = BGFDecoder(syndrome, h0Compact, h1Compact);

        // BGFDecoder signals failure with null. Substitute the all-zero vector and
        // remember the failure so the rejection key is always chosen below.
        boolean decoded = ePrimeBits != null;
        if (!decoded) {
            ePrimeBits = new byte[2 * this.r];
        }

        byte[] ePrimeBytes = new byte[2 * this.R_BYTE];
        Utils.fromBitArrayToByteArray(ePrimeBytes, ePrimeBits);
        byte[] e0Bits = Arrays.copyOfRange(ePrimeBits, 0, this.r);
        byte[] e1Bits = Arrays.copyOfRange(ePrimeBits, this.r, ePrimeBits.length);
        byte[] e0Bytes = new byte[this.R_BYTE];
        Utils.fromBitArrayToByteArray(e0Bytes, e0Bits);
        byte[] e1Bytes = new byte[this.R_BYTE];
        Utils.fromBitArrayToByteArray(e1Bytes, e1Bits);

        // m' = c1 XOR L(e0', e1')
        byte[] mPrime = Utils.xorBytes(c1, functionL(e0Bytes, e1Bytes), this.L_BYTE);
        byte[] wlist = functionH(mPrime);

        // NOTE(review): Arrays.areEqual is presumably not a constant-time comparison;
        // consider a constant-time variant if timing leakage matters here.
        boolean accept = decoded && Arrays.areEqual(ePrimeBytes, wlist);
        byte[] sessionKey = accept ? functionK(mPrime, c0, c1) : functionK(sigma, c0, c1);
        System.arraycopy(sessionKey, 0, k, 0, sessionKey.length);
    }

    /**
     * Computes the syndrome {@code c0 * h0} modulo the reduction polynomial and
     * returns it in the index-reversed layout produced by {@link #transpose}.
     *
     * @param c0 ciphertext polynomial, one coefficient per byte
     * @param h0 secret polynomial, one coefficient per byte
     */
    private byte[] computeSyndrome(byte[] c0, byte[] h0) {
        PolynomialGF2mSmallM c0Poly = new PolynomialGF2mSmallM(this.field, c0);
        PolynomialGF2mSmallM h0Poly = new PolynomialGF2mSmallM(this.field, h0);
        PolynomialGF2mSmallM product = c0Poly.modKaratsubaMultiplyBigDeg(h0Poly, this.reductionPoly);
        return transpose(product.getEncoded());
    }

    /**
     * Iterative bit-flipping decoder using "black" and "gray" position masks.
     * Each iteration runs one {@link #BFIter}; the first iteration is followed by
     * two masked passes over the black and gray positions it recorded.
     *
     * @param s syndrome, one bit per byte; updated in place
     * @return the error vector ({@code 2 * r} entries, one bit per byte) when the
     *         syndrome weight reaches zero, otherwise {@code null}
     * @throws IllegalArgumentException if {@code r} has no threshold defined
     */
    private byte[] BGFDecoder(byte[] s, int[] h0Compact, int[] h1Compact) {
        byte[] e = new byte[2 * this.r];
        int[] h0CompactCol = getColumnFromCompactVersion(h0Compact);
        int[] h1CompactCol = getColumnFromCompactVersion(h1Compact);
        int maskedThreshold = (this.hw + 1) / 2 + 1;

        for (int i = 1; i <= this.nbIter; ++i) {
            byte[] black = new byte[2 * this.r];
            byte[] gray = new byte[2 * this.r];
            int T = threshold(Utils.getHammingWeight(s), this.r);
            BFIter(s, e, T, h0Compact, h1Compact, h0CompactCol, h1CompactCol, black, gray);
            if (i == 1) {
                BFMaskedIter(s, e, black, maskedThreshold, h0Compact, h1Compact, h0CompactCol, h1CompactCol);
                BFMaskedIter(s, e, gray, maskedThreshold, h0Compact, h1Compact, h0CompactCol, h1CompactCol);
            }
        }
        return Utils.getHammingWeight(s) == 0 ? e : null;
    }

    /**
     * Pads {@code in} to {@code r} entries and reverses indices 1..r-1, leaving
     * index 0 in place: {@code out[0] = in[0]}, {@code out[i] = in[r - i]}.
     */
    private byte[] transpose(byte[] in) {
        byte[] padded = Utils.append0s(in, this.r);
        byte[] out = new byte[this.r];
        out[0] = padded[0];
        for (int i = 1; i < this.r; ++i) {
            out[i] = padded[this.r - i];
        }
        return out;
    }

    /**
     * One unmasked bit-flipping pass. Positions whose counter is at least
     * {@code T} are flipped in {@code e} and marked black; positions whose counter
     * is at least {@code T - tau} are marked gray. The syndrome is updated only
     * after both halves have been scanned.
     */
    private void BFIter(byte[] s, byte[] e, int T, int[] h0Compact, int[] h1Compact, int[] h0CompactCol, int[] h1CompactCol, byte[] black, byte[] gray) {
        boolean[] flipped = new boolean[2 * this.r];
        classifyAndFlip(s, e, T, h0CompactCol, 0, black, gray, flipped);
        classifyAndFlip(s, e, T, h1CompactCol, this.r, black, gray, flipped);
        applyFlipsToSyndrome(s, flipped, h0Compact, h1Compact);
    }

    /** Scans one half (positions {@code offset .. offset + r - 1}) for {@link #BFIter}. */
    private void classifyAndFlip(byte[] s, byte[] e, int T, int[] hCompactCol, int offset, byte[] black, byte[] gray, boolean[] flipped) {
        for (int j = 0; j < this.r; ++j) {
            // s is not modified during the scan, so one counter evaluation suffices.
            int counter = ctr(hCompactCol, s, j);
            int pos = offset + j;
            if (counter >= T) {
                updateNewErrorIndex(e, pos);
                flipped[pos] = true;
                black[pos] = 1;
            } else if (counter >= T - this.tau) {
                gray[pos] = 1;
            }
        }
    }

    /**
     * Bit-flipping pass restricted to positions where {@code mask} is 1: such a
     * position is flipped when its counter is at least {@code T}.
     */
    private void BFMaskedIter(byte[] s, byte[] e, byte[] mask, int T, int[] h0Compact, int[] h1Compact, int[] h0CompactCol, int[] h1CompactCol) {
        boolean[] flipped = new boolean[2 * this.r];
        flipMasked(s, e, mask, T, h0CompactCol, 0, flipped);
        flipMasked(s, e, mask, T, h1CompactCol, this.r, flipped);
        applyFlipsToSyndrome(s, flipped, h0Compact, h1Compact);
    }

    /** Scans one half (positions {@code offset .. offset + r - 1}) for {@link #BFMaskedIter}. */
    private void flipMasked(byte[] s, byte[] e, byte[] mask, int T, int[] hCompactCol, int offset, boolean[] flipped) {
        for (int j = 0; j < this.r; ++j) {
            int pos = offset + j;
            // Test the mask first: it is O(1) whereas the counter is O(hw).
            if (mask[pos] != 1 || ctr(hCompactCol, s, j) < T) {
                continue;
            }
            updateNewErrorIndex(e, pos);
            flipped[pos] = true;
        }
    }

    /** Updates the syndrome for every flipped position, in ascending index order. */
    private void applyFlipsToSyndrome(byte[] s, boolean[] flipped, int[] h0Compact, int[] h1Compact) {
        for (int i = 0; i < flipped.length; ++i) {
            if (flipped[i]) {
                recomputeSyndrome(s, i, h0Compact, h1Compact);
            }
        }
    }

    /**
     * Returns the flipping threshold for the given syndrome weight.
     *
     * @throws IllegalArgumentException if {@code r} is not 12323, 24659 or 40973
     */
    private int threshold(int hammingWeight, int r) {
        switch (r) {
            case 12323:
                return thresholdFromParameters(hammingWeight, 0.0069722, 13.53, 36);
            case 24659:
                return thresholdFromParameters(hammingWeight, 0.005265, 15.2588, 52);
            case 40973:
                return thresholdFromParameters(hammingWeight, 0.00402312, 17.8785, 69);
            default:
                // A threshold of 0 would make every position qualify for flipping.
                throw new IllegalArgumentException("no decoder threshold defined for r = " + r);
        }
    }

    /** Computes {@code max(floor(slope * hammingWeight + intercept), minimum)}. */
    private static int thresholdFromParameters(int hammingWeight, double slope, double intercept, int minimum) {
        int floored = (int)Math.floor(slope * (double)hammingWeight + intercept);
        return Math.max(floored, minimum);
    }

    /**
     * Counts how many of the {@code hw} syndrome entries
     * {@code s[(hCompactCol[i] + j) mod r]} are equal to 1.
     */
    private int ctr(int[] hCompactCol, byte[] s, int j) {
        int count = 0;
        for (int i = 0; i < this.hw; ++i) {
            if (s[(hCompactCol[i] + j) % this.r] == 1) {
                ++count;
            }
        }
        return count;
    }

    /**
     * Writes the indices of the set bits of packed polynomial {@code h}
     * (least-significant bit of each byte first, bit positions below {@code r})
     * into {@code compactVersion}, in ascending order.
     */
    private void convertToCompact(int[] compactVersion, byte[] h) {
        int count = 0;
        for (int i = 0; i < this.R_BYTE; ++i) {
            for (int j = 0; j < 8 && i * 8 + j != this.r; ++j) {
                if (((h[i] >> j) & 1) == 1) {
                    compactVersion[count++] = i * 8 + j;
                }
            }
        }
    }

    /**
     * Maps each index {@code p} in {@code hCompact} to {@code r - p}, in reversed
     * order, keeping a leading index 0 as 0.
     */
    private int[] getColumnFromCompactVersion(int[] hCompact) {
        int[] hCompactColumn = new int[this.hw];
        if (hCompact[0] == 0) {
            hCompactColumn[0] = 0;
            for (int i = 1; i < this.hw; ++i) {
                hCompactColumn[i] = this.r - hCompact[this.hw - i];
            }
        } else {
            for (int i = 0; i < this.hw; ++i) {
                hCompactColumn[i] = this.r - hCompact[this.hw - 1 - i];
            }
        }
        return hCompactColumn;
    }

    /**
     * Toggles the syndrome bits affected by flipping error position {@code index}.
     * Positions below {@code r} use {@code h0Compact}, the rest use
     * {@code h1Compact}; each affected entry is {@code (column - p) mod r}.
     */
    private void recomputeSyndrome(byte[] syndrome, int index, int[] h0Compact, int[] h1Compact) {
        boolean firstHalf = index < this.r;
        int[] hCompact = firstHalf ? h0Compact : h1Compact;
        int column = firstHalf ? index : index - this.r;
        for (int i = 0; i < this.hw; ++i) {
            int pos = column - hCompact[i];
            if (pos < 0) {
                pos += this.r;
            }
            syndrome[pos] ^= 1;
        }
    }

    /**
     * Toggles the entry of {@code e} corresponding to decoder position
     * {@code index}. Within each half the index is reversed ({@code j} becomes
     * {@code r - j}), except the first position of each half, which stays put.
     */
    private void updateNewErrorIndex(byte[] e, int index) {
        int newIndex = index;
        if (index != 0 && index != this.r) {
            newIndex = index > this.r ? 2 * this.r - index + this.r : this.r - index;
        }
        e[newIndex] ^= 1;
    }
}

