package com.atlassian.aws.ec2;

import com.amazonaws.services.ec2.model.Subnet;
import com.atlassian.aws.AWSAccount;
import com.atlassian.aws.ec2.model.AvailabilityZoneId;
import com.atlassian.aws.ec2.model.ResourceId;
import com.atlassian.aws.ec2.model.SubnetId;
import com.google.common.base.Function;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multimaps;
import org.apache.log4j.Logger;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

public class SubnetChooser
{
    private static final Logger log = Logger.getLogger(SubnetChooser.class);
    private final AvailabilityZoneChooser availabilityZoneChooser;
    private final AWSAccount awsAccount;

    public SubnetChooser(final AvailabilityZoneChooser availabilityZoneChooser, final AWSAccount awsAccount)
    {
        this.availabilityZoneChooser = availabilityZoneChooser;
        this.awsAccount = awsAccount;
    }

    @Nullable
    public SubnetId choose(final Iterable<SubnetId> subnets, final EC2InstanceType instanceType)
    {
        if (Iterables.isEmpty(subnets))
        {
            return null;
        }
        final ImmutableMultimap<AvailabilityZoneId, Subnet> availabilityZones = toAvailabilityZones(awsAccount.getSubnetCache().describe(ResourceId.getIds(subnets)));
        final AvailabilityZoneId chosenAz = availabilityZoneChooser.choose(availabilityZones.keySet(), instanceType);

        return SubnetId.from(Iterables.get(availabilityZones.get(chosenAz), 0).getSubnetId());
    }

    public void blacklist(@NotNull final SubnetId subnet, @NotNull final EC2InstanceType instanceType)
    {
        final AvailabilityZoneId availabilityZone = toAvailabilityZone(subnet);

        log.info("Adding subnet " + subnet + "/" + availabilityZone + " to blacklist");
        availabilityZoneChooser.blacklist(availabilityZone, instanceType);
    }

    @NotNull
    private AvailabilityZoneId toAvailabilityZone(@NotNull final SubnetId subnetId)
    {
        final Subnet subnet = Iterables.getOnlyElement(awsAccount.getSubnetCache().describe(subnetId.getId()));
        return AvailabilityZoneId.from(subnet.getAvailabilityZone());
    }

    private ImmutableListMultimap<AvailabilityZoneId, Subnet> toAvailabilityZones(final Iterable<Subnet> subnets)
    {
        return Multimaps.index(subnets, new Function<Subnet, AvailabilityZoneId>()
        {
            @Override
            public AvailabilityZoneId apply(final Subnet subnet)
            {
                return AvailabilityZoneId.from(subnet.getAvailabilityZone());
            }
        });
    }
}
