package org.jfrog.security.file;

import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo;
import org.bouncycastle.cert.X509CertificateHolder;
import org.bouncycastle.openssl.PEMKeyPair;
import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;
import org.bouncycastle.openssl.jcajce.JcaPEMWriter;
import org.jfrog.security.ssl.CertificateKeyHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.security.cert.CertificateEncodingException;
import javax.security.cert.X509Certificate;
import java.io.*;
import java.security.KeyPair;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

/**
 * Static utility methods for converting private keys, key pairs and X.509 certificates to and from PEM text.
 * Reading is done with Bouncy Castle's {@link PEMParser} and {@link JcaPEMKeyConverter}; writing with
 * {@link JcaPEMWriter}.
 *
 * @author Yinon Avraham
 * Created on 09/10/2016.
 */
public abstract class PemHelper {
    private static final Logger log = LoggerFactory.getLogger(PemHelper.class);

    /** Not instantiable: all members are static. */
    private PemHelper() {}

    /**
     * Writes {@code privateKey} to {@code file} as a single PEM object, replacing any existing file content.
     *
     * @param file       destination file
     * @param privateKey key to write
     * @throws IOException if the file cannot be written or the key cannot be PEM-encoded
     */
    public static void savePrivateKey(File file, PrivateKey privateKey) throws IOException {
        savePemObjects(file, privateKey);
    }

    /**
     * Reads the first PEM object in {@code file} and returns it as a private key. A PEM object that is parsed as
     * a key pair is narrowed to its private key.
     *
     * @param file PEM file to read
     * @return the private key
     * @throws IOException              if the file cannot be read or its PEM content cannot be parsed
     * @throws IllegalArgumentException if the file holds no PEM object or the first object is not a private key
     */
    public static PrivateKey readPrivateKey(File file) throws IOException {
        return readPemObject(new FileReader(file), PrivateKey.class);
    }

    /**
     * Renders {@code privateKey} as PEM text.
     *
     * @param privateKey key to render
     * @return the PEM representation
     * @throws IllegalArgumentException if the key cannot be written in PEM format
     */
    public static String privateKeyAsPemString(PrivateKey privateKey) {
        return objectAsString(privateKey);
    }

    /**
     * Renders {@code certificate} as PEM text.
     *
     * @param certificate certificate to render
     * @return the PEM representation
     * @throws IllegalArgumentException if the certificate cannot be written in PEM format
     */
    public static String certificateAsPemString(Certificate certificate) {
        return objectAsString(certificate);
    }

    /**
     * Parses the first PEM object in {@code content} and returns it as a private key. A PEM object that is parsed
     * as a key pair is narrowed to its private key.
     *
     * @param content PEM text
     * @return the private key
     * @throws IOException              if the PEM content cannot be parsed
     * @throws IllegalArgumentException if there is no PEM object or the first object is not a private key
     */
    public static PrivateKey readPrivateKey(String content) throws IOException {
        return readPemObject(new StringReader(content), PrivateKey.class);
    }

    /**
     * Writes the public key followed by the private key of {@code keyPair} to {@code file}, replacing any existing
     * file content. This is the layout {@link #readKeyPair(File)} expects.
     *
     * @param file    destination file
     * @param keyPair key pair to write
     * @throws IOException if the file cannot be written or a key cannot be PEM-encoded
     */
    public static void saveKeyPair(File file, KeyPair keyPair) throws IOException {
        savePemObjects(file, keyPair.getPublic(), keyPair.getPrivate());
    }

    /**
     * Reads a key pair from a file whose first PEM object is a public key and whose second is a private key, as
     * written by {@link #saveKeyPair(File, KeyPair)}.
     *
     * @param file PEM file to read
     * @return the key pair
     * @throws IOException              if the file cannot be read or its PEM content cannot be parsed
     * @throws IllegalArgumentException if fewer than two PEM objects are present or they are not of the
     *                                  expected types
     */
    public static KeyPair readKeyPair(File file) throws IOException {
        List<Object> objects = readPemObjects(new FileReader(file), PublicKey.class, PrivateKey.class);
        PublicKey publicKey = (PublicKey) objects.get(0);
        PrivateKey privateKey = (PrivateKey) objects.get(1);
        return new KeyPair(publicKey, privateKey);
    }

    /**
     * Scans every PEM object in {@code content} and collects the first private key and the first X.509
     * certificate, in whatever order they appear. Objects that cannot be parsed are skipped (logged at debug
     * level); any additional keys or certificates are ignored.
     *
     * @param content PEM text containing at least one private key and one certificate
     * @return a holder populated with the key and the certificate
     * @throws CertificateException if the key or the certificate is missing, or the content cannot be read
     * @throws RuntimeException     if a parsed certificate cannot be converted to a {@link Certificate}
     */
    public static CertificateKeyHolder readCertificateAndPrivateKey(String content) throws CertificateException {
        try {
            List<Object> objects = parseAllObjects(new StringReader(content));
            CertificateKeyHolder certKeyHolder = new CertificateKeyHolder();
            boolean foundCertificate = false;
            boolean foundPrivate = false;
            for (Object curr : objects) {
                if (!foundPrivate && (curr instanceof PrivateKeyInfo || curr instanceof PEMKeyPair)) {
                    PrivateKey privateKey = fromParsedObject(curr, PrivateKey.class);
                    certKeyHolder.setKey(privateKey);
                    foundPrivate = true;
                } else if (!foundCertificate && curr instanceof X509CertificateHolder) {
                    Certificate certificate = fromParsedObject(curr, Certificate.class);
                    certKeyHolder.setCertificate(certificate);
                    foundCertificate = true;
                }
            }
            if (!foundCertificate || !foundPrivate) {
                throw new CertificateException("Missing private key or certificate");
            }
            return certKeyHolder;
        } catch (IOException e) {
            throw new CertificateException("An error occurred while reading content", e);
        }
    }

    /**
     * Writes a legacy {@code javax.security.cert} certificate to {@code file}. The certificate is first re-encoded
     * as a {@link java.security.cert.Certificate} so that it can be PEM-written.
     *
     * @param file        destination file
     * @param certificate legacy certificate to write
     * @throws IOException              if the file cannot be written
     * @throws IllegalArgumentException if the certificate cannot be encoded or re-decoded
     */
    public static void saveCertificate(File file, X509Certificate certificate) throws IOException {
        try {
            CertificateFactory certificateFactory = CertificateFactory.getInstance("X509");
            ByteArrayInputStream certificateByteStream = new ByteArrayInputStream(certificate.getEncoded());
            Certificate serializableCertificate = certificateFactory.generateCertificate(certificateByteStream);
            saveCertificate(file, serializableCertificate);
        } catch (CertificateException | CertificateEncodingException e) {
            throw new IllegalArgumentException("Unexpected certificate error.", e);
        }
    }

    /**
     * Writes {@code certificate} to {@code file} as a single PEM object, replacing any existing file content.
     *
     * @param file        destination file
     * @param certificate certificate to write
     * @throws IOException if the file cannot be written or the certificate cannot be PEM-encoded
     */
    public static void saveCertificate(File file, Certificate certificate) throws IOException {
        savePemObjects(file, certificate);
    }

    /**
     * Reads the first PEM object in {@code file} and returns it as a certificate.
     *
     * @param file PEM file to read
     * @return the certificate
     * @throws IOException              if the file cannot be read or its PEM content cannot be parsed
     * @throws IllegalArgumentException if the file holds no PEM object or the first object is not a certificate
     * @throws RuntimeException         if the parsed certificate cannot be converted to a {@link Certificate}
     */
    public static Certificate readCertificate(File file) throws IOException {
        return readPemObject(new FileReader(file), Certificate.class);
    }

    /**
     * Parses the first PEM object in {@code content} and returns it as a certificate.
     *
     * @param content PEM text
     * @return the certificate
     * @throws IOException              if the PEM content cannot be parsed
     * @throws IllegalArgumentException if there is no PEM object or the first object is not a certificate
     * @throws RuntimeException         if the parsed certificate cannot be converted to a {@link Certificate}
     */
    public static Certificate readCertificate(String content) throws IOException {
        return readPemObject(new StringReader(content), Certificate.class);
    }

    /**
     * Parses the first PEM object from {@code reader} and converts it to {@code targetClass}. The reader is closed
     * (through the parser) before returning.
     */
    private static <T> T readPemObject(Reader reader, Class<T> targetClass)
            throws IOException {
        try (PEMParser pemParser = new PEMParser(reader)) {
            return fromParsedObject(pemParser.readObject(), targetClass);
        }
    }

    /**
     * Parses every PEM object from {@code reader}, in order. This is best effort: an object whose parsing throws
     * an {@link IOException} is logged at debug level and skipped. The reader is closed before returning.
     *
     * @return the raw parsed objects (never {@code null} entries)
     */
    private static List<Object> parseAllObjects(Reader reader) throws IOException {
        List<Object> objects = new ArrayList<>();
        try (PEMParser pemParser = new PEMParser(reader)) {
            while (true) {
                Object pemObject;
                try {
                    pemObject = pemParser.readObject();
                } catch (IOException e) {
                    // Pass the exception itself: a bare message argument without a "{}" placeholder is dropped by SLF4J.
                    log.debug("Could not read PEM object.", e);
                    continue;
                }
                if (pemObject == null) {
                    // PEMParser signals end of input with null.
                    break;
                }
                objects.add(pemObject);
            }
        }
        return objects;
    }

    /**
     * Parses exactly two consecutive PEM objects from {@code reader} and converts them to {@code targetClass1} and
     * {@code targetClass2} respectively. The reader is closed before returning.
     *
     * @return a two-element list holding the converted objects, in input order
     */
    private static List<Object> readPemObjects(Reader reader, Class<?> targetClass1, Class<?> targetClass2)
            throws IOException {
        List<Object> objects = new ArrayList<>(2);
        try (PEMParser pemParser = new PEMParser(reader)) {
            objects.add(fromParsedObject(pemParser.readObject(), targetClass1));
            objects.add(fromParsedObject(pemParser.readObject(), targetClass2));
        }
        return objects;
    }

    /**
     * Converts a raw object returned by {@link PEMParser#readObject()} into {@code targetClass}.
     * <ul>
     *     <li>{@link PEMKeyPair} becomes a {@link KeyPair}.</li>
     *     <li>{@link PrivateKeyInfo} becomes a {@link PrivateKey}.</li>
     *     <li>{@link SubjectPublicKeyInfo} becomes a {@link PublicKey}.</li>
     *     <li>{@link X509CertificateHolder} becomes a {@link Certificate}, re-decoded by an X.509
     *         {@link CertificateFactory}.</li>
     * </ul>
     * Any other object is passed through unchanged. A resulting {@link KeyPair} is narrowed to its private key or
     * its public key when {@code targetClass} is a private or public key type, respectively.
     *
     * @param object      raw parser output; {@code null} means the input held no (further) PEM object
     * @param targetClass expected result type
     * @return the converted object
     * @throws IOException              if key conversion fails
     * @throws IllegalArgumentException if {@code object} is {@code null} or cannot be converted to
     *                                  {@code targetClass}
     * @throws RuntimeException         if certificate conversion fails
     */
    private static <T> T fromParsedObject(Object object, Class<T> targetClass) throws IOException {
        if (object == null) {
            // Empty input, or fewer PEM objects than the caller expected.
            throw new IllegalArgumentException("No PEM object found (expected: " + targetClass + ")");
        }
        Object result = object;
        JcaPEMKeyConverter converter = new JcaPEMKeyConverter();
        if (object instanceof PEMKeyPair) {
            result = converter.getKeyPair((PEMKeyPair) object);
        } else if (object instanceof PrivateKeyInfo) {
            result = converter.getPrivateKey((PrivateKeyInfo) object);
        } else if (object instanceof SubjectPublicKeyInfo) {
            result = converter.getPublicKey((SubjectPublicKeyInfo) object);
        } else if (object instanceof X509CertificateHolder) {
            try {
                CertificateFactory certFactory = CertificateFactory.getInstance("X509");
                byte[] encodedBytes = ((X509CertificateHolder) object).toASN1Structure().getEncoded();
                result = certFactory.generateCertificate(new ByteArrayInputStream(encodedBytes));
            } catch (CertificateException e) {
                throw new RuntimeException("Failed to convert parsed PEM object to certificate.", e);
            }
        }
        // A private key may be parsed as a key pair, so pick the half of the pair that the caller asked for.
        if (result instanceof KeyPair) {
            KeyPair keyPair = (KeyPair) result;
            if (PrivateKey.class.isAssignableFrom(targetClass)) {
                result = keyPair.getPrivate();
            } else if (PublicKey.class.isAssignableFrom(targetClass)) {
                result = keyPair.getPublic();
            }
        }
        if (!targetClass.isAssignableFrom(result.getClass())) {
            throw new IllegalArgumentException("Object type is not as expected " +
                    "(parsed type: " + object.getClass() + ", adjusted: " + result.getClass() + ", expected: " + targetClass + ")");
        }
        return targetClass.cast(result);
    }

    /**
     * Writes one or more objects to {@code file} as consecutive PEM objects, replacing any existing file content.
     * The file is closed before returning, including on failure.
     */
    private static void savePemObjects(File file, Object object, Object... objects) throws IOException {
        try (JcaPEMWriter pemWriter = new JcaPEMWriter(new FileWriter(file))) {
            pemWriter.writeObject(object);
            for (Object obj : objects) {
                pemWriter.writeObject(obj);
            }
            pemWriter.flush();
        }
    }

    /**
     * Renders {@code object} as PEM text using {@link JcaPEMWriter}.
     *
     * @param object object to render (a key, certificate, or any other type the writer supports)
     * @return the PEM representation
     * @throws IllegalArgumentException if the object cannot be written in PEM format
     */
    public static String objectAsString(Object object) {
        StringWriter writer = new StringWriter();
        try (JcaPEMWriter pemWriter = new JcaPEMWriter(writer)) {
            pemWriter.writeObject(object);
        } catch (IOException e) {
            throw new IllegalArgumentException("Failed to write object in PEM format.", e);
        }
        // Read only after try-with-resources has closed (and therefore flushed) the PEM writer.
        return writer.toString();
    }
}
