/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cassandra.service;

import java.lang.management.ManagementFactory;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Random;
import java.util.Vector;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.LinkedBlockingQueue;
import javax.management.MBeanServer;
import javax.management.ObjectName;
import org.apache.cassandra.service.PBSPredictionResult;
import org.apache.cassandra.service.PBSPredictorMBean;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Records per-message read and write latencies and uses them to estimate,
 * through repeated randomized trials, how likely a read is to observe a
 * preceding write for a given replication configuration.
 *
 * <p>The class is a singleton (see {@link #instance()}) that registers itself
 * with the platform MBean server under {@link #MBEAN_NAME}. Latency logging is
 * off by default and must be switched on with
 * {@link #enableConsistencyPredictionLogging()} before any data is collected
 * or a prediction can be made.
 *
 * <p>For every logged message two latencies are kept per response:
 * <ul>
 *   <li>the <em>send</em> latency: {@code responseCreationTime - startTime};</li>
 *   <li>the <em>reply</em> latency: {@code receivedTime - responseCreationTime}.</li>
 * </ul>
 * Inside {@link #doPrediction} these are referred to as W/A (send/reply of
 * writes) and R/S (send/reply of reads).
 */
public class PBSPredictor
implements PBSPredictorMBean {
    private static final Logger logger = LoggerFactory.getLogger(PBSPredictor.class);
    public static final String MBEAN_NAME = "org.apache.cassandra.service:type=PBSPredictor";
    private static final boolean DEFAULT_DO_LOG_LATENCIES = false;
    private static final int DEFAULT_MAX_LOGGED_LATENCIES = 10000;
    private static final int DEFAULT_NUMBER_TRIALS_PREDICTION = 10000;

    /** Ids of logged writes in insertion order; the head is the next eviction candidate. */
    private final Queue<String> writeMessageIds = new LinkedBlockingQueue<String>();
    /** Ids of logged reads in insertion order; the head is the next eviction candidate. */
    private final Queue<String> readMessageIds = new LinkedBlockingQueue<String>();
    private final Map<String, MessageLatencyCollection> messageIdToWriteLats = new ConcurrentHashMap<String, MessageLatencyCollection>();
    private final Map<String, MessageLatencyCollection> messageIdToReadLats = new ConcurrentHashMap<String, MessageLatencyCollection>();
    private final Random random = new Random();
    private boolean initialized = false;

    // The three settings below are written by the management setters and read
    // by the logging/prediction methods, so they are volatile to make a change
    // visible to every thread.
    private volatile boolean logLatencies = DEFAULT_DO_LOG_LATENCIES;
    private volatile int maxLoggedLatencies = DEFAULT_MAX_LOGGED_LATENCIES;
    private volatile int numberTrialsPrediction = DEFAULT_NUMBER_TRIALS_PREDICTION;

    private static final PBSPredictor instance = new PBSPredictor();

    /**
     * @return the process-wide predictor instance
     */
    public static PBSPredictor instance() {
        return instance;
    }

    private PBSPredictor() {
        this.init();
    }

    /** Turns latency logging on. */
    @Override
    public void enableConsistencyPredictionLogging() {
        this.logLatencies = true;
    }

    /** Turns latency logging off; already recorded latencies are kept. */
    @Override
    public void disableConsistencyPredictionLogging() {
        this.logLatencies = false;
    }

    /**
     * @return {@code true} if latency logging is currently enabled
     */
    public boolean isLoggingEnabled() {
        return this.logLatencies;
    }

    /**
     * Sets the maximum number of reads, and separately of writes, whose
     * latencies are retained. The limit is enforced the next time an
     * operation is started.
     *
     * @param maxLogged maximum number of retained operations per kind
     */
    @Override
    public void setMaxLoggedLatenciesForConsistencyPrediction(int maxLogged) {
        this.maxLoggedLatencies = maxLogged;
    }

    /**
     * Sets how many randomized trials {@link #doPrediction} runs.
     *
     * @param numTrials number of trials; {@code doPrediction} rejects values below 1
     */
    @Override
    public void setNumberTrialsForConsistencyPrediction(int numTrials) {
        this.numberTrialsPrediction = numTrials;
    }

    /**
     * Registers this object with the platform MBean server. Calls after the
     * first successful one are no-ops.
     *
     * @throws RuntimeException wrapping any failure reported during registration
     */
    public synchronized void init() {
        if (this.initialized) {
            return;
        }
        MBeanServer mbs = ManagementFactory.getPlatformMBeanServer();
        try {
            mbs.registerMBean(this, new ObjectName(MBEAN_NAME));
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        this.initialized = true;
    }

    /**
     * Returns a uniformly chosen element of {@code list}.
     *
     * @throws RuntimeException if {@code list} is empty
     */
    private long getRandomElement(List<Long> list) {
        if (list.isEmpty()) {
            throw new RuntimeException("Not enough data for prediction");
        }
        return list.get(this.random.nextInt(list.size()));
    }

    /** Returns the arithmetic mean of {@code list} (NaN for an empty list). */
    private float listAverage(List<Long> list) {
        long accum = 0L;
        for (long value : list) {
            accum += value;
        }
        return (float)accum / (float)list.size();
    }

    /**
     * Sorts {@code list} in place and returns the element at the given
     * percentile.
     *
     * @param percentile fraction in [0, 1]; 1 selects the largest element
     * @throws RuntimeException if {@code list} is empty
     */
    private long getPercentile(List<Long> list, float percentile) {
        if (list.isEmpty()) {
            throw new RuntimeException("Not enough data for prediction");
        }
        Collections.sort(list);
        // size * 1.0 would index one past the end, so clamp to the last element.
        int index = Math.min((int)((float)list.size() * percentile), list.size() - 1);
        return list.get(index);
    }

    /**
     * Draws a random latency from the bucket recorded for the given response
     * number. If no bucket exists for that number, a bucket is chosen
     * uniformly at random from those that do exist.
     *
     * @param samples        buckets keyed by 1-based response number
     * @param responseNumber 1-based response number to sample for
     * @throws RuntimeException if there is nothing to sample from
     */
    private long getRandomLatencySample(Map<Integer, List<Long>> samples, int responseNumber) {
        List<Long> bucket = samples.get(responseNumber);
        if (bucket == null) {
            if (samples.isEmpty()) {
                throw new RuntimeException("Not enough data for prediction");
            }
            Object[] keys = samples.keySet().toArray();
            bucket = samples.get(keys[this.random.nextInt(keys.length)]);
        }
        return this.getRandomElement(bucket);
    }

    /**
     * Estimates consistency probability and read/write latencies for the
     * given configuration by running the configured number of randomized
     * trials over the logged latencies.
     *
     * @param n                   number of replicas; must be at least {@code r} and {@code w}
     * @param r                   number of read responses waited for; at least 1
     * @param w                   number of write responses waited for; at least 1
     * @param timeSinceWrite      time between the write returning and the read starting,
     *                            in the same unit as the logged latencies
     * @param numberVersionsStale exponent applied to the per-version inconsistency
     *                            probability; must not be negative
     * @param percentileLatency   percentile, in [0, 1], reported for read and write latency
     * @return the prediction result
     * @throws IllegalArgumentException if an argument is out of range
     * @throws IllegalStateException    if logging is disabled, the trial count is not
     *                                  positive, or no latencies have been recorded
     */
    @Override
    public PBSPredictionResult doPrediction(int n, int r, int w, float timeSinceWrite, int numberVersionsStale, float percentileLatency) {
        if (r > n) {
            throw new IllegalArgumentException("r must be less than or equal to n");
        }
        // r == 0 / w == 0 would index position -1 of the sorted latency lists below.
        if (r < 1) {
            throw new IllegalArgumentException("r must be positive");
        }
        if (w > n) {
            throw new IllegalArgumentException("w must be less than or equal to n");
        }
        if (w < 1) {
            throw new IllegalArgumentException("w must be positive");
        }
        // Written as a negated range test so that NaN is rejected as well.
        if (!(percentileLatency >= 0.0f && percentileLatency <= 1.0f)) {
            throw new IllegalArgumentException("percentileLatency must be between 0 and 1 inclusive");
        }
        if (numberVersionsStale < 0) {
            throw new IllegalArgumentException("numberVersionsStale must be non-negative");
        }
        if (!this.logLatencies) {
            throw new IllegalStateException("Latency logging is not enabled");
        }
        // Read the setting once so the loop bound and the divisor below agree
        // even if the setter is invoked while a prediction is running.
        final int trials = this.numberTrialsPrediction;
        if (trials < 1) {
            throw new IllegalStateException("Number of prediction trials must be positive");
        }
        Map<Integer, List<Long>> wLatencies = this.getOrderedWLatencies();
        Map<Integer, List<Long>> aLatencies = this.getOrderedALatencies();
        Map<Integer, List<Long>> rLatencies = this.getOrderedRLatencies();
        Map<Integer, List<Long>> sLatencies = this.getOrderedSLatencies();
        if (wLatencies.isEmpty() || aLatencies.isEmpty()) {
            throw new IllegalStateException("No write latencies have been recorded so far. Run some (non-local) inserts.");
        }
        if (rLatencies.isEmpty() || sLatencies.isEmpty()) {
            throw new IllegalStateException("No read latencies have been recorded so far. Run some (non-local) reads.");
        }

        List<Long> readLatencies = new ArrayList<Long>();
        List<Long> writeLatencies = new ArrayList<Long>();
        long consistentReads = 0L;
        // Per-trial scratch lists, indexed by replica and reused across trials.
        List<Long> trialWLatencies = new ArrayList<Long>(n);
        List<Long> trialRLatencies = new ArrayList<Long>(r);
        List<Long> replicaWriteLatencies = new ArrayList<Long>(n);
        List<Long> replicaReadLatencies = new ArrayList<Long>(r);

        for (int trial = 0; trial < trials; ++trial) {
            trialWLatencies.clear();
            trialRLatencies.clear();
            replicaWriteLatencies.clear();
            replicaReadLatencies.clear();

            // The latency buckets are keyed by 1-based response number, so
            // replica index i samples from bucket i + 1.
            for (int replica = 0; replica < n; ++replica) {
                long trialWLatency = this.getRandomLatencySample(wLatencies, replica + 1);
                long trialALatency = this.getRandomLatencySample(aLatencies, replica + 1);
                trialWLatencies.add(trialWLatency);
                replicaWriteLatencies.add(trialWLatency + trialALatency);
            }
            for (int replica = 0; replica < r; ++replica) {
                long trialRLatency = this.getRandomLatencySample(rLatencies, replica + 1);
                long trialSLatency = this.getRandomLatencySample(sLatencies, replica + 1);
                trialRLatencies.add(trialRLatency);
                replicaReadLatencies.add(trialRLatency + trialSLatency);
            }

            // The write completes when the w-th fastest round trip arrives.
            Collections.sort(replicaWriteLatencies);
            long writeLatency = replicaWriteLatencies.get(w - 1);
            writeLatencies.add(writeLatency);

            // The read completes when the r-th fastest round trip arrives.
            List<Long> sortedReplicaReadLatencies = new ArrayList<Long>(replicaReadLatencies);
            Collections.sort(sortedReplicaReadLatencies);
            readLatencies.add(sortedReplicaReadLatencies.get(r - 1));

            // Walk the read responses from fastest to slowest; the read is
            // counted as consistent if any responding replica had received
            // the write by the time the read request reached it.
            for (int responseNumber = 0; responseNumber < r; ++responseNumber) {
                int replicaNumber = replicaReadLatencies.indexOf(sortedReplicaReadLatencies.get(responseNumber));
                long readArrival = trialRLatencies.get(replicaNumber);
                long writeArrival = trialWLatencies.get(replicaNumber);
                if ((float)writeLatency + timeSinceWrite + (float)readArrival >= (float)writeArrival) {
                    ++consistentReads;
                    break;
                }
                // Recorded latencies are never negative, so -1 marks this
                // replica as consumed and keeps indexOf from matching it
                // again when two replicas share the same round-trip time.
                replicaReadLatencies.set(replicaNumber, -1L);
            }
        }

        float oneVersionConsistencyProbability = (float)consistentReads / (float)trials;
        float consistencyProbability = (float)(1.0 - Math.pow(1.0f - oneVersionConsistencyProbability, numberVersionsStale));
        float averageWriteLatency = this.listAverage(writeLatencies);
        float averageReadLatency = this.listAverage(readLatencies);
        long percentileWriteLatency = this.getPercentile(writeLatencies, percentileLatency);
        long percentileReadLatency = this.getPercentile(readLatencies, percentileLatency);
        return new PBSPredictionResult(n, r, w, timeSinceWrite, numberVersionsStale, consistencyProbability, averageReadLatency, averageWriteLatency, percentileReadLatency, percentileLatency, percentileWriteLatency, percentileLatency);
    }

    /** Starts logging a write with the current time as its start; no-op while logging is disabled. */
    public void startWriteOperation(String id) {
        if (!this.logLatencies) {
            return;
        }
        this.startWriteOperation(id, System.currentTimeMillis());
    }

    /**
     * Starts logging a write; no-op while logging is disabled.
     *
     * @param id        message id used by later {@code logWriteResponse} calls
     * @param startTime time the write was started
     */
    public void startWriteOperation(String id, long startTime) {
        if (!this.logLatencies) {
            return;
        }
        this.recordOperationStart(this.writeMessageIds, this.messageIdToWriteLats, id, startTime);
    }

    /** Starts logging a read with the current time as its start; no-op while logging is disabled. */
    public void startReadOperation(String id) {
        if (!this.logLatencies) {
            return;
        }
        this.startReadOperation(id, System.currentTimeMillis());
    }

    /**
     * Starts logging a read; no-op while logging is disabled.
     *
     * @param id        message id used by later {@code logReadResponse} calls
     * @param startTime time the read was started
     */
    public void startReadOperation(String id, long startTime) {
        if (!this.logLatencies) {
            return;
        }
        this.recordOperationStart(this.readMessageIds, this.messageIdToReadLats, id, startTime);
    }

    /** Logs a write response received now; no-op while logging is disabled. */
    public void logWriteResponse(String id, long constructionTime) {
        if (!this.logLatencies) {
            return;
        }
        this.logWriteResponse(id, constructionTime, System.currentTimeMillis());
    }

    /**
     * Logs one response to a write. Ignored while logging is disabled or when
     * {@code id} is not (or no longer) being tracked.
     *
     * @param id                   message id passed to {@code startWriteOperation}
     * @param responseCreationTime time the response was created
     * @param receivedTime         time the response was received
     */
    public void logWriteResponse(String id, long responseCreationTime, long receivedTime) {
        if (!this.logLatencies) {
            return;
        }
        this.recordResponse(this.messageIdToWriteLats, id, responseCreationTime, receivedTime);
    }

    /** Logs a read response received now; no-op while logging is disabled. */
    public void logReadResponse(String id, long constructionTime) {
        if (!this.logLatencies) {
            return;
        }
        this.logReadResponse(id, constructionTime, System.currentTimeMillis());
    }

    /**
     * Logs one response to a read. Ignored while logging is disabled or when
     * {@code id} is not (or no longer) being tracked.
     *
     * @param id                   message id passed to {@code startReadOperation}
     * @param responseCreationTime time the response was created
     * @param receivedTime         time the response was received
     */
    public void logReadResponse(String id, long responseCreationTime, long receivedTime) {
        if (!this.logLatencies) {
            return;
        }
        this.recordResponse(this.messageIdToReadLats, id, responseCreationTime, receivedTime);
    }

    /** Begins tracking {@code id} and evicts the oldest tracked ids beyond the configured limit. */
    private void recordOperationStart(Queue<String> ids, Map<String, MessageLatencyCollection> latenciesById, String id, long startTime) {
        assert (!latenciesById.containsKey(id));
        // Insert into the map before queueing the id: once the id is in the
        // queue another thread may evict it, and the map entry must already
        // exist at that point or it would never be removed.
        latenciesById.put(id, new MessageLatencyCollection(startTime));
        ids.add(id);
        // Loop so that lowering the limit at runtime actually shrinks the log.
        final int limit = this.maxLoggedLatencies;
        while (ids.size() > limit) {
            // poll() rather than remove(): a concurrent caller may have
            // drained the queue between the size check and this call.
            String evicted = ids.poll();
            if (evicted == null) {
                break;
            }
            latenciesById.remove(evicted);
        }
    }

    /** Adds one send/reply latency pair to the collection tracked for {@code id}, if any. */
    private void recordResponse(Map<String, MessageLatencyCollection> latenciesById, String id, long responseCreationTime, long receivedTime) {
        MessageLatencyCollection collection = latenciesById.get(id);
        if (collection == null) {
            return;
        }
        long startTime = collection.getStartTime();
        // Clamp at zero so that out-of-order timestamps never yield a negative latency.
        collection.addSendLat(Math.max(0L, responseCreationTime - startTime));
        collection.addReplyLat(Math.max(0L, receivedTime - responseCreationTime));
    }

    /** @return send latencies of logged writes, grouped by 1-based response rank */
    Map<Integer, List<Long>> getOrderedWLatencies() {
        List<Collection<Long>> allWLatencies = new ArrayList<Collection<Long>>();
        for (MessageLatencyCollection wlc : this.messageIdToWriteLats.values()) {
            allWLatencies.add(wlc.getSendLats());
        }
        return this.getOrderedLatencies(allWLatencies);
    }

    /** @return reply latencies of logged writes, grouped by 1-based response rank */
    Map<Integer, List<Long>> getOrderedALatencies() {
        List<Collection<Long>> allALatencies = new ArrayList<Collection<Long>>();
        for (MessageLatencyCollection wlc : this.messageIdToWriteLats.values()) {
            allALatencies.add(wlc.getReplyLats());
        }
        return this.getOrderedLatencies(allALatencies);
    }

    /** @return send latencies of logged reads, grouped by 1-based response rank */
    Map<Integer, List<Long>> getOrderedRLatencies() {
        List<Collection<Long>> allRLatencies = new ArrayList<Collection<Long>>();
        for (MessageLatencyCollection rlc : this.messageIdToReadLats.values()) {
            allRLatencies.add(rlc.getSendLats());
        }
        return this.getOrderedLatencies(allRLatencies);
    }

    /** @return reply latencies of logged reads, grouped by 1-based response rank */
    Map<Integer, List<Long>> getOrderedSLatencies() {
        List<Collection<Long>> allSLatencies = new ArrayList<Collection<Long>>();
        for (MessageLatencyCollection rlc : this.messageIdToReadLats.values()) {
            allSLatencies.add(rlc.getReplyLats());
        }
        return this.getOrderedLatencies(allSLatencies);
    }

    /**
     * Sorts each per-message latency collection and regroups the values by
     * rank: key {@code k} (starting at 1) maps to the k-th smallest latency
     * of every message that recorded at least {@code k} values.
     */
    private Map<Integer, List<Long>> getOrderedLatencies(Collection<Collection<Long>> latencyLists) {
        Map<Integer, List<Long>> ret = new HashMap<Integer, List<Long>>();
        for (Collection<Long> latencies : latencyLists) {
            List<Long> sortedLatencies = new ArrayList<Long>(latencies);
            Collections.sort(sortedLatencies);
            for (int i = 0; i < sortedLatencies.size(); ++i) {
                Integer responseNumber = i + 1;
                List<Long> bucket = ret.get(responseNumber);
                if (bucket == null) {
                    bucket = new ArrayList<Long>();
                    ret.put(responseNumber, bucket);
                }
                bucket.add(sortedLatencies.get(i));
            }
        }
        return ret;
    }

    /**
     * Start time of one logged operation plus the send and reply latencies of
     * every response logged for it so far.
     */
    private static final class MessageLatencyCollection {
        private final long startTime;
        private final Collection<Long> sendLats = new ConcurrentLinkedQueue<Long>();
        private final Collection<Long> replyLats = new ConcurrentLinkedQueue<Long>();

        MessageLatencyCollection(long startTime) {
            this.startTime = startTime;
        }

        void addSendLat(long sendLat) {
            this.sendLats.add(sendLat);
        }

        void addReplyLat(long replyLat) {
            this.replyLats.add(replyLat);
        }

        Collection<Long> getSendLats() {
            return this.sendLats;
        }

        Collection<Long> getReplyLats() {
            return this.replyLats;
        }

        long getStartTime() {
            return this.startTime;
        }
    }
}

