package com.atlassian.aws.ec2.caches;

import com.amazonaws.services.ec2.AmazonEC2Async;
import com.amazonaws.services.ec2.model.GetPasswordDataRequest;
import com.amazonaws.services.ec2.model.GetPasswordDataResult;
import com.atlassian.aws.ec2.model.InstanceId;
import com.atlassian.aws.utils.CryptoUtils;
import com.google.common.base.Throwables;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ComputationException;
import org.apache.commons.lang3.StringUtils;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import javax.crypto.Cipher;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.Security;
import java.util.Base64;
import java.util.Objects;
import java.util.concurrent.ExecutionException;

/**
 * Retrieves EC2 instance passwords via {@code GetPasswordData}, decrypts them with a key file and caches the result
 * per (key file, instance id) pair.
 * <p>
 * Two Guava caches are used: one for decrypted passwords and one for the {@link Cipher} built for each key file.
 * Neither cache is configured with a size bound or expiry, so entries stay in memory for the lifetime of this
 * object. A load that throws (e.g. password not available yet) is not stored, so a later call retries it.
 */
public class InstancePasswordCache
{
    /**
     * Signals that EC2 returned no password data for the instance (yet); {@link #getPassword(File, InstanceId)}
     * maps this condition to a {@code null} result.
     */
    private static class PasswordNotAvailableException extends IllegalStateException
    {
        public PasswordNotAvailableException(final String msg)
        {
            super(msg);
        }
    }

    /** Cipher transformation used to decrypt the password data returned by EC2. */
    public static final String CIPHER_SPEC = "RSA/ECB/PKCS1Padding";

    private final LoadingCache<KeyInstanceIdHolder, String> instanceId2password;

    /**
     * @param ec2Client client used to call {@code GetPasswordData}
     */
    public InstancePasswordCache(@NotNull final AmazonEC2Async ec2Client)
    {
        // Registering the provider constructs a fresh (expensive) BouncyCastleProvider each time; skip it when one is
        // already installed. Presumably required by CryptoUtils for key parsing / CIPHER_SPEC - TODO confirm.
        if (Security.getProvider(BouncyCastleProvider.PROVIDER_NAME) == null)
        {
            Security.addProvider(new BouncyCastleProvider());
        }

        instanceId2password = makePasswordComputingMap(ec2Client, makeCipherMap());
    }

    /**
     * Builds the per-key-file cipher cache; ciphers are created by {@link CryptoUtils#getCipherForKey}.
     */
    private static LoadingCache<File, Cipher> makeCipherMap()
    {
        return CacheBuilder.newBuilder().build(new CacheLoader<File, Cipher>()
        {
            @Override
            public Cipher load(final File keyFile) throws Exception
            {
                return CryptoUtils.getCipherForKey(keyFile, CIPHER_SPEC);
            }
        });
    }

    /**
     * Builds the password cache. Each load fetches the password data from EC2, Base64-decodes it and decrypts it with
     * the cipher for the entry's key file.
     */
    private static LoadingCache<KeyInstanceIdHolder, String> makePasswordComputingMap(final AmazonEC2Async ec2Client, final LoadingCache<File, Cipher> key2cipher)
    {
        return CacheBuilder.newBuilder().build(new CacheLoader<KeyInstanceIdHolder, String>()
        {
            @Override
            public String load(final KeyInstanceIdHolder args)
            {
                final GetPasswordDataRequest request = new GetPasswordDataRequest(args.getInstanceId().getId());
                final GetPasswordDataResult passwordData = ec2Client.getPasswordData(request);

                final String base64encodedPassword = passwordData.getPasswordData();
                if (StringUtils.isEmpty(base64encodedPassword))
                {
                    // Not cached: getPassword() turns this into null and the next call asks EC2 again.
                    throw new ComputationException(new PasswordNotAvailableException("Password is not (yet) available. If password generation was enabled for this image, note that generation and encryption takes a few moments. Please wait up to 15 minutes after launching an instance before trying to retrieve the generated password."));
                }
                // MIME decoder tolerates line breaks inside the Base64 payload.
                final byte[] encodedPassword = Base64.getMimeDecoder().decode(base64encodedPassword);

                try
                {
                    return decrypt(key2cipher, args.getKeyFile(), encodedPassword);
                }
                catch (Exception e)
                {
                    throw new ComputationException(e);
                }
            }
        });
    }

    /**
     * Decrypts {@code encrypted} with the cached cipher for {@code keyFile} and decodes the plaintext as UTF-8.
     *
     * @throws ExecutionException       if the cipher for {@code keyFile} could not be created
     * @throws GeneralSecurityException if decryption fails
     */
    private static String decrypt(final LoadingCache<File, Cipher> key2cipher, final File keyFile, final byte[] encrypted)
            throws ExecutionException, GeneralSecurityException
    {
        final Cipher cipher = key2cipher.get(keyFile);
        final byte[] plaintext;
        // Cipher objects are stateful and not thread-safe, yet the cache hands the same instance to every caller.
        synchronized (cipher)
        {
            try
            {
                plaintext = cipher.doFinal(encrypted);
            }
            catch (GeneralSecurityException | RuntimeException e)
            {
                // After a failed doFinal the cipher may need re-initialisation; drop it so the next call rebuilds it.
                key2cipher.invalidate(keyFile);
                throw e;
            }
        }
        // Explicit charset instead of the platform default.
        return new String(plaintext, StandardCharsets.UTF_8);
    }

    /**
     * Returns the decrypted password of {@code instanceId}, fetching and caching it on first use.
     *
     * @param keyFile    key file passed to {@code CryptoUtils.getCipherForKey} to build the decryption cipher
     * @param instanceId instance whose password is requested
     * @return the password, or {@code null} if EC2 returned no password data for the instance (yet)
     * @throws IllegalArgumentException wrapping any other failure (EC2 call, Base64 decoding, cipher creation,
     *                                  decryption, or a {@code null} argument)
     */
    @Nullable
    public String getPassword(final File keyFile, final InstanceId instanceId)
    {
        try
        {
            return instanceId2password.get(new KeyInstanceIdHolder(keyFile, instanceId));
        }
        catch (Exception e)
        {
            // The marker exception arrives wrapped (ComputationException and Guava's own wrappers), so check the root.
            if (Throwables.getRootCause(e) instanceof PasswordNotAvailableException)
            {
                return null;
            }
            throw new IllegalArgumentException(e);
        }
    }

    /**
     * Composite cache key: passwords are cached per (key file, instance id) pair.
     */
    private static final class KeyInstanceIdHolder
    {
        private final File keyFile;
        private final InstanceId instanceId;

        KeyInstanceIdHolder(final File keyFile, final InstanceId instanceId)
        {
            this.keyFile = Objects.requireNonNull(keyFile, "keyFile");
            this.instanceId = Objects.requireNonNull(instanceId, "instanceId");
        }

        public InstanceId getInstanceId()
        {
            return instanceId;
        }

        public File getKeyFile()
        {
            return keyFile;
        }

        @Override
        public boolean equals(final Object o)
        {
            if (this == o)
            {
                return true;
            }
            if (o == null || getClass() != o.getClass())
            {
                return false;
            }
            final KeyInstanceIdHolder that = (KeyInstanceIdHolder) o;
            return Objects.equals(instanceId, that.instanceId) && Objects.equals(keyFile, that.keyFile);
        }

        @Override
        public int hashCode()
        {
            return Objects.hash(keyFile, instanceId);
        }
    }
}
