/*
 * Decompiled with CFR 0.152.
 */
package tech.kwik.agent15.handshake;

import java.io.ByteArrayInputStream;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.security.cert.CertificateEncodingException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import tech.kwik.agent15.TlsConstants;
import tech.kwik.agent15.alert.BadCertificateAlert;
import tech.kwik.agent15.alert.DecodeErrorException;
import tech.kwik.agent15.handshake.HandshakeMessage;

/**
 * The TLS {@code Certificate} handshake message.
 *
 * <p>An instance is either built from one or more certificates, in which case its wire encoding is
 * computed immediately and returned by {@link #getBytes()}, or created empty and populated by
 * {@link #parse(ByteBuffer)}.
 *
 * <p>Wire layout handled here (TLS 1.3 style, with a request context):
 * <pre>
 *   handshake header: 1 byte type, uint24 body length
 *   opaque certificate_request_context&lt;0..2^8-1&gt;
 *   CertificateEntry certificate_list&lt;0..2^24-1&gt;, each entry being
 *       opaque cert_data&lt;1..2^24-1&gt;
 *       Extension extensions&lt;0..2^16-1&gt;
 * </pre>
 * Per-entry extensions are skipped when parsing and written as empty when serializing.
 */
public class CertificateMessage extends HandshakeMessage {
    /**
     * Minimum size handed to the handshake header check: 4 (header) + 1 (context length)
     * + 3 (list length) + 3 (cert_data length) + 2 (extensions length).
     * NOTE(review): if the superclass compares this against the total message size, an empty
     * certificate list (8 bytes) would be rejected; confirm this is intended.
     */
    private static final int MINIMUM_MESSAGE_SIZE = 13;
    /** Largest value representable in a TLS uint24 length field. */
    private static final int MAX_UINT24 = 0xFFFFFF;

    private byte[] requestContext;
    /** First certificate of the chain, or null when the chain is empty. */
    private X509Certificate endEntityCertificate;
    private List<X509Certificate> certificateChain = new ArrayList<X509Certificate>();
    /** Wire encoding of the whole message, including the 4-byte handshake header. */
    private byte[] raw;

    /**
     * Creates a message with an empty request context carrying a single certificate, or an
     * empty certificate list when {@code certificate} is {@code null}.
     *
     * @param certificate the end-entity certificate; may be {@code null}
     */
    public CertificateMessage(X509Certificate certificate) {
        this.requestContext = new byte[0];
        this.endEntityCertificate = certificate;
        this.certificateChain = certificate != null ? List.of(certificate) : Collections.emptyList();
        this.serialize();
    }

    /**
     * Creates a message with an empty request context carrying the given chain; the first
     * element is taken as the end-entity certificate.
     *
     * @param certificateChain the certificates to send, end-entity first; must not be empty
     * @throws NullPointerException if {@code certificateChain} is null
     * @throws IllegalArgumentException if {@code certificateChain} is empty
     */
    public CertificateMessage(List<X509Certificate> certificateChain) {
        Objects.requireNonNull(certificateChain, "certificateChain");
        if (certificateChain.isEmpty()) {
            throw new IllegalArgumentException("certificate chain must not be empty");
        }
        this.requestContext = new byte[0];
        // Defensive copy so later changes to the caller's list cannot desync the chain from raw.
        this.certificateChain = new ArrayList<>(certificateChain);
        this.endEntityCertificate = this.certificateChain.get(0);
        this.serialize();
    }

    /**
     * Creates a message carrying a single certificate with the given certificate_request_context,
     * which is included in the encoding.
     *
     * @param requestContext the request context (at most 255 bytes); {@code null} is treated as empty
     * @param certificate the end-entity certificate; must not be null
     * @throws NullPointerException if {@code certificate} is null
     * @throws IllegalArgumentException if {@code requestContext} is longer than 255 bytes
     */
    public CertificateMessage(byte[] requestContext, X509Certificate certificate) {
        Objects.requireNonNull(certificate);
        this.requestContext = requestContext != null ? requestContext.clone() : new byte[0];
        this.endEntityCertificate = certificate;
        this.certificateChain = List.of(certificate);
        this.serialize();
    }

    /** Creates an empty message, to be filled by {@link #parse(ByteBuffer)}. */
    public CertificateMessage() {
    }

    @Override
    public TlsConstants.HandshakeType getType() {
        return TlsConstants.HandshakeType.certificate;
    }

    /**
     * Parses a Certificate message starting at the buffer's current position, replacing the
     * request context, chain, end-entity certificate and raw bytes held by this instance.
     * On success the buffer is left positioned just after the message.
     *
     * @param buffer buffer positioned at the start of the handshake header
     * @return this instance
     * @throws DecodeErrorException if the message is truncated, contains an empty certificate
     *     entry, or its length fields are inconsistent with each other
     * @throws BadCertificateAlert if a certificate cannot be decoded as X.509
     */
    public CertificateMessage parse(ByteBuffer buffer) throws DecodeErrorException, BadCertificateAlert {
        int startPosition = buffer.position();
        // Length of the message body, i.e. excluding the 4-byte handshake header.
        int remainingLength = this.parseHandshakeHeader(buffer, TlsConstants.HandshakeType.certificate, MINIMUM_MESSAGE_SIZE);
        int messageEnd = startPosition + 4 + remainingLength;
        try {
            int certificateRequestContextSize = buffer.get() & 0xFF;
            this.requestContext = new byte[certificateRequestContextSize];
            buffer.get(this.requestContext);
            this.parseCertificateEntries(buffer);
            // The body must consist of exactly the context and the certificate list; otherwise the
            // list length disagrees with the handshake length (or overran into the next message).
            if (buffer.position() != messageEnd) {
                throw new DecodeErrorException("inconsistent certificate message length");
            }
            this.raw = new byte[4 + remainingLength];
            buffer.position(startPosition);
            buffer.get(this.raw);
            return this;
        }
        catch (BufferUnderflowException notEnoughBytes) {
            throw new DecodeErrorException("message underflow");
        }
    }

    /**
     * Parses the certificate_list (uint24 length followed by entries) at the buffer's position,
     * replacing the current chain. Entry extensions are skipped.
     *
     * @return the number of certificate entries parsed
     * @throws DecodeErrorException on an empty cert_data, a length exceeding the available bytes,
     *     or entries that do not end exactly at the declared list length
     * @throws BadCertificateAlert if a certificate cannot be decoded as X.509
     * @throws BufferUnderflowException if a length field itself is truncated (mapped by the caller)
     */
    private int parseCertificateEntries(ByteBuffer buffer) throws DecodeErrorException, BadCertificateAlert {
        // Start from a fresh, mutable list: the field may hold an immutable list or a previous parse.
        this.certificateChain = new ArrayList<X509Certificate>();
        this.endEntityCertificate = null;

        int certificateListSize = readUint24(buffer);
        if (certificateListSize > buffer.remaining()) {
            throw new DecodeErrorException("message underflow");
        }
        int listEnd = buffer.position() + certificateListSize;

        CertificateFactory certificateFactory;
        try {
            certificateFactory = CertificateFactory.getInstance("X.509");
        }
        catch (CertificateException e) {
            throw new BadCertificateAlert("could not parse certificate");
        }

        int certCount = 0;
        while (buffer.position() < listEnd) {
            int certSize = readUint24(buffer);
            if (certSize == 0) {
                // cert_data has a minimum length of 1.
                throw new DecodeErrorException("empty certificate entry");
            }
            // Check before allocating, as certSize comes from the peer.
            if (certSize > buffer.remaining()) {
                throw new DecodeErrorException("message underflow");
            }
            byte[] certificateData = new byte[certSize];
            buffer.get(certificateData);
            X509Certificate certificate;
            try {
                certificate = (X509Certificate) certificateFactory.generateCertificate(new ByteArrayInputStream(certificateData));
            }
            catch (CertificateException e) {
                throw new BadCertificateAlert("could not parse certificate");
            }
            if (this.certificateChain.isEmpty()) {
                this.endEntityCertificate = certificate;
            }
            this.certificateChain.add(certificate);
            certCount++;

            int extensionsSize = buffer.getShort() & 0xFFFF;
            if (extensionsSize > buffer.remaining()) {
                throw new DecodeErrorException("message underflow");
            }
            buffer.position(buffer.position() + extensionsSize);
        }
        if (buffer.position() != listEnd) {
            throw new DecodeErrorException("certificate list length mismatch");
        }
        return certCount;
    }

    /**
     * Computes {@link #raw} from the request context and certificate chain. Each entry is written
     * with an empty extensions block; all length fields use their full TLS width.
     *
     * @throws IllegalArgumentException if the request context exceeds 255 bytes or the message
     *     body would not fit in a uint24 length
     */
    private void serialize() {
        byte[] context = this.requestContext != null ? this.requestContext : new byte[0];
        if (context.length > 0xFF) {
            throw new IllegalArgumentException("Certificate request context too long");
        }
        List<byte[]> encodedCerts = this.certificateChain.stream().map(this::encode).collect(Collectors.toList());

        long certificateListSize = 0;
        for (byte[] encodedCert : encodedCerts) {
            // uint24 cert_data length + cert_data + uint16 extensions length (always 0)
            certificateListSize += 3L + encodedCert.length + 2;
        }
        long bodySize = 1L + context.length + 3 + certificateListSize;
        // Every inner length is bounded by the body length, so one check covers them all.
        if (bodySize > MAX_UINT24) {
            throw new IllegalArgumentException("Certificate size not supported");
        }

        ByteBuffer buffer = ByteBuffer.allocate(4 + (int) bodySize);
        buffer.putInt(TlsConstants.HandshakeType.certificate.value << 24 | (int) bodySize);
        buffer.put((byte) context.length);
        buffer.put(context);
        putUint24(buffer, (int) certificateListSize);
        for (byte[] encodedCert : encodedCerts) {
            putUint24(buffer, encodedCert.length);
            buffer.put(encodedCert);
            buffer.putShort((short) 0);
        }
        this.raw = buffer.array();
    }

    /** Reads a big-endian 24-bit unsigned integer. */
    private static int readUint24(ByteBuffer buffer) {
        return (buffer.get() & 0xFF) << 16 | (buffer.get() & 0xFF) << 8 | buffer.get() & 0xFF;
    }

    /** Writes the low 24 bits of {@code value} big-endian. */
    private static void putUint24(ByteBuffer buffer, int value) {
        buffer.put((byte) (value >>> 16));
        buffer.putShort((short) value);
    }

    /**
     * Returns the DER encoding of a certificate as given by {@link X509Certificate#getEncoded()}.
     *
     * @throws RuntimeException wrapping the {@link CertificateEncodingException} if encoding fails
     */
    byte[] encode(X509Certificate certificate) {
        try {
            return certificate.getEncoded();
        }
        catch (CertificateEncodingException e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns the complete wire encoding, including the handshake header (internal array, not a copy). */
    @Override
    public byte[] getBytes() {
        return this.raw;
    }

    /** Returns the certificate_request_context (internal array, not a copy). */
    public byte[] getRequestContext() {
        return this.requestContext;
    }

    /** Returns the first certificate of the chain, or {@code null} if the chain is empty. */
    public X509Certificate getEndEntityCertificate() {
        return this.endEntityCertificate;
    }

    /**
     * Returns the certificate chain, end-entity first. This is the internal list; depending on
     * how the instance was created it may be unmodifiable.
     */
    public List<X509Certificate> getCertificateChain() {
        return this.certificateChain;
    }
}

