/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.mls.TreeKEM;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.bouncycastle.crypto.AsymmetricCipherKeyPair;
import org.bouncycastle.mls.TreeKEM.FilteredDirectPath;
import org.bouncycastle.mls.TreeKEM.LeafIndex;
import org.bouncycastle.mls.TreeKEM.Node;
import org.bouncycastle.mls.TreeKEM.NodeIndex;
import org.bouncycastle.mls.TreeKEM.TreeKEMPublicKey;
import org.bouncycastle.mls.TreeKEM.Utils;
import org.bouncycastle.mls.TreeSize;
import org.bouncycastle.mls.codec.HPKECiphertext;
import org.bouncycastle.mls.codec.UpdatePath;
import org.bouncycastle.mls.crypto.MlsCipherSuite;
import org.bouncycastle.mls.crypto.Secret;
import org.bouncycastle.util.Strings;
import org.bouncycastle.util.encoders.Hex;

/**
 * Private TreeKEM state held by a single group member.
 * <p>
 * Tracks the path secrets this member knows ({@code pathSecrets}), a cache of HPKE key pairs for
 * tree nodes ({@code privateKeyCache}), and the update secret produced by the most recent
 * implantation of a path secret ({@code updateSecret}).
 */
public class TreeKEMPrivateKey {
    MlsCipherSuite suite;
    LeafIndex index;
    Secret updateSecret;
    Map<NodeIndex, Secret> pathSecrets;
    Map<NodeIndex, AsymmetricCipherKeyPair> privateKeyCache;

    /**
     * Returns the secret derived one "path" step beyond the last node written by the most
     * recent implant, or {@code null} if nothing has been implanted yet.
     */
    public Secret getUpdateSecret() {
        return this.updateSecret;
    }

    /** Records {@code secret} as the path secret for node {@code index}. */
    public void insertPathSecret(NodeIndex index, Secret secret) {
        this.pathSecrets.put(index, secret);
    }

    /** Caches {@code keyPair} as the key pair for node {@code index}. */
    public void insertPrivateKey(NodeIndex index, AsymmetricCipherKeyPair keyPair) {
        this.privateKeyCache.put(index, keyPair);
    }

    /**
     * Creates an empty private state for the member at leaf {@code index}.
     *
     * @param suite cipher suite used for secret derivation and HPKE operations
     * @param index this member's leaf index
     */
    public TreeKEMPrivateKey(MlsCipherSuite suite, LeafIndex index) {
        this.suite = suite;
        this.index = index;
        this.pathSecrets = new HashMap<NodeIndex, Secret>();
        this.privateKeyCache = new HashMap<NodeIndex, AsymmetricCipherKeyPair>();
    }

    /**
     * Returns a copy with its own maps. The map entries (Secret and key-pair objects) are
     * shared with this instance rather than deep-copied. The update secret is copied too,
     * so the clone reflects the full state of this instance.
     */
    public TreeKEMPrivateKey copy() {
        TreeKEMPrivateKey clone = new TreeKEMPrivateKey(this.suite, this.index);
        clone.pathSecrets.putAll(this.pathSecrets);
        clone.privateKeyCache.putAll(this.privateKeyCache);
        clone.updateSecret = this.updateSecret;
        return clone;
    }

    /**
     * Creates a private state that knows only its own leaf key pair.
     *
     * @param suite cipher suite
     * @param index this member's leaf index
     * @param leafPriv key pair for this member's leaf node
     */
    public static TreeKEMPrivateKey solo(MlsCipherSuite suite, LeafIndex index, AsymmetricCipherKeyPair leafPriv) {
        TreeKEMPrivateKey priv = new TreeKEMPrivateKey(suite, index);
        priv.privateKeyCache.put(new NodeIndex(index), leafPriv);
        return priv;
    }

    /**
     * Creates a private state by implanting {@code leafSecret} at leaf {@code from}. Path
     * secrets for every node on that leaf's filtered direct path in {@code pub} are derived
     * from it.
     */
    public static TreeKEMPrivateKey create(TreeKEMPublicKey pub, LeafIndex from, Secret leafSecret) throws Exception {
        TreeKEMPrivateKey priv = new TreeKEMPrivateKey(pub.suite, from);
        priv.implant(pub, new NodeIndex(from), leafSecret);
        return priv;
    }

    /**
     * Creates the private state for a joining member.
     *
     * @param pub public tree the member is joining
     * @param index the joiner's leaf index
     * @param leafPriv key pair for the joiner's leaf node
     * @param intersect node at which {@code pathSecret} is implanted
     * @param pathSecret path secret for {@code intersect}, or {@code null} if none was provided
     */
    public static TreeKEMPrivateKey joiner(TreeKEMPublicKey pub, LeafIndex index, AsymmetricCipherKeyPair leafPriv, NodeIndex intersect, Secret pathSecret) throws Exception {
        TreeKEMPrivateKey priv = new TreeKEMPrivateKey(pub.suite, index);
        priv.privateKeyCache.put(new NodeIndex(index), leafPriv);
        if (pathSecret != null) {
            priv.implant(pub, intersect, pathSecret);
        }
        return priv;
    }

    /**
     * Renders a debug summary of this state. The output contains the own node index, then
     * each path secret (its first 4 bytes) with the first 4 bytes of the public key derived
     * from it, then the first 4 bytes of each cached public key.
     * <p>
     * NOTE(review): the output includes a prefix of each path secret. Use it for debugging
     * only and do not log it in production.
     */
    public String dump() throws IOException {
        StringBuilder sb = new StringBuilder();
        sb.append("Tree (priv)").append(Strings.lineSeparator());
        sb.append("  Index: ").append(new NodeIndex(this.index).value()).append(Strings.lineSeparator());
        sb.append("  Secrets: ").append(Strings.lineSeparator());
        for (NodeIndex n : this.pathSecrets.keySet()) {
            Secret pathSecret = this.pathSecrets.get(n);
            Secret nodeSecret = pathSecret.deriveSecret(this.suite, "node");
            AsymmetricCipherKeyPair sk = this.suite.getHPKE().deriveKeyPair(nodeSecret.value());
            sb.append("    ").append(n.value())
                .append(" => ").append(Hex.toHexString(pathSecret.value(), 0, 4))
                .append(" => ").append(Hex.toHexString(this.suite.getHPKE().serializePublicKey(sk.getPublic()), 0, 4))
                .append(Strings.lineSeparator());
        }
        sb.append("  Cached key pairs: ").append(Strings.lineSeparator());
        for (NodeIndex n : this.privateKeyCache.keySet()) {
            AsymmetricCipherKeyPair sk = this.privateKeyCache.get(n);
            sb.append("    ").append(n.value())
                .append(" => ").append(Hex.toHexString(this.suite.getHPKE().serializePublicKey(sk.getPublic()), 0, 4))
                .append(Strings.lineSeparator());
        }
        return sb.toString();
    }

    /**
     * Drops all state for nodes whose index exceeds that of the last leaf of a tree with
     * {@code size.leafCount()} leaves.
     * <p>
     * Both path secrets and cached key pairs are pruned. A key pair cached without a matching
     * path secret (e.g. via {@link #insertPrivateKey}) would otherwise survive truncation.
     */
    public void truncate(TreeSize size) {
        NodeIndex lastLeaf = new NodeIndex(new LeafIndex((int)(size.leafCount() - 1L)));
        List<NodeIndex> toRemove = new ArrayList<NodeIndex>();
        for (NodeIndex n : this.pathSecrets.keySet()) {
            if (n.value() > lastLeaf.value()) {
                toRemove.add(n);
            }
        }
        for (NodeIndex n : this.privateKeyCache.keySet()) {
            if (n.value() > lastLeaf.value()) {
                toRemove.add(n);
            }
        }
        // Remove after iterating so that neither map is modified while its key set is being traversed.
        for (NodeIndex n : toRemove) {
            this.pathSecrets.remove(n);
            this.privateKeyCache.remove(n);
        }
    }

    /**
     * Replaces this member's leaf key with the private key deserialized from
     * {@code leafSkBytes}, and forgets any path secret held for the leaf.
     */
    public void setLeafKey(byte[] leafSkBytes) {
        NodeIndex n = new NodeIndex(this.index);
        this.pathSecrets.remove(n);
        AsymmetricCipherKeyPair leafSk = this.suite.getHPKE().deserializePrivateKey(leafSkBytes, null);
        this.privateKeyCache.put(n, leafSk);
    }

    /**
     * Decrypts the path secret addressed to this member from {@code path} (sent by leaf
     * {@code from}), implants it, and verifies the result against {@code pub}.
     *
     * @param from leaf index of the sender of the update path
     * @param pub public tree used to compute the sender's filtered direct path and resolutions
     * @param context context bytes passed to the HPKE decryption
     * @param path the sender's update path
     * @param except leaves removed from the copath resolution before it is matched against the ciphertexts
     * @throws Exception if the path is malformed, no usable private key is held, or the
     *         resulting private state does not match {@code pub}
     */
    public void decap(LeafIndex from, TreeKEMPublicKey pub, byte[] context, UpdatePath path, List<LeafIndex> except) throws Exception {
        NodeIndex ni = new NodeIndex(this.index);
        FilteredDirectPath dp = pub.getFilteredDirectPath(new NodeIndex(from));
        if (dp.parents.size() != path.getNodes().size()) {
            throw new Exception("Malformed direct path");
        }

        // Locate the first node on the sender's filtered direct path that is an ancestor of this member.
        int dpi = 0;
        NodeIndex overlapNode = null;
        ArrayList<NodeIndex> res = new ArrayList<NodeIndex>();
        for (; dpi < dp.parents.size(); ++dpi) {
            if (ni.isBelow(dp.parents.get(dpi))) {
                overlapNode = dp.parents.get(dpi);
                // Copy: removeLeaves below mutates its argument, and must not alter the
                // resolution list owned by the FilteredDirectPath.
                res = new ArrayList<NodeIndex>(dp.resolutions.get(dpi));
                break;
            }
        }
        if (dpi == dp.parents.size()) {
            throw new Exception("No overlap in path");
        }

        Utils.removeLeaves(res, except);
        if (res.size() != path.getNodes().get(dpi).getEncryptedPathSecret().size()) {
            throw new Exception("Malformed direct path node");
        }

        // Find the first node in the (filtered) resolution for which a private key is held.
        int resi = 0;
        while (resi < res.size() && !this.havePrivateKey(res.get(resi))) {
            ++resi;
        }
        if (resi == res.size()) {
            throw new Exception("No private key to decrypt path secret");
        }

        AsymmetricCipherKeyPair priv = this.setPrivateKey(res.get(resi), false);
        HPKECiphertext ct = path.getNodes().get(dpi).getEncryptedPathSecret().get(resi);
        Secret pathSecret = new Secret(this.suite.decryptWithLabel(this.suite.getHPKE().serializePrivateKey(priv.getPrivate()), "UpdatePathNode", context, ct.getKemOutput(), ct.getCiphertext()));
        this.implant(pub, overlapNode, pathSecret);
        if (!this.consistent(pub)) {
            throw new Exception("TreeKEMPublicKey inconsistant with TreeKEMPrivateKey");
        }
    }

    /** Returns true if a key for {@code n} is cached or derivable from a held path secret. */
    private boolean havePrivateKey(NodeIndex n) {
        return this.pathSecrets.containsKey(n) || this.privateKeyCache.containsKey(n);
    }

    /**
     * Checks that every private key this state holds agrees with the public key stored at the
     * same node of {@code other}. Nodes that are blank ({@code null}) in {@code other} are
     * skipped.
     * <p>
     * Both cached key pairs and keys derivable from path secrets are checked. Derived keys are
     * not cached, so this method does not modify the state. Earlier code only checked cached
     * keys, which missed every node just written by {@code implant()}, because implant evicts
     * those cache entries.
     *
     * @return false if the cipher suites differ or any held key mismatches the public tree
     */
    public final boolean consistent(TreeKEMPublicKey other) throws IOException {
        if (this.suite.getSuiteID() != other.suite.getSuiteID()) {
            return false;
        }
        for (NodeIndex n : this.privateKeyCache.keySet()) {
            if (!this.matchesPublicNode(other, n)) {
                return false;
            }
        }
        for (NodeIndex n : this.pathSecrets.keySet()) {
            if (this.privateKeyCache.containsKey(n)) {
                continue; // already checked via the cache above
            }
            if (!this.matchesPublicNode(other, n)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns true if node {@code n} is blank in {@code other}, or if its public key equals
     * the public half of the key this state holds for {@code n}.
     * Callers must only pass nodes for which {@link #havePrivateKey} is true.
     */
    private boolean matchesPublicNode(TreeKEMPublicKey other, NodeIndex n) throws IOException {
        Node optNode = other.nodeAt(n).node;
        if (optNode == null) {
            return true;
        }
        AsymmetricCipherKeyPair priv = this.getPrivateKey(n);
        return Arrays.equals(optNode.getPublicKey(), this.suite.getHPKE().serializePublicKey(priv.getPublic()));
    }

    /**
     * Returns the key pair for {@code n} (cached, or derived from its path secret), or
     * {@code null} if neither is held. When {@code isConst} is false, a non-null result is
     * stored in the cache.
     */
    protected AsymmetricCipherKeyPair setPrivateKey(NodeIndex n, boolean isConst) throws IOException {
        AsymmetricCipherKeyPair priv = this.getPrivateKey(n);
        if (priv != null && !isConst) {
            this.privateKeyCache.put(n, priv);
        }
        return priv;
    }

    /**
     * Looks up the cached key pair for {@code n}. Failing that, derives one from its path
     * secret (via the "node" derivation) without caching it. Returns {@code null} if neither
     * is held.
     */
    private AsymmetricCipherKeyPair getPrivateKey(NodeIndex n) throws IOException {
        if (this.privateKeyCache.containsKey(n)) {
            return this.privateKeyCache.get(n);
        }
        if (!this.pathSecrets.containsKey(n)) {
            return null;
        }
        Secret nodeSecret = this.pathSecrets.get(n).deriveSecret(this.suite, "node");
        return this.suite.getHPKE().deriveKeyPair(nodeSecret.value());
    }

    /**
     * Stores {@code pathSecret} at {@code start}, then derives ("path") a fresh secret for each
     * parent on {@code start}'s filtered direct path in {@code pub}. Cached keys for all of
     * those nodes are evicted, since they are now stale. Finally sets {@code updateSecret} to
     * one more "path" derivation beyond the last node.
     */
    private void implant(TreeKEMPublicKey pub, NodeIndex start, Secret pathSecret) throws Exception {
        FilteredDirectPath fdp = pub.getFilteredDirectPath(start);
        Secret secret = new Secret(pathSecret.value());
        this.pathSecrets.put(start, secret);
        this.privateKeyCache.remove(start);
        for (NodeIndex n : fdp.parents) {
            secret = secret.deriveSecret(pub.suite, "path");
            this.pathSecrets.put(n, secret);
            this.privateKeyCache.remove(n);
        }
        this.updateSecret = secret.deriveSecret(pub.suite, "path");
    }

    /**
     * Returns the path secret held for the common ancestor of this member's leaf and
     * {@code to}, or an empty (zero-length) secret if none is held.
     */
    public Secret getSharedPathSecret(LeafIndex to) {
        NodeIndex n = this.index.commonAncestor(to);
        if (!this.pathSecrets.containsKey(n)) {
            return new Secret(new byte[0]);
        }
        return this.pathSecrets.get(n);
    }
}

