package com.atlassian.aws;

import com.amazonaws.AmazonClientException;
import com.amazonaws.AmazonServiceException;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.services.ec2.AmazonEC2Async;
import com.amazonaws.services.ec2.AmazonEC2AsyncClient;
import com.amazonaws.services.ec2.model.AccountAttribute;
import com.amazonaws.services.ec2.model.Address;
import com.amazonaws.services.ec2.model.AttachVolumeRequest;
import com.amazonaws.services.ec2.model.AvailabilityZone;
import com.amazonaws.services.ec2.model.CancelSpotInstanceRequestsRequest;
import com.amazonaws.services.ec2.model.CreateKeyPairRequest;
import com.amazonaws.services.ec2.model.CreateKeyPairResult;
import com.amazonaws.services.ec2.model.CreateSecurityGroupRequest;
import com.amazonaws.services.ec2.model.CreateSecurityGroupResult;
import com.amazonaws.services.ec2.model.CreateTagsRequest;
import com.amazonaws.services.ec2.model.CreateVolumeRequest;
import com.amazonaws.services.ec2.model.CreateVolumeResult;
import com.amazonaws.services.ec2.model.DeleteVolumeRequest;
import com.amazonaws.services.ec2.model.DescribeAccountAttributesResult;
import com.amazonaws.services.ec2.model.DescribeAddressesRequest;
import com.amazonaws.services.ec2.model.DescribeAddressesResult;
import com.amazonaws.services.ec2.model.DescribeKeyPairsRequest;
import com.amazonaws.services.ec2.model.DescribeKeyPairsResult;
import com.amazonaws.services.ec2.model.DescribeSecurityGroupsRequest;
import com.amazonaws.services.ec2.model.DescribeSecurityGroupsResult;
import com.amazonaws.services.ec2.model.DescribeSpotPriceHistoryRequest;
import com.amazonaws.services.ec2.model.EbsInstanceBlockDeviceSpecification;
import com.amazonaws.services.ec2.model.Filter;
import com.amazonaws.services.ec2.model.GetConsoleOutputRequest;
import com.amazonaws.services.ec2.model.GetConsoleOutputResult;
import com.amazonaws.services.ec2.model.Image;
import com.amazonaws.services.ec2.model.Instance;
import com.amazonaws.services.ec2.model.InstanceBlockDeviceMappingSpecification;
import com.amazonaws.services.ec2.model.KeyPair;
import com.amazonaws.services.ec2.model.KeyPairInfo;
import com.amazonaws.services.ec2.model.ModifyInstanceAttributeRequest;
import com.amazonaws.services.ec2.model.SecurityGroup;
import com.amazonaws.services.ec2.model.SpotInstanceRequest;
import com.amazonaws.services.ec2.model.SpotPrice;
import com.amazonaws.services.ec2.model.Subnet;
import com.amazonaws.services.ec2.model.Tag;
import com.amazonaws.services.ec2.model.TerminateInstancesRequest;
import com.amazonaws.services.ec2.model.Volume;
import com.amazonaws.services.ec2.model.Vpc;
import com.atlassian.aws.ec2.AmazonEc2Utils;
import com.atlassian.aws.ec2.EC2InstanceListener;
import com.atlassian.aws.ec2.InstanceLaunchConfiguration;
import com.atlassian.aws.ec2.Protocol;
import com.atlassian.aws.ec2.RemoteEC2Instance;
import com.atlassian.aws.ec2.RemoteEC2InstanceImpl;
import com.atlassian.aws.ec2.SpotPrices;
import com.atlassian.aws.ec2.awssdk.AwsSupportConstants;
import com.atlassian.aws.ec2.caches.ImageCache;
import com.atlassian.aws.ec2.caches.InstanceCache;
import com.atlassian.aws.ec2.caches.InstancePasswordCache;
import com.atlassian.aws.ec2.caches.SpotRequestCache;
import com.atlassian.aws.ec2.caches.SubnetCache;
import com.atlassian.aws.ec2.caches.VolumeCache;
import com.atlassian.aws.ec2.caches.VpcCache;
import com.atlassian.aws.ec2.model.InstanceId;
import com.atlassian.aws.ec2.model.VpcId;
import com.atlassian.aws.utils.Eithers;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.Iterables;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.log4j.Logger;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

class AWSAccountImpl implements AWSAccount
{
    // Class-wide log4j logger.
    private static final Logger log = Logger.getLogger(AWSAccountImpl.class);

    /**
     * <p>The period of time (in seconds) to wait between requests for EC2 instance status updates.</p>
     */
    private static final int DEFAULT_SUPERVISION_INTERVAL_SECONDS = 20;
    // Region handed to the constructor; returned unchanged by getRegion().
    private final AwsSupportConstants.Region region;

    // Supervision interval passed to every RemoteEC2InstanceImpl created by newEC2Instance().
    // Starts at the default above and is overwritten by setMaximumInstanceStatusAgeSeconds().
    private int supervisionIntervalSeconds = DEFAULT_SUPERVISION_INTERVAL_SECONDS;

    /**
     * <p>The number of successive attempts to obtain EC2 instance status updates that are allowed to fail before the
     * instance is considered to have failed, and an attempt is made to terminate it.</p>
     */
    private static final int MAX_SUCCESSIVE_SUPERVISION_FAILURES = 10;

    // Executor given to the client factory and to each RemoteEC2InstanceImpl.
    private final ScheduledExecutorService scheduledExecutorService;

    // CallTimingProxy-wrapped view of deprecatedEc2Client; every EC2 call made by this class goes through it.
    private final AmazonEC2Async asyncEc2Client;
    // Unwrapped client produced by the Ec2ClientFactory; only exposed through getAwsClient().
    private final AmazonEC2AsyncClient deprecatedEc2Client;
    // Caches below are all constructed over asyncEc2Client in the constructor.
    private final SpotRequestCache spotRequestCache;
    private final InstanceCache instanceCache;
    private final VpcCache vpcCache;
    private final VolumeCache volumeCache;
    private final SubnetCache subnetCache;
    private final InstancePasswordCache instancePasswordCache;
    private final ImageCache imageCache;

    /**
     * Availability zones keyed by zone name, memoized for one minute. The supplier never throws
     * {@link AmazonClientException}: a failed query is returned as the left side of the pair, so the
     * failure is memoized for the same period as a successful answer would be.
     */
    final Supplier<Pair<AWSException, Map<String, AvailabilityZone>>> availabilityZones = Suppliers.memoizeWithExpiration(
            new Supplier<Pair<AWSException, Map<String, AvailabilityZone>>>()
            {
                @Override
                public Pair<AWSException, Map<String, AvailabilityZone>> get()
                {
                    try
                    {
                        final List<AvailabilityZone> zones = asyncEc2Client.describeAvailabilityZones().getAvailabilityZones();
                        final Map<String, AvailabilityZone> zonesByName = zones.stream()
                                .collect(Collectors.toMap(zone -> zone.getZoneName(), zone -> zone));
                        return Eithers.right(zonesByName);
                    }
                    catch (final AmazonClientException queryFailure)
                    {
                        return Eithers.left(new AWSException("Failed to query EC2 for availability zones descriptions.", queryFailure));
                    }
                }
            },
            1, TimeUnit.MINUTES);


    /**
     * Account attributes keyed by attribute name, memoized for one minute. Any exception raised by the
     * EC2 call propagates to whoever invokes {@code get()}.
     */
    private final Supplier<Map<String, AccountAttribute>> accountAttributes = Suppliers.memoizeWithExpiration(
            new Supplier<Map<String, AccountAttribute>>()
            {
                @Override
                public Map<String, AccountAttribute> get()
                {
                    return asyncEc2Client.describeAccountAttributes()
                            .getAccountAttributes()
                            .stream()
                            .collect(Collectors.toMap(attribute -> attribute.getAttributeName(), attribute -> attribute));
                }
            }, 1, TimeUnit.MINUTES);

    /**
     * Spot prices memoized for one minute. Each refresh asks EC2 for the spot price history starting
     * at the moment of the refresh and wraps the returned entries in a {@link SpotPrices}.
     */
    private final Supplier<SpotPrices> spotPrices = Suppliers.memoizeWithExpiration(new Supplier<SpotPrices>()
    {
        @Override
        public SpotPrices get()
        {
            final DescribeSpotPriceHistoryRequest historyFromNow = new DescribeSpotPriceHistoryRequest();
            historyFromNow.setStartTime(new Date());

            return new SpotPrices(asyncEc2Client.describeSpotPriceHistory(historyFromNow).getSpotPriceHistory());
        }
    }, 1, TimeUnit.MINUTES);

    AWSAccountImpl(final Ec2ClientFactory ec2ClientFactory,
                   final ScheduledExecutorService scheduledExecutorService,
                   final AWSCredentials awsCredentials, final AwsSupportConstants.Region region)
    {
        if (StringUtils.isBlank(awsCredentials.getAWSAccessKeyId()) || StringUtils.isBlank(awsCredentials.getAWSSecretKey()))
        {
            throw new IllegalArgumentException("awsAccessId and awsSecretKey must be specified.");
        }
        this.scheduledExecutorService = scheduledExecutorService;
        deprecatedEc2Client = ec2ClientFactory.newAwsAsyncClient(region, awsCredentials, scheduledExecutorService);
        asyncEc2Client = CallTimingProxy.wrap(deprecatedEc2Client, AmazonEC2Async.class);

        spotRequestCache = new SpotRequestCache(asyncEc2Client);
        instanceCache = new InstanceCache(asyncEc2Client);
        vpcCache = new VpcCache(asyncEc2Client);
        volumeCache = new VolumeCache(asyncEc2Client);
        instancePasswordCache = new InstancePasswordCache(asyncEc2Client);
        subnetCache = new SubnetCache(asyncEc2Client);
        imageCache = new ImageCache(asyncEc2Client);
        this.region = region;
    }

    /**
     * Probes the account by listing availability zones.
     *
     * @return null when the call succeeds or is rejected with {@code UNAUTHORISED_OPERATION};
     *         otherwise the message of the service error
     * @throws AWSException if the call fails with a client-side (non-service) error
     */
    @Nullable
    @Override
    public String getAccountValidationError() throws AWSException
    {
        try
        {
            asyncEc2Client.describeAvailabilityZones();
            return null;
        }
        catch (final AmazonServiceException serviceFailure)
        {
            if (!AmazonServiceErrorCode.UNAUTHORISED_OPERATION.is(serviceFailure))
            {
                log.info("Unable to validate account: ", serviceFailure);
                return serviceFailure.getMessage();
            }
            return null;
        }
        catch (final AmazonClientException clientFailure)
        {
            throw new AWSException("Failed to determine the validity of AWS credentials.", clientFailure);
        }
    }


    /**
     * Gets the console output for an instance.
     *
     * @param instanceId id of the instance whose console output is requested
     * @return the Base64-decoded console output interpreted as UTF-8 (an empty string when EC2 reports no
     *         output), or null if an error occurred while retrieving or decoding the log.
     */
    @Override
    @Nullable
    public String getConsoleOutput(String instanceId)
    {
        try
        {
            final GetConsoleOutputResult consoleOutput = asyncEc2Client.getConsoleOutput(new GetConsoleOutputRequest(instanceId));
            // a missing output is treated as an empty log rather than an error
            final String base64Log = StringUtils.trimToEmpty(consoleOutput.getOutput());
            return new String(Base64.getMimeDecoder().decode(base64Log), StandardCharsets.UTF_8);
        }
        catch (AmazonClientException e)
        {
            log.warn("Unable to get console output for " + instanceId + ". Null being returned.", e);
            return null;
        }
        catch (IllegalArgumentException e)
        {
            // Base64.Decoder#decode rejects malformed input with IllegalArgumentException; honour the
            // "null on error" contract instead of letting it escape to the caller.
            log.warn("Unable to decode console output for " + instanceId + ". Null being returned.", e);
            return null;
        }
    }
    
    /**
     * Lists the instances known to the instance cache, leaving out those whose state is
     * {@code Terminated}.
     *
     * @return a new list of non-terminated instances
     * @throws AWSException if the lookup fails with an {@link AmazonClientException}
     */
    @Override
    @NotNull
    public Collection<com.amazonaws.services.ec2.model.Instance> getAllInstances() throws AWSException
    {
        try
        {
            final List<Instance> notTerminated = new ArrayList<>();
            for (final Instance candidate : instanceCache.describe())
            {
                if (!AwsSupportConstants.InstanceStateName.Terminated.is(candidate.getState()))
                {
                    notTerminated.add(candidate);
                }
            }
            return notTerminated;
        }
        catch (AmazonClientException lookupFailure)
        {
            throw new AWSException("Unable to retrieve the list of running elastic instances.", lookupFailure);
        }
    }

    /**
     * Describes the given spot instance requests through the spot request cache and keeps only those
     * whose state is {@code OPEN}.
     *
     * @param spotInstanceRequestIds ids forwarded to the cache lookup
     * @return a new list holding the open requests
     */
    @Override
    @NotNull
    public Collection<SpotInstanceRequest> describePendingSpotInstanceRequests(final String ... spotInstanceRequestIds)
    {
        final List<SpotInstanceRequest> openRequests = new ArrayList<>();
        for (final SpotInstanceRequest request : spotRequestCache.describe(spotInstanceRequestIds))
        {
            if (AwsSupportConstants.SpotInstanceRequestState.OPEN.is(request.getState()))
            {
                openRequests.add(request);
            }
        }
        return openRequests;
    }

    /**
     * Builds a {@link RemoteEC2InstanceImpl} bound to this account, using the current supervision
     * interval, the fixed maximum number of successive supervision failures and this account's executor.
     *
     * @param instanceLaunchConfiguration launch configuration for the new instance object
     * @param listener listener handed to the new instance object
     * @return the newly constructed instance object
     */
    @NotNull
    @Override
    public RemoteEC2Instance newEC2Instance(@NotNull final InstanceLaunchConfiguration instanceLaunchConfiguration, final EC2InstanceListener listener)
    {
        return new RemoteEC2InstanceImpl(
                instanceLaunchConfiguration,
                supervisionIntervalSeconds,
                MAX_SUCCESSIVE_SUPERVISION_FAILURES,
                listener,
                this,
                scheduledExecutorService);
    }

    /**
     * Describes the account's addresses, optionally restricted by a {@code domain} filter.
     *
     * @param domains domain values to filter by; when none are given the request carries no filter
     * @return the addresses reported by EC2
     */
    @NotNull
    @Override
    public List<Address> describeAddresses(final String ... domains)
    {
        final DescribeAddressesRequest request = new DescribeAddressesRequest();
        if (domains.length > 0)
        {
            request.withFilters(newDomainFilter(domains));
        }

        return asyncEc2Client.describeAddresses(request).getAddresses();
    }

    /**
     * Creates a filter named {@code domain} whose values are the given domains.
     *
     * @param domains filter values
     * @return the new filter
     */
    @NotNull
    private static Filter newDomainFilter(final String[] domains)
    {
        final List<String> domainValues = Arrays.asList(domains);
        return new Filter("domain", domainValues);
    }


    /**
     * Returns the account attributes keyed by attribute name, as provided by the memoizing
     * {@code accountAttributes} supplier.
     *
     * @return attribute name to attribute
     */
    @NotNull
    @Override
    public Map<String, AccountAttribute> getAccountAttributes()
    {
        final Map<String, AccountAttribute> attributesByName = accountAttributes.get();
        return attributesByName;
    }

    /**
     * Exposes the subnet cache created for this account.
     *
     * @return the subnet cache instance held by this account
     */
    @Override
    public SubnetCache getSubnetCache()
    {
        return this.subnetCache;
    }

    /**
     * Describes all security groups using an unfiltered request.
     *
     * @return the security groups reported by EC2
     * @throws AWSException if the query fails with an {@link AmazonClientException}
     */
    @NotNull
    @Override
    public Iterable<SecurityGroup> describeSecurityGroups() throws AWSException
    {
        try
        {
            return asyncEc2Client.describeSecurityGroups(new DescribeSecurityGroupsRequest()).getSecurityGroups();
        }
        catch (final AmazonClientException queryFailure)
        {
            throw new AWSException("Failed to query EC2 for group descriptions.", queryFailure);
        }
    }

    /**
     * Creates a security group and returns a locally assembled {@link SecurityGroup} describing it
     * (group id from the creation result; name, description and VPC id from the arguments).
     *
     * @param name group name
     * @param description group description
     * @param vpcId VPC to create the group in; a null or undefined id results in a null VPC id on the request
     * @return the description of the created group
     * @throws AWSException if creation fails with an {@link AmazonClientException}
     */
    @NotNull
    @Override
    public SecurityGroup newSecurityGroup(@NotNull final String name, @NotNull final String description, @Nullable VpcId vpcId) throws AWSException
    {
        final boolean vpcSpecified = vpcId != null && !vpcId.isUndefined();
        final String vpcIdStr = vpcSpecified ? vpcId.getId() : null;

        try
        {
            final CreateSecurityGroupRequest request = new CreateSecurityGroupRequest()
                    .withGroupName(name)
                    .withDescription(description)
                    .withVpcId(vpcIdStr);
            final String createdGroupId = asyncEc2Client.createSecurityGroup(request).getGroupId();

            return new SecurityGroup()
                    .withGroupId(createdGroupId)
                    .withGroupName(name)
                    .withDescription(description)
                    .withVpcId(vpcIdStr);
        }
        catch (final AmazonClientException creationFailure)
        {
            throw new AWSException("Failed to create EC2 security group.", creationFailure);
        }
    }

    /**
     * Delegates to {@link AmazonEc2Utils#ensureInboundTrafficIsAllowed} using this account's EC2 client.
     *
     * @param group security group to operate on
     * @param protocol protocol of the inbound rule
     * @param cidrIpRange CIDR range of the inbound rule
     * @param port port of the inbound rule
     */
    @Override
    public void ensureInboundTrafficIsAllowed(@NotNull final SecurityGroup group, @NotNull final Protocol protocol, @NotNull final String cidrIpRange, final int port)
    {
        AmazonEc2Utils.ensureInboundTrafficIsAllowed(
                asyncEc2Client,
                group,
                protocol,
                cidrIpRange,
                port);
    }

    /**
     * Describes EC2 key pairs and indexes them by key name.
     *
     * @param keyNames key names to restrict the request to; when none are given the request is unrestricted
     * @return key name to key pair info
     */
    @NotNull
    @Override
    public Map<String, KeyPairInfo> describeEc2KeyPairs(final String... keyNames)
    {
        final DescribeKeyPairsRequest request = new DescribeKeyPairsRequest();
        if (keyNames.length > 0)
        {
            request.withKeyNames(keyNames);
        }

        final List<KeyPairInfo> keyPairs = asyncEc2Client.describeKeyPairs(request).getKeyPairs();
        return keyPairs.stream()
                .collect(Collectors.toMap(info -> info.getKeyName(), info -> info));
    }

    /**
     * Creates an EC2 key pair with the given name.
     *
     * @param name name of the key pair to create
     * @return the key pair taken from the creation result
     * @throws AWSException if creation fails with an {@link AmazonClientException}
     */
    @NotNull
    @Override
    public KeyPair newEC2KeyPair(final String name) throws AWSException
    {
        try
        {
            final CreateKeyPairRequest request = new CreateKeyPairRequest(name);
            return asyncEc2Client.createKeyPair(request).getKeyPair();
        }
        catch (final AmazonClientException creationFailure)
        {
            throw new AWSException("Failed to create EC2 key pair.", creationFailure);
        }
    }

    /**
     * Returns the availability zones keyed by zone name, unwrapping the result held by the memoizing
     * {@code availabilityZones} supplier via {@link Eithers#getOrThrow}.
     *
     * @return zone name to availability zone
     * @throws AWSException as declared; raised by the unwrapping step
     */
    @NotNull
    @Override
    public Map<String, AvailabilityZone> getAvailabilityZones() throws AWSException
    {
        final Pair<AWSException, Map<String, AvailabilityZone>> zonesOrFailure = availabilityZones.get();
        return Eithers.getOrThrow(zonesOrFailure);
    }

    /**
     * Groups every subnet reported by EC2 under the VPC it belongs to. Only VPCs that own at least one
     * subnet appear in the result.
     *
     * @return VPC to the subnets it contains
     * @throws AWSException if listing the subnets or resolving a subnet's VPC fails with an
     *         {@link AmazonClientException}
     */
    @NotNull
    @Override
    public Map<Vpc, Collection<Subnet>> describeVpcs() throws AWSException
    {
        try
        {
            final List<Subnet> subnets = asyncEc2Client.describeSubnets().getSubnets();

            final Map<Vpc, Collection<Subnet>> vpcsAndSubnets = new HashMap<>();
            for (final Subnet subnet : subnets)
            {
                // The VPC lookup is kept inside the try so a client error raised while resolving it is
                // reported as AWSException too (presumably the cache may query EC2 - same treatment as
                // the instance cache call in getAllInstances).
                final Vpc vpc = Iterables.getOnlyElement(vpcCache.describe(subnet.getVpcId()));
                vpcsAndSubnets.computeIfAbsent(vpc, key -> new ArrayList<>()).add(subnet);
            }
            return vpcsAndSubnets;
        }
        catch (final AmazonClientException e)
        {
            throw new AWSException("Failed to fetch the VPCs list", e);
        }
    }

    /**
     * Requests termination of a single instance.
     *
     * @param instanceId id of the instance to terminate
     * @throws AWSException if the request fails with an {@link AmazonClientException}
     */
    @Override
    public void shutdownInstance(String instanceId) throws AWSException
    {
        final TerminateInstancesRequest request = new TerminateInstancesRequest(Collections.singletonList(instanceId));
        try
        {
            asyncEc2Client.terminateInstances(request);
        }
        catch (AmazonClientException terminationFailure)
        {
            throw new AWSException("Error terminating elastic instance with id '" + instanceId + "'", terminationFailure);
        }
    }

    /**
     * Sends a cancellation request for the given spot instance request ids. Exceptions raised by the
     * client are not translated here.
     *
     * @param spotInstanceRequestIds ids of the spot instance requests to cancel
     */
    @Override
    public void cancelSpotInstanceRequests(String... spotInstanceRequestIds)
    {
        asyncEc2Client.cancelSpotInstanceRequests(
                new CancelSpotInstanceRequestsRequest().withSpotInstanceRequestIds(spotInstanceRequestIds));
    }

    /**
     * Requests deletion of an EBS volume.
     *
     * @param volumeId id of the volume to delete
     * @throws AWSException if the request fails with an {@link AmazonClientException}
     */
    @Override
    public void deleteVolume(String volumeId) throws AWSException
    {
        final DeleteVolumeRequest request = new DeleteVolumeRequest(volumeId);
        try
        {
            asyncEc2Client.deleteVolume(request);
        }
        catch (AmazonClientException deletionFailure)
        {
            throw new AWSException("Error deleting ebs volume with id '" + volumeId + "'", deletionFailure);
        }
    }

    /**
     * Returns the spot prices provided by the memoizing {@code spotPrices} supplier.
     *
     * @return the current (possibly memoized) spot prices
     */
    @NotNull
    @Override
    public SpotPrices getSpotPrices()
    {
        final SpotPrices currentPrices = spotPrices.get();
        return currentPrices;
    }

    /**
     * Returns the client obtained from the client factory, i.e. the one that is not wrapped by the
     * call-timing proxy.
     *
     * @return the unwrapped EC2 client
     */
    @NotNull
    @Override
    public AmazonEC2AsyncClient getAwsClient()
    {
        return this.deprecatedEc2Client;
    }

    /**
     * Returns the call-timing-proxied EC2 client used internally by this account.
     *
     * @return the proxied EC2 client
     */
    @Override
    public AmazonEC2Async getAmazonEc2()
    {
        return this.asyncEc2Client;
    }

    /**
     * Describes instances through the instance cache.
     *
     * @param instanceIds ids forwarded to the cache lookup
     * @return whatever the instance cache returns for those ids
     */
    @Override
    public Collection<Instance> describeInstances(final String... instanceIds)
    {
        final Collection<Instance> described = instanceCache.describe(instanceIds);
        return described;
    }

    /**
     * Describes subnets through the subnet cache.
     *
     * @param subnetIds ids forwarded to the cache lookup
     * @return whatever the subnet cache returns for those ids
     */
    @NotNull
    @Override
    public Collection<Subnet> describeSubnets(final String... subnetIds)
    {
        final Collection<Subnet> described = subnetCache.describe(subnetIds);
        return described;
    }

    /**
     * String-based convenience overload: converts the arguments to a {@link File} and an
     * {@link InstanceId} and delegates to {@link #getPassword(File, InstanceId)}.
     *
     * @param privateKeyFile path of the private key file
     * @param instanceId id of the instance
     * @return the result of the delegated lookup
     */
    @Override
    public String getPassword(@NotNull final String privateKeyFile, @NotNull final String instanceId)
    {
        final File keyFile = new File(privateKeyFile);
        final InstanceId typedInstanceId = InstanceId.from(instanceId);
        return getPassword(keyFile, typedInstanceId);
    }

    /**
     * Looks up the password for an instance through the instance password cache.
     *
     * @param privateKeyFile private key file passed to the cache
     * @param instanceId id of the instance
     * @return whatever the password cache returns
     */
    @Override
    public String getPassword(@NotNull final File privateKeyFile, @NotNull final InstanceId instanceId)
    {
        final String password = instancePasswordCache.getPassword(privateKeyFile, instanceId);
        return password;
    }

    /**
     * Describes spot instance requests through the spot request cache, without filtering by state.
     *
     * @param spotInstanceRequestIds ids forwarded to the cache lookup
     * @return whatever the spot request cache returns for those ids
     */
    @Override
    public Collection<SpotInstanceRequest> describeSpotInstanceRequests(final String... spotInstanceRequestIds)
    {
        final Collection<SpotInstanceRequest> described = spotRequestCache.describe(spotInstanceRequestIds);
        return described;
    }

    /**
     * Describes a single image.
     *
     * @param imageId id of the image to look up
     * @return the only image returned by {@link #describeImages(String...)} for that id, or null when
     *         the lookup yields no image
     */
    @Override
    public Image describeImage(@NotNull String imageId)
    {
        final List<Image> matches = describeImages(imageId);
        return Iterables.getOnlyElement(matches, null);
    }

    /**
     * Describes images through the image cache.
     *
     * @param imageIds ids forwarded to the cache lookup
     * @return whatever the image cache returns for those ids
     */
    @Override
    public List<Image> describeImages(final String... imageIds)
    {
        final List<Image> images = imageCache.describeImages(imageIds);
        return images;
    }

    /**
     * Submits a request to put one tag on one resource. The request goes through
     * {@code createTagsAsync} and the returned handle is discarded, so this method does not report
     * whether the tagging eventually succeeded.
     *
     * @param resourceId id of the resource to tag
     * @param key tag key
     * @param value tag value
     */
    @Override
    public void createTag(@NotNull final String resourceId, @NotNull final String key, @NotNull final String value)
    {
        final Tag tag = new Tag(key, value);
        final CreateTagsRequest request = new CreateTagsRequest()
                .withResources(resourceId)
                .withTags(tag);

        asyncEc2Client.createTagsAsync(request);
    }

    /**
     * Describes EBS volumes through the volume cache.
     *
     * @return whatever the volume cache returns
     * @throws AWSException if the lookup throws any {@link Exception}
     */
    @Override
    @NotNull
    public Collection<Volume> describeVolumes() throws AWSException
    {
        final Collection<Volume> volumes;
        try
        {
            volumes = volumeCache.describe();
        }
        catch (Exception lookupFailure)
        {
            throw new AWSException("Failed to retrieve information about EBS volumes.", lookupFailure);
        }
        return volumes;
    }

    /**
     * Creates an EBS volume from a snapshot in the given availability zone. Exceptions raised by the
     * client are not translated here.
     *
     * @param ebsSnapshotId id of the snapshot to create the volume from
     * @param availabilityZone availability zone for the new volume
     * @return the id of the created volume
     */
    @Override
    public String createVolume(@NotNull final String ebsSnapshotId, @NotNull final String availabilityZone)
    {
        final CreateVolumeRequest request = new CreateVolumeRequest()
                .withSnapshotId(ebsSnapshotId)
                .withAvailabilityZone(availabilityZone);

        return asyncEc2Client.createVolume(request).getVolume().getVolumeId();
    }

    /**
     * Attaches a volume to an instance under the given device name and, when requested, follows up by
     * flagging the attachment as delete-on-termination.
     *
     * @param volumeId id of the volume to attach
     * @param instanceId id of the instance to attach to
     * @param deviceName device name for the attachment
     * @param deleteOnTermination whether to enable delete-on-termination after attaching
     */
    @Override
    public void attachVolume(@NotNull final String volumeId, @NotNull final String instanceId, @NotNull final String deviceName, final boolean deleteOnTermination)
    {
        asyncEc2Client.attachVolume(
                new AttachVolumeRequest()
                        .withVolumeId(volumeId)
                        .withInstanceId(instanceId)
                        .withDevice(deviceName));

        if (!deleteOnTermination)
        {
            return;
        }
        enableDeleteOnTermination(instanceId, volumeId, deviceName);
    }

    /**
     * Modifies the instance's block device mapping so that the given volume, attached under the given
     * device name, has delete-on-termination set to true.
     *
     * @param instanceId id of the instance whose attribute is modified
     * @param volumeId id of the attached volume
     * @param attachedDeviceName device name the volume is attached under
     */
    private void enableDeleteOnTermination(@NotNull final String instanceId, @NotNull final String volumeId, @NotNull final String attachedDeviceName)
    {
        final InstanceBlockDeviceMappingSpecification deviceMapping = new InstanceBlockDeviceMappingSpecification()
                .withDeviceName(attachedDeviceName)
                .withEbs(new EbsInstanceBlockDeviceSpecification()
                        .withVolumeId(volumeId)
                        .withDeleteOnTermination(true));

        asyncEc2Client.modifyInstanceAttribute(
                new ModifyInstanceAttributeRequest()
                        .withBlockDeviceMappings(deviceMapping)
                        .withInstanceId(instanceId));
    }

    /**
     * Forwards the maximum status age to the volume cache.
     *
     * @param maximumStatusAgeSeconds value passed to the cache, in seconds
     */
    @Override
    public void setMaximumEbsVolumeStatusAgeSeconds(final int maximumStatusAgeSeconds)
    {
        this.volumeCache.setMaximumStatusAgeSeconds(maximumStatusAgeSeconds);
    }

    /**
     * Forwards the maximum status age to the instance cache. The same value also becomes the
     * supervision interval handed to instance objects created afterwards by {@code newEC2Instance}.
     *
     * @param maximumStatusAgeSeconds value in seconds
     */
    @Override
    public void setMaximumInstanceStatusAgeSeconds(final int maximumStatusAgeSeconds)
    {
        this.supervisionIntervalSeconds = maximumStatusAgeSeconds;
        this.instanceCache.setMaximumStatusAgeSeconds(maximumStatusAgeSeconds);
    }

    /**
     * Forwards the maximum status age to the spot request cache.
     *
     * @param maximumStatusAgeSeconds value passed to the cache, in seconds
     */
    @Override
    public void setMaximumSpotRequestStatusAgeSeconds(final int maximumStatusAgeSeconds)
    {
        this.spotRequestCache.setMaximumStatusAgeSeconds(maximumStatusAgeSeconds);
    }

    /**
     * Returns the region that was supplied to the constructor.
     *
     * @return this account's region
     */
    @Override
    public AwsSupportConstants.Region getRegion()
    {
        return this.region;
    }
}
