/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.mls.client;

import com.google.protobuf.ByteString;
import com.google.protobuf.MessageOrBuilder;
import io.grpc.Status;
import io.grpc.StatusException;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import mls_client.MLSClientGrpc;
import mls_client.MlsClient;
import org.bouncycastle.crypto.AsymmetricCipherKeyPair;
import org.bouncycastle.mls.TreeKEM.LeafIndex;
import org.bouncycastle.mls.TreeKEM.LeafNode;
import org.bouncycastle.mls.TreeKEM.LifeTime;
import org.bouncycastle.mls.TreeKEM.TreeKEMPublicKey;
import org.bouncycastle.mls.client.KeyPackageWithSecrets;
import org.bouncycastle.mls.codec.Capabilities;
import org.bouncycastle.mls.codec.Credential;
import org.bouncycastle.mls.codec.Extension;
import org.bouncycastle.mls.codec.ExtensionType;
import org.bouncycastle.mls.codec.ExternalSender;
import org.bouncycastle.mls.codec.GroupInfo;
import org.bouncycastle.mls.codec.KeyPackage;
import org.bouncycastle.mls.codec.MLSInputStream;
import org.bouncycastle.mls.codec.MLSMessage;
import org.bouncycastle.mls.codec.MLSOutputStream;
import org.bouncycastle.mls.codec.PreSharedKeyID;
import org.bouncycastle.mls.codec.Proposal;
import org.bouncycastle.mls.codec.ProtocolVersion;
import org.bouncycastle.mls.codec.ResumptionPSKUsage;
import org.bouncycastle.mls.codec.Welcome;
import org.bouncycastle.mls.codec.WireFormat;
import org.bouncycastle.mls.crypto.MlsCipherSuite;
import org.bouncycastle.mls.crypto.Secret;
import org.bouncycastle.mls.protocol.Group;
import org.bouncycastle.util.Pack;

public class MLSClientImpl
extends MLSClientGrpc.MLSClientImplBase {
    // Group states handed out to the test harness, keyed by state ID (see storeGroup/loadGroup).
    // NOTE(review): plain HashMaps with no synchronization; confirm the gRPC server never
    // dispatches handlers for this instance concurrently.
    Map<Integer, CachedGroup> groupCache = new HashMap<Integer, CachedGroup>();
    // Pending joins keyed by transaction ID; filled by createKeyPackage, read by joinGroup and storePSK.
    Map<Integer, CachedJoin> joinCache = new HashMap<Integer, CachedJoin>();
    // Pending reinitializations keyed by an ID derived from the KeyPackage reference (see storeReinit).
    Map<Integer, CachedReinit> reinitCache = new HashMap<Integer, CachedReinit>();
    // Serialized signature private keys keyed by an ID derived from their first bytes (see storeSigner).
    Map<Integer, byte[]> signerCache = new HashMap<Integer, byte[]>();

    /**
     * Returns the method name found at index 3 of the current thread's stack trace.
     *
     * <p>On typical JVMs frame 0 is {@code Thread.getStackTrace}, frame 1 is this method and
     * frame 2 its caller, so index 3 is the caller's caller. Stack layout is JVM-dependent.
     */
    private static String getCallerMethodName() {
        StackTraceElement[] frames = Thread.currentThread().getStackTrace();
        StackTraceElement target = frames[3];
        return target.getMethodName();
    }

    /**
     * Runs {@code f} and reports any exception it throws to {@code observer}.
     *
     * <p>A {@link StatusException} thrown by the handler (for example NOT_FOUND from a missing
     * state) is forwarded unchanged, so the client receives the intended status code. Any other
     * exception is reported as INTERNAL, with its message as the description and the exception
     * attached as the server-side cause. On success the handler is expected to complete the
     * observer itself.
     *
     * @param f        the handler to run
     * @param observer the response stream that receives the error, if any
     */
    private static <T> void catchWrap(Function f, StreamObserver<T> observer) {
        try {
            f.run();
        }
        catch (StatusException e) {
            // Already carries a deliberate status; re-wrapping it as INTERNAL would hide it.
            observer.onError(e);
        }
        catch (Exception e) {
            observer.onError(Status.INTERNAL.withDescription(e.getMessage()).withCause(e).asException());
        }
    }

    /**
     * Looks up the cached group named by the request's {@code state_id} field and runs {@code f} on it.
     *
     * <p>If no group is cached under that ID, NOT_FOUND is reported and {@code f} is not run.
     * A {@link StatusException} thrown by {@code f} is forwarded unchanged; any other exception
     * is reported as INTERNAL with the exception attached as the server-side cause.
     *
     * @param f        handler to run against the cached group
     * @param request  a request message declaring an integer {@code state_id} field
     * @param observer response stream that receives the error, if any
     */
    private <T> void stateWrap(FunctionWithState f, MessageOrBuilder request, StreamObserver<T> observer) {
        int stateID = (Integer)request.getField(request.getDescriptorForType().findFieldByName("state_id"));
        CachedGroup group = this.loadGroup(stateID);
        if (group == null) {
            observer.onError(Status.NOT_FOUND.withDescription("Unknown state").asException());
            // onError is terminal: running f(null) would NPE and report a second error.
            return;
        }
        try {
            f.run(group);
        }
        catch (StatusException e) {
            observer.onError(e);
        }
        catch (Exception e) {
            observer.onError(Status.INTERNAL.withDescription(e.getMessage()).withCause(e).asException());
        }
    }

    /**
     * Caches {@code group} and returns the state ID under which it was stored.
     *
     * <p>The ID is the first four bytes of the epoch authenticator read little-endian and masked
     * to 27 bits, plus this member's leaf index. An existing entry with the same ID is replaced.
     */
    private int storeGroup(Group group, boolean encryptHandshake) {
        int stateID = Pack.littleEndianToInt(group.getEpochAuthenticator(), 0) & 0x7FFFFFF;
        CachedGroup cached = new CachedGroup(group, encryptHandshake);
        stateID += group.getIndex().value();
        this.groupCache.put(stateID, cached);
        return stateID;
    }

    /** Returns the group cached under {@code stateID}, or {@code null} when there is none. */
    private CachedGroup loadGroup(int stateID) {
        return this.groupCache.get(stateID);
    }

    /**
     * Caches a pending join and returns its transaction ID.
     *
     * <p>The ID is the KeyPackage reference hash ("MLS 1.0 KeyPackage Reference"), read
     * little-endian from its first four bytes and masked to 27 bits.
     *
     * @throws IOException if the KeyPackage cannot be encoded
     */
    private int storeJoin(KeyPackageWithSecrets kpSecrets) throws IOException {
        MlsCipherSuite cipherSuite = kpSecrets.keyPackage.getSuite();
        byte[] encodedKeyPackage = MLSOutputStream.encode(kpSecrets.keyPackage);
        byte[] reference = cipherSuite.refHash(encodedKeyPackage, "MLS 1.0 KeyPackage Reference");
        int transactionID = Pack.littleEndianToInt(reference, 0) & 0x7FFFFFF;
        CachedJoin pending = new CachedJoin(kpSecrets);
        this.joinCache.put(transactionID, pending);
        return transactionID;
    }

    /** Returns the pending join cached under {@code joinID}, or {@code null} when there is none. */
    private CachedJoin loadJoin(int joinID) {
        return this.joinCache.get(joinID);
    }

    /**
     * Caches a serialized signature private key and returns its ID.
     *
     * <p>The ID is the key's first four bytes read little-endian, masked to 27 bits.
     */
    private int storeSigner(byte[] sigPriv) {
        int id = Pack.littleEndianToInt(sigPriv, 0) & 0x7FFFFFF;
        this.signerCache.put(id, sigPriv);
        return id;
    }

    /** Returns the signature private key cached under {@code signerID}, or {@code null} when there is none. */
    private byte[] loadSigner(int signerID) {
        return this.signerCache.get(signerID);
    }

    /**
     * Caches a pending reinitialization and returns its ID.
     *
     * <p>The ID is derived from the KeyPackage reference hash in the same way as join IDs.
     *
     * @throws IOException if the KeyPackage cannot be encoded
     */
    private int storeReinit(KeyPackageWithSecrets kpSk, Group.Tombstone tombstone, boolean encryptHandshake) throws IOException {
        MlsCipherSuite cipherSuite = kpSk.keyPackage.getSuite();
        byte[] encodedKeyPackage = MLSOutputStream.encode(kpSk.keyPackage);
        byte[] reference = cipherSuite.refHash(encodedKeyPackage, "MLS 1.0 KeyPackage Reference");
        int id = Pack.littleEndianToInt(reference, 0) & 0x7FFFFFF;
        CachedReinit pending = new CachedReinit(kpSk, tombstone, encryptHandshake);
        this.reinitCache.put(id, pending);
        return id;
    }

    /** Returns the pending reinitialization cached under {@code reinitID}, or {@code null} when there is none. */
    private CachedReinit loadReinit(int reinitID) {
        return this.reinitCache.get(reinitID);
    }

    /**
     * Finds a cached state for the given group ID and epoch.
     *
     * <p>Scans every cached state. If several match, the last one visited in the map's
     * iteration order is returned; returns {@code null} when none match.
     */
    private CachedGroup findState(byte[] groupID, long epoch) {
        CachedGroup match = null;
        for (CachedGroup candidate : this.groupCache.values()) {
            boolean sameGroup = candidate != null && Arrays.equals(candidate.group.getGroupID(), groupID);
            if (sameGroup && candidate.group.getEpoch() == epoch) {
                match = candidate;
            }
        }
        return match;
    }

    /**
     * Generates a fresh KeyPackage for {@code identity} together with its private keys.
     *
     * <p>Creates three independent key pairs (HPKE init key, HPKE leaf encryption key,
     * signature key) and a basic credential. It then builds a leaf node and KeyPackage with
     * default capabilities, default lifetime and no extensions.
     */
    private KeyPackageWithSecrets newKeyPackage(MlsCipherSuite suite, byte[] identity) throws Exception {
        AsymmetricCipherKeyPair initKeys = suite.getHPKE().generatePrivateKey();
        AsymmetricCipherKeyPair leafKeys = suite.getHPKE().generatePrivateKey();
        AsymmetricCipherKeyPair signingKeys = suite.generateSignatureKeyPair();
        Credential credential = Credential.forBasic(identity);

        byte[] leafPublic = suite.getHPKE().serializePublicKey(leafKeys.getPublic());
        byte[] signingPublic = suite.serializeSignaturePublicKey(signingKeys.getPublic());
        LeafNode leaf = new LeafNode(suite, leafPublic, signingPublic, credential, new Capabilities(), new LifeTime(), new ArrayList<Extension>(), suite.serializeSignaturePrivateKey(signingKeys.getPrivate()));

        byte[] initPublic = suite.getHPKE().serializePublicKey(initKeys.getPublic());
        KeyPackage keyPackage = new KeyPackage(suite, initPublic, leaf, new ArrayList<Extension>(), suite.serializeSignaturePrivateKey(signingKeys.getPrivate()));
        return new KeyPackageWithSecrets(initKeys, leafKeys, signingKeys, keyPackage);
    }

    /**
     * Returns the index of the first occupied leaf whose credential identity equals {@code id}.
     *
     * @throws Exception "Unknown member identity" when no leaf matches
     */
    private LeafIndex findMember(TreeKEMPublicKey tree, byte[] id) throws Exception {
        for (int leaf = 0; leaf < tree.getSize().leafCount(); leaf++) {
            LeafIndex candidate = new LeafIndex(leaf);
            LeafNode node = tree.getLeafNode(candidate);
            if (node == null) {
                continue;  // blank leaf
            }
            if (Arrays.equals(node.getCredential().getIdentity(), id)) {
                return candidate;
            }
        }
        throw new Exception("Unknown member identity");
    }

    /**
     * Builds a by-value {@link Proposal} from a test-harness proposal description.
     *
     * <p>Supported types are "add", "remove", "externalPSK", "resumptionPSK",
     * "groupContextExtensions" and "reinit".
     *
     * @param suite   cipher suite of the group; its KDF hash length sizes the PSK nonces
     * @param groupID group ID used for resumption PSKs
     * @param tree    ratchet tree searched for the member named by a "remove" description
     * @param desc    the description received from the harness
     * @return the corresponding proposal
     * @throws IllegalStateException if the proposal type is not recognised
     * @throws Exception             if decoding the KeyPackage or locating the removed member fails
     */
    private Proposal proposalFromDescription(MlsCipherSuite suite, byte[] groupID, TreeKEMPublicKey tree, MlsClient.ProposalDescription desc) throws Exception {
        SecureRandom random = new SecureRandom();
        String proposalType = desc.getProposalType().toStringUtf8();
        switch (proposalType) {
            case "add": {
                MLSMessage kp = (MLSMessage)MLSInputStream.decode(desc.getKeyPackage().toByteArray(), MLSMessage.class);
                return Proposal.add(kp.keyPackage);
            }
            case "remove": {
                LeafIndex removedIndex = this.findMember(tree, desc.getRemovedId().toByteArray());
                return Proposal.remove(removedIndex);
            }
            case "externalPSK": {
                byte[] externalPskID = desc.getPskId().toByteArray();
                byte[] extNonce = new byte[suite.getKDF().getHashLength()];
                random.nextBytes(extNonce);
                return Proposal.preSharedKey(PreSharedKeyID.external(externalPskID, extNonce));
            }
            case "resumptionPSK": {
                long epoch = desc.getEpochId();
                // psk_nonce must be a fresh random value of length KDF.Nh (RFC 9420, Section 8.4),
                // exactly as for external PSKs.
                byte[] resNonce = new byte[suite.getKDF().getHashLength()];
                random.nextBytes(resNonce);
                PreSharedKeyID resPskID = PreSharedKeyID.resumption(ResumptionPSKUsage.APPLICATION, groupID, epoch, resNonce);
                return Proposal.preSharedKey(resPskID);
            }
            case "groupContextExtensions":
            case "reinit": {
                ArrayList<Extension> extList = new ArrayList<Extension>();
                for (int i = 0; i < desc.getExtensionsCount(); ++i) {
                    extList.add(new Extension(desc.getExtensions(i).getExtensionType(), desc.getExtensions(i).getExtensionData().toByteArray()));
                }
                if ("reinit".equals(proposalType)) {
                    return Proposal.reInit(desc.getGroupId().toByteArray(), ProtocolVersion.mls10, MlsCipherSuite.getSuite((short)desc.getCipherSuite()), extList);
                }
                return Proposal.groupContextExtensions(extList);
            }
            default:
                break;
        }
        // Report the decoded type name, not ByteString.toString() (which prints an object summary).
        throw new IllegalStateException("Unknown proposal-by-value type: " + proposalType);
    }

    /** Drops the group cached under {@code stateID}; does nothing if there is none. */
    private void removeGroup(int stateID) {
        Integer key = Integer.valueOf(stateID);
        this.groupCache.remove(key);
    }

    /**
     * Handles Name: replies with this implementation's name, "BouncyCastle".
     *
     * @param request         unused
     * @param responseObserver receives the single response and is then completed
     */
    private void nameImpl(MlsClient.NameRequest request, StreamObserver<MlsClient.NameResponse> responseObserver) {
        MlsClient.NameResponse response = MlsClient.NameResponse.newBuilder().setName("BouncyCastle").build();
        // Pass the typed response; an (Object) cast does not type-check against StreamObserver<NameResponse>.
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    /** gRPC entry point for Name; runs {@code nameImpl} under catchWrap's error reporting. */
    @Override
    public void name(final MlsClient.NameRequest request, final StreamObserver<MlsClient.NameResponse> responseObserver) {
        Function handler = new Function() {
            @Override
            public void run() throws Exception {
                MLSClientImpl.this.nameImpl(request, responseObserver);
            }
        };
        catchWrap(handler, responseObserver);
    }

    /**
     * Handles SupportedCiphersuites: lists every ID in {@code MlsCipherSuite.ALL_SUPPORTED_SUITES}.
     *
     * @param request          unused
     * @param responseObserver receives the single response and is then completed
     */
    private void supportedCiphersuitesImpl(MlsClient.SupportedCiphersuitesRequest request, StreamObserver<MlsClient.SupportedCiphersuitesResponse> responseObserver) {
        MlsClient.SupportedCiphersuitesResponse.Builder builder = MlsClient.SupportedCiphersuitesResponse.newBuilder().clearCiphersuites();
        for (short id : MlsCipherSuite.ALL_SUPPORTED_SUITES) {
            builder.addCiphersuites(id);
        }
        // Typed onNext; the decompiler's (Object) cast does not compile against StreamObserver<T>.
        responseObserver.onNext(builder.build());
        responseObserver.onCompleted();
    }

    /** gRPC entry point for SupportedCiphersuites; runs the handler under catchWrap's error reporting. */
    @Override
    public void supportedCiphersuites(final MlsClient.SupportedCiphersuitesRequest request, final StreamObserver<MlsClient.SupportedCiphersuitesResponse> responseObserver) {
        Function handler = new Function() {
            @Override
            public void run() throws Exception {
                MLSClientImpl.this.supportedCiphersuitesImpl(request, responseObserver);
            }
        };
        catchWrap(handler, responseObserver);
    }

    /**
     * Handles CreateGroup: creates a new group whose only member is the requested identity.
     *
     * <p>Generates a fresh HPKE leaf key pair and signature key pair and builds a
     * basic-credential leaf node. It creates the group with no group-context extensions,
     * caches it and replies with its state ID.
     *
     * @throws Exception if key generation or group construction fails
     */
    public void createGroupImpl(MlsClient.CreateGroupRequest request, StreamObserver<MlsClient.CreateGroupResponse> responseObserver) throws Exception {
        byte[] groupID = request.getGroupId().toByteArray();
        MlsCipherSuite suite = MlsCipherSuite.getSuite((short)request.getCipherSuite());
        byte[] identity = request.getIdentity().toByteArray();
        AsymmetricCipherKeyPair leafKeyPair = suite.getHPKE().generatePrivateKey();
        AsymmetricCipherKeyPair sigKeyPair = suite.generateSignatureKeyPair();
        Credential cred = Credential.forBasic(identity);
        LeafNode leafNode = new LeafNode(suite, suite.getHPKE().serializePublicKey(leafKeyPair.getPublic()), suite.serializeSignaturePublicKey(sigKeyPair.getPublic()), cred, new Capabilities(), new LifeTime(), new ArrayList<Extension>(), suite.serializeSignaturePrivateKey(sigKeyPair.getPrivate()));
        // NOTE(review): the group receives a copy of the leaf node keyed by its own encryption
        // key; why a copy is needed is not visible from this file.
        Group group = new Group(groupID, suite, leafKeyPair, suite.serializeSignaturePrivateKey(sigKeyPair.getPrivate()), leafNode.copy(leafNode.getEncryptionKey()), new ArrayList<Extension>());
        int stateId = this.storeGroup(group, request.getEncryptHandshake());
        MlsClient.CreateGroupResponse response = MlsClient.CreateGroupResponse.newBuilder().setStateId(stateId).build();
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    /** gRPC entry point for CreateGroup; runs the handler under catchWrap's error reporting. */
    @Override
    public void createGroup(final MlsClient.CreateGroupRequest request, final StreamObserver<MlsClient.CreateGroupResponse> responseObserver) {
        Function handler = new Function() {
            @Override
            public void run() throws Exception {
                MLSClientImpl.this.createGroupImpl(request, responseObserver);
            }
        };
        catchWrap(handler, responseObserver);
    }

    /**
     * Handles CreateKeyPackage: generates a KeyPackage for the requested identity and caches it
     * as a pending join.
     *
     * <p>The reply carries the serialized init, encryption and signature private keys, the
     * KeyPackage wrapped in an MLSMessage, and the transaction ID for a later JoinGroup.
     *
     * @throws Exception if key generation or encoding fails
     */
    private void createKeyPackageImpl(MlsClient.CreateKeyPackageRequest request, StreamObserver<MlsClient.CreateKeyPackageResponse> responseObserver) throws Exception {
        MlsCipherSuite suite = MlsCipherSuite.getSuite((short)request.getCipherSuite());
        byte[] identity = request.getIdentity().toByteArray();
        KeyPackageWithSecrets kpSecrets = this.newKeyPackage(suite, identity);
        int joinID = this.storeJoin(kpSecrets);

        byte[] initPriv = suite.getHPKE().serializePrivateKey(kpSecrets.initKeyPair.getPrivate());
        byte[] encryptionPriv = suite.getHPKE().serializePrivateKey(kpSecrets.encryptionKeyPair.getPrivate());
        byte[] signaturePriv = suite.serializeSignaturePrivateKey(kpSecrets.signatureKeyPair.getPrivate());
        byte[] keyPackage = MLSOutputStream.encode(MLSMessage.keyPackage(kpSecrets.keyPackage));

        MlsClient.CreateKeyPackageResponse response = MlsClient.CreateKeyPackageResponse.newBuilder()
                .setInitPriv(ByteString.copyFrom(initPriv))
                .setEncryptionPriv(ByteString.copyFrom(encryptionPriv))
                .setSignaturePriv(ByteString.copyFrom(signaturePriv))
                .setKeyPackage(ByteString.copyFrom(keyPackage))
                .setTransactionId(joinID)
                .build();
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    /** gRPC entry point for CreateKeyPackage; runs the handler under catchWrap's error reporting. */
    @Override
    public void createKeyPackage(final MlsClient.CreateKeyPackageRequest request, final StreamObserver<MlsClient.CreateKeyPackageResponse> responseObserver) {
        Function handler = new Function() {
            @Override
            public void run() throws Exception {
                MLSClientImpl.this.createKeyPackageImpl(request, responseObserver);
            }
        };
        catchWrap(handler, responseObserver);
    }

    /**
     * Handles JoinGroup: processes a Welcome using the secrets of a pending join.
     *
     * <p>The pending join is selected by the request's transaction ID. An out-of-band ratchet
     * tree is used only when the request carries a non-empty one. The new group is cached, and
     * the reply carries its state ID and epoch authenticator.
     *
     * @throws Exception "Unknown transaction ID" if no pending join matches, or if decoding or
     *                   Welcome processing fails
     */
    private void joinGroupImpl(MlsClient.JoinGroupRequest request, StreamObserver<MlsClient.JoinGroupResponse> responseObserver) throws Exception {
        CachedJoin join = this.loadJoin(request.getTransactionId());
        if (join == null) {
            throw new Exception("Unknown transaction ID");
        }
        MLSMessage welcomeMsg = (MLSMessage)MLSInputStream.decode(request.getWelcome().toByteArray(), MLSMessage.class);
        Welcome welcome = welcomeMsg.welcome;

        byte[] ratchetTreeBytes = request.getRatchetTree().toByteArray();
        TreeKEMPublicKey ratchetTree = null;
        if (ratchetTreeBytes.length > 0) {
            ratchetTree = (TreeKEMPublicKey)MLSInputStream.decode(ratchetTreeBytes, TreeKEMPublicKey.class);
            ratchetTree.setSuite(welcomeMsg.getCipherSuite());
        }

        MlsCipherSuite suite = welcome.getSuite();
        Group group = new Group(suite.getHPKE().serializePrivateKey(join.kpSecrets.initKeyPair.getPrivate()), join.kpSecrets.encryptionKeyPair, suite.serializeSignaturePrivateKey(join.kpSecrets.signatureKeyPair.getPrivate()), join.kpSecrets.keyPackage, welcome, ratchetTree, join.externalPsks, new HashMap<Group.EpochRef, byte[]>());
        byte[] epochAuthenticator = group.getEpochAuthenticator();
        int stateID = this.storeGroup(group, request.getEncryptHandshake());
        MlsClient.JoinGroupResponse response = MlsClient.JoinGroupResponse.newBuilder().setStateId(stateID).setEpochAuthenticator(ByteString.copyFrom(epochAuthenticator)).build();
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    /** gRPC entry point for JoinGroup; runs the handler under catchWrap's error reporting. */
    @Override
    public void joinGroup(final MlsClient.JoinGroupRequest request, final StreamObserver<MlsClient.JoinGroupResponse> responseObserver) {
        Function handler = new Function() {
            @Override
            public void run() throws Exception {
                MLSClientImpl.this.joinGroupImpl(request, responseObserver);
            }
        };
        catchWrap(handler, responseObserver);
    }

    /**
     * Handles ExternalJoin: joins the group described by a GroupInfo through an external commit.
     *
     * <p>The method:
     * <ul>
     *   <li>generates fresh init, leaf (HPKE) and signature key pairs and a basic credential for
     *       the requested identity;</li>
     *   <li>optionally locates this identity's earlier leaf so the commit removes it;</li>
     *   <li>injects any external PSKs carried in the request.</li>
     * </ul>
     * The resulting group is cached with handshake encryption off. The reply carries its state
     * ID, the encoded commit and the new epoch authenticator.
     *
     * @throws Exception if decoding fails, no tree is available for the remove-prior search, the
     *                   identity has no prior leaf, or the external join itself fails
     */
    private void externalJoinImpl(MlsClient.ExternalJoinRequest request, StreamObserver<MlsClient.ExternalJoinResponse> responseObserver) throws Exception {
        MLSMessage groupInfoMsg = (MLSMessage)MLSInputStream.decode(request.getGroupInfo().toByteArray(), MLSMessage.class);
        GroupInfo groupInfo = groupInfoMsg.groupInfo;
        MlsCipherSuite suite = groupInfo.getSuite();
        AsymmetricCipherKeyPair initKeyPair = suite.getHPKE().generatePrivateKey();
        AsymmetricCipherKeyPair leafKeyPair = suite.getHPKE().generatePrivateKey();
        AsymmetricCipherKeyPair sigKeyPair = suite.generateSignatureKeyPair();
        byte[] identity = request.getIdentity().toByteArray();
        Credential cred = Credential.forBasic(identity);
        LeafNode leafNode = new LeafNode(suite, suite.getHPKE().serializePublicKey(leafKeyPair.getPublic()), suite.serializeSignaturePublicKey(sigKeyPair.getPublic()), cred, new Capabilities(), new LifeTime(), new ArrayList<Extension>(), suite.serializeSignaturePrivateKey(sigKeyPair.getPrivate()));
        KeyPackage kp = new KeyPackage(suite, suite.getHPKE().serializePublicKey(initKeyPair.getPublic()), leafNode, new ArrayList<Extension>(), suite.serializeSignaturePrivateKey(sigKeyPair.getPrivate()));

        // An out-of-band ratchet tree is optional; an empty field means "not supplied".
        byte[] ratchetTreeBytes = request.getRatchetTree().toByteArray();
        TreeKEMPublicKey ratchetTree = null;
        if (ratchetTreeBytes.length > 0) {
            ratchetTree = (TreeKEMPublicKey)MLSInputStream.decode(ratchetTreeBytes, TreeKEMPublicKey.class);
            ratchetTree.setSuite(suite);
        }

        LeafIndex removeIndex = null;
        if (request.getRemovePrior()) {
            removeIndex = findPriorAppearance(groupInfo, ratchetTree, identity);
        }

        HashMap<Secret, byte[]> externalPSKs = new HashMap<Secret, byte[]>();
        for (int i = 0; i < request.getPsksCount(); ++i) {
            MlsClient.PreSharedKey psk = request.getPsks(i);
            externalPSKs.put(new Secret(psk.getPskId().toByteArray()), psk.getPskSecret().toByteArray());
        }

        byte[] leafSecret = new byte[suite.getKDF().getHashLength()];
        new SecureRandom().nextBytes(leafSecret);
        Group.GroupWithMessage gwm = Group.externalJoin(new Secret(leafSecret), sigKeyPair, kp, groupInfo, ratchetTree, new Group.MessageOptions(false, new byte[0], 0), removeIndex, externalPSKs);
        int stateID = this.storeGroup(gwm.group, false);
        MlsClient.ExternalJoinResponse response = MlsClient.ExternalJoinResponse.newBuilder()
                .setStateId(stateID)
                .setCommit(ByteString.copyFrom(MLSOutputStream.encode(gwm.message)))
                .setEpochAuthenticator(ByteString.copyFrom(gwm.group.getEpochAuthenticator()))
                .build();
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    /**
     * Finds the leaf an external joiner previously occupied, for the remove-prior option.
     *
     * <p>Searches a clone of {@code ratchetTree} when one was supplied. Otherwise it searches
     * the first ratchet tree found in the GroupInfo extensions. If several leaves carry the
     * identity, the last one wins.
     *
     * @throws Exception "No tree available" when neither source provides a tree, or
     *                   "Prior appearance not found" when no leaf carries {@code identity}
     */
    private static LeafIndex findPriorAppearance(GroupInfo groupInfo, TreeKEMPublicKey ratchetTree, byte[] identity) throws Exception {
        TreeKEMPublicKey tree = null;
        if (ratchetTree != null) {
            tree = TreeKEMPublicKey.clone(ratchetTree);
        } else {
            for (Extension ext : groupInfo.getExtensions()) {
                tree = ext.getRatchetTree();
                if (tree != null) {
                    break;
                }
            }
            if (tree == null) {
                throw new Exception("No tree available");
            }
        }

        LeafIndex found = null;
        for (int i = 0; i < tree.getSize().leafCount(); ++i) {
            LeafIndex index = new LeafIndex(i);
            LeafNode leaf = tree.getLeafNode(index);
            if (leaf != null && Arrays.equals(identity, leaf.getCredential().getIdentity())) {
                found = index;
            }
        }
        if (found == null) {
            throw new Exception("Prior appearance not found");
        }
        return found;
    }

    /** gRPC entry point for ExternalJoin; runs the handler under catchWrap's error reporting. */
    @Override
    public void externalJoin(final MlsClient.ExternalJoinRequest request, final StreamObserver<MlsClient.ExternalJoinResponse> responseObserver) {
        Function handler = new Function() {
            @Override
            public void run() throws Exception {
                MLSClientImpl.this.externalJoinImpl(request, responseObserver);
            }
        };
        catchWrap(handler, responseObserver);
    }

    /**
     * Handles GroupInfo: returns the cached group's GroupInfo message.
     *
     * <p>By default the ratchet tree is inlined in the GroupInfo. If the request asks for an
     * external tree, the tree is encoded separately into the response instead.
     *
     * @throws Exception if building or encoding the GroupInfo fails
     */
    private void groupInfoImpl(CachedGroup entry, MlsClient.GroupInfoRequest request, StreamObserver<MlsClient.GroupInfoResponse> responseObserver) throws Exception {
        boolean inlineTree = !request.getExternalTree();
        MLSMessage groupInfo = entry.group.getGroupInfo(inlineTree);
        MlsClient.GroupInfoResponse.Builder builder = MlsClient.GroupInfoResponse.newBuilder().setGroupInfo(ByteString.copyFrom(MLSOutputStream.encode(groupInfo)));
        if (!inlineTree) {
            builder.setRatchetTree(ByteString.copyFrom(MLSOutputStream.encode(entry.group.getTree())));
        }
        responseObserver.onNext(builder.build());
        responseObserver.onCompleted();
    }

    /** gRPC entry point for GroupInfo; stateWrap resolves state_id and reports errors. */
    @Override
    public void groupInfo(final MlsClient.GroupInfoRequest request, final StreamObserver<MlsClient.GroupInfoResponse> responseObserver) {
        FunctionWithState handler = new FunctionWithState() {
            @Override
            public void run(CachedGroup cached) throws Exception {
                MLSClientImpl.this.groupInfoImpl(cached, request, responseObserver);
            }
        };
        this.stateWrap(handler, request, responseObserver);
    }

    /**
     * Handles StateAuth: returns the cached group's epoch authenticator as the state-auth secret.
     *
     * @param request unused apart from having selected {@code entry}
     */
    private void stateAuthImpl(CachedGroup entry, MlsClient.StateAuthRequest request, StreamObserver<MlsClient.StateAuthResponse> responseObserver) {
        MlsClient.StateAuthResponse response = MlsClient.StateAuthResponse.newBuilder().setStateAuthSecret(ByteString.copyFrom(entry.group.getEpochAuthenticator())).build();
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    /** gRPC entry point for StateAuth; stateWrap resolves state_id and reports errors. */
    @Override
    public void stateAuth(final MlsClient.StateAuthRequest request, final StreamObserver<MlsClient.StateAuthResponse> responseObserver) {
        FunctionWithState handler = new FunctionWithState() {
            @Override
            public void run(CachedGroup cached) throws Exception {
                MLSClientImpl.this.stateAuthImpl(cached, request, responseObserver);
            }
        };
        this.stateWrap(handler, request, responseObserver);
    }

    /**
     * Handles Export: derives an exporter secret from the cached group's key schedule.
     *
     * <p>The label, context and output length all come from the request.
     *
     * @throws IOException if the exporter derivation fails
     */
    private void exportImpl(CachedGroup entry, MlsClient.ExportRequest request, StreamObserver<MlsClient.ExportResponse> responseObserver) throws IOException {
        String label = request.getLabel();
        byte[] context = request.getContext().toByteArray();
        int size = request.getKeyLength();
        byte[] secret = entry.group.getKeySchedule().MLSExporter(label, context, size);
        MlsClient.ExportResponse response = MlsClient.ExportResponse.newBuilder().setExportedSecret(ByteString.copyFrom(secret)).build();
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    /** gRPC entry point for Export; stateWrap resolves state_id and reports errors. */
    @Override
    public void export(final MlsClient.ExportRequest request, final StreamObserver<MlsClient.ExportResponse> responseObserver) {
        FunctionWithState handler = new FunctionWithState() {
            @Override
            public void run(CachedGroup cached) throws Exception {
                MLSClientImpl.this.exportImpl(cached, request, responseObserver);
            }
        };
        this.stateWrap(handler, request, responseObserver);
    }

    /**
     * Handles Protect: encrypts an application message with the cached group.
     *
     * <p>Both the authenticated data and the plaintext come from the request. The padding
     * argument passed to {@code Group.protect} is 0.
     *
     * @throws Exception if protection or encoding fails
     */
    private void protectImpl(CachedGroup entry, MlsClient.ProtectRequest request, StreamObserver<MlsClient.ProtectResponse> responseObserver) throws Exception {
        MLSMessage ct = entry.group.protect(request.getAuthenticatedData().toByteArray(), request.getPlaintext().toByteArray(), 0);
        MlsClient.ProtectResponse response = MlsClient.ProtectResponse.newBuilder().setCiphertext(ByteString.copyFrom(MLSOutputStream.encode(ct))).build();
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    /** gRPC entry point for Protect; stateWrap resolves state_id and reports errors. */
    @Override
    public void protect(final MlsClient.ProtectRequest request, final StreamObserver<MlsClient.ProtectResponse> responseObserver) {
        FunctionWithState handler = new FunctionWithState() {
            @Override
            public void run(CachedGroup cached) throws Exception {
                MLSClientImpl.this.protectImpl(cached, request, responseObserver);
            }
        };
        this.stateWrap(handler, request, responseObserver);
    }

    /**
     * Handles Unprotect: decrypts a protected message and returns its authenticated data and plaintext.
     *
     * <p>If the message's epoch differs from the requested state's epoch, another cached state
     * of the same group at the message's epoch is used instead.
     *
     * @throws Exception "Unknown state for unprotect" if no cached state matches the message
     *                   epoch, or if decoding or decryption fails
     */
    private void unprotectImpl(CachedGroup entry, MlsClient.UnprotectRequest request, StreamObserver<MlsClient.UnprotectResponse> responseObserver) throws Exception {
        MLSMessage ct = (MLSMessage)MLSInputStream.decode(request.getCiphertext().toByteArray(), MLSMessage.class);
        byte[] groupID = entry.group.getGroupID();
        long epoch = ct.getEpoch();
        Group group = entry.group;
        if (entry.group.getEpoch() != epoch) {
            CachedGroup cached = this.findState(groupID, epoch);
            if (cached == null) {
                throw new Exception("Unknown state for unprotect");
            }
            group = cached.group;
        }
        // Index 0 is the authenticated data and index 1 the plaintext.
        byte[][] authAndPt = group.unprotect(ct);
        MlsClient.UnprotectResponse response = MlsClient.UnprotectResponse.newBuilder()
                .setAuthenticatedData(ByteString.copyFrom(authAndPt[0]))
                .setPlaintext(ByteString.copyFrom(authAndPt[1]))
                .build();
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    /** gRPC entry point for Unprotect; stateWrap resolves state_id and reports errors. */
    @Override
    public void unprotect(final MlsClient.UnprotectRequest request, final StreamObserver<MlsClient.UnprotectResponse> responseObserver) {
        FunctionWithState handler = new FunctionWithState() {
            @Override
            public void run(CachedGroup cached) throws Exception {
                MLSClientImpl.this.unprotectImpl(cached, request, responseObserver);
            }
        };
        this.stateWrap(handler, request, responseObserver);
    }

    /**
     * Handles StorePSK: records an external PSK for a pending join or a cached group.
     *
     * <p>The request's ID is first looked up as a join transaction ID. If no join matches, it
     * is looked up as a group state ID.
     *
     * @throws StatusException NOT_FOUND when the ID names neither a pending join nor a cached group
     */
    private void storePSKImpl(MlsClient.StorePSKRequest request, StreamObserver<MlsClient.StorePSKResponse> responseObserver) throws StatusException {
        MlsClient.StorePSKResponse response = MlsClient.StorePSKResponse.newBuilder().build();
        int id = request.getStateOrTransactionId();
        Secret pskId = new Secret(request.getPskId().toByteArray());
        byte[] pskSecret = request.getPskSecret().toByteArray();
        CachedJoin join = this.loadJoin(id);
        if (join != null) {
            join.externalPsks.put(pskId, pskSecret);
        } else {
            CachedGroup cached = this.loadGroup(id);
            if (cached == null) {
                throw Status.NOT_FOUND.withDescription("Unknown state").asException();
            }
            cached.group.insertExternalPsk(pskId, pskSecret);
        }
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    /** gRPC entry point for StorePSK; runs the handler under catchWrap's error reporting. */
    @Override
    public void storePSK(final MlsClient.StorePSKRequest request, final StreamObserver<MlsClient.StorePSKResponse> responseObserver) {
        Function handler = new Function() {
            @Override
            public void run() throws Exception {
                MLSClientImpl.this.storePSKImpl(request, responseObserver);
            }
        };
        catchWrap(handler, responseObserver);
    }

    /**
     * Handles AddProposal: creates an Add proposal for the KeyPackage in the request.
     *
     * <p>The proposal is framed using the cached entry's message options.
     *
     * @throws Exception if decoding or proposal creation fails
     */
    private void addProposalImpl(CachedGroup entry, MlsClient.AddProposalRequest request, StreamObserver<MlsClient.ProposalResponse> responseObserver) throws Exception {
        MLSMessage keyPackage = (MLSMessage)MLSInputStream.decode(request.getKeyPackage().toByteArray(), MLSMessage.class);
        MLSMessage message = entry.group.add(keyPackage.keyPackage, entry.messageOptions);
        MlsClient.ProposalResponse response = MlsClient.ProposalResponse.newBuilder().setProposal(ByteString.copyFrom(MLSOutputStream.encode(message))).build();
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    /** gRPC entry point for AddProposal; stateWrap resolves state_id and reports errors. */
    @Override
    public void addProposal(final MlsClient.AddProposalRequest request, final StreamObserver<MlsClient.ProposalResponse> responseObserver) {
        FunctionWithState handler = new FunctionWithState() {
            @Override
            public void run(CachedGroup cached) throws Exception {
                MLSClientImpl.this.addProposalImpl(cached, request, responseObserver);
            }
        };
        this.stateWrap(handler, request, responseObserver);
    }

    /**
     * Handles UpdateProposal: creates an Update proposal with a freshly generated HPKE leaf key.
     *
     * <p>Default leaf-node options are used, and the proposal is framed with the cached
     * entry's message options.
     *
     * @throws Exception if key generation or proposal creation fails
     */
    private void updateProposalImpl(CachedGroup entry, MlsClient.UpdateProposalRequest request, StreamObserver<MlsClient.ProposalResponse> responseObserver) throws Exception {
        AsymmetricCipherKeyPair leafSk = entry.group.getSuite().getHPKE().generatePrivateKey();
        Proposal update = entry.group.updateProposal(leafSk, new Group.LeafNodeOptions());
        MLSMessage message = entry.group.update(update, entry.messageOptions);
        MlsClient.ProposalResponse response = MlsClient.ProposalResponse.newBuilder().setProposal(ByteString.copyFrom(MLSOutputStream.encode(message))).build();
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    /** gRPC entry point for UpdateProposal; stateWrap resolves state_id and reports errors. */
    @Override
    public void updateProposal(final MlsClient.UpdateProposalRequest request, final StreamObserver<MlsClient.ProposalResponse> responseObserver) {
        FunctionWithState handler = new FunctionWithState() {
            @Override
            public void run(CachedGroup cached) throws Exception {
                MLSClientImpl.this.updateProposalImpl(cached, request, responseObserver);
            }
        };
        this.stateWrap(handler, request, responseObserver);
    }

    /**
     * Handles RemoveProposal: creates a Remove proposal for the member with the requested identity.
     *
     * @throws Exception if no member has that identity or proposal creation fails
     */
    private void removeProposalImpl(CachedGroup entry, MlsClient.RemoveProposalRequest request, StreamObserver<MlsClient.ProposalResponse> responseObserver) throws Exception {
        LeafIndex removedIndex = this.findMember(entry.group.getTree(), request.getRemovedId().toByteArray());
        MLSMessage message = entry.group.remove(removedIndex, entry.messageOptions);
        MlsClient.ProposalResponse response = MlsClient.ProposalResponse.newBuilder().setProposal(ByteString.copyFrom(MLSOutputStream.encode(message))).build();
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    /** gRPC entry point for RemoveProposal; stateWrap resolves state_id and reports errors. */
    @Override
    public void removeProposal(final MlsClient.RemoveProposalRequest request, final StreamObserver<MlsClient.ProposalResponse> responseObserver) {
        FunctionWithState handler = new FunctionWithState() {
            @Override
            public void run(CachedGroup cached) throws Exception {
                MLSClientImpl.this.removeProposalImpl(cached, request, responseObserver);
            }
        };
        this.stateWrap(handler, request, responseObserver);
    }

    /**
     * Builds a PreSharedKey proposal referencing the external PSK ID in the request
     * and returns the encoded proposal message.
     *
     * @param entry            cached group state the proposal is issued from
     * @param request          carries the external PSK ID
     * @param responseObserver receives exactly one ProposalResponse, then completion
     * @throws Exception if proposal construction or encoding fails
     */
    private void externalPSKProposalImpl(CachedGroup entry, MlsClient.ExternalPSKProposalRequest request, StreamObserver<MlsClient.ProposalResponse> responseObserver) throws Exception {
        MLSMessage pskProposal = entry.group.preSharedKey(request.getPskId().toByteArray(), entry.messageOptions);
        byte[] encoded = MLSOutputStream.encode(pskProposal);
        MlsClient.ProposalResponse.Builder reply = MlsClient.ProposalResponse.newBuilder();
        reply.setProposal(ByteString.copyFrom(encoded));
        responseObserver.onNext(reply.build());
        responseObserver.onCompleted();
    }

    @Override
    public void externalPSKProposal(final MlsClient.ExternalPSKProposalRequest request, final StreamObserver<MlsClient.ProposalResponse> responseObserver) {
        // Delegates to externalPSKProposalImpl with the CachedGroup that stateWrap passes to run().
        FunctionWithState action = new FunctionWithState() {
            @Override
            public void run(CachedGroup cached) throws Exception {
                MLSClientImpl.this.externalPSKProposalImpl(cached, request, responseObserver);
            }
        };
        this.stateWrap(action, request, responseObserver);
    }

    /**
     * Builds a resumption PreSharedKey proposal for this group's own ID at the epoch
     * named in the request, and returns the encoded proposal message.
     *
     * @param entry            cached group state the proposal is issued from
     * @param request          carries the epoch to resume from
     * @param responseObserver receives exactly one ProposalResponse, then completion
     * @throws Exception if proposal construction or encoding fails
     */
    private void resumptionPSKProposalImpl(CachedGroup entry, MlsClient.ResumptionPSKProposalRequest request, StreamObserver<MlsClient.ProposalResponse> responseObserver) throws Exception {
        Group group = entry.group;
        MLSMessage pskProposal = group.preSharedKey(group.getGroupID(), request.getEpochId(), entry.messageOptions);
        MlsClient.ProposalResponse.Builder reply = MlsClient.ProposalResponse.newBuilder();
        reply.setProposal(ByteString.copyFrom(MLSOutputStream.encode(pskProposal)));
        responseObserver.onNext(reply.build());
        responseObserver.onCompleted();
    }

    @Override
    public void resumptionPSKProposal(final MlsClient.ResumptionPSKProposalRequest request, final StreamObserver<MlsClient.ProposalResponse> responseObserver) {
        // Delegates to resumptionPSKProposalImpl with the CachedGroup that stateWrap passes to run().
        FunctionWithState action = new FunctionWithState() {
            @Override
            public void run(CachedGroup cached) throws Exception {
                MLSClientImpl.this.resumptionPSKProposalImpl(cached, request, responseObserver);
            }
        };
        this.stateWrap(action, request, responseObserver);
    }

    /**
     * Builds a GroupContextExtensions proposal whose extension list is exactly the
     * (type, data) pairs carried by the request, and returns the encoded message.
     *
     * @param entry            cached group state the proposal is issued from
     * @param request          carries the proposed extension list
     * @param responseObserver receives exactly one ProposalResponse, then completion
     * @throws Exception if proposal construction or encoding fails
     */
    private void groupContextExtensionsProposalImpl(CachedGroup entry, MlsClient.GroupContextExtensionsProposalRequest request, StreamObserver<MlsClient.ProposalResponse> responseObserver) throws Exception {
        int extCount = request.getExtensionsCount();
        ArrayList<Extension> extensions = new ArrayList<Extension>(extCount);
        for (int idx = 0; idx < extCount; idx++) {
            extensions.add(new Extension(request.getExtensions(idx).getExtensionType(), request.getExtensions(idx).getExtensionData().toByteArray()));
        }
        byte[] encoded = MLSOutputStream.encode(entry.group.groupContextExtensions(extensions, entry.messageOptions));
        MlsClient.ProposalResponse.Builder reply = MlsClient.ProposalResponse.newBuilder();
        reply.setProposal(ByteString.copyFrom(encoded));
        responseObserver.onNext(reply.build());
        responseObserver.onCompleted();
    }

    @Override
    public void groupContextExtensionsProposal(final MlsClient.GroupContextExtensionsProposalRequest request, final StreamObserver<MlsClient.ProposalResponse> responseObserver) {
        // Delegates to groupContextExtensionsProposalImpl with the CachedGroup that stateWrap passes to run().
        FunctionWithState action = new FunctionWithState() {
            @Override
            public void run(CachedGroup cached) throws Exception {
                MLSClientImpl.this.groupContextExtensionsProposalImpl(cached, request, responseObserver);
            }
        };
        this.stateWrap(action, request, responseObserver);
    }

    /**
     * Creates a Commit over the request's by-reference proposal messages and its
     * by-value proposal descriptions, and returns the encoded Commit and Welcome.
     *
     * <p>The resulting group state is stored under a new ID. That ID and the commit
     * bytes are recorded on {@code entry} as {@code pendingGroupID} / {@code pendingCommit}.
     *
     * @param entry            cached group state to commit from
     * @param request          by-reference proposal messages, by-value proposal
     *                         descriptions, and the force-path / external-tree flags
     * @param responseObserver receives exactly one CommitResponse, then completion
     * @throws Exception if a by-reference message yields a new group state (it was a
     *                   commit, not a proposal), or if proposal construction,
     *                   committing or encoding fails
     */
    private void commitImpl(CachedGroup entry, MlsClient.CommitRequest request, StreamObserver<MlsClient.CommitResponse> responseObserver) throws Exception {
        int byRefSize = request.getByReferenceCount();
        // Feed each by-reference message into the current group. handle() returning a
        // non-null Group is treated as an error: only proposals are expected here.
        for (int i = 0; i < byRefSize; ++i) {
            byte[] msg = request.getByReference(i).toByteArray();
            Group shouldBeNull = entry.group.handle(msg, null);
            if (shouldBeNull == null) continue;
            throw new Exception("Commit included among proposals");
        }
        // Materialize the by-value proposals from their protobuf descriptions.
        ArrayList<Proposal> byValue = new ArrayList<Proposal>();
        for (int i = 0; i < request.getByValueCount(); ++i) {
            MlsClient.ProposalDescription desc = request.getByValue(i);
            Proposal proposal = this.proposalFromDescription(entry.group.getSuite(), entry.group.getGroupID(), entry.group.getTree(), desc);
            byValue.add(proposal);
        }
        boolean forcePath = request.getForcePath();
        boolean inlineTree = !request.getExternalTree();
        // Fresh random leaf secret, sized to the suite's KDF hash length.
        SecureRandom random = new SecureRandom();
        byte[] leafSecret = new byte[entry.group.getSuite().getKDF().getHashLength()];
        random.nextBytes(leafSecret);
        // NOTE(review): CommitParameters(0) is a magic value, presumably the "normal
        // commit" mode; confirm against Group.CommitParameters.
        Group.GroupWithMessage gwm = entry.group.commit(new Secret(leafSecret), new Group.CommitOptions(byValue, inlineTree, forcePath, null), entry.messageOptions, new Group.CommitParameters(0));
        byte[] commitBytes = MLSOutputStream.encode(gwm.message);
        // The same message object is re-encoded with the welcome wire format to get the
        // Welcome bytes. Order matters: this mutates gwm.message, so the commit bytes
        // must be produced first.
        gwm.message.wireFormat = WireFormat.mls_welcome;
        byte[] welcomeBytes = MLSOutputStream.encode(gwm.message);
        // Store the next state and remember it as pending. handleCommitImpl compares
        // incoming commit bytes against pendingCommit to reuse this stored state.
        int nextID = this.storeGroup(gwm.group, entry.encryptHandshake);
        entry.pendingCommit = commitBytes;
        entry.pendingGroupID = nextID;
        MlsClient.CommitResponse.Builder builder = MlsClient.CommitResponse.newBuilder().setCommit(ByteString.copyFrom((byte[])commitBytes)).setWelcome(ByteString.copyFrom((byte[])welcomeBytes));
        // When the caller asked for an external tree (inlineTree == false), the ratchet
        // tree is returned separately in the response.
        if (!inlineTree) {
            builder.setRatchetTree(ByteString.copyFrom((byte[])MLSOutputStream.encode(gwm.group.getTree())));
        }
        responseObserver.onNext((Object)builder.build());
        responseObserver.onCompleted();
    }

    @Override
    public void commit(final MlsClient.CommitRequest request, final StreamObserver<MlsClient.CommitResponse> responseObserver) {
        // Delegates to commitImpl with the CachedGroup that stateWrap passes to run().
        FunctionWithState action = new FunctionWithState() {
            @Override
            public void run(CachedGroup cached) throws Exception {
                MLSClientImpl.this.commitImpl(cached, request, responseObserver);
            }
        };
        this.stateWrap(action, request, responseObserver);
    }

    /**
     * Processes a commit, together with the proposals it may reference, and reports
     * the resulting state ID.
     *
     * <p>If the commit bytes equal this entry's own pending commit, the stored pending
     * state ID is reported (with no epoch authenticator) and the pending marker is
     * cleared. Otherwise any pending state is removed and cleared first. The
     * proposals are then handled, followed by the commit.
     *
     * @param entry            cached group state receiving the commit
     * @param request          the commit bytes and accompanying proposal messages
     * @param responseObserver receives exactly one HandleCommitResponse, then completion
     * @throws Exception if a proposal message yields a new state, or if handling the
     *                   commit yields none
     */
    private void handleCommitImpl(CachedGroup entry, MlsClient.HandleCommitRequest request, StreamObserver<MlsClient.HandleCommitResponse> responseObserver) throws Exception {
        byte[] commitBytes = request.getCommit().toByteArray();
        boolean isOwnPendingCommit = entry.pendingCommit != null && Arrays.equals(commitBytes, entry.pendingCommit);
        if (isOwnPendingCommit) {
            int pendingID = entry.pendingGroupID;
            entry.resetPending();
            responseObserver.onNext(MlsClient.HandleCommitResponse.newBuilder().setStateId(pendingID).build());
            responseObserver.onCompleted();
            return;
        }
        // Some other commit arrived: discard whatever state this entry had pending.
        if (entry.pendingGroupID != -1) {
            this.removeGroup(entry.pendingGroupID);
            entry.resetPending();
        }
        int proposalCount = request.getProposalCount();
        for (int idx = 0; idx < proposalCount; idx++) {
            if (entry.group.handle(request.getProposal(idx).toByteArray(), null) != null) {
                throw new Exception("Commit included among proposals");
            }
        }
        Group next = entry.group.handle(commitBytes, null);
        if (next == null) {
            throw new Exception("Commit failed to produce a new state");
        }
        byte[] epochAuthenticator = next.getEpochAuthenticator();
        int nextID = this.storeGroup(next, entry.encryptHandshake);
        MlsClient.HandleCommitResponse.Builder reply = MlsClient.HandleCommitResponse.newBuilder();
        reply.setStateId(nextID);
        reply.setEpochAuthenticator(ByteString.copyFrom(epochAuthenticator));
        responseObserver.onNext(reply.build());
        responseObserver.onCompleted();
    }

    @Override
    public void handleCommit(final MlsClient.HandleCommitRequest request, final StreamObserver<MlsClient.HandleCommitResponse> responseObserver) {
        // Delegates to handleCommitImpl with the CachedGroup that stateWrap passes to run().
        FunctionWithState action = new FunctionWithState() {
            @Override
            public void run(CachedGroup cached) throws Exception {
                MLSClientImpl.this.handleCommitImpl(cached, request, responseObserver);
            }
        };
        this.stateWrap(action, request, responseObserver);
    }

    /**
     * Reports the state that was stored when this entry's pending commit was created:
     * its state ID and epoch authenticator.
     *
     * @param entry            cached group state carrying the pending commit
     * @param request          the gRPC request (not read by this method)
     * @param responseObserver receives exactly one HandleCommitResponse, then completion
     * @throws Exception if no commit is pending, or if {@code loadGroup} returns null
     *                   for the pending state ID
     */
    private void handlePendingCommitImpl(CachedGroup entry, MlsClient.HandlePendingCommitRequest request, StreamObserver<MlsClient.HandleCommitResponse> responseObserver) throws Exception {
        if (entry.pendingCommit == null || entry.pendingGroupID == -1) {
            throw new Exception("No pending commit to handle");
        }
        int nextID = entry.pendingGroupID;
        CachedGroup next = this.loadGroup(nextID);
        if (next == null) {
            // Same wording as handlePendingReInitCommitImpl; the former text
            // ("No Internal error: ...") contradicted itself.
            throw new Exception("Internal error: No state for next ID");
        }
        byte[] epochAuthenticator = next.group.getEpochAuthenticator();
        MlsClient.HandleCommitResponse response = MlsClient.HandleCommitResponse.newBuilder()
                .setStateId(nextID)
                .setEpochAuthenticator(ByteString.copyFrom(epochAuthenticator))
                .build();
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    @Override
    public void handlePendingCommit(final MlsClient.HandlePendingCommitRequest request, final StreamObserver<MlsClient.HandleCommitResponse> responseObserver) {
        // Delegates to handlePendingCommitImpl with the CachedGroup that stateWrap passes to run().
        FunctionWithState action = new FunctionWithState() {
            @Override
            public void run(CachedGroup cached) throws Exception {
                MLSClientImpl.this.handlePendingCommitImpl(cached, request, responseObserver);
            }
        };
        this.stateWrap(action, request, responseObserver);
    }

    /**
     * Builds a ReInit proposal targeting the requested group ID, cipher suite and
     * extensions (protocol version fixed to {@code ProtocolVersion.mls10}), and
     * returns the encoded proposal message.
     *
     * @param entry            cached group state the proposal is issued from
     * @param request          new group ID, cipher suite code and extensions
     * @param responseObserver receives exactly one ProposalResponse, then completion
     * @throws Exception if suite lookup, proposal construction or encoding fails
     */
    private void reInitProposalImpl(CachedGroup entry, MlsClient.ReInitProposalRequest request, StreamObserver<MlsClient.ProposalResponse> responseObserver) throws Exception {
        byte[] newGroupID = request.getGroupId().toByteArray();
        MlsCipherSuite newSuite = MlsCipherSuite.getSuite((short)request.getCipherSuite());
        int extCount = request.getExtensionsCount();
        ArrayList<Extension> extensions = new ArrayList<Extension>(extCount);
        for (int idx = 0; idx < extCount; idx++) {
            extensions.add(new Extension(request.getExtensions(idx).getExtensionType(), request.getExtensions(idx).getExtensionData().toByteArray()));
        }
        MLSMessage reinitProposal = entry.group.reinit(newGroupID, ProtocolVersion.mls10, newSuite, extensions, entry.messageOptions);
        MlsClient.ProposalResponse.Builder reply = MlsClient.ProposalResponse.newBuilder();
        reply.setProposal(ByteString.copyFrom(MLSOutputStream.encode(reinitProposal)));
        responseObserver.onNext(reply.build());
        responseObserver.onCompleted();
    }

    @Override
    public void reInitProposal(final MlsClient.ReInitProposalRequest request, final StreamObserver<MlsClient.ProposalResponse> responseObserver) {
        // Delegates to reInitProposalImpl with the CachedGroup that stateWrap passes to run().
        FunctionWithState action = new FunctionWithState() {
            @Override
            public void run(CachedGroup cached) throws Exception {
                MLSClientImpl.this.reInitProposalImpl(cached, request, responseObserver);
            }
        };
        this.stateWrap(action, request, responseObserver);
    }

    /**
     * Commits exactly one by-reference ReInit proposal. The tombstone and a fresh key
     * package (same identity, new suite) are stored via {@code storeReinit}, and the
     * encoded commit is returned.
     *
     * <p>The commit bytes are recorded as {@code entry.pendingCommit}. Note that
     * {@code entry.pendingGroupID} receives the reinit ID from {@code storeReinit},
     * not a group-state ID.
     *
     * @param entry            cached group state to commit from
     * @param request          must carry exactly one by-reference proposal message
     * @param responseObserver receives exactly one CommitResponse, then completion
     * @throws Exception if the request is malformed, the referenced message yields a
     *                   new state, or committing/encoding fails
     */
    private void reInitCommitImpl(CachedGroup entry, MlsClient.CommitRequest request, StreamObserver<MlsClient.CommitResponse> responseObserver) throws Exception {
        boolean inlineTree = !request.getExternalTree();
        boolean forcePath = request.getForcePath();
        if (request.getByReferenceCount() != 1) {
            throw new Exception("Malformed ReInit CommitRequest");
        }
        Group group = entry.group;
        if (group.handle(request.getByReference(0).toByteArray(), null) != null) {
            throw new Exception("Commit included among proposals");
        }
        byte[] leafSecret = new byte[group.getSuite().getKDF().getHashLength()];
        new SecureRandom().nextBytes(leafSecret);
        Group.CommitOptions options = new Group.CommitOptions(new ArrayList<Proposal>(), inlineTree, forcePath, null);
        Group.TombstoneWithMessage twm = group.reinitCommit(leafSecret, options, entry.messageOptions);
        byte[] identity = group.getTree().getLeafNode(group.getIndex()).getCredential().getIdentity();
        KeyPackageWithSecrets kpSk = this.newKeyPackage(twm.getSuite(), identity);
        int reinitID = this.storeReinit(kpSk, twm, entry.encryptHandshake);
        byte[] commitBytes = MLSOutputStream.encode(twm.getMessage());
        entry.pendingCommit = commitBytes;
        entry.pendingGroupID = reinitID;
        responseObserver.onNext(MlsClient.CommitResponse.newBuilder().setCommit(ByteString.copyFrom(commitBytes)).build());
        responseObserver.onCompleted();
    }

    @Override
    public void reInitCommit(final MlsClient.CommitRequest request, final StreamObserver<MlsClient.CommitResponse> responseObserver) {
        // Delegates to reInitCommitImpl with the CachedGroup that stateWrap passes to run().
        FunctionWithState action = new FunctionWithState() {
            @Override
            public void run(CachedGroup cached) throws Exception {
                MLSClientImpl.this.reInitCommitImpl(cached, request, responseObserver);
            }
        };
        this.stateWrap(action, request, responseObserver);
    }

    /**
     * Reports the reinit record created for this entry's pending ReInit commit: the
     * reinit ID, its key package (wrapped as an MLSMessage), and the tombstone's
     * epoch authenticator.
     *
     * @param entry            cached group state carrying the pending commit
     * @param request          the gRPC request (not read by this method)
     * @param responseObserver receives exactly one HandleReInitCommitResponse, then completion
     * @throws Exception if no commit is pending, or if {@code loadReinit} returns null
     */
    private void handlePendingReInitCommitImpl(CachedGroup entry, MlsClient.HandlePendingCommitRequest request, StreamObserver<MlsClient.HandleReInitCommitResponse> responseObserver) throws Exception {
        boolean hasPending = entry.pendingCommit != null && entry.pendingGroupID != -1;
        if (!hasPending) {
            throw new Exception("No pending commit to handle");
        }
        int reinitID = entry.pendingGroupID;
        CachedReinit reinit = this.loadReinit(reinitID);
        if (reinit == null) {
            throw new Exception("Internal error: No state for next ID");
        }
        byte[] keyPackageBytes = MLSOutputStream.encode(MLSMessage.keyPackage(reinit.kpSk.keyPackage));
        byte[] epochAuthenticator = reinit.tombstone.getEpochAuthenticator();
        MlsClient.HandleReInitCommitResponse response = MlsClient.HandleReInitCommitResponse.newBuilder()
                .setReinitId(reinitID)
                .setKeyPackage(ByteString.copyFrom(keyPackageBytes))
                .setEpochAuthenticator(ByteString.copyFrom(epochAuthenticator))
                .build();
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    @Override
    public void handlePendingReInitCommit(final MlsClient.HandlePendingCommitRequest request, final StreamObserver<MlsClient.HandleReInitCommitResponse> responseObserver) {
        // Delegates to handlePendingReInitCommitImpl with the CachedGroup that stateWrap passes to run().
        FunctionWithState action = new FunctionWithState() {
            @Override
            public void run(CachedGroup cached) throws Exception {
                MLSClientImpl.this.handlePendingReInitCommitImpl(cached, request, responseObserver);
            }
        };
        this.stateWrap(action, request, responseObserver);
    }

    /**
     * Processes another member's ReInit commit, which must carry exactly one
     * referenced proposal.
     *
     * <p>The resulting tombstone and a fresh key package (same identity, the
     * tombstone's suite) are stored via {@code storeReinit}. The reinit ID, key
     * package and epoch authenticator are returned.
     *
     * @param entry            cached group state receiving the commit
     * @param request          one proposal message plus the commit bytes
     * @param responseObserver receives exactly one HandleReInitCommitResponse, then completion
     * @throws Exception if the request is malformed, the proposal message yields a new
     *                   state, or handling/encoding fails
     */
    private void handleReInitCommitImpl(CachedGroup entry, MlsClient.HandleCommitRequest request, StreamObserver<MlsClient.HandleReInitCommitResponse> responseObserver) throws Exception {
        if (request.getProposalCount() != 1) {
            throw new Exception("Malformed ReInit CommitRequest");
        }
        Group group = entry.group;
        if (group.handle(request.getProposal(0).toByteArray(), null) != null) {
            throw new Exception("Commit included among proposals");
        }
        MLSMessage commit = (MLSMessage)MLSInputStream.decode(request.getCommit().toByteArray(), MLSMessage.class);
        Group.Tombstone tombstone = group.handleReinitCommit(commit);
        byte[] identity = group.getTree().getLeafNode(group.getIndex()).getCredential().getIdentity();
        KeyPackageWithSecrets kpSk = this.newKeyPackage(tombstone.getSuite(), identity);
        int reinitID = this.storeReinit(kpSk, tombstone, entry.encryptHandshake);
        byte[] keyPackageBytes = MLSOutputStream.encode(MLSMessage.keyPackage(kpSk.keyPackage));
        MlsClient.HandleReInitCommitResponse response = MlsClient.HandleReInitCommitResponse.newBuilder()
                .setReinitId(reinitID)
                .setKeyPackage(ByteString.copyFrom(keyPackageBytes))
                .setEpochAuthenticator(ByteString.copyFrom(tombstone.getEpochAuthenticator()))
                .build();
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    @Override
    public void handleReInitCommit(final MlsClient.HandleCommitRequest request, final StreamObserver<MlsClient.HandleReInitCommitResponse> responseObserver) {
        // Delegates to handleReInitCommitImpl with the CachedGroup that stateWrap passes to run().
        FunctionWithState action = new FunctionWithState() {
            @Override
            public void run(CachedGroup cached) throws Exception {
                MLSClientImpl.this.handleReInitCommitImpl(cached, request, responseObserver);
            }
        };
        this.stateWrap(action, request, responseObserver);
    }

    /**
     * Creates the successor group of a reinit and a Welcome for the given key
     * packages, using the tombstone and key-package secrets cached under the
     * request's reinit ID.
     *
     * @param request          reinit ID, member key packages, and force-path /
     *                         external-tree flags
     * @param responseObserver receives exactly one CreateSubgroupResponse, then completion
     * @throws StatusException INVALID_ARGUMENT if the reinit ID is unknown
     * @throws Exception if decoding, welcome creation or encoding fails
     */
    private void reInitWelcomeImpl(MlsClient.ReInitWelcomeRequest request, StreamObserver<MlsClient.CreateSubgroupResponse> responseObserver) throws Exception {
        CachedReinit reinit = this.loadReinit(request.getReinitId());
        if (reinit == null) {
            throw Status.INVALID_ARGUMENT.withDescription("Unknown reinit ID").asException();
        }
        ArrayList<KeyPackage> keyPackages = new ArrayList<KeyPackage>();
        for (int i = 0; i < request.getKeyPackageCount(); ++i) {
            MLSMessage message = (MLSMessage)MLSInputStream.decode(request.getKeyPackage(i).toByteArray(), MLSMessage.class);
            keyPackages.add(message.keyPackage);
        }
        boolean inlineTree = !request.getExternalTree();
        boolean forcePath = request.getForcePath();
        MlsCipherSuite suite = reinit.tombstone.getSuite();
        // Fill the leaf secret with fresh randomness, as commitImpl, reInitCommitImpl and
        // createBranchImpl do. Without this the array stays all-zero, which hands a
        // constant, publicly predictable secret to createWelcome.
        byte[] leafSecret = new byte[suite.getKDF().getHashLength()];
        new SecureRandom().nextBytes(leafSecret);
        Group.GroupWithMessage gwm = reinit.tombstone.createWelcome(reinit.kpSk.encryptionKeyPair, suite.serializeSignaturePrivateKey(reinit.kpSk.signatureKeyPair.getPrivate()), reinit.kpSk.keyPackage.getLeafNode(), keyPackages, leafSecret, new Group.CommitOptions(null, inlineTree, forcePath, null));
        byte[] welcomeData = MLSOutputStream.encode(gwm.message);
        int stateID = this.storeGroup(gwm.group, reinit.encryptHandshake);
        MlsClient.CreateSubgroupResponse.Builder builder = MlsClient.CreateSubgroupResponse.newBuilder()
                .setStateId(stateID)
                .setWelcome(ByteString.copyFrom(welcomeData))
                .setEpochAuthenticator(ByteString.copyFrom(gwm.group.getEpochAuthenticator()));
        // When an external tree was requested, ship the ratchet tree separately.
        if (!inlineTree) {
            builder.setRatchetTree(ByteString.copyFrom(MLSOutputStream.encode(gwm.group.getTree())));
        }
        responseObserver.onNext(builder.build());
        responseObserver.onCompleted();
    }

    @Override
    public void reInitWelcome(final MlsClient.ReInitWelcomeRequest request, final StreamObserver<MlsClient.CreateSubgroupResponse> responseObserver) {
        // Delegates to reInitWelcomeImpl through catchWrap; no cached group is looked up here.
        Function task = new Function() {
            @Override
            public void run() throws Exception {
                MLSClientImpl.this.reInitWelcomeImpl(request, responseObserver);
            }
        };
        MLSClientImpl.catchWrap(task, responseObserver);
    }

    /**
     * Joins the successor group of a reinit from a Welcome message, using the
     * tombstone and key-package secrets cached under the request's reinit ID.
     *
     * <p>An external ratchet tree is used only when its bytes are non-empty; its
     * suite is set from the Welcome's cipher suite.
     *
     * @param request          reinit ID, Welcome bytes, optional ratchet tree
     * @param responseObserver receives exactly one JoinGroupResponse, then completion
     * @throws StatusException INVALID_ARGUMENT if the reinit ID is unknown
     * @throws Exception if decoding or joining fails
     */
    private void handleReInitWelcomeImpl(MlsClient.HandleReInitWelcomeRequest request, StreamObserver<MlsClient.JoinGroupResponse> responseObserver) throws Exception {
        CachedReinit reinit = this.loadReinit(request.getReinitId());
        if (reinit == null) {
            throw Status.INVALID_ARGUMENT.withDescription("Unknown reinit ID").asException();
        }
        MLSMessage welcome = (MLSMessage)MLSInputStream.decode(request.getWelcome().toByteArray(), MLSMessage.class);
        byte[] treeBytes = request.getRatchetTree().toByteArray();
        TreeKEMPublicKey externalTree;
        if (treeBytes.length == 0) {
            externalTree = null;
        } else {
            externalTree = (TreeKEMPublicKey)MLSInputStream.decode(treeBytes, TreeKEMPublicKey.class);
            externalTree.setSuite(welcome.getCipherSuite());
        }
        Group joined = reinit.tombstone.handleWelcome(reinit.kpSk.initKeyPair, reinit.kpSk.encryptionKeyPair, reinit.kpSk.signatureKeyPair, reinit.kpSk.keyPackage, welcome, externalTree);
        int stateID = this.storeGroup(joined, reinit.encryptHandshake);
        MlsClient.JoinGroupResponse.Builder reply = MlsClient.JoinGroupResponse.newBuilder();
        reply.setStateId(stateID);
        reply.setEpochAuthenticator(ByteString.copyFrom(joined.getEpochAuthenticator()));
        responseObserver.onNext(reply.build());
        responseObserver.onCompleted();
    }

    @Override
    public void handleReInitWelcome(final MlsClient.HandleReInitWelcomeRequest request, final StreamObserver<MlsClient.JoinGroupResponse> responseObserver) {
        // Delegates to handleReInitWelcomeImpl through catchWrap; no cached group is looked up here.
        Function task = new Function() {
            @Override
            public void run() throws Exception {
                MLSClientImpl.this.handleReInitWelcomeImpl(request, responseObserver);
            }
        };
        MLSClientImpl.catchWrap(task, responseObserver);
    }

    /**
     * Creates a branch (subgroup) of this group under the requested group ID, adding
     * the given key packages, and returns the new state ID, Welcome and epoch
     * authenticator.
     *
     * <p>A fresh key package is generated for this member's own identity, in the
     * current group's suite.
     *
     * @param entry            cached group state to branch from
     * @param request          key packages, extensions, group ID and flags
     * @param responseObserver receives exactly one CreateSubgroupResponse, then completion
     * @throws Exception if decoding, branch creation or encoding fails
     */
    private void createBranchImpl(CachedGroup entry, MlsClient.CreateBranchRequest request, StreamObserver<MlsClient.CreateSubgroupResponse> responseObserver) throws Exception {
        int kpCount = request.getKeyPackagesCount();
        ArrayList<KeyPackage> keyPackages = new ArrayList<KeyPackage>(kpCount);
        for (int idx = 0; idx < kpCount; idx++) {
            MLSMessage kpMessage = (MLSMessage)MLSInputStream.decode(request.getKeyPackages(idx).toByteArray(), MLSMessage.class);
            keyPackages.add(kpMessage.keyPackage);
        }
        int extCount = request.getExtensionsCount();
        ArrayList<Extension> extensions = new ArrayList<Extension>(extCount);
        for (int idx = 0; idx < extCount; idx++) {
            extensions.add(new Extension(request.getExtensions(idx).getExtensionType(), request.getExtensions(idx).getExtensionData().toByteArray()));
        }
        Group group = entry.group;
        byte[] identity = group.getTree().getLeafNode(group.getIndex()).getCredential().getIdentity();
        boolean inlineTree = !request.getExternalTree();
        boolean forcePath = request.getForcePath();
        byte[] branchGroupID = request.getGroupId().toByteArray();
        MlsCipherSuite suite = group.getSuite();
        KeyPackageWithSecrets kpSK = this.newKeyPackage(suite, identity);
        byte[] leafSecret = new byte[suite.getKDF().getHashLength()];
        new SecureRandom().nextBytes(leafSecret);
        Group.CommitOptions options = new Group.CommitOptions(null, inlineTree, forcePath, null);
        Group.GroupWithMessage gwm = group.createBranch(branchGroupID, kpSK.encryptionKeyPair, kpSK.signatureKeyPair, kpSK.keyPackage.getLeafNode(), extensions, keyPackages, leafSecret, options);
        int nextID = this.storeGroup(gwm.group, entry.encryptHandshake);
        byte[] welcomeBytes = MLSOutputStream.encode(gwm.message);
        MlsClient.CreateSubgroupResponse.Builder builder = MlsClient.CreateSubgroupResponse.newBuilder()
                .setStateId(nextID)
                .setWelcome(ByteString.copyFrom(welcomeBytes))
                .setEpochAuthenticator(ByteString.copyFrom(gwm.group.getEpochAuthenticator()));
        if (!inlineTree) {
            builder.setRatchetTree(ByteString.copyFrom(MLSOutputStream.encode(gwm.group.getTree())));
        }
        responseObserver.onNext(builder.build());
        responseObserver.onCompleted();
    }

    @Override
    public void createBranch(final MlsClient.CreateBranchRequest request, final StreamObserver<MlsClient.CreateSubgroupResponse> responseObserver) {
        // Delegates to createBranchImpl with the CachedGroup that stateWrap passes to run().
        FunctionWithState action = new FunctionWithState() {
            @Override
            public void run(CachedGroup cached) throws Exception {
                MLSClientImpl.this.createBranchImpl(cached, request, responseObserver);
            }
        };
        this.stateWrap(action, request, responseObserver);
    }

    /**
     * Joins a branch of this group from a Welcome, using the key-package secrets
     * cached under the request's transaction ID.
     *
     * <p>An external ratchet tree is used only when its bytes are non-empty; its
     * suite is set from the Welcome's cipher suite.
     *
     * @param entry            cached group state the branch was made from
     * @param request          transaction ID, Welcome bytes, optional ratchet tree
     * @param responseObserver receives exactly one HandleBranchResponse, then completion
     * @throws StatusException INVALID_ARGUMENT if the transaction ID is unknown
     * @throws Exception if decoding or joining fails
     */
    private void handleBranchImpl(CachedGroup entry, MlsClient.HandleBranchRequest request, StreamObserver<MlsClient.HandleBranchResponse> responseObserver) throws Exception {
        CachedJoin join = this.loadJoin(request.getTransactionId());
        if (join == null) {
            throw Status.INVALID_ARGUMENT.withDescription("Unknown transaction ID").asException();
        }
        MLSMessage welcome = (MLSMessage)MLSInputStream.decode(request.getWelcome().toByteArray(), MLSMessage.class);
        byte[] treeBytes = request.getRatchetTree().toByteArray();
        TreeKEMPublicKey externalTree;
        if (treeBytes.length == 0) {
            externalTree = null;
        } else {
            externalTree = (TreeKEMPublicKey)MLSInputStream.decode(treeBytes, TreeKEMPublicKey.class);
            externalTree.setSuite(welcome.getCipherSuite());
        }
        Group branch = entry.group.handleBranch(join.kpSecrets.initKeyPair, join.kpSecrets.encryptionKeyPair, join.kpSecrets.signatureKeyPair, join.kpSecrets.keyPackage, welcome, externalTree);
        int stateID = this.storeGroup(branch, entry.encryptHandshake);
        MlsClient.HandleBranchResponse.Builder reply = MlsClient.HandleBranchResponse.newBuilder();
        reply.setStateId(stateID);
        reply.setEpochAuthenticator(ByteString.copyFrom(branch.getEpochAuthenticator()));
        responseObserver.onNext(reply.build());
        responseObserver.onCompleted();
    }

    @Override
    public void handleBranch(final MlsClient.HandleBranchRequest request, final StreamObserver<MlsClient.HandleBranchResponse> responseObserver) {
        // Delegates to handleBranchImpl with the CachedGroup that stateWrap passes to run().
        FunctionWithState action = new FunctionWithState() {
            @Override
            public void run(CachedGroup cached) throws Exception {
                MLSClientImpl.this.handleBranchImpl(cached, request, responseObserver);
            }
        };
        this.stateWrap(action, request, responseObserver);
    }

    /**
     * Prepares a new-member Add proposal for the group described by the request's
     * GroupInfo.
     *
     * <p>A fresh key package is generated for the requested identity and its secrets
     * are cached via {@code storeJoin}. The serialized init, encryption and signature
     * private keys are returned together with the proposal and the transaction ID.
     *
     * @param request          GroupInfo message bytes and joiner identity
     * @param responseObserver receives exactly one NewMemberAddProposalResponse, then completion
     * @throws Exception if decoding, key generation, proposal creation or encoding fails
     */
    private void newMemberAddProposalImpl(MlsClient.NewMemberAddProposalRequest request, StreamObserver<MlsClient.NewMemberAddProposalResponse> responseObserver) throws Exception {
        MLSMessage decoded = (MLSMessage)MLSInputStream.decode(request.getGroupInfo().toByteArray(), MLSMessage.class);
        GroupInfo groupInfo = decoded.groupInfo;
        MlsCipherSuite suite = groupInfo.getSuite();
        KeyPackageWithSecrets kpSk = this.newKeyPackage(suite, request.getIdentity().toByteArray());
        byte[] initSk = suite.getHPKE().serializePrivateKey(kpSk.initKeyPair.getPrivate());
        byte[] encryptionSk = suite.getHPKE().serializePrivateKey(kpSk.encryptionKeyPair.getPrivate());
        byte[] signatureSk = suite.serializeSignaturePrivateKey(kpSk.signatureKeyPair.getPrivate());
        MLSMessage addProposal = Group.newMemberAdd(groupInfo.getGroupID(), groupInfo.getEpoch(), kpSk.keyPackage, kpSk.signatureKeyPair);
        int joinID = this.storeJoin(kpSk);
        MlsClient.NewMemberAddProposalResponse.Builder reply = MlsClient.NewMemberAddProposalResponse.newBuilder();
        reply.setInitPriv(ByteString.copyFrom(initSk));
        reply.setEncryptionPriv(ByteString.copyFrom(encryptionSk));
        reply.setSignaturePriv(ByteString.copyFrom(signatureSk));
        reply.setProposal(ByteString.copyFrom(MLSOutputStream.encode(addProposal)));
        reply.setTransactionId(joinID);
        responseObserver.onNext(reply.build());
        responseObserver.onCompleted();
    }

    @Override
    public void newMemberAddProposal(final MlsClient.NewMemberAddProposalRequest request, final StreamObserver<MlsClient.NewMemberAddProposalResponse> responseObserver) {
        // Delegates to newMemberAddProposalImpl through catchWrap; no cached group is looked up here.
        Function task = new Function() {
            @Override
            public void run() throws Exception {
                MLSClientImpl.this.newMemberAddProposalImpl(request, responseObserver);
            }
        };
        MLSClientImpl.catchWrap(task, responseObserver);
    }

    /**
     * Generates a signature key pair in the requested suite and wraps its public key
     * with a basic credential as an ExternalSender.
     *
     * <p>The serialized private key is cached via {@code storeSigner}. The encoded
     * ExternalSender and the signer ID are returned.
     *
     * @param request          cipher suite code and signer identity
     * @param responseObserver receives exactly one CreateExternalSignerResponse, then completion
     * @throws Exception if suite lookup, key generation or encoding fails
     */
    private void createExternalSignerImpl(MlsClient.CreateExternalSignerRequest request, StreamObserver<MlsClient.CreateExternalSignerResponse> responseObserver) throws Exception {
        MlsCipherSuite suite = MlsCipherSuite.getSuite((short)request.getCipherSuite());
        AsymmetricCipherKeyPair signingKeys = suite.generateSignatureKeyPair();
        Credential credential = Credential.forBasic(request.getIdentity().toByteArray());
        byte[] publicKey = suite.serializeSignaturePublicKey(signingKeys.getPublic());
        ExternalSender sender = new ExternalSender(publicKey, credential);
        int signerID = this.storeSigner(suite.serializeSignaturePrivateKey(signingKeys.getPrivate()));
        MlsClient.CreateExternalSignerResponse.Builder reply = MlsClient.CreateExternalSignerResponse.newBuilder();
        reply.setExternalSender(ByteString.copyFrom(MLSOutputStream.encode(sender)));
        reply.setSignerId(signerID);
        responseObserver.onNext(reply.build());
        responseObserver.onCompleted();
    }

    @Override
    public void createExternalSigner(final MlsClient.CreateExternalSignerRequest request, final StreamObserver<MlsClient.CreateExternalSignerResponse> responseObserver) {
        // Delegates to createExternalSignerImpl through catchWrap; no cached group is looked up here.
        Function task = new Function() {
            @Override
            public void run() throws Exception {
                MLSClientImpl.this.createExternalSignerImpl(request, responseObserver);
            }
        };
        MLSClientImpl.catchWrap(task, responseObserver);
    }

    /**
     * Proposes a GroupContextExtensions change that appends the given external sender
     * to the group's external_senders extension.
     *
     * <p>A GroupContextExtensions proposal carries the complete new extension list, so
     * every other current group-context extension is carried over unchanged. An
     * existing external_senders extension is replaced in place with the extended
     * sender list. If there is none, a new one is appended.
     *
     * @param entry            cached group state the proposal is issued from
     * @param request          encoded ExternalSender to add
     * @param responseObserver receives exactly one ProposalResponse, then completion
     * @throws Exception if decoding, proposal construction or encoding fails
     */
    private void addExternalSignerImpl(CachedGroup entry, MlsClient.AddExternalSignerRequest request, StreamObserver<MlsClient.ProposalResponse> responseObserver) throws Exception {
        byte[] extSender = request.getExternalSender().toByteArray();
        ArrayList<Extension> current = new ArrayList<Extension>(entry.group.getExtensions());
        ArrayList<Extension> extList = new ArrayList<Extension>(current.size() + 1);
        ArrayList<ExternalSender> extSenders = new ArrayList<ExternalSender>();
        int sendersPos = -1;
        for (Extension ext : current) {
            if (ext.extensionType != ExtensionType.EXTERNAL_SENDERS) {
                // Previously all of these were dropped, deleting them from the group context.
                extList.add(ext);
                continue;
            }
            // Copy so the existing extension's sender list is never mutated.
            extSenders = new ArrayList<ExternalSender>(ext.getSenders());
            if (sendersPos == -1) {
                sendersPos = extList.size();
                extList.add(null); // placeholder, replaced below
            }
        }
        extSenders.add((ExternalSender)MLSInputStream.decode(extSender, ExternalSender.class));
        Extension sendersExt = Extension.externalSender(extSenders);
        if (sendersPos == -1) {
            extList.add(sendersExt);
        } else {
            extList.set(sendersPos, sendersExt);
        }
        MLSMessage proposal = entry.group.groupContextExtensions(extList, entry.messageOptions);
        MlsClient.ProposalResponse response = MlsClient.ProposalResponse.newBuilder()
                .setProposal(ByteString.copyFrom(MLSOutputStream.encode(proposal)))
                .build();
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    @Override
    public void addExternalSigner(final MlsClient.AddExternalSignerRequest request, final StreamObserver<MlsClient.ProposalResponse> responseObserver) {
        // Delegates to addExternalSignerImpl with the CachedGroup that stateWrap passes to run().
        FunctionWithState action = new FunctionWithState() {
            @Override
            public void run(CachedGroup cached) throws Exception {
                MLSClientImpl.this.addExternalSignerImpl(cached, request, responseObserver);
            }
        };
        this.stateWrap(action, request, responseObserver);
    }

    /**
     * Builds an external proposal for the group described by the request's GroupInfo,
     * signed with the private key cached under the request's signer ID.
     *
     * <p>The signer's public key must appear in the group's external_senders
     * extension. Its position in that list is used as the sender index.
     *
     * @param request          GroupInfo bytes, ratchet tree bytes, signer ID and the
     *                         proposal description
     * @param responseObserver receives exactly one ProposalResponse, then completion
     * @throws Exception if the signer ID is unknown, the signer is not listed as an
     *                   external sender, or decoding/signing/encoding fails
     */
    private void externalSignerProposalImpl(MlsClient.ExternalSignerProposalRequest request, StreamObserver<MlsClient.ProposalResponse> responseObserver) throws Exception {
        byte[] groupMsgData = request.getGroupInfo().toByteArray();
        MLSMessage groupMsg = (MLSMessage)MLSInputStream.decode(groupMsgData, MLSMessage.class);
        GroupInfo groupInfo = groupMsg.groupInfo;
        MlsCipherSuite suite = groupInfo.getSuite();
        byte[] groupID = groupInfo.getGroupID();
        long epoch = groupInfo.getEpoch();
        byte[] treeData = request.getRatchetTree().toByteArray();
        TreeKEMPublicKey tree = (TreeKEMPublicKey)MLSInputStream.decode(treeData, TreeKEMPublicKey.class);
        // Attach the group's suite to the decoded tree, as handleBranchImpl and
        // handleReInitWelcomeImpl do. Presumably the wire form does not carry it.
        tree.setSuite(suite);
        byte[] sigPriv = this.loadSigner(request.getSignerId());
        if (sigPriv == null) {
            throw new Exception("Unknown signer ID");
        }
        byte[] sigPub = suite.serializeSignaturePublicKey(suite.deserializeSignaturePrivateKey(sigPriv).getPublic());
        List<ExternalSender> extSenders = new ArrayList<ExternalSender>();
        for (Extension ext : groupInfo.getGroupContext().getExtensions()) {
            if (ext.extensionType == ExtensionType.EXTERNAL_SENDERS) {
                extSenders = new ArrayList<ExternalSender>(ext.getSenders());
            }
        }
        // Index of the (last) external sender whose signature key matches ours.
        int sigIndex = -1;
        for (int i = 0; i < extSenders.size(); ++i) {
            if (Arrays.equals(extSenders.get(i).getSignatureKey(), sigPub)) {
                sigIndex = i;
            }
        }
        if (sigIndex == -1) {
            throw new Exception("Requested signer not allowed for this group");
        }
        Proposal proposal = this.proposalFromDescription(suite, groupID, tree, request.getDescription());
        MLSMessage signedProposal = MLSMessage.externalProposal(suite, groupID, epoch, proposal, sigIndex, sigPriv);
        MlsClient.ProposalResponse response = MlsClient.ProposalResponse.newBuilder()
                .setProposal(ByteString.copyFrom(MLSOutputStream.encode(signedProposal)))
                .build();
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    /**
     * gRPC handler for ExternalSignerProposal.
     *
     * <p>Runs {@code externalSignerProposalImpl} through {@code catchWrap}, which receives
     * the same response observer. No cached group state is involved: the group is described
     * entirely by the request.
     */
    @Override
    public void externalSignerProposal(final MlsClient.ExternalSignerProposalRequest request, final StreamObserver<MlsClient.ProposalResponse> responseObserver) {
        Function action = () -> externalSignerProposalImpl(request, responseObserver);
        catchWrap(action, responseObserver);
    }

    /**
     * Drops the cached group state identified by the request's state ID.
     *
     * @param request          carries the state ID to release
     * @param responseObserver receives an empty {@code FreeResponse}, then completion
     * @throws StatusException with {@code NOT_FOUND} ("Unknown state") when the ID is not in
     *                         {@code groupCache}
     */
    private void freeImpl(MlsClient.FreeRequest request, StreamObserver<MlsClient.FreeResponse> responseObserver) throws StatusException {
        int id = request.getStateId();
        if (this.groupCache.containsKey(id)) {
            this.removeGroup(id);
            responseObserver.onNext(MlsClient.FreeResponse.newBuilder().build());
            responseObserver.onCompleted();
            return;
        }
        throw Status.NOT_FOUND.withDescription("Unknown state").asException();
    }

    /**
     * gRPC handler for Free: runs {@code freeImpl} through {@code catchWrap} with the same
     * response observer.
     */
    @Override
    public void free(final MlsClient.FreeRequest request, final StreamObserver<MlsClient.FreeResponse> responseObserver) {
        Function action = () -> freeImpl(request, responseObserver);
        catchWrap(action, responseObserver);
    }

    /**
     * Per-group state cached by this client: the live {@link Group}, its handshake-encryption
     * setting, the message options derived from it, and any pending-commit bookkeeping.
     */
    class CachedGroup {
        Group group;
        boolean encryptHandshake;
        // Built once from encryptHandshake. The remaining arguments (empty byte[] and 0) are
        // presumably authenticated data and padding; confirm against Group.MessageOptions.
        Group.MessageOptions messageOptions;
        // Pending commit bytes; null when nothing is pending (see resetPending()).
        byte[] pendingCommit;
        // State ID paired with the pending commit; -1 when nothing is pending. Initialized to
        // -1 so a freshly created entry matches the state resetPending() produces (it was 0).
        int pendingGroupID = -1;

        public CachedGroup(Group group, boolean encryptHandshake) {
            this.group = group;
            this.encryptHandshake = encryptHandshake;
            this.messageOptions = new Group.MessageOptions(encryptHandshake, new byte[0], 0);
        }

        /** Clears the pending commit and its state ID back to the "nothing pending" sentinels. */
        public void resetPending() {
            this.pendingCommit = null;
            this.pendingGroupID = -1;
        }
    }

    /**
     * Cached state for a pending join: a key package together with its secrets, plus a
     * (initially empty) map of external PSKs. The map is keyed by {@link Secret} with raw
     * {@code byte[]} values; presumably PSK ID to PSK secret, verify against the code that
     * fills it. NOTE(review): HashMap lookups rely on Secret's equals/hashCode.
     */
    class CachedJoin {
        KeyPackageWithSecrets kpSecrets;
        Map<Secret, byte[]> externalPsks = new HashMap<Secret, byte[]>();

        public CachedJoin(KeyPackageWithSecrets kpSecrets) {
            this.kpSecrets = kpSecrets;
        }
    }

    /**
     * Cached state for a pending reinitialization. It holds the key package with secrets
     * ({@code kpSk}), the {@link Group.Tombstone} passed in at creation, and the
     * handshake-encryption flag to apply.
     */
    class CachedReinit {
        KeyPackageWithSecrets kpSk;
        Group.Tombstone tombstone;
        boolean encryptHandshake;

        public CachedReinit(KeyPackageWithSecrets kpSk, Group.Tombstone tombstone, boolean encryptHandshake) {
            this.encryptHandshake = encryptHandshake;
            this.tombstone = tombstone;
            this.kpSk = kpSk;
        }
    }

    /**
     * A stateless unit of work handed to {@code catchWrap}; it may throw any exception.
     * (Member interfaces are implicitly static, and interface methods implicitly public.)
     */
    @FunctionalInterface
    public interface Function {
        void run() throws Exception;
    }

    @FunctionalInterface
    public static interface FunctionWithState {
        public void run(CachedGroup var1) throws Exception;
    }
}

