/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.pqc.crypto.sphincsplus;

import org.bouncycastle.crypto.Digest;
import org.bouncycastle.crypto.Xof;
import org.bouncycastle.crypto.digests.SHA256Digest;
import org.bouncycastle.crypto.digests.SHA512Digest;
import org.bouncycastle.crypto.digests.SHAKEDigest;
import org.bouncycastle.crypto.generators.MGF1BytesGenerator;
import org.bouncycastle.crypto.macs.HMac;
import org.bouncycastle.crypto.params.KeyParameter;
import org.bouncycastle.crypto.params.MGFParameters;
import org.bouncycastle.pqc.crypto.sphincsplus.ADRS;
import org.bouncycastle.pqc.crypto.sphincsplus.IndexedDigest;
import org.bouncycastle.util.Arrays;
import org.bouncycastle.util.Pack;

/**
 * Parameter holder and hash-function interface for SPHINCS+.
 *
 * <p>Concrete subclasses supply the F, H, H_msg, T_l, PRF and PRF_msg functions on top of a
 * particular hash family ({@link Sha256Engine}, {@link Shake256Engine}). Instances keep mutable
 * digest state in fields, so a single engine must not be used from several threads at once.
 */
abstract class SPHINCSPlusEngine {
    /** When true, hash inputs are masked with a bitmask derived from the public seed and address. */
    final boolean robust;
    /** Security parameter: length in bytes of seeds and hash outputs. */
    final int N;
    /** Winternitz parameter w (16 or 256). */
    final int WOTS_W;
    /** log2(WOTS_W). */
    final int WOTS_LOGW;
    /** WOTS_LEN1 + WOTS_LEN2. */
    final int WOTS_LEN;
    /** 8 * N / WOTS_LOGW. */
    final int WOTS_LEN1;
    /** Precomputed second WOTS length for the supported (w, n) combinations. */
    final int WOTS_LEN2;
    final int D;
    final int A;
    final int K;
    final int H;
    /** H / D (integer division). */
    final int H_PRIME;
    /** 2^A. */
    final int T;

    /**
     * Returns a copy of {@code m} in which every byte is XORed with the byte at the same index of
     * {@code mask}.
     *
     * @param m    input bytes (not modified)
     * @param mask mask bytes; must be at least {@code m.length} long
     * @return a new array of length {@code m.length}
     */
    protected static byte[] xor(byte[] m, byte[] mask) {
        byte[] r = Arrays.clone(m);
        for (int i = 0; i < r.length; ++i) {
            r[i] ^= mask[i];
        }
        return r;
    }

    /**
     * Derives the WOTS and hypertree constants from the raw parameter set.
     *
     * @param robust whether the robust (bitmasked) variant is used
     * @param n      output length in bytes
     * @param w      Winternitz parameter; only 16 and 256 are supported
     * @param d      number of hypertree layers
     * @param a      FORS tree height (T = 2^a)
     * @param k      number of FORS trees
     * @param h      total hypertree height
     * @throws IllegalArgumentException if {@code w} is not 16 or 256, or {@code n} exceeds 256
     */
    public SPHINCSPlusEngine(boolean robust, int n, int w, int d, int a, int k, int h) {
        this.N = n;
        if (w == 16) {
            this.WOTS_LOGW = 4;
            if (n <= 8) {
                this.WOTS_LEN2 = 2;
            } else if (n <= 136) {
                this.WOTS_LEN2 = 3;
            } else if (n <= 256) {
                this.WOTS_LEN2 = 4;
            } else {
                throw new IllegalArgumentException("cannot precompute SPX_WOTS_LEN2 for n outside {2, .., 256}");
            }
        } else if (w == 256) {
            this.WOTS_LOGW = 8;
            if (n <= 1) {
                this.WOTS_LEN2 = 1;
            } else if (n <= 256) {
                this.WOTS_LEN2 = 2;
            } else {
                throw new IllegalArgumentException("cannot precompute SPX_WOTS_LEN2 for n outside {2, .., 256}");
            }
        } else {
            throw new IllegalArgumentException("wots_w assumed 16 or 256");
        }
        this.WOTS_W = w;
        this.WOTS_LEN1 = 8 * n / this.WOTS_LOGW;
        this.WOTS_LEN = this.WOTS_LEN1 + this.WOTS_LEN2;
        this.robust = robust;
        this.D = d;
        this.A = a;
        this.K = k;
        this.H = h;
        this.H_PRIME = h / d;
        this.T = 1 << a;
    }

    abstract byte[] F(byte[] var1, ADRS var2, byte[] var3);

    abstract byte[] H(byte[] var1, ADRS var2, byte[] var3, byte[] var4);

    abstract IndexedDigest H_msg(byte[] var1, byte[] var2, byte[] var3, byte[] var4);

    abstract byte[] T_l(byte[] var1, ADRS var2, byte[] var3);

    abstract byte[] PRF(byte[] var1, byte[] var2, ADRS var3);

    abstract byte[] PRF_msg(byte[] var1, byte[] var2, byte[] var3);

    /**
     * Number of bytes an H_msg implementation must produce.
     *
     * <p>The layout is ceil(A*K/8) FORS message bytes, then ceil((H - H/D)/8) tree-index bytes,
     * then ceil((H/D)/8) leaf-index bytes.
     */
    int hMsgLength() {
        int leafBits = H / D;
        int treeBits = H - leafBits;
        return (A * K + 7) / 8 + (treeBits + 7) / 8 + (leafBits + 7) / 8;
    }

    /**
     * Splits a raw H_msg output of {@link #hMsgLength()} bytes into its three parts.
     *
     * <p>The tree and leaf indices are read big-endian and masked to {@code H - H/D} and
     * {@code H/D} bits respectively. When a bit count is 0, the corresponding byte count is 0 too,
     * so the index is already 0. The shift-by-width then being a no-op is harmless.
     *
     * @param out raw digest bytes laid out as FORS message || tree index || leaf index
     * @return the tree index, leaf index and FORS message bytes
     */
    IndexedDigest toIndexedDigest(byte[] out) {
        int forsMsgBytes = (A * K + 7) / 8;
        int leafBits = H / D;
        int treeBits = H - leafBits;
        int leafBytes = (leafBits + 7) / 8;
        int treeBytes = (treeBits + 7) / 8;

        // Right-align the tree-index bytes in an 8-byte big-endian buffer.
        byte[] treeIndexBuf = new byte[8];
        System.arraycopy(out, forsMsgBytes, treeIndexBuf, 8 - treeBytes, treeBytes);
        long treeIndex = Pack.bigEndianToLong(treeIndexBuf, 0) & (-1L >>> (64 - treeBits));

        // Right-align the leaf-index bytes in a 4-byte big-endian buffer.
        byte[] leafIndexBuf = new byte[4];
        System.arraycopy(out, forsMsgBytes + treeBytes, leafIndexBuf, 4 - leafBytes, leafBytes);
        int leafIndex = Pack.bigEndianToInt(leafIndexBuf, 0) & (-1 >>> (32 - leafBits));

        return new IndexedDigest(treeIndex, leafIndex, Arrays.copyOfRange(out, 0, forsMsgBytes));
    }

    /**
     * SHA-2 based engine.
     *
     * <p>Tree hashing and PRF use SHA-256. When {@code n == 32}, the message digest, PRF_msg HMAC
     * and H_msg MGF1 use SHA-512; otherwise they use SHA-256.
     */
    static class Sha256Engine
    extends SPHINCSPlusEngine {
        /** Zero block used to pad the public seed up to a 64-byte SHA-256 block. */
        private final byte[] padding = new byte[64];
        private final Digest treeDigest = new SHA256Digest();
        /** MGF1-SHA-256 used by bitmask256; kept as a field so robust mode does not allocate per call. */
        private final MGF1BytesGenerator mgf1Sha256 = new MGF1BytesGenerator(new SHA256Digest());
        private final byte[] digestBuf;
        private final HMac treeHMac;
        private final MGF1BytesGenerator mgf1;
        private final byte[] hmacBuf;
        private final Digest msgDigest;

        public Sha256Engine(boolean robust, int n, int w, int d, int a, int k, int h) {
            super(robust, n, w, d, a, k, h);
            if (n == 32) {
                this.msgDigest = new SHA512Digest();
                this.treeHMac = new HMac(new SHA512Digest());
                this.mgf1 = new MGF1BytesGenerator(new SHA512Digest());
            } else {
                this.msgDigest = new SHA256Digest();
                this.treeHMac = new HMac(new SHA256Digest());
                this.mgf1 = new MGF1BytesGenerator(new SHA256Digest());
            }
            this.digestBuf = new byte[this.treeDigest.getDigestSize()];
            this.hmacBuf = new byte[this.treeHMac.getMacSize()];
        }

        /** Hashes a single input {@code m1} under {@code pkSeed} and {@code adrs}; returns N bytes. */
        @Override
        public byte[] F(byte[] pkSeed, ADRS adrs, byte[] m1) {
            return thash(pkSeed, adrs, m1);
        }

        /** Hashes {@code m1 || m2} under {@code pkSeed} and {@code adrs}; returns N bytes. */
        @Override
        public byte[] H(byte[] pkSeed, ADRS adrs, byte[] m1, byte[] m2) {
            return thash(pkSeed, adrs, Arrays.concatenate(m1, m2));
        }

        /**
         * Computes digest = msgDigest(prf || pkSeed || pkRoot || message), then expands
         * prf || pkSeed || digest with MGF1 to {@link #hMsgLength()} bytes and splits the result.
         */
        @Override
        IndexedDigest H_msg(byte[] prf, byte[] pkSeed, byte[] pkRoot, byte[] message) {
            byte[] dig = new byte[this.msgDigest.getDigestSize()];
            this.msgDigest.update(prf, 0, prf.length);
            this.msgDigest.update(pkSeed, 0, pkSeed.length);
            this.msgDigest.update(pkRoot, 0, pkRoot.length);
            this.msgDigest.update(message, 0, message.length);
            this.msgDigest.doFinal(dig, 0);
            // Masking an all-zero buffer yields the raw MGF1 output.
            byte[] out = this.bitmask(Arrays.concatenate(prf, pkSeed, dig), new byte[hMsgLength()]);
            return toIndexedDigest(out);
        }

        /** Hashes an arbitrary-length input {@code m} under {@code pkSeed} and {@code adrs}; returns N bytes. */
        @Override
        public byte[] T_l(byte[] pkSeed, ADRS adrs, byte[] m) {
            return thash(pkSeed, adrs, m);
        }

        /**
         * Returns the first {@code skSeed.length} bytes of SHA-256(skSeed || compressedADRS).
         * NOTE(review): {@code pkSeed} is not used by this implementation.
         */
        @Override
        byte[] PRF(byte[] pkSeed, byte[] skSeed, ADRS adrs) {
            int n = skSeed.length;
            this.treeDigest.update(skSeed, 0, skSeed.length);
            byte[] compressedADRS = this.compressedADRS(adrs);
            this.treeDigest.update(compressedADRS, 0, compressedADRS.length);
            this.treeDigest.doFinal(this.digestBuf, 0);
            return Arrays.copyOfRange(this.digestBuf, 0, n);
        }

        /** Returns the first N bytes of HMAC_prf(randomiser || message). */
        @Override
        public byte[] PRF_msg(byte[] prf, byte[] randomiser, byte[] message) {
            this.treeHMac.init(new KeyParameter(prf));
            this.treeHMac.update(randomiser, 0, randomiser.length);
            this.treeHMac.update(message, 0, message.length);
            this.treeHMac.doFinal(this.hmacBuf, 0);
            return Arrays.copyOfRange(this.hmacBuf, 0, this.N);
        }

        /**
         * Shared body of F, H and T_l.
         *
         * <p>Computes SHA-256(pkSeed || zero padding to 64 bytes || compressedADRS || m'), truncated
         * to N bytes. Here m' is {@code m} masked with MGF1-SHA-256(pkSeed || compressedADRS) in
         * robust mode, and {@code m} itself otherwise.
         */
        private byte[] thash(byte[] pkSeed, ADRS adrs, byte[] m) {
            byte[] compressedADRS = this.compressedADRS(adrs);
            byte[] input = this.robust
                ? this.bitmask256(Arrays.concatenate(pkSeed, compressedADRS), m)
                : m;
            this.treeDigest.update(pkSeed, 0, pkSeed.length);
            this.treeDigest.update(this.padding, 0, this.padding.length - pkSeed.length);
            this.treeDigest.update(compressedADRS, 0, compressedADRS.length);
            this.treeDigest.update(input, 0, input.length);
            this.treeDigest.doFinal(this.digestBuf, 0);
            return Arrays.copyOfRange(this.digestBuf, 0, this.N);
        }

        /**
         * Builds the 22-byte compressed address. It takes byte 3, bytes 8..15, byte 19 and
         * bytes 20..31 of {@code adrs.value}, in that order.
         */
        private byte[] compressedADRS(ADRS adrs) {
            byte[] rv = new byte[22];
            System.arraycopy(adrs.value, 3, rv, 0, 1);
            System.arraycopy(adrs.value, 8, rv, 1, 8);
            System.arraycopy(adrs.value, 19, rv, 9, 1);
            System.arraycopy(adrs.value, 20, rv, 10, 12);
            return rv;
        }

        /** Returns {@code m} XORed with MGF1(key) using the n-dependent digest (SHA-512 when n == 32). */
        protected byte[] bitmask(byte[] key, byte[] m) {
            return maskWith(this.mgf1, key, m);
        }

        /** Returns {@code m} XORed with MGF1-SHA-256(key). */
        protected byte[] bitmask256(byte[] key, byte[] m) {
            return maskWith(this.mgf1Sha256, key, m);
        }

        /** Generates {@code m.length} MGF1 bytes from {@code key} with {@code gen} and XORs {@code m} into them. */
        private static byte[] maskWith(MGF1BytesGenerator gen, byte[] key, byte[] m) {
            byte[] mask = new byte[m.length];
            gen.init(new MGFParameters(key));
            gen.generateBytes(mask, 0, mask.length);
            for (int i = 0; i < mask.length; ++i) {
                mask[i] ^= m[i];
            }
            return mask;
        }
    }

    /** SHAKE256-based engine; every function is a single SHAKE256 invocation. */
    static class Shake256Engine
    extends SPHINCSPlusEngine {
        private final Xof treeDigest = new SHAKEDigest(256);

        public Shake256Engine(boolean robust, int n, int w, int d, int a, int k, int h) {
            super(robust, n, w, d, a, k, h);
        }

        /** Hashes a single input {@code m1} under {@code pkSeed} and {@code adrs}; returns N bytes. */
        @Override
        byte[] F(byte[] pkSeed, ADRS adrs, byte[] m1) {
            return thash(pkSeed, adrs, m1);
        }

        /** Hashes {@code m1 || m2} under {@code pkSeed} and {@code adrs}; returns N bytes. */
        @Override
        byte[] H(byte[] pkSeed, ADRS adrs, byte[] m1, byte[] m2) {
            return thash(pkSeed, adrs, Arrays.concatenate(m1, m2));
        }

        /** Squeezes {@link #hMsgLength()} bytes of SHAKE256(R || pkSeed || pkRoot || message) and splits them. */
        @Override
        IndexedDigest H_msg(byte[] R, byte[] pkSeed, byte[] pkRoot, byte[] message) {
            byte[] out = new byte[hMsgLength()];
            this.treeDigest.update(R, 0, R.length);
            this.treeDigest.update(pkSeed, 0, pkSeed.length);
            this.treeDigest.update(pkRoot, 0, pkRoot.length);
            this.treeDigest.update(message, 0, message.length);
            this.treeDigest.doFinal(out, 0, out.length);
            return toIndexedDigest(out);
        }

        /** Hashes an arbitrary-length input {@code m} under {@code pkSeed} and {@code adrs}; returns N bytes. */
        @Override
        byte[] T_l(byte[] pkSeed, ADRS adrs, byte[] m) {
            return thash(pkSeed, adrs, m);
        }

        /**
         * Returns N bytes of SHAKE256(skSeed || adrs.value).
         * NOTE(review): {@code pkSeed} is not used by this implementation.
         */
        @Override
        byte[] PRF(byte[] pkSeed, byte[] skSeed, ADRS adrs) {
            this.treeDigest.update(skSeed, 0, skSeed.length);
            this.treeDigest.update(adrs.value, 0, adrs.value.length);
            byte[] prf = new byte[this.N];
            this.treeDigest.doFinal(prf, 0, this.N);
            return prf;
        }

        /** Returns N bytes of SHAKE256(prf || randomiser || message). */
        @Override
        public byte[] PRF_msg(byte[] prf, byte[] randomiser, byte[] message) {
            this.treeDigest.update(prf, 0, prf.length);
            this.treeDigest.update(randomiser, 0, randomiser.length);
            this.treeDigest.update(message, 0, message.length);
            byte[] out = new byte[this.N];
            this.treeDigest.doFinal(out, 0, out.length);
            return out;
        }

        /** Returns {@code m} XORed with {@code m.length} bytes of SHAKE256(pkSeed || adrs.value). */
        protected byte[] bitmask(byte[] pkSeed, ADRS adrs, byte[] m) {
            byte[] mask = new byte[m.length];
            this.treeDigest.update(pkSeed, 0, pkSeed.length);
            this.treeDigest.update(adrs.value, 0, adrs.value.length);
            this.treeDigest.doFinal(mask, 0, mask.length);
            for (int i = 0; i < mask.length; ++i) {
                mask[i] ^= m[i];
            }
            return mask;
        }

        /**
         * Shared body of F, H and T_l.
         *
         * <p>Computes N bytes of SHAKE256(pkSeed || adrs.value || m'). Here m' is {@code m} masked
         * by {@link #bitmask} in robust mode, and {@code m} itself otherwise. The mask is fully
         * squeezed before the main hash starts absorbing, so the two uses of {@code treeDigest}
         * do not interleave.
         */
        private byte[] thash(byte[] pkSeed, ADRS adrs, byte[] m) {
            byte[] mTheta = this.robust ? this.bitmask(pkSeed, adrs, m) : m;
            byte[] rv = new byte[this.N];
            this.treeDigest.update(pkSeed, 0, pkSeed.length);
            this.treeDigest.update(adrs.value, 0, adrs.value.length);
            this.treeDigest.update(mTheta, 0, mTheta.length);
            this.treeDigest.doFinal(rv, 0, rv.length);
            return rv;
        }
    }
}

