/*
 * Decompiled with CFR 0.152.
 */
package net.snowflake.client.jdbc.internal.grpc.xds;

import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import net.snowflake.client.jdbc.internal.google.common.base.MoreObjects;
import net.snowflake.client.jdbc.internal.google.common.base.Preconditions;
import net.snowflake.client.jdbc.internal.google.common.collect.HashMultiset;
import net.snowflake.client.jdbc.internal.google.common.collect.ImmutableMap;
import net.snowflake.client.jdbc.internal.google.common.primitives.UnsignedInteger;
import net.snowflake.client.jdbc.internal.grpc.Attributes;
import net.snowflake.client.jdbc.internal.grpc.ConnectivityState;
import net.snowflake.client.jdbc.internal.grpc.EquivalentAddressGroup;
import net.snowflake.client.jdbc.internal.grpc.InternalLogId;
import net.snowflake.client.jdbc.internal.grpc.LoadBalancer;
import net.snowflake.client.jdbc.internal.grpc.LoadBalancerProvider;
import net.snowflake.client.jdbc.internal.grpc.Status;
import net.snowflake.client.jdbc.internal.grpc.SynchronizationContext;
import net.snowflake.client.jdbc.internal.grpc.util.GracefulSwitchLoadBalancer;
import net.snowflake.client.jdbc.internal.grpc.util.MultiChildLoadBalancer;
import net.snowflake.client.jdbc.internal.grpc.xds.InternalXdsAttributes;
import net.snowflake.client.jdbc.internal.grpc.xds.XdsLogger;
import net.snowflake.client.jdbc.internal.grpc.xds.XdsNameResolver;
import net.snowflake.client.jdbc.internal.grpc.xds.XxHash64;
import net.snowflake.client.jdbc.internal.javax.annotation.Nullable;

/**
 * Load balancer that places one child balancer per endpoint on a consistent-hash ring and routes
 * each RPC to the ring entry at or after the RPC's hash (read from the
 * {@link XdsNameResolver#RPC_HASH_KEY} call option).
 *
 * <p>Each endpoint receives a number of ring entries proportional to its weight
 * ({@link InternalXdsAttributes#ATTR_SERVER_WEIGHT}, defaulting to 1), bounded by the
 * {@link RingHashConfig} minimum/maximum ring sizes.
 */
final class RingHashLoadBalancer
extends MultiChildLoadBalancer {
    private static final Status RPC_HASH_NOT_FOUND = Status.INTERNAL.withDescription("RPC hash not found. Probably a bug because xds resolver config selector always generates a hash.");
    private static final XxHash64 hashFunc = XxHash64.INSTANCE;
    private final XdsLogger logger;
    private final SynchronizationContext syncContext;
    /** Sorted, unmodifiable ring; rebuilt on every accepted address update. */
    private List<RingEntry> ring;

    RingHashLoadBalancer(LoadBalancer.Helper helper) {
        super(helper);
        this.syncContext = Preconditions.checkNotNull(helper.getSynchronizationContext(), "syncContext");
        this.logger = XdsLogger.withLogId(InternalLogId.allocate("ring_hash_lb", helper.getAuthority()));
        this.logger.log(XdsLogger.XdsLogLevel.INFO, "Created");
    }

    /**
     * Validates the resolved addresses, hands them to the base class to create/update one child
     * balancer per endpoint, rebuilds the hash ring from the endpoint weights and publishes a new
     * picker.
     *
     * @param resolvedAddresses addresses plus a {@link RingHashConfig} load-balancing config
     * @return {@link Status#OK} on success, otherwise an UNAVAILABLE status describing why the
     *     update was rejected (the same status is also reported via
     *     {@code handleNameResolutionError})
     * @throws IllegalArgumentException if the load-balancing config is missing
     */
    @Override
    public Status acceptResolvedAddresses(LoadBalancer.ResolvedAddresses resolvedAddresses) {
        this.logger.log(XdsLogger.XdsLogLevel.DEBUG, "Received resolution result: {0}", resolvedAddresses);
        List<EquivalentAddressGroup> addrList = resolvedAddresses.getAddresses();
        Status addressValidityStatus = this.validateAddrList(addrList);
        if (!addressValidityStatus.isOk()) {
            return addressValidityStatus;
        }
        try {
            // Tells the base class that an address update is in progress.
            this.resolvingAddresses = true;
            MultiChildLoadBalancer.AcceptResolvedAddressRetVal acceptRetVal = super.acceptResolvedAddressesInternal(resolvedAddresses);
            if (!acceptRetVal.status.isOk()) {
                Status unavailableStatus = Status.UNAVAILABLE.withDescription("Ring hash lb error: EDS resolution was successful, but was not accepted by base class (" + acceptRetVal.status + ")");
                this.handleNameResolutionError(unavailableStatus);
                return unavailableStatus;
            }
            RingHashConfig config = (RingHashConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
            if (config == null) {
                throw new IllegalArgumentException("Missing RingHash configuration");
            }
            // Sum weights per attribute-stripped address group; a missing weight counts as 1.
            Map<EquivalentAddressGroup, Long> serverWeights = new HashMap<EquivalentAddressGroup, Long>();
            long totalWeight = 0L;
            for (EquivalentAddressGroup eag : addrList) {
                Long weight = eag.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT);
                if (weight == null) {
                    weight = 1L;
                }
                totalWeight += weight;
                serverWeights.merge(RingHashLoadBalancer.stripAttrs(eag), weight, Long::sum);
            }
            // Only positive weights take part in the minimum: a zero weight would make the
            // normalized minimum 0 and the scale NaN, which would yield an empty ring.
            // validateAddrList guarantees totalWeight > 0, so a positive weight exists.
            long minWeight = Long.MAX_VALUE;
            for (long serverWeight : serverWeights.values()) {
                if (serverWeight > 0L && serverWeight < minWeight) {
                    minWeight = serverWeight;
                }
            }
            double normalizedMinWeight = (double) minWeight / (double) totalWeight;
            // Scale so the lightest server gets at least ceil(normalizedMin * minRingSize) entries,
            // capped at maxRingSize.
            double scale = Math.min(Math.ceil(normalizedMinWeight * (double) config.minRingSize) / normalizedMinWeight, (double) config.maxRingSize);
            this.ring = RingHashLoadBalancer.buildRing(serverWeights, totalWeight, scale);
            this.updateOverallBalancingState();
            this.shutdownRemoved(acceptRetVal.removedChildren);
        } finally {
            this.resolvingAddresses = false;
        }
        return Status.OK;
    }

    /**
     * Aggregates the children's connectivity states and publishes a new {@link RingHashPicker}.
     *
     * <p>Aggregation rules, in priority order:
     * <ul>
     *   <li>any READY child gives READY;
     *   <li>two or more TRANSIENT_FAILURE children give TRANSIENT_FAILURE;
     *   <li>any CONNECTING child gives CONNECTING;
     *   <li>exactly one TRANSIENT_FAILURE child among several gives CONNECTING;
     *   <li>any IDLE child gives IDLE;
     *   <li>otherwise the result is TRANSIENT_FAILURE.
     * </ul>
     * Does nothing once this balancer has reached SHUTDOWN.
     *
     * @throws IllegalStateException if no child balancer exists
     */
    @Override
    protected void updateOverallBalancingState() {
        Preconditions.checkState(!this.getChildLbStates().isEmpty(), "no subchannel has been created");
        if (this.currentConnectivityState == ConnectivityState.SHUTDOWN) {
            this.logger.log(XdsLogger.XdsLogLevel.DEBUG, "UpdateOverallBalancingState called after shutdown");
            return;
        }
        int numIdle = 0;
        int numReady = 0;
        int numConnecting = 0;
        int numTF = 0;
        countStates:
        for (MultiChildLoadBalancer.ChildLbState childLbState : this.getChildLbStates()) {
            switch (childLbState.getCurrentState()) {
                case READY:
                    ++numReady;
                    // A single READY child already decides the aggregate state; stop counting.
                    break countStates;
                case CONNECTING:
                    ++numConnecting;
                    break;
                case IDLE:
                    ++numIdle;
                    break;
                case TRANSIENT_FAILURE:
                    ++numTF;
                    break;
                default:
                    // Other states (e.g. SHUTDOWN) do not contribute to the aggregate.
                    break;
            }
        }
        ConnectivityState overallState;
        if (numReady > 0) {
            overallState = ConnectivityState.READY;
        } else if (numTF >= 2) {
            overallState = ConnectivityState.TRANSIENT_FAILURE;
        } else if (numConnecting > 0) {
            overallState = ConnectivityState.CONNECTING;
        } else if (numTF == 1 && this.getChildLbStates().size() > 1) {
            overallState = ConnectivityState.CONNECTING;
        } else if (numIdle > 0) {
            overallState = ConnectivityState.IDLE;
        } else {
            overallState = ConnectivityState.TRANSIENT_FAILURE;
        }
        RingHashPicker picker = new RingHashPicker(this.syncContext, this.ring, this.getImmutableChildMap());
        this.getHelper().updateBalancingState(overallState, picker);
        this.currentConnectivityState = overallState;
    }

    @Override
    protected boolean reconnectOnIdle() {
        return false;
    }

    @Override
    protected boolean reactivateChildOnReuse() {
        return false;
    }

    /** Creates a {@link RingHashChildLbState} for the endpoint {@code key}. */
    @Override
    protected MultiChildLoadBalancer.ChildLbState createChildLbState(Object key, Object policyConfig, LoadBalancer.SubchannelPicker initialPicker, LoadBalancer.ResolvedAddresses resolvedAddresses) {
        return new RingHashChildLbState((MultiChildLoadBalancer.Endpoint) key, this.getChildAddresses(key, resolvedAddresses, null));
    }

    /**
     * Checks that the address list is usable for building a ring. The list must be non-empty and
     * must contain no duplicate socket addresses. Each weight must lie in [0, 2^32-1]. The weights
     * must sum to a value in (0, 2^32-1].
     *
     * @return {@link Status#OK}, or an UNAVAILABLE status that has also been reported via
     *     {@code handleNameResolutionError}
     */
    private Status validateAddrList(List<EquivalentAddressGroup> addrList) {
        if (addrList.isEmpty()) {
            Status unavailableStatus = Status.UNAVAILABLE.withDescription("Ring hash lb error: EDS resolution was successful, but returned server addresses are empty.");
            this.handleNameResolutionError(unavailableStatus);
            return unavailableStatus;
        }
        String dupAddrString = this.validateNoDuplicateAddresses(addrList);
        if (dupAddrString != null) {
            Status unavailableStatus = Status.UNAVAILABLE.withDescription("Ring hash lb error: EDS resolution was successful, but there were duplicate addresses: " + dupAddrString);
            this.handleNameResolutionError(unavailableStatus);
            return unavailableStatus;
        }
        long totalWeight = 0L;
        for (EquivalentAddressGroup eag : addrList) {
            Long weight = eag.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT);
            if (weight == null) {
                weight = 1L;
            }
            if (weight < 0L) {
                Status unavailableStatus = Status.UNAVAILABLE.withDescription(String.format("Ring hash lb error: EDS resolution was successful, but returned a negative weight for %s.", RingHashLoadBalancer.stripAttrs(eag)));
                this.handleNameResolutionError(unavailableStatus);
                return unavailableStatus;
            }
            if (weight > UnsignedInteger.MAX_VALUE.longValue()) {
                Status unavailableStatus = Status.UNAVAILABLE.withDescription(String.format("Ring hash lb error: EDS resolution was successful, but returned a weight too large to fit in an unsigned int for %s.", RingHashLoadBalancer.stripAttrs(eag)));
                this.handleNameResolutionError(unavailableStatus);
                return unavailableStatus;
            }
            totalWeight += weight;
        }
        if (totalWeight > UnsignedInteger.MAX_VALUE.longValue()) {
            Status unavailableStatus = Status.UNAVAILABLE.withDescription(String.format("Ring hash lb error: EDS resolution was successful, but returned a sum of weights too large to fit in an unsigned int (%d).", totalWeight));
            this.handleNameResolutionError(unavailableStatus);
            return unavailableStatus;
        }
        if (totalWeight == 0L) {
            // With no positive weight the ring would be empty and every pick would fail.
            Status unavailableStatus = Status.UNAVAILABLE.withDescription("Ring hash lb error: EDS resolution was successful, but all returned weights are zero.");
            this.handleNameResolutionError(unavailableStatus);
            return unavailableStatus;
        }
        return Status.OK;
    }

    /**
     * Returns a human-readable description of socket addresses that appear more than once across
     * all address groups, or {@code null} if every address is unique.
     */
    @Nullable
    private String validateNoDuplicateAddresses(List<EquivalentAddressGroup> addrList) {
        Set<SocketAddress> addresses = new HashSet<SocketAddress>();
        HashMultiset<String> dups = HashMultiset.create();
        for (EquivalentAddressGroup eag : addrList) {
            for (SocketAddress address : eag.getAddresses()) {
                if (!addresses.add(address)) {
                    dups.add(address.toString());
                }
            }
        }
        if (dups.isEmpty()) {
            return null;
        }
        // The multiset counts only repeats, so +1 accounts for the first occurrence.
        return dups.entrySet().stream().map(dup -> String.format("Address: %s, count: %d", dup.getElement(), dup.getCount() + 1)).collect(Collectors.joining("; "));
    }

    /**
     * Builds the sorted ring. Each server gets about {@code scale * weight / totalWeight} entries.
     * Each entry is keyed by the XxHash64 of "{first address}_{counter}". Cumulative target
     * counts are used so the rounding error does not accumulate across servers.
     *
     * @return an unmodifiable list of entries sorted by hash (signed order)
     */
    private static List<RingEntry> buildRing(Map<EquivalentAddressGroup, Long> serverWeights, long totalWeight, double scale) {
        ArrayList<RingEntry> ring = new ArrayList<RingEntry>();
        double currentHashes = 0.0;
        double targetHashes = 0.0;
        for (Map.Entry<EquivalentAddressGroup, Long> entry : serverWeights.entrySet()) {
            MultiChildLoadBalancer.Endpoint endpoint = new MultiChildLoadBalancer.Endpoint(entry.getKey());
            double normalizedWeight = (double) entry.getValue().longValue() / (double) totalWeight;
            StringBuilder sb = new StringBuilder(entry.getKey().getAddresses().get(0).toString());
            sb.append('_');
            int lengthWithoutCounter = sb.length();
            targetHashes += scale * normalizedWeight;
            long i = 0L;
            while (currentHashes < targetHashes) {
                sb.append(i);
                long hash = hashFunc.hashAsciiString(sb.toString());
                ring.add(new RingEntry(hash, endpoint));
                ++i;
                currentHashes += 1.0;
                sb.setLength(lengthWithoutCounter);
            }
        }
        Collections.sort(ring);
        return Collections.unmodifiableList(ring);
    }

    /** Returns {@code eag} without attributes (returns {@code eag} itself if it has none). */
    public static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) {
        if (eag.getAttributes() == Attributes.EMPTY) {
            return eag;
        }
        return new EquivalentAddressGroup(eag.getAddresses());
    }

    /** Not used: this balancer builds its own {@link RingHashPicker}. */
    @Override
    protected LoadBalancer.SubchannelPicker getSubchannelPicker(Map<Object, LoadBalancer.SubchannelPicker> childPickers) {
        throw new UnsupportedOperationException("Not used by RingHash");
    }

    /** Returns the attribute-stripped address groups of the given child states. */
    static Set<EquivalentAddressGroup> getStrippedChildEags(Collection<MultiChildLoadBalancer.ChildLbState> states) {
        return states.stream().map(MultiChildLoadBalancer.ChildLbState::getEag).map(RingHashLoadBalancer::stripAttrs).collect(Collectors.toSet());
    }

    @Override
    protected Collection<MultiChildLoadBalancer.ChildLbState> getChildLbStates() {
        return super.getChildLbStates();
    }

    @Override
    protected MultiChildLoadBalancer.ChildLbState getChildLbStateEag(EquivalentAddressGroup eag) {
        return super.getChildLbStateEag(eag);
    }

    /** Child state for one endpoint, backed by a pick-first balancer that can be reactivated. */
    class RingHashChildLbState
    extends MultiChildLoadBalancer.ChildLbState {
        public RingHashChildLbState(MultiChildLoadBalancer.Endpoint key, LoadBalancer.ResolvedAddresses resolvedAddresses) {
            super(key, RingHashLoadBalancer.this.pickFirstLbProvider, null, LoadBalancer.EMPTY_PICKER, resolvedAddresses, true);
        }

        /**
         * If this child is deactivated, this method does the following:
         * <ul>
         *   <li>moves the parent balancer to CONNECTING;
         *   <li>switches the child to pick-first;
         *   <li>marks the child reactivated;
         *   <li>replays the child's stored addresses into it.
         * </ul>
         * Otherwise it is a no-op. The {@code policyProvider} argument is ignored; pick-first is
         * always used.
         */
        @Override
        protected void reactivate(LoadBalancerProvider policyProvider) {
            if (!this.isDeactivated()) {
                return;
            }
            RingHashLoadBalancer.this.currentConnectivityState = ConnectivityState.CONNECTING;
            this.getLb().switchTo(RingHashLoadBalancer.this.pickFirstLbProvider);
            this.markReactivated();
            this.getLb().acceptResolvedAddresses(this.getResolvedAddresses());
            RingHashLoadBalancer.this.logger.log(XdsLogger.XdsLogLevel.DEBUG, "Child balancer {0} reactivated", this.getKey());
        }

        /** Reactivates this child with pick-first if it is deactivated. */
        public void activate() {
            this.reactivate(RingHashLoadBalancer.this.pickFirstLbProvider);
        }

        @Override
        protected void shutdown() {
            super.shutdown();
        }

        @Override
        protected GracefulSwitchLoadBalancer getLb() {
            return super.getLb();
        }
    }

    /** Ring-size bounds from the xDS ring-hash policy config; both positive, min <= max. */
    static final class RingHashConfig {
        final long minRingSize;
        final long maxRingSize;

        RingHashConfig(long minRingSize, long maxRingSize) {
            Preconditions.checkArgument(minRingSize > 0L, "minRingSize <= 0");
            Preconditions.checkArgument(maxRingSize > 0L, "maxRingSize <= 0");
            Preconditions.checkArgument(minRingSize <= maxRingSize, "minRingSize > maxRingSize");
            this.minRingSize = minRingSize;
            this.maxRingSize = maxRingSize;
        }

        @Override
        public String toString() {
            return MoreObjects.toStringHelper(this).add("minRingSize", this.minRingSize).add("maxRingSize", this.maxRingSize).toString();
        }
    }

    /** One point on the ring; ordered by hash using signed {@code long} comparison. */
    private static final class RingEntry
    implements Comparable<RingEntry> {
        private final long hash;
        private final MultiChildLoadBalancer.Endpoint addrKey;

        private RingEntry(long hash, MultiChildLoadBalancer.Endpoint addrKey) {
            this.hash = hash;
            this.addrKey = addrKey;
        }

        @Override
        public int compareTo(RingEntry entry) {
            return Long.compare(this.hash, entry.hash);
        }
    }

    /** A child state paired with the connectivity state captured when the picker was built. */
    private static final class SubchannelView {
        private final RingHashChildLbState childLbState;
        private final ConnectivityState connectivityState;

        private SubchannelView(RingHashChildLbState childLbState, ConnectivityState state) {
            this.childLbState = childLbState;
            this.connectivityState = state;
        }
    }

    /** Picker that maps an RPC hash onto the ring and walks clockwise to a usable child. */
    private static final class RingHashPicker
    extends LoadBalancer.SubchannelPicker {
        private final SynchronizationContext syncContext;
        private final List<RingEntry> ring;
        private final Map<MultiChildLoadBalancer.Endpoint, SubchannelView> pickableSubchannels;

        private RingHashPicker(SynchronizationContext syncContext, List<RingEntry> ring, ImmutableMap<Object, MultiChildLoadBalancer.ChildLbState> subchannels) {
            this.syncContext = syncContext;
            this.ring = ring;
            this.pickableSubchannels = new HashMap<MultiChildLoadBalancer.Endpoint, SubchannelView>(subchannels.size());
            for (Map.Entry<Object, MultiChildLoadBalancer.ChildLbState> entry : subchannels.entrySet()) {
                RingHashChildLbState childLbState = (RingHashChildLbState) entry.getValue();
                // Capture each child's state now so all picks on this picker see the same snapshot.
                this.pickableSubchannels.put((MultiChildLoadBalancer.Endpoint) entry.getKey(), new SubchannelView(childLbState, childLbState.getCurrentState()));
            }
        }

        /**
         * Returns the index of the first ring entry whose hash is {@code >= requestHash}. Hashes
         * are compared as signed longs, the same order used to sort the ring. If the request hash
         * exceeds every entry, the search wraps around to index 0 because the ring is circular.
         * Returns 0 for an empty or single-entry ring.
         */
        private int getTargetIndex(long requestHash) {
            int low = 0;
            int high = this.ring.size();
            while (low < high) {
                int mid = (low + high) >>> 1;
                if (this.ring.get(mid).hash < requestHash) {
                    low = mid + 1;
                } else {
                    high = mid;
                }
            }
            return low == this.ring.size() ? 0 : low;
        }

        /**
         * Starts at the target ring entry and walks clockwise. The result depends on the state of
         * the first entry that is not skipped:
         * <ul>
         *   <li>READY: delegates to that child's picker;
         *   <li>CONNECTING: returns no result;
         *   <li>IDLE: schedules activation or a connection request on the sync context, then
         *       returns no result;
         *   <li>any other state: skips the entry and moves on.
         * </ul>
         * If every entry is skipped, this method delegates to the target entry's picker.
         */
        @Override
        public LoadBalancer.PickResult pickSubchannel(LoadBalancer.PickSubchannelArgs args) {
            Long requestHash = args.getCallOptions().getOption(XdsNameResolver.RPC_HASH_KEY);
            if (requestHash == null) {
                return LoadBalancer.PickResult.withError(RPC_HASH_NOT_FOUND);
            }
            int targetIndex = this.getTargetIndex(requestHash);
            for (int i = 0; i < this.ring.size(); ++i) {
                int index = (targetIndex + i) % this.ring.size();
                SubchannelView subchannelView = this.pickableSubchannels.get(this.ring.get(index).addrKey);
                RingHashChildLbState childLbState = subchannelView.childLbState;
                if (subchannelView.connectivityState == ConnectivityState.READY) {
                    return childLbState.getCurrentPicker().pickSubchannel(args);
                }
                if (subchannelView.connectivityState == ConnectivityState.CONNECTING) {
                    return LoadBalancer.PickResult.withNoResult();
                }
                if (subchannelView.connectivityState != ConnectivityState.IDLE) {
                    continue;
                }
                // Child state must only be mutated from the synchronization context.
                this.syncContext.execute(() -> {
                    if (childLbState.isDeactivated()) {
                        childLbState.activate();
                    } else {
                        childLbState.getLb().requestConnection();
                    }
                });
                return LoadBalancer.PickResult.withNoResult();
            }
            RingHashChildLbState originalSubchannel = this.pickableSubchannels.get(this.ring.get(targetIndex).addrKey).childLbState;
            return originalSubchannel.getCurrentPicker().pickSubchannel(args);
        }
    }
}

