/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.pqc.crypto.slhdsa;

import java.math.BigInteger;
import java.util.ArrayDeque;
import java.util.LinkedList;
import org.bouncycastle.pqc.crypto.slhdsa.ADRS;
import org.bouncycastle.pqc.crypto.slhdsa.NodeEntry;
import org.bouncycastle.pqc.crypto.slhdsa.SIG_FORS;
import org.bouncycastle.pqc.crypto.slhdsa.SLHDSAEngine;
import org.bouncycastle.util.Arrays;

/**
 * FORS (Forest Of Random Subsets) few-time signature component used by SLH-DSA
 * (FIPS 205, "FORS: Forest of Random Subsets").
 *
 * <p>All parameters and primitives come from the supplied {@link SLHDSAEngine}. The forest has
 * {@code engine.K} trees. Each tree has {@code 2^engine.A} leaves, so its height is
 * {@code engine.A}. Leaves of tree {@code i} occupy the global indices
 * {@code [i * 2^A, (i + 1) * 2^A)}.
 *
 * <p>The numeric address types used below (6, 3, 4) appear to be the FIPS 205 values
 * FORS_PRF, FORS_TREE and FORS_ROOTS. NOTE(review): confirm against the ADRS constants.
 */
class Fors {
    SLHDSAEngine engine;

    /**
     * @param engine parameter set and hash/PRF primitives used by every operation
     */
    public Fors(SLHDSAEngine engine) {
        this.engine = engine;
    }

    /**
     * Computes the root of the FORS subtree of height {@code z} whose leftmost leaf has global
     * index {@code s}. This is FIPS 205 {@code fors_node}, done iteratively with a stack of
     * pending subtree roots.
     *
     * @param skSeed secret seed passed to the PRF that derives each leaf's secret value
     * @param s global index of the leftmost leaf; must be a multiple of {@code 2^z}
     * @param z height of the subtree to compute
     * @param pkSeed public seed passed to PRF, F and H
     * @param adrsParam address template; it is copied, and the caller's instance is not modified
     * @return the subtree root, or {@code null} if {@code s} is not aligned to {@code 2^z}
     */
    byte[] treehash(byte[] skSeed, int s, int z, byte[] pkSeed, ADRS adrsParam) {
        // The leftmost leaf of a height-z subtree must be aligned to 2^z.
        if ((s >>> z) << z != s) {
            return null;
        }
        // Stack of pending subtree roots; the head is the most recently pushed entry.
        ArrayDeque<NodeEntry> stack = new ArrayDeque<NodeEntry>();
        ADRS adrs = new ADRS(adrsParam);
        for (int idx = 0; idx < 1 << z; ++idx) {
            // Derive this leaf's secret value (address type 6).
            adrs.setTypeAndClear(6);
            adrs.setKeyPairAddress(adrsParam.getKeyPairAddress());
            adrs.setTreeHeight(0);
            adrs.setTreeIndex(s + idx);
            byte[] sk = this.engine.PRF(pkSeed, skSeed, adrs);
            // Hash the secret into the leaf node (address type 3).
            adrs.changeType(3);
            byte[] node = this.engine.F(pkSeed, adrs, sk);
            // Each stack entry is tagged with the height of the parent it would form next.
            // A fresh leaf is therefore tagged 1, and it merges with any stacked entry that
            // carries the same tag.
            adrs.setTreeHeight(1);
            int adrsTreeHeight = 1;
            int adrsTreeIndex = s + idx;
            while (!stack.isEmpty() && stack.peekFirst().nodeHeight == adrsTreeHeight) {
                // The current node is a right child; its parent index is (index - 1) / 2.
                adrsTreeIndex = (adrsTreeIndex - 1) / 2;
                adrs.setTreeIndex(adrsTreeIndex);
                NodeEntry left = stack.removeFirst();
                node = this.engine.H(pkSeed, adrs, left.nodeValue, node);
                adrs.setTreeHeight(++adrsTreeHeight);
            }
            stack.addFirst(new NodeEntry(node, adrsTreeHeight));
        }
        return stack.getFirst().nodeValue;
    }

    /**
     * Produces a FORS signature (FIPS 205 {@code fors_sign}).
     *
     * @param md message digest; its first {@code ceil(K * A / 8)} bytes are split into
     *     {@code K} leaf indices of {@code A} bits each
     * @param skSeed secret seed used to derive leaf secrets and subtree roots
     * @param pkSeed public seed passed to PRF, F and H
     * @param paramAdrs address template; it is copied, and the caller's instance is not modified
     * @return {@code K} entries, each holding one revealed leaf secret and its {@code A}-node
     *     authentication path
     */
    public SIG_FORS[] sign(byte[] md, byte[] skSeed, byte[] pkSeed, ADRS paramAdrs) {
        ADRS adrs = new ADRS(paramAdrs);
        int[] idxs = Fors.base2B(md, this.engine.A, this.engine.K);
        SIG_FORS[] sig_fors = new SIG_FORS[this.engine.K];
        for (int i = 0; i < this.engine.K; ++i) {
            int idx = idxs[i];
            // Global index of the first leaf of tree i.
            int treeOffset = i << this.engine.A;
            adrs.setTypeAndClear(6);
            adrs.setKeyPairAddress(paramAdrs.getKeyPairAddress());
            adrs.setTreeHeight(0);
            adrs.setTreeIndex(treeOffset + idx);
            byte[] sk = this.engine.PRF(pkSeed, skSeed, adrs);
            adrs.changeType(3);
            byte[][] authPath = new byte[this.engine.A][];
            for (int j = 0; j < this.engine.A; ++j) {
                // Index of the sibling of the height-j ancestor of leaf idx. Its subtree starts
                // at leaf (s << j) within tree i.
                int s = (idx >>> j) ^ 1;
                authPath[j] = this.treehash(skSeed, treeOffset + (s << j), j, pkSeed, adrs);
            }
            sig_fors[i] = new SIG_FORS(sk, authPath);
        }
        return sig_fors;
    }

    /**
     * Recomputes the FORS public key from a signature (FIPS 205 {@code fors_pkFromSig}).
     *
     * <p>This method only sets the tree height and tree index on {@code adrs}; it does not set
     * its type. NOTE(review): the caller presumably supplies an address already typed as a FORS
     * tree address; verify.
     *
     * @param sig_fors per-tree signature entries; at least {@code K} are read
     * @param message digest from which the {@code K} leaf indices are extracted
     * @param pkSeed public seed passed to F, H and T_l
     * @param adrs address to use. It is MODIFIED in place: its tree height and index are
     *     overwritten.
     * @return the compressed FORS public key, i.e. T_l over the concatenated {@code K} roots
     */
    public byte[] pkFromSig(SIG_FORS[] sig_fors, byte[] message, byte[] pkSeed, ADRS adrs) {
        byte[][] root = new byte[this.engine.K][];
        int[] idxs = Fors.base2B(message, this.engine.A, this.engine.K);
        for (int i = 0; i < this.engine.K; ++i) {
            int idx = idxs[i];
            int leafIndex = (i << this.engine.A) + idx;
            byte[] sk = sig_fors[i].getSK();
            adrs.setTreeHeight(0);
            adrs.setTreeIndex(leafIndex);
            byte[] node = this.engine.F(pkSeed, adrs, sk);
            byte[][] authPath = sig_fors[i].getAuthPath();
            // NOTE(review): this looks redundant with the set above unless F modifies adrs.
            // It is kept to preserve exact behavior.
            adrs.setTreeIndex(leafIndex);
            // Climb from the leaf to the root of tree i, one authentication node per level.
            for (int j = 0; j < this.engine.A; ++j) {
                adrs.setTreeHeight(j + 1);
                if ((idx & (1 << j)) == 0) {
                    // Current node is a left child; its sibling goes on the right.
                    adrs.setTreeIndex(adrs.getTreeIndex() / 2);
                    node = this.engine.H(pkSeed, adrs, node, authPath[j]);
                } else {
                    // Current node is a right child; its sibling goes on the left.
                    adrs.setTreeIndex((adrs.getTreeIndex() - 1) / 2);
                    node = this.engine.H(pkSeed, adrs, authPath[j], node);
                }
            }
            root[i] = node;
        }
        // Compress all K roots into the public key (address type 4). Use a fresh copy of the
        // address so that the key-pair address carries over from the caller's address.
        ADRS forspkADRS = new ADRS(adrs);
        forspkADRS.setTypeAndClear(4);
        forspkADRS.setKeyPairAddress(adrs.getKeyPairAddress());
        return this.engine.T_l(pkSeed, forspkADRS, Arrays.concatenate(root));
    }

    /**
     * Splits the big-endian bit string {@code msg} into {@code outLen} unsigned integers of
     * {@code b} bits each (FIPS 205 {@code base_2b}). Exactly {@code ceil(outLen * b / 8)}
     * leading bytes of {@code msg} are read.
     *
     * <p>The accumulator keeps only the bits not yet emitted (at most {@code b + 7} of them).
     * It therefore always fits in a {@code long}, and each step takes constant time.
     *
     * @param msg input bytes, consumed most-significant bit first
     * @param b bits per output value; must be in {@code [0, 31]} so every value fits in an int
     * @param outLen number of values to produce
     * @return {@code outLen} values, each in {@code [0, 2^b)}
     * @throws IllegalArgumentException if {@code b} is out of range or {@code msg} is shorter
     *     than {@code ceil(outLen * b / 8)} bytes
     */
    static int[] base2B(byte[] msg, int b, int outLen) {
        if (b < 0 || b > 31) {
            throw new IllegalArgumentException("b must be in [0, 31]: " + b);
        }
        long bytesNeeded = ((long) outLen * b + 7) / 8;
        if (bytesNeeded > 0 && bytesNeeded > msg.length) {
            throw new IllegalArgumentException(
                    "message too short: need " + bytesNeeded + " bytes, got " + msg.length);
        }
        int[] baseB = new int[outLen];
        long mask = (1L << b) - 1;
        long total = 0;
        int bits = 0;
        int in = 0;
        for (int out = 0; out < outLen; ++out) {
            while (bits < b) {
                total = (total << 8) | (msg[in++] & 0xFF);
                bits += 8;
            }
            bits -= b;
            baseB[out] = (int) ((total >>> bits) & mask);
            // Discard the bits just emitted so total never holds more than `bits` bits.
            total &= (1L << bits) - 1;
        }
        return baseB;
    }
}

