/*
 * Decompiled with CFR 0.152.
 */
package io.rsocket.loadbalance;

import io.rsocket.RSocket;
import io.rsocket.core.RSocketConnector;
import io.rsocket.loadbalance.ClientLoadbalanceStrategy;
import io.rsocket.loadbalance.WeightedStats;
import io.rsocket.loadbalance.WeightedStatsRequestInterceptor;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Function;
import reactor.util.annotation.Nullable;

/**
 * Client-side load-balancing strategy that chooses an {@link RSocket} by comparing an
 * "algorithmic weight" derived from the socket's availability and its {@link WeightedStats}
 * (pending requests and latency quantiles).
 *
 * <p>Selection rules:
 * <ul>
 *   <li>one socket: it is returned as-is;</li>
 *   <li>two sockets: both are weighted and the heavier one wins (ties go to the first);</li>
 *   <li>three or more: up to {@code maxPairSelectionAttempts} random pairs of distinct sockets are
 *       drawn until both members report {@code availability() > 0}; the last drawn pair is then
 *       weighted the same way.</li>
 * </ul>
 *
 * <p>Stats are obtained through a resolver function. If none is configured, a default resolver is
 * used which records stats through a {@link WeightedStatsRequestInterceptor} registered on the
 * connector in {@link #initialize(RSocketConnector)}.
 */
public class WeightedLoadbalanceStrategy implements ClientLoadbalanceStrategy {

    /** Exponent used by {@link #calculateFactor} to scale latencies outside the quantile band. */
    private static final double EXP_FACTOR = 4.0;

    /** Maximum number of random pairs drawn by {@link #select} when more than two sockets exist. */
    final int maxPairSelectionAttempts;

    /** Resolves the stats of a socket; the default resolver returns {@code null} for unknown ones. */
    final Function<RSocket, WeightedStats> weightedStatsResolver;

    private WeightedLoadbalanceStrategy(int numberOfAttempts, @Nullable Function<RSocket, WeightedStats> resolver) {
        this.maxPairSelectionAttempts = numberOfAttempts;
        this.weightedStatsResolver = resolver != null ? resolver : new DefaultWeightedStatsResolver();
    }

    /**
     * Hooks stats collection into the given connector when the built-in resolver is in use.
     * A user-supplied resolver is left untouched.
     *
     * @param connector the connector whose requester-side interceptors should be extended
     */
    @Override
    public void initialize(RSocketConnector connector) {
        final Function<RSocket, WeightedStats> resolver = this.weightedStatsResolver;
        if (resolver instanceof DefaultWeightedStatsResolver) {
            ((DefaultWeightedStatsResolver) resolver).init(connector);
        }
    }

    /**
     * Picks one socket from {@code sockets} according to the rules described on this class.
     *
     * <p>NOTE(review): an empty list is not handled specially and makes
     * {@link ThreadLocalRandom#nextInt(int)} throw {@link IllegalArgumentException}; presumably
     * callers never pass an empty list — confirm.
     *
     * @param sockets the candidate sockets
     * @return the selected socket
     */
    @Override
    public RSocket select(List<RSocket> sockets) {
        final int size = sockets.size();
        if (size == 1) {
            return sockets.get(0);
        }
        if (size == 2) {
            return pickHeavier(sockets.get(0), sockets.get(1));
        }

        // Always draw at least one pair: with a non-positive attempt budget no candidate would be
        // drawn at all and null would be returned to the caller.
        final int attempts = Math.max(1, this.maxPairSelectionAttempts);
        RSocket rsc1 = null;
        RSocket rsc2 = null;
        for (int i = 0; i < attempts; i++) {
            int i1 = ThreadLocalRandom.current().nextInt(size);
            // Draw from [0, size - 1) and shift past i1 so both indices are always distinct.
            int i2 = ThreadLocalRandom.current().nextInt(size - 1);
            if (i2 >= i1) {
                i2++;
            }
            rsc1 = sockets.get(i1);
            rsc2 = sockets.get(i2);
            if (rsc1.availability() > 0.0 && rsc2.availability() > 0.0) {
                break;
            }
        }
        // Both candidates are non-null here: they were dereferenced inside the loop above,
        // which ran at least once.
        return pickHeavier(rsc1, rsc2);
    }

    /**
     * Returns whichever socket has the larger algorithmic weight, preferring {@code first} when
     * the weights are equal (or not comparable).
     */
    private RSocket pickHeavier(RSocket first, RSocket second) {
        final Function<RSocket, WeightedStats> resolver = this.weightedStatsResolver;
        final double w1 = algorithmicWeight(first, resolver.apply(first));
        final double w2 = algorithmicWeight(second, resolver.apply(second));
        return w1 < w2 ? second : first;
    }

    /**
     * Computes the weight used to compare two sockets: higher is better.
     *
     * <p>The predicted latency is reduced further when it is below the lower quantile and
     * amplified when it is above the higher quantile, then combined with the number of pending
     * requests and both availability figures.
     *
     * @param rSocket the socket being weighted
     * @param weightedStats stats for the socket, or {@code null} if none are known
     * @return {@code 0.0} when stats are missing, the socket is disposed, or its availability is
     *     exactly zero; otherwise
     *     {@code availability * weightedAvailability / (1 + adjustedLatency * (pending + 1))}
     */
    private static double algorithmicWeight(RSocket rSocket, @Nullable WeightedStats weightedStats) {
        if (weightedStats == null || rSocket.isDisposed() || rSocket.availability() == 0.0) {
            return 0.0;
        }
        final int pending = weightedStats.pending();
        double latency = weightedStats.predictedLatency();
        final double low = weightedStats.lowerQuantileLatency();
        // For a positive lower quantile this keeps the upper bound strictly above it.
        final double high = Math.max(weightedStats.higherQuantileLatency(), low * 1.001);
        // Floor of 1.0 keeps the divisor in calculateFactor away from zero.
        final double bandWidth = Math.max(high - low, 1.0);
        if (latency < low) {
            latency /= calculateFactor(low, latency, bandWidth);
        } else if (latency > high) {
            latency *= calculateFactor(latency, high, bandWidth);
        }
        return rSocket.availability() * weightedStats.weightedAvailability()
                / (1.0 + latency * (double) (pending + 1));
    }

    /**
     * Returns {@code (1 + (u - l) / bandWidth) ^ EXP_FACTOR}, the scaling applied to a latency
     * lying {@code u - l} outside the quantile band.
     */
    private static double calculateFactor(double u, double l, double bandWidth) {
        final double alpha = (u - l) / bandWidth;
        return Math.pow(1.0 + alpha, EXP_FACTOR);
    }

    /** Creates a strategy with default settings; equivalent to {@code builder().build()}. */
    public static WeightedLoadbalanceStrategy create() {
        return new Builder().build();
    }

    /** Returns a new {@link Builder} for customizing the strategy. */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Built-in resolver: installs one {@link WeightedStatsRequestInterceptor} per requester
     * {@link RSocket} and keeps it in a map keyed by that socket until the interceptor is disposed.
     */
    private static class DefaultWeightedStatsResolver implements Function<RSocket, WeightedStats> {

        final Map<RSocket, WeightedStats> statsMap = new ConcurrentHashMap<RSocket, WeightedStats>();

        private DefaultWeightedStatsResolver() {
        }

        /** Returns the recorded stats for {@code rSocket}, or {@code null} if none are registered. */
        @Override
        public WeightedStats apply(RSocket rSocket) {
            return this.statsMap.get(rSocket);
        }

        /** Registers a requester-side interceptor factory that records stats per socket. */
        void init(RSocketConnector connector) {
            connector.interceptors(registry -> registry.forRequestsInRequester((RSocket rSocket) -> {
                final WeightedStatsRequestInterceptor interceptor = new WeightedStatsRequestInterceptor() {
                    @Override
                    public void dispose() {
                        // Forget this socket's stats once its interceptor is disposed.
                        statsMap.remove(rSocket);
                    }
                };
                statsMap.put(rSocket, interceptor);
                return interceptor;
            }));
        }
    }

    /** Fluent builder for {@link WeightedLoadbalanceStrategy}. */
    public static class Builder {

        private int maxPairSelectionAttempts = 5;

        @Nullable
        private Function<RSocket, WeightedStats> weightedStatsResolver;

        private Builder() {
        }

        /**
         * Sets how many random pairs {@code select} may draw looking for two available sockets
         * (default 5). Values below 1 behave like 1.
         */
        public Builder maxPairSelectionAttempts(int numberOfAttempts) {
            this.maxPairSelectionAttempts = numberOfAttempts;
            return this;
        }

        /**
         * Sets a custom stats resolver. When left unset (or {@code null}), the built-in
         * interceptor-based resolver is used.
         */
        public Builder weightedStatsResolver(Function<RSocket, WeightedStats> resolver) {
            this.weightedStatsResolver = resolver;
            return this;
        }

        /** Builds the strategy from the current settings. */
        public WeightedLoadbalanceStrategy build() {
            return new WeightedLoadbalanceStrategy(this.maxPairSelectionAttempts, this.weightedStatsResolver);
        }
    }
}

