/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.mls.TreeKEM;

import java.io.IOException;
import java.security.InvalidParameterException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.bouncycastle.crypto.AsymmetricCipherKeyPair;
import org.bouncycastle.mls.TreeKEM.FilteredDirectPath;
import org.bouncycastle.mls.TreeKEM.LeafIndex;
import org.bouncycastle.mls.TreeKEM.LeafNode;
import org.bouncycastle.mls.TreeKEM.LeafNodeHashInput;
import org.bouncycastle.mls.TreeKEM.LeafNodeSource;
import org.bouncycastle.mls.TreeKEM.Node;
import org.bouncycastle.mls.TreeKEM.NodeIndex;
import org.bouncycastle.mls.TreeKEM.OptionalNode;
import org.bouncycastle.mls.TreeKEM.ParentHashInput;
import org.bouncycastle.mls.TreeKEM.ParentNode;
import org.bouncycastle.mls.TreeKEM.ParentNodeHashInput;
import org.bouncycastle.mls.TreeKEM.TreeHashInput;
import org.bouncycastle.mls.TreeKEM.TreeKEMPrivateKey;
import org.bouncycastle.mls.TreeKEM.Utils;
import org.bouncycastle.mls.TreeSize;
import org.bouncycastle.mls.codec.HPKECiphertext;
import org.bouncycastle.mls.codec.MLSInputStream;
import org.bouncycastle.mls.codec.MLSOutputStream;
import org.bouncycastle.mls.codec.UpdatePath;
import org.bouncycastle.mls.codec.UpdatePathNode;
import org.bouncycastle.mls.crypto.MlsCipherSuite;
import org.bouncycastle.mls.crypto.Secret;
import org.bouncycastle.mls.protocol.Group;
import org.bouncycastle.util.Strings;
import org.bouncycastle.util.encoders.Hex;

/**
 * Public half of a TreeKEM ratchet tree.
 *
 * <p>The tree is stored as a flat list of {@link OptionalNode}s indexed by
 * {@link NodeIndex}, together with a cache of per-node tree hashes. The
 * serialized form is the node list cut off after the last non-blank leaf;
 * the cipher suite is not part of the encoding and has to be supplied via
 * {@link #setSuite(MlsCipherSuite)} after decoding.
 *
 * <p>Invariant maintained by the constructors, {@link #allocateLeaf()} and
 * {@link #addLeaf(LeafNode)}: {@code nodes.size() == size.width()}.
 * {@link #truncate()} may leave {@code nodes} empty.
 */
public class TreeKEMPublicKey
implements MLSInputStream.Readable,
MLSOutputStream.Writable {
    /** Cipher suite used for hashing and HPKE; {@code null} after stream decoding until {@link #setSuite} is called. */
    MlsCipherSuite suite;
    /** Current tree dimensions (leaf count and node width). */
    TreeSize size;
    /** Tree-hash cache keyed by node; filled by {@link #getHash}, invalidated by {@code clearHashPath}. */
    Map<NodeIndex, byte[]> hashes;
    /** Cache of "original" tree hashes computed by {@code originalTreeHash}. */
    private final Map<NodeIndex, byte[]> treeHashCache;
    /** For each entry in {@link #treeHashCache}, the size of the exclusion list it was computed with. */
    private final Map<NodeIndex, Integer> exceptCache;
    /** Flat node storage, indexed by {@link NodeIndex#value()}. */
    ArrayList<OptionalNode> nodes;

    /** @return the cipher suite currently attached to this tree (may be {@code null}) */
    public MlsCipherSuite getSuite() {
        return this.suite;
    }

    /** @return the current tree size */
    public TreeSize getSize() {
        return this.size;
    }

    /**
     * Produces an independent copy of {@code other} by encoding and decoding it.
     * The suite is copied over explicitly because it is not serialized, and all
     * tree hashes are recomputed on the copy.
     *
     * @param other tree to copy
     * @return a new tree with the same nodes and suite
     * @throws IOException if encoding, decoding or hashing fails
     */
    public static TreeKEMPublicKey clone(TreeKEMPublicKey other) throws IOException {
        TreeKEMPublicKey tree = (TreeKEMPublicKey)MLSInputStream.decode(MLSOutputStream.encode(other), TreeKEMPublicKey.class);
        tree.setSuite(other.suite);
        tree.setHashAll();
        return tree;
    }

    /**
     * Creates an empty tree (zero leaves) bound to {@code suite}.
     *
     * @param suite cipher suite used for later hash and HPKE operations
     * @throws IOException declared for interface compatibility
     */
    public TreeKEMPublicKey(MlsCipherSuite suite) throws IOException {
        this.suite = suite;
        this.hashes = new HashMap<NodeIndex, byte[]>();
        this.nodes = new ArrayList<OptionalNode>();
        this.treeHashCache = new HashMap<NodeIndex, byte[]>();
        this.exceptCache = new HashMap<NodeIndex, Integer>();
        this.size = TreeSize.forLeaves(0L);
        // The node list is empty here, so the size never needs to grow;
        // only pad in case a zero-leaf tree reports a non-zero width.
        this.padNodes();
    }

    /**
     * Decodes a tree from {@code stream}. The size starts at one leaf and is
     * doubled until it can hold every decoded node; missing trailing nodes are
     * filled with blanks. The suite is left unset.
     *
     * @param stream source of the encoded node list
     * @throws IOException if the node list cannot be read
     */
    public TreeKEMPublicKey(MLSInputStream stream) throws IOException {
        this.hashes = new HashMap<NodeIndex, byte[]>();
        this.nodes = new ArrayList<OptionalNode>();
        this.treeHashCache = new HashMap<NodeIndex, byte[]>();
        this.exceptCache = new HashMap<NodeIndex, Integer>();
        stream.readList(this.nodes, OptionalNode.class);
        this.size = TreeSize.forLeaves(1L);
        while (this.size.width() < (long)this.nodes.size()) {
            this.size = TreeSize.forLeaves(this.size.leafCount() * 2L);
        }
        this.padNodes();
    }

    /** Appends blank nodes until {@code nodes} is as wide as {@code size}. */
    private void padNodes() {
        while ((long)this.nodes.size() < this.size.width()) {
            this.nodes.add(OptionalNode.blankNode());
        }
    }

    /**
     * Writes the node list up to and including the last non-blank leaf
     * (leaf 0 is always included when any node exists).
     *
     * @param stream destination stream
     * @throws IOException if writing fails
     */
    @Override
    public void writeTo(MLSOutputStream stream) throws IOException {
        if (this.nodes.isEmpty()) {
            // Nothing to anchor the cut on (fresh tree, or truncate() cleared it);
            // computing the sub-list bounds below would be out of range.
            stream.writeList(this.nodes.subList(0, 0));
            return;
        }
        LeafIndex cut = new LeafIndex((int)this.size.leafCount() - 1);
        while (cut.value > 0 && this.nodeAt(cut).isBlank()) {
            --cut.value;
        }
        int end = (int)new NodeIndex(cut).value() + 1;
        stream.writeList(this.nodes.subList(0, end));
    }

    /** @param suite cipher suite to use for subsequent hash and HPKE operations */
    public void setSuite(MlsCipherSuite suite) {
        this.suite = suite;
    }

    /**
     * @return one line per cached tree hash in the form {@code "<node index> : <hex>"};
     *         line order follows the hash map's iteration order
     */
    public String dumpHashes() {
        StringBuilder sb = new StringBuilder();
        for (Map.Entry<NodeIndex, byte[]> entry : this.hashes.entrySet()) {
            sb.append(entry.getKey().value()).append(" : ");
            sb.append(Hex.toHexString(entry.getValue())).append(Strings.lineSeparator());
        }
        return sb.toString();
    }

    /**
     * Renders a human-readable picture of the tree: one line per node showing
     * the first four bytes of its public key, an indentation proportional to
     * the node's level, {@code X}/{@code _} for occupied/blank, and for parent
     * nodes the list of unmerged leaves.
     *
     * @return the multi-line dump, ending with the stored node count
     */
    public String dump() {
        StringBuilder sb = new StringBuilder();
        sb.append("Tree:").append(Strings.lineSeparator());
        for (int i = 0; (long)i < this.size.width(); ++i) {
            NodeIndex index = new NodeIndex(i);
            boolean blank = this.nodeAt(index).isBlank();
            sb.append(String.format("  %03d : ", i));
            if (blank) {
                sb.append("        ");
            } else {
                byte[] publicKey = this.nodeAt(index).node.getPublicKey();
                sb.append(Hex.toHexString(publicKey, 0, 4));
            }
            sb.append("  | ");
            for (int depth = 0; (long)depth < index.level(); ++depth) {
                sb.append("  ");
            }
            if (blank) {
                sb.append("_");
            } else {
                sb.append("X");
                if (!index.isLeaf()) {
                    ParentNode parent = this.nodeAt(index).getParentNode();
                    sb.append(" [");
                    for (LeafIndex unmerged : parent.unmerged_leaves) {
                        sb.append(unmerged.value).append(", ");
                    }
                    sb.append("]");
                }
            }
            sb.append(Strings.lineSeparator());
        }
        sb.append("nodeCount: ").append(this.nodes.size()).append(Strings.lineSeparator());
        return sb.toString();
    }

    /**
     * Derives fresh path keys for leaf {@code from}, re-signs the leaf for a
     * commit and merges the resulting (ciphertext-free) update path into this
     * tree.
     *
     * @param from       leaf performing the update; must not be blank
     * @param leafSecret secret the new private key material is derived from
     * @param groupId    group identifier passed to the leaf's commit signature
     * @param sigPriv    signature private key used to sign the new leaf
     * @param options    leaf node options forwarded to {@code LeafNode.forCommit}
     * @return the private key state matching the updated path
     * @throws Exception if {@code from} is blank or the path cannot be built/merged
     */
    public TreeKEMPrivateKey update(LeafIndex from, Secret leafSecret, byte[] groupId, byte[] sigPriv, Group.LeafNodeOptions options) throws Exception {
        OptionalNode leafNode = this.nodeAt(from);
        if (leafNode.isBlank()) {
            throw new Exception("Cannot update from blank node");
        }
        TreeKEMPrivateKey priv = TreeKEMPrivateKey.create(this, from, leafSecret);
        FilteredDirectPath dp = this.getFilteredDirectPath(new NodeIndex(from));

        // Stub update path: public keys only, no encrypted path secrets.
        ArrayList<UpdatePathNode> pathNodes = new ArrayList<UpdatePathNode>();
        for (NodeIndex n : dp.parents) {
            AsymmetricCipherKeyPair nodePriv = priv.setPrivateKey(n, false);
            byte[] nodePub = this.suite.getHPKE().serializePublicKey(nodePriv.getPublic());
            pathNodes.add(new UpdatePathNode(nodePub, new ArrayList<HPKECiphertext>()));
        }

        // The leaf carries the first parent hash of the chain, or an empty
        // value when there is no chain (nobody else in the tree).
        byte[][] ph = this.parentHashes(from, dp, pathNodes);
        byte[] leafParentHash = ph.length != 0 ? ph[0] : new byte[]{};

        byte[] leafPub = this.suite.getHPKE().serializePublicKey(priv.setPrivateKey(new NodeIndex(from), false).getPublic());
        LeafNode newLeaf = leafNode.getLeafNode().forCommit(this.suite, groupId, from, leafPub, leafParentHash, options, sigPriv);
        this.merge(from, new UpdatePath(newLeaf, pathNodes));
        return priv;
    }

    /**
     * Builds the update path for {@code priv.index}, encrypting each path
     * secret to every node in the corresponding copath resolution except the
     * leaves listed in {@code except}.
     *
     * @param priv    private key state holding the path secrets
     * @param context context bytes bound into each HPKE encryption
     * @param except  leaves that must not receive a ciphertext
     * @return the update path with the sender's current leaf and the encrypted nodes
     * @throws Exception if key derivation or encryption fails
     */
    public UpdatePath encap(TreeKEMPrivateKey priv, byte[] context, List<LeafIndex> except) throws Exception {
        FilteredDirectPath dp = this.getFilteredDirectPath(new NodeIndex(priv.index));
        ArrayList<UpdatePathNode> pathNodes = new ArrayList<UpdatePathNode>();
        for (int i = 0; i < dp.parents.size(); ++i) {
            NodeIndex n = dp.parents.get(i);
            // Work on a copy so that filtering does not alter the path's own resolution.
            ArrayList<NodeIndex> recipients = new ArrayList<NodeIndex>(dp.resolutions.get(i));
            Utils.removeLeaves(recipients, except);
            Secret pathSecret = priv.pathSecrets.get(n);
            AsymmetricCipherKeyPair nodePriv = priv.setPrivateKey(n, false);
            ArrayList<HPKECiphertext> cts = new ArrayList<HPKECiphertext>();
            for (NodeIndex recipient : recipients) {
                byte[] recipientPub = this.nodeAt(recipient).node.getPublicKey();
                byte[][] ctAndEnc = this.suite.encryptWithLabel(recipientPub, "UpdatePathNode", context, pathSecret.value());
                cts.add(new HPKECiphertext(ctAndEnc[1], ctAndEnc[0]));
            }
            pathNodes.add(new UpdatePathNode(this.suite.getHPKE().serializePublicKey(nodePriv.getPublic()), cts));
        }
        LeafNode newLeaf = this.getLeafNode(priv.index);
        return new UpdatePath(newLeaf, pathNodes);
    }

    /**
     * @return the cached tree hash of the root
     * @throws Exception if the root hash has not been computed yet
     */
    public byte[] getRootHash() throws Exception {
        NodeIndex r = NodeIndex.root(this.size);
        if (!this.hashes.containsKey(r)) {
            throw new Exception("Root hash not set");
        }
        return this.hashes.get(r);
    }

    /**
     * Installs {@code path} into the tree: replaces the leaf at {@code from},
     * overwrites every node on the filtered direct path with a fresh parent
     * node (empty unmerged-leaves list) and recomputes the tree hashes.
     *
     * @param from leaf the path originates from
     * @param path update path to install
     * @throws Exception if the path length does not match the filtered direct path
     */
    public void merge(LeafIndex from, UpdatePath path) throws Exception {
        this.nodeAt(from).node = new Node(path.getLeafNode());
        FilteredDirectPath dp = this.getFilteredDirectPath(new NodeIndex(from));
        int pathLength = dp.parents.size();
        if (pathLength != path.getNodes().size()) {
            throw new Exception("Malformed direct path");
        }
        byte[][] ph = this.parentHashes(from, dp, path.getNodes());
        for (int i = 0; i < pathLength; ++i) {
            // Each parent stores the hash computed for the next node up;
            // the topmost one stores an empty parent hash.
            byte[] parentHash = i < pathLength - 1 ? ph[i + 1] : new byte[]{};
            byte[] encryptionKey = path.getNodes().get(i).getEncryptionKey();
            this.nodeAt(dp.parents.get(i)).node = new Node(new ParentNode(encryptionKey, parentHash, new ArrayList<LeafIndex>()));
        }
        this.clearHashPath(from);
        this.setHashAll();
    }

    /**
     * @param leaf leaf node to look for
     * @return the index of the first leaf equal to {@code leaf}, or {@code -1} if absent
     */
    public int find(LeafNode leaf) {
        for (int i = 0; (long)i < this.size.leafCount(); ++i) {
            OptionalNode node = this.nodeAt(new LeafIndex(i));
            if (!node.isBlank() && node.isLeaf() && node.getLeafNode().equals(leaf)) {
                return i;
            }
        }
        return -1;
    }

    /** @return {@code true} if the leaf slot at {@code index} is occupied */
    public boolean hasLeaf(LeafIndex index) {
        return !this.nodeAt(index).isBlank();
    }

    /**
     * Returns the part of the two leaves' filtered direct paths that they
     * share, compared position by position after reversing both paths. The
     * resolutions in the result are taken from {@code leaf2}'s path.
     *
     * @throws Exception propagated from {@link #getFilteredDirectPath(NodeIndex)}
     */
    protected FilteredDirectPath getFilteredCommonDirectPath(LeafIndex leaf1, LeafIndex leaf2) throws Exception {
        FilteredDirectPath xPath = this.getFilteredDirectPath(new NodeIndex(leaf1));
        FilteredDirectPath yPath = this.getFilteredDirectPath(new NodeIndex(leaf2));
        xPath.reverse();
        yPath.reverse();
        FilteredDirectPath commonPath = new FilteredDirectPath();
        // The paths may differ in length (filtering drops different nodes),
        // so only the overlapping prefix can be compared.
        int limit = Math.min(xPath.parents.size(), yPath.parents.size());
        for (int i = 0; i < limit; ++i) {
            if (xPath.parents.get(i).value() != yPath.parents.get(i).value()) {
                break;
            }
            commonPath.parents.add(xPath.parents.get(i));
            commonPath.resolutions.add(yPath.resolutions.get(i));
        }
        commonPath.reverse();
        return commonPath;
    }

    /**
     * Walks the copath of {@code index} and keeps, for every copath node with
     * a non-empty resolution, that node's parent together with the resolution.
     *
     * @throws Exception declared for interface compatibility
     */
    protected FilteredDirectPath getFilteredDirectPath(NodeIndex index) throws Exception {
        FilteredDirectPath fdp = new FilteredDirectPath();
        for (NodeIndex copathNode : index.copath(this.size)) {
            ArrayList<NodeIndex> res = this.resolve(copathNode);
            if (res.isEmpty()) {
                continue;
            }
            fdp.parents.add(copathNode.parent());
            fdp.resolutions.add(res);
        }
        return fdp;
    }

    /**
     * @return the leaf node stored at {@code index}, or {@code null} when the
     *         slot does not report itself as a leaf
     */
    public LeafNode getLeafNode(LeafIndex index) {
        OptionalNode node = this.nodeAt(index);
        if (!node.isLeaf()) {
            return null;
        }
        return node.getLeafNode();
    }

    /**
     * Computes the resolution of {@code index}: a non-blank node resolves to
     * itself plus (for parents) its unmerged leaves; a blank leaf resolves to
     * nothing; a blank parent resolves to the concatenation of its children's
     * resolutions (left first).
     */
    public ArrayList<NodeIndex> resolve(NodeIndex index) {
        OptionalNode node = this.nodeAt(index);
        if (!node.isBlank()) {
            ArrayList<NodeIndex> out = new ArrayList<NodeIndex>();
            out.add(index);
            if (!index.isLeaf()) {
                for (LeafIndex unmerged : node.getParentNode().unmerged_leaves) {
                    out.add(new NodeIndex(unmerged));
                }
            }
            return out;
        }
        if (index.level() == 0L) {
            return new ArrayList<NodeIndex>();
        }
        ArrayList<NodeIndex> combined = this.resolve(index.left());
        combined.addAll(this.resolve(index.right()));
        return combined;
    }

    /**
     * Finds the leftmost blank leaf slot, doubling the tree (or growing it
     * from zero to one leaf) when every existing slot is taken. The node list
     * is padded with blanks so the returned slot is backed by real storage.
     *
     * @return index of a blank leaf slot
     */
    public LeafIndex allocateLeaf() {
        LeafIndex index = new LeafIndex(0);
        while ((long)index.value < this.size.leafCount() && !this.nodeAt(index).isBlank()) {
            ++index.value;
        }
        if ((long)index.value >= this.size.leafCount()) {
            long leaves = this.size.leafCount();
            this.size = TreeSize.forLeaves(leaves == 0L ? 1L : leaves * 2L);
        }
        // Keep nodes.size() == size.width(); otherwise nodeAt() hands out
        // throw-away blanks and truncate() would halve a too-short list.
        this.padNodes();
        return index;
    }

    /**
     * Places {@code leaf} in the leftmost blank slot and records it as an
     * unmerged leaf (in sorted position) on every non-blank node of its direct
     * path. Cached hashes along the path are invalidated but not recomputed.
     *
     * @param leaf leaf node to insert
     * @return the slot the leaf was stored in
     */
    public LeafIndex addLeaf(LeafNode leaf) {
        LeafIndex index = this.allocateLeaf();
        this.nodeAt(index).node = new Node(leaf);
        for (NodeIndex n : index.directPath(this.size)) {
            if (this.nodeAt(n).node == null) {
                continue;
            }
            ParentNode parent = this.nodeAt(n).getParentNode();
            parent.unmerged_leaves.add(this.upperBound(parent.unmerged_leaves, index), index);
        }
        this.clearHashPath(index);
        return index;
    }

    /** Drops the cached tree hash of the leaf and of every node on its direct path. */
    private void clearHashPath(LeafIndex index) {
        this.hashes.remove(new NodeIndex(index));
        for (NodeIndex n : index.directPath(this.size)) {
            this.hashes.remove(n);
        }
    }

    /**
     * Binary search over a list sorted by {@code value}.
     *
     * @return the position of the first element whose value is greater than
     *         {@code index.value}, or {@code list.size()} if there is none
     */
    int upperBound(List<LeafIndex> list, LeafIndex index) {
        int lo = 0;
        int hi = list.size() - 1;
        while (lo <= hi) {
            int mid = (lo + hi) >>> 1;
            if (list.get(mid).value <= index.value) {
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        return lo;
    }

    /**
     * Blanks the leaf's direct path, then stores {@code leaf} at {@code index}
     * and invalidates the cached hashes along the path.
     */
    public void updateLeaf(LeafIndex index, LeafNode leaf) {
        this.blankPath(index);
        this.nodeAt(index).node = new Node(leaf);
        this.clearHashPath(index);
    }

    /**
     * Blanks the leaf at {@code index} and every node on its direct path, and
     * invalidates their cached hashes. Does nothing when no nodes are stored.
     */
    public void blankPath(LeafIndex index) {
        if (this.nodes.isEmpty()) {
            return;
        }
        this.nodeAt(new NodeIndex(index)).node = null;
        for (NodeIndex n : index.directPath(this.size)) {
            this.nodeAt(n).node = null;
        }
        this.clearHashPath(index);
    }

    /** Convenience overload of {@link #nodeAt(NodeIndex)} for leaf indices. */
    private OptionalNode nodeAt(LeafIndex n) {
        return this.nodeAt(new NodeIndex(n));
    }

    /**
     * @return the stored node at {@code n}, or a fresh blank node when {@code n}
     *         lies inside the tree width but beyond the stored list
     * @throws InvalidParameterException if {@code n} is outside the tree width
     */
    OptionalNode nodeAt(NodeIndex n) {
        if (n.value() >= this.size.width()) {
            throw new InvalidParameterException("Node index not in tree");
        }
        if (n.value() >= (long)this.nodes.size()) {
            return OptionalNode.blankNode();
        }
        return this.nodes.get((int)n.value());
    }

    /**
     * Shrinks the tree after removals: invalidates hashes above trailing blank
     * leaves, clears the node list entirely if no leaf is left, and otherwise
     * drops the right half of the tree for as long as the last occupied leaf
     * stays in the left half.
     */
    public void truncate() {
        if (this.size.leafCount() == 0L) {
            return;
        }
        LeafIndex index = new LeafIndex((int)this.size.leafCount() - 1);
        while (index.value > 0 && this.nodeAt(index).isBlank()) {
            this.clearHashPath(index);
            --index.value;
        }
        if (this.nodeAt(index).isBlank()) {
            this.nodes.clear();
            return;
        }
        while (this.size.leafCount() / 2L > (long)index.value) {
            this.nodes.subList(this.nodes.size() / 2, this.nodes.size()).clear();
            this.size = TreeSize.forLeaves(this.size.leafCount() / 2L);
        }
    }

    /**
     * Ensures the tree hash of the root (and therefore of every node not yet
     * cached) is computed.
     *
     * @throws IOException if encoding a hash input fails
     */
    public void setHashAll() throws IOException {
        this.getHash(NodeIndex.root(this.size));
    }

    /**
     * Computes the parent-hash chain for an update path sent from {@code from}.
     * Entry {@code i} is the hash for position {@code i} of the sequence
     * "leaf, then filtered direct path without its last node".
     *
     * @param from  sending leaf
     * @param fdp   filtered direct path of {@code from}; not modified
     * @param nodes update path nodes, one per entry of {@code fdp}
     * @return the chain, or an empty array when {@code fdp} is empty
     * @throws Exception if the lengths disagree or a required sibling hash is missing
     */
    private byte[][] parentHashes(LeafIndex from, FilteredDirectPath fdp, List<UpdatePathNode> nodes) throws Exception {
        // No filtered path means nobody to hash towards; callers treat a
        // zero-length chain as "no parent hash".
        if (fdp.parents.isEmpty()) {
            return new byte[0][];
        }
        NodeIndex fromNode = new NodeIndex(from);
        NodeIndex root = NodeIndex.root(this.size);
        FilteredDirectPath dp = fdp.clone();
        // Drop the topmost entry and put the sending leaf in front.
        dp.parents.remove(dp.parents.size() - 1);
        dp.resolutions.remove(dp.resolutions.size() - 1);
        if (!fromNode.equals(root)) {
            dp.parents.add(0, fromNode);
            dp.resolutions.add(0, new ArrayList<NodeIndex>());
        }
        if (dp.parents.size() != nodes.size()) {
            throw new Exception("Malformed UpdatePath");
        }
        // Walk top-down so each hash can feed into the one below it.
        NodeIndex last = root;
        byte[] lastHash = new byte[]{};
        byte[][] ph = new byte[dp.parents.size()][];
        for (int i = dp.parents.size() - 1; i >= 0; --i) {
            NodeIndex n = dp.parents.get(i);
            NodeIndex s = n.sibling(last);
            ParentNode parentNode = new ParentNode(nodes.get(i).getEncryptionKey(), lastHash, new ArrayList<LeafIndex>());
            lastHash = this.getParentHash(parentNode, s);
            ph[i] = lastHash;
            last = n;
        }
        return ph;
    }

    /**
     * Hashes {@code parent}'s encryption key and parent hash together with the
     * cached tree hash of {@code cpChild}.
     *
     * @throws Exception if no tree hash is cached for {@code cpChild}
     */
    private byte[] getParentHash(ParentNode parent, NodeIndex cpChild) throws Exception {
        if (!this.hashes.containsKey(cpChild)) {
            throw new Exception("Child hash not set");
        }
        ParentHashInput hashInput = new ParentHashInput(parent.encryptionKey, parent.parentHash, this.hashes.get(cpChild));
        return this.suite.hash(MLSOutputStream.encode(hashInput));
    }

    /**
     * Checks the parent hash carried by {@code path}'s leaf against the chain
     * recomputed for {@code from}. With an empty chain the leaf is accepted
     * only if its source is not {@code COMMIT}.
     *
     * @throws Exception if the chain cannot be computed
     */
    public boolean verifyParentHash(LeafIndex from, UpdatePath path) throws Exception {
        FilteredDirectPath fdp = this.getFilteredDirectPath(new NodeIndex(from));
        byte[][] hashChain = this.parentHashes(from, fdp, path.getNodes());
        if (hashChain.length == 0) {
            return path.getLeafNode().leaf_node_source != LeafNodeSource.COMMIT;
        }
        return Arrays.equals(path.getLeafNode().parent_hash, hashChain[0]);
    }

    /**
     * Checks every non-blank parent node, level by level: at least one node in
     * the resolution of its left or right child must carry the parent hash
     * computed from this parent and the opposite child.
     *
     * @return {@code false} as soon as one parent has no matching child
     * @throws IOException if encoding a hash input fails
     */
    public boolean verifyParentHash() throws IOException {
        long width = this.size.width();
        long height = NodeIndex.root(this.size).level();
        for (int level = 1; (long)level <= height; ++level) {
            // Nodes of one level are `stride` apart, starting at stride/2 - 1.
            long stride = 2L << level;
            int first = (int)((stride >>> 1) - 1L);
            for (int p = first; (long)p < width; p = (int)((long)p + stride)) {
                NodeIndex parent = new NodeIndex(p);
                if (this.nodeAt(parent).isBlank()) {
                    continue;
                }
                NodeIndex left = parent.left();
                NodeIndex right = parent.right();
                // The hash expected in a child is computed over its sibling.
                byte[] expectedInLeft = this.originalParentHash(parent, right);
                byte[] expectedInRight = this.originalParentHash(parent, left);
                if (!this.hasParentHash(left, expectedInLeft) && !this.hasParentHash(right, expectedInRight)) {
                    return false;
                }
            }
        }
        return true;
    }

    /** @return {@code true} if any node in the resolution of {@code child} stores {@code targetParentHash} */
    private boolean hasParentHash(NodeIndex child, byte[] targetParentHash) {
        for (NodeIndex n : this.resolve(child)) {
            if (Arrays.equals(this.nodeAt(n).node.getParentHash(), targetParentHash)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Tree hash of the subtree at {@code index} as it would be without the
     * leaves in {@code parentExcept}. Leaves outside the subtree are ignored;
     * if none remain, the regular cached tree hash is returned. Results are
     * cached per node and reused when the exclusion list has the same size.
     *
     * @throws IOException if encoding a hash input fails
     */
    private byte[] originalTreeHash(NodeIndex index, List<LeafIndex> parentExcept) throws IOException {
        // Narrow the exclusion list to leaves below this node.
        ArrayList<LeafIndex> except = new ArrayList<LeafIndex>();
        for (LeafIndex leaf : parentExcept) {
            if (new NodeIndex(leaf).isBelow(index)) {
                except.add(leaf);
            }
        }
        if (except.isEmpty()) {
            return this.hashes.get(index);
        }
        if (this.treeHashCache.containsKey(index) && this.exceptCache.get(index).intValue() == except.size()) {
            return this.treeHashCache.get(index);
        }
        byte[] hash;
        if (index.isLeaf()) {
            // An excluded leaf is hashed as if its slot were empty.
            LeafNodeHashInput leafHashInput = new LeafNodeHashInput(new LeafIndex(index), null);
            hash = this.suite.hash(MLSOutputStream.encode(TreeHashInput.forLeafNode(leafHashInput)));
        } else {
            ParentNodeHashInput parentHashInput = new ParentNodeHashInput(null, this.originalTreeHash(index.left(), except), this.originalTreeHash(index.right(), except));
            if (!this.nodeAt(index).isBlank()) {
                ParentNode parentNode = this.nodeAt(index).getParentNode();
                parentHashInput.parentNode = parentNode;
                // The live node is hashed with the excluded leaves stripped
                // temporarily; always put the full list back, even on failure.
                ArrayList<LeafIndex> unmergedOriginal = new ArrayList<LeafIndex>(parentNode.unmerged_leaves);
                try {
                    parentNode.unmerged_leaves.removeAll(except);
                    hash = this.suite.hash(MLSOutputStream.encode(TreeHashInput.forParentNode(parentHashInput)));
                } finally {
                    parentNode.unmerged_leaves = unmergedOriginal;
                }
            } else {
                hash = this.suite.hash(MLSOutputStream.encode(TreeHashInput.forParentNode(parentHashInput)));
            }
        }
        this.treeHashCache.put(index, hash);
        this.exceptCache.put(index, except.size());
        return hash;
    }

    /**
     * Parent hash of {@code parent} computed over {@code sibling}'s tree hash
     * with {@code parent}'s unmerged leaves excluded.
     *
     * @throws IOException if encoding a hash input fails
     */
    private byte[] originalParentHash(NodeIndex parent, NodeIndex sibling) throws IOException {
        ParentNode parentNode = this.nodeAt(parent).getParentNode();
        byte[] siblingHash = this.originalTreeHash(sibling, parentNode.unmerged_leaves);
        return this.suite.hash(MLSOutputStream.encode(new ParentHashInput(parentNode.encryptionKey, parentNode.parentHash, siblingHash)));
    }

    /**
     * Returns the tree hash of {@code index}, computing and caching it (and,
     * recursively, the hashes of its children) when it is not cached yet.
     *
     * @param index node to hash
     * @return the tree hash of the node
     * @throws IOException if encoding a hash input fails
     */
    public byte[] getHash(NodeIndex index) throws IOException {
        if (this.hashes.containsKey(index)) {
            return this.hashes.get(index);
        }
        OptionalNode node = this.nodeAt(index);
        byte[] hashInput;
        if (index.level() == 0L) {
            LeafNodeHashInput leafInput = new LeafNodeHashInput(new LeafIndex(index), null);
            if (!node.isBlank()) {
                leafInput.leafNode = node.getLeafNode();
            }
            hashInput = MLSOutputStream.encode(TreeHashInput.forLeafNode(leafInput));
        } else {
            ParentNodeHashInput parentInput = new ParentNodeHashInput(null, this.getHash(index.left()), this.getHash(index.right()));
            if (!node.isBlank()) {
                parentInput.parentNode = node.getParentNode();
            }
            hashInput = MLSOutputStream.encode(TreeHashInput.forParentNode(parentInput));
        }
        byte[] hash = this.suite.hash(hashInput);
        this.hashes.put(index, hash);
        return hash;
    }
}

