/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.parallelconsumer.state;

import io.confluent.csid.utils.LoopingResumingIterator;
import io.confluent.parallelconsumer.ParallelConsumerOptions;
import io.confluent.parallelconsumer.internal.PCModule;
import io.confluent.parallelconsumer.metrics.PCMetrics;
import io.confluent.parallelconsumer.metrics.PCMetricsDef;
import io.confluent.parallelconsumer.state.ProcessingShard;
import io.confluent.parallelconsumer.state.ShardKey;
import io.confluent.parallelconsumer.state.WorkContainer;
import io.confluent.parallelconsumer.state.WorkManager;
import io.micrometer.core.instrument.Gauge;
import io.micrometer.core.instrument.Tag;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.NavigableSet;
import java.util.Optional;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.common.TopicPartition;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Owns the {@link ProcessingShard}s that hold queued work, keyed by {@link ShardKey}, together with a retry
 * queue of failed {@link WorkContainer}s ordered by the time their retry becomes due.
 * <p>
 * The shard map is a {@link ConcurrentHashMap}. The retry queue is a plain {@link TreeSet} and this class adds no
 * synchronisation around it.
 *
 * @param <K> key type of the consumed records
 * @param <V> value type of the consumed records
 */
public class ShardManager<K, V> {
    private static final Logger log = LoggerFactory.getLogger(ShardManager.class);

    private final PCModule<K, V> module;
    private final ParallelConsumerOptions<?, ?> options;
    private final WorkManager<K, V> wm;

    /** All shards that currently exist, keyed by the shard key computed from the configured ordering. */
    private final Map<ShardKey, ProcessingShard<K, V>> processingShards = new ConcurrentHashMap<>();

    /**
     * Orders retry candidates by due time, then topic, then partition, then offset.
     * <p>
     * Topic and partition are compared as separate keys: concatenating them into a single string makes distinct
     * partitions collide (topic {@code "a1"} partition {@code 1} and topic {@code "a"} partition {@code 11} both
     * yield {@code "a11"}), and a {@link TreeSet} silently drops elements its comparator considers equal.
     * The lambda parameter types are explicit because the chained call has no target type to infer them from.
     */
    private final Comparator<WorkContainer<?, ?>> retryQueueWorkContainerComparator = Comparator
            .comparing((WorkContainer<?, ?> workContainer) -> workContainer.getRetryDueAt())
            .thenComparing((WorkContainer<?, ?> workContainer) -> workContainer.getTopicPartition().topic())
            .thenComparingInt((WorkContainer<?, ?> workContainer) -> workContainer.getTopicPartition().partition())
            .thenComparing(WorkContainer::offset);

    /** Failed work awaiting retry, sorted by {@link #retryQueueWorkContainerComparator}. */
    private final NavigableSet<WorkContainer<?, ?>> retryQueue = new TreeSet<>(this.retryQueueWorkContainerComparator);

    /** Shard key at which the next {@link #getWorkIfAvailable(int)} call resumes iterating, if any. */
    private Optional<ShardKey> iterationResumePoint = Optional.empty();

    private Gauge shardsSizeGauge;
    private Gauge numberOfShardsGauge;
    private final PCMetrics pcMetrics;

    /**
     * Creates a manager with no shards and registers its gauges.
     *
     * @param module source of the options and metrics registry; also passed to every created {@link WorkContainer}
     * @param wm     work manager whose {@code getPm()} result is handed to newly created shards
     */
    public ShardManager(PCModule<K, V> module, WorkManager<K, V> wm) {
        this.module = module;
        this.wm = wm;
        this.options = module.options();
        this.pcMetrics = module.pcMetrics();
        this.initMetrics();
    }

    /**
     * Looks up the shard registered under the given key.
     *
     * @param key shard key to look up
     * @return the shard, or empty when no shard is registered under {@code key}
     */
    Optional<ProcessingShard<K, V>> getShard(ShardKey key) {
        return Optional.ofNullable(this.processingShards.get(key));
    }

    /**
     * Computes the shard key for a work container using the configured ordering.
     *
     * @param wc work container to compute the key for
     * @return the shard key
     */
    ShardKey computeShardKey(WorkContainer<?, ?> wc) {
        return ShardKey.of(wc, this.options.getOrdering());
    }

    /**
     * Computes the shard key for a raw consumer record using the configured ordering.
     *
     * @param wc consumer record to compute the key for
     * @return the shard key
     */
    ShardKey computeShardKey(ConsumerRecord<?, ?> wc) {
        return ShardKey.of(wc, this.options.getOrdering());
    }

    /**
     * Sums {@link ProcessingShard#getCountOfWorkAwaitingSelection()} over all shards.
     *
     * @return the total across all shards
     */
    public long getNumberOfWorkQueuedInShardsAwaitingSelection() {
        return this.processingShards.values().stream()
                .mapToLong(ProcessingShard::getCountOfWorkAwaitingSelection)
                .sum();
    }

    /**
     * Reports whether any shard has work waiting to be processed.
     *
     * @return {@code true} if {@link ProcessingShard#workIsWaitingToBeProcessed()} holds for at least one shard
     */
    public boolean workIsWaitingToBeProcessed() {
        // A sequential scan is used on purpose: the per-shard check is a cheap predicate and does not justify
        // dispatching onto the shared fork-join pool.
        return this.processingShards.values().stream()
                .anyMatch(ProcessingShard::workIsWaitingToBeProcessed);
    }

    /**
     * Removes, from the shards and the retry queue, the work for every record present in the given collection.
     *
     * @param recordsFromRemovedPartition records of a removed partition; empty optionals are skipped
     */
    void removeAnyShardEntriesReferencedFrom(Collection<Optional<ConsumerRecord<K, V>>> recordsFromRemovedPartition) {
        // Snapshot the present records first, then remove them one by one.
        List<ConsumerRecord<K, V>> polledRecordsFromPartition = recordsFromRemovedPartition.stream()
                .filter(Optional::isPresent)
                .map(Optional::get)
                .collect(Collectors.toList());
        for (ConsumerRecord<K, V> consumerRecord : polledRecordsFromPartition) {
            this.removeWorkFromShardFor(consumerRecord);
        }
    }

    /**
     * Removes the work for one record from its shard and from the retry queue, then drops the shard if that left
     * it empty (see {@link #removeShardIfEmpty(ShardKey)}).
     *
     * @param consumerRecord record whose work should be removed
     */
    private void removeWorkFromShardFor(ConsumerRecord<K, V> consumerRecord) {
        ShardKey shardKey = this.computeShardKey(consumerRecord);
        // Single lookup instead of containsKey + get, so the shard cannot vanish between the two calls.
        ProcessingShard<K, V> shard = this.processingShards.get(shardKey);
        if (shard == null) {
            log.trace("Shard referenced by WC: {} with shard key: {} already removed", consumerRecord, shardKey);
            return;
        }

        WorkContainer<K, V> removedWC = shard.remove(consumerRecord.offset());
        // NOTE(review): shard.remove presumably returns null when the offset is not tracked - confirm against
        // ProcessingShard. A null must not reach the TreeSet: its comparator dereferences the element and throws.
        if (removedWC != null) {
            this.retryQueue.remove(removedWC);
        }
        this.removeShardIfEmpty(shardKey);
    }

    /**
     * Wraps the record in a new {@link WorkContainer} and adds it to its shard, creating the shard on first use.
     *
     * @param epochOfInboundRecords epoch passed to the new work container
     * @param aRecord               record to queue
     */
    public void addWorkContainer(long epochOfInboundRecords, ConsumerRecord<K, V> aRecord) {
        WorkContainer<K, V> wc = new WorkContainer<K, V>(epochOfInboundRecords, aRecord, this.module);
        ShardKey shardKey = this.computeShardKey(wc);
        ProcessingShard<K, V> shard = this.processingShards.computeIfAbsent(shardKey,
                ignore -> new ProcessingShard<K, V>(shardKey, this.options, this.wm.getPm()));
        shard.addWorkContainer(wc);
    }

    /**
     * Removes the shard registered under {@code key} when it exists, is empty, and the configured ordering is
     * {@code KEY}. Under any other ordering the shard is kept even when empty.
     *
     * @param key key of the shard to check
     */
    void removeShardIfEmpty(ShardKey key) {
        Optional<ProcessingShard<K, V>> shardOpt = this.getShard(key);
        boolean keyOrdering = this.options.getOrdering().equals(ParallelConsumerOptions.ProcessingOrder.KEY);
        if (keyOrdering && shardOpt.isPresent() && shardOpt.get().isEmpty()) {
            log.trace("Removing empty shard (key: {})", key);
            this.processingShards.remove(key);
        }
    }

    /**
     * Records a successful result: removes the container from the retry queue, forwards the success to its shard
     * and drops the shard if it became empty. When the shard no longer exists the result is only logged.
     *
     * @param wc the work container that completed successfully
     */
    public void onSuccess(WorkContainer<?, ?> wc) {
        this.retryQueue.remove(wc);
        ShardKey key = this.computeShardKey(wc);
        Optional<ProcessingShard<K, V>> shardOptional = this.getShard(key);
        if (shardOptional.isPresent()) {
            shardOptional.get().onSuccess(wc);
            this.removeShardIfEmpty(key);
        } else {
            log.trace("Dropping successful result for revoked partition {}. Record in question was: {}", key, wc.getCr());
        }
    }

    /**
     * Records a failed result by adding the container to the retry queue.
     *
     * @param wc the work container that failed
     */
    public void onFailure(WorkContainer<?, ?> wc) {
        log.debug("Work FAILED");
        // NOTE(review): the set position is derived from getRetryDueAt() at insertion time; this assumes the
        // value is already up to date here and does not change while the container is queued - confirm.
        this.retryQueue.add(wc);
    }

    /**
     * Finds the delay until the next retry is due among containers that are not in flight.
     *
     * @return {@link WorkContainer#getDelayUntilRetryDue()} of the first container in retry-queue order for which
     *         {@link WorkContainer#isNotInFlight()} is true, or empty when there is none
     */
    public Optional<Duration> getLowestRetryTime() {
        for (WorkContainer<?, ?> workContainer : this.retryQueue) {
            if (workContainer.isNotInFlight()) {
                return Optional.of(workContainer.getDelayUntilRetryDue());
            }
        }
        return Optional.empty();
    }

    /**
     * Collects work from the shards, starting at the saved resume point, until the requested amount has been
     * gathered or the iterator yields no further shard. The entry at which iteration stopped is saved as the
     * resume point for the next call.
     *
     * @param requestedMaxWorkToRetrieve upper bound on the number of containers to return
     * @return the collected work, possibly empty
     */
    public List<WorkContainer<K, V>> getWorkIfAvailable(int requestedMaxWorkToRetrieve) {
        LoopingResumingIterator<ShardKey, ProcessingShard<K, V>> shardQueueIterator =
                new LoopingResumingIterator<ShardKey, ProcessingShard<K, V>>(this.iterationResumePoint, this.processingShards);
        List<WorkContainer<K, V>> workFromAllShards = new ArrayList<>();

        Optional<Map.Entry<ShardKey, ProcessingShard<K, V>>> next = shardQueueIterator.next();
        while (workFromAllShards.size() < requestedMaxWorkToRetrieve && next.isPresent()) {
            ProcessingShard<K, V> shard = next.get().getValue();
            // Only ask each shard for what is still missing, so the total never exceeds the request.
            int remainingToGet = requestedMaxWorkToRetrieve - workFromAllShards.size();
            List<WorkContainer<K, V>> work = shard.getWorkIfAvailable(remainingToGet);
            workFromAllShards.addAll(work);
            next = shardQueueIterator.next();
        }

        if (workFromAllShards.size() >= requestedMaxWorkToRetrieve) {
            log.debug("Work taken is now over max (iteration resume point is {})", this.iterationResumePoint);
        }
        this.updateResumePoint(next);
        return workFromAllShards;
    }

    /**
     * Asks every shard to remove its stale work containers.
     * <p>
     * All shards are visited on every call. A lazy {@code map(...).anyMatch(...)} pipeline would stop at the first
     * shard reporting a removal and leave the remaining shards uncleaned.
     *
     * @return {@code true} if at least one shard reported that it removed something
     */
    public boolean removeStaleContainers() {
        boolean removed = false;
        for (ProcessingShard<K, V> shard : this.processingShards.values()) {
            // Call the shard first so the removal is never skipped once 'removed' is already true.
            boolean removedFromShard = shard.removeStaleWorkContainersFromShard();
            removed = removedFromShard || removed;
        }
        if (removed) {
            log.debug("there are stale work containers removed");
        }
        return removed;
    }

    /**
     * Saves the key of the given entry as the resume point, or clears the resume point when no entry is given.
     *
     * @param lastShard the entry at which iteration stopped, if any
     */
    private void updateResumePoint(Optional<Map.Entry<ShardKey, ProcessingShard<K, V>>> lastShard) {
        this.iterationResumePoint = lastShard.map(Map.Entry::getKey);
        if (this.iterationResumePoint.isPresent()) {
            log.debug("Work taken is now over max, stopping (saving iteration resume point {})", this.iterationResumePoint);
        }
    }

    /** Registers the gauges for the total number of shard entries and the number of shards. */
    private void initMetrics() {
        this.shardsSizeGauge = this.pcMetrics.gaugeFromMetricDef(PCMetricsDef.SHARDS_SIZE, this, shardManager -> shardManager.processingShards.values().stream().mapToInt(processingShard -> processingShard.getEntries().size()).sum(), new Tag[0]);
        this.numberOfShardsGauge = this.pcMetrics.gaugeFromMetricDef(PCMetricsDef.NUMBER_OF_SHARDS, this, shardManager -> shardManager.processingShards.keySet().size(), new Tag[0]);
    }

    /**
     * @return the options this manager was configured with
     */
    public ParallelConsumerOptions<?, ?> getOptions() {
        return this.options;
    }

    /**
     * @return the live shard map (not a copy)
     */
    private Map<ShardKey, ProcessingShard<K, V>> getProcessingShards() {
        return this.processingShards;
    }

    /**
     * @return the comparator that orders the retry queue
     */
    Comparator<WorkContainer<?, ?>> getRetryQueueWorkContainerComparator() {
        return this.retryQueueWorkContainerComparator;
    }

    /**
     * @return the live retry queue (not a copy)
     */
    NavigableSet<WorkContainer<?, ?>> getRetryQueue() {
        return this.retryQueue;
    }
}

