/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.benchmark.search.aggregations;

import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.IntConsumer;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.BytesRef;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.opensearch.action.OriginalIndices;
import org.opensearch.action.search.QueryPhaseResultConsumer;
import org.opensearch.action.search.SearchPhaseController;
import org.opensearch.action.search.SearchProgressListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.common.breaker.NoopCircuitBreaker;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.index.Index;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.indices.breaker.NoneCircuitBreakerService;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.SearchModule;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.aggregations.AggregationBuilder;
import org.opensearch.search.aggregations.AggregationBuilders;
import org.opensearch.search.aggregations.BucketOrder;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.aggregations.MultiBucketConsumerService;
import org.opensearch.search.aggregations.bucket.terms.StringTerms;
import org.opensearch.search.aggregations.bucket.terms.TermsAggregator;
import org.opensearch.search.aggregations.pipeline.PipelineAggregator;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.query.QuerySearchResult;

@Warmup(iterations=5)
@Measurement(iterations=7)
@BenchmarkMode(value={Mode.AverageTime})
@OutputTimeUnit(value=TimeUnit.MILLISECONDS)
@State(value=Scope.Thread)
@Fork(value=1)
public class TermsReduceBenchmark {
    private final SearchModule searchModule = new SearchModule(Settings.EMPTY, Collections.emptyList());
    private final NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(this.searchModule.getNamedWriteables());
    private final SearchPhaseController controller = new SearchPhaseController(this.namedWriteableRegistry, req -> new InternalAggregation.ReduceContextBuilder(this){

        public InternalAggregation.ReduceContext forPartialReduction() {
            return InternalAggregation.ReduceContext.forPartialReduction(null, null, () -> PipelineAggregator.PipelineTree.EMPTY);
        }

        public InternalAggregation.ReduceContext forFinalReduction() {
            MultiBucketConsumerService.MultiBucketConsumer bucketConsumer = new MultiBucketConsumerService.MultiBucketConsumer(Integer.MAX_VALUE, new NoneCircuitBreakerService().getBreaker("request"));
            return InternalAggregation.ReduceContext.forFinalReduction(null, null, (IntConsumer)bucketConsumer, (PipelineAggregator.PipelineTree)PipelineAggregator.PipelineTree.EMPTY);
        }
    });
    @Param(value={"32", "512"})
    private int bufferSize;

    @Benchmark
    public SearchPhaseController.ReducedQueryPhase reduceAggs(TermsList candidateList) throws Exception {
        ArrayList<QuerySearchResult> shards = new ArrayList<QuerySearchResult>();
        for (int i = 0; i < candidateList.size(); ++i) {
            QuerySearchResult result = new QuerySearchResult();
            result.setShardIndex(i);
            result.from(0);
            result.size(0);
            result.topDocs(new TopDocsAndMaxScore(new TopDocs(new TotalHits(1000L, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), new ScoreDoc[0]), Float.NaN), new DocValueFormat[]{DocValueFormat.RAW});
            result.aggregations(candidateList.get(i));
            result.setSearchShardTarget(new SearchShardTarget("node", new ShardId(new Index("index", "index"), i), null, OriginalIndices.NONE));
            shards.add(result);
        }
        SearchRequest request = new SearchRequest();
        request.source(new SearchSourceBuilder().size(0).aggregation((AggregationBuilder)AggregationBuilders.terms((String)"test")));
        request.setBatchedReduceSize(this.bufferSize);
        ExecutorService executor = Executors.newFixedThreadPool(1);
        QueryPhaseResultConsumer consumer = new QueryPhaseResultConsumer(request, (Executor)executor, (CircuitBreaker)new NoopCircuitBreaker("request"), this.controller, SearchProgressListener.NOOP, this.namedWriteableRegistry, shards.size(), exc -> {});
        CountDownLatch latch = new CountDownLatch(shards.size());
        for (int i = 0; i < shards.size(); ++i) {
            consumer.consumeResult((SearchPhaseResult)shards.get(i), () -> latch.countDown());
        }
        latch.await();
        SearchPhaseController.ReducedQueryPhase phase = consumer.reduce();
        executor.shutdownNow();
        return phase;
    }

    @State(value=Scope.Benchmark)
    public static class TermsList
    extends AbstractList<InternalAggregations> {
        @Param(value={"1600172297"})
        long seed;
        @Param(value={"64", "128", "512"})
        int numShards;
        @Param(value={"100"})
        int topNSize;
        @Param(value={"1", "10", "100"})
        int cardinalityFactor;
        List<InternalAggregations> aggsList;

        @Setup
        public void setup() {
            int i;
            this.aggsList = new ArrayList<InternalAggregations>();
            Random rand = new Random(this.seed);
            int cardinality = this.cardinalityFactor * this.topNSize;
            BytesRef[] dict = new BytesRef[cardinality];
            for (i = 0; i < dict.length; ++i) {
                dict[i] = new BytesRef((CharSequence)Long.toString(rand.nextLong()));
            }
            for (i = 0; i < this.numShards; ++i) {
                this.aggsList.add(InternalAggregations.from(Collections.singletonList(this.newTerms(rand, dict, true))));
            }
        }

        private StringTerms newTerms(Random rand, BytesRef[] dict, boolean withNested) {
            HashSet<BytesRef> randomTerms = new HashSet<BytesRef>();
            for (int i = 0; i < this.topNSize; ++i) {
                randomTerms.add(dict[rand.nextInt(dict.length)]);
            }
            ArrayList<StringTerms.Bucket> buckets = new ArrayList<StringTerms.Bucket>();
            for (BytesRef term : randomTerms) {
                InternalAggregations subAggs = withNested ? InternalAggregations.from(Collections.singletonList(this.newTerms(rand, dict, false))) : InternalAggregations.EMPTY;
                buckets.add(new StringTerms.Bucket(term, (long)rand.nextInt(10000), subAggs, true, 0L, DocValueFormat.RAW));
            }
            Collections.sort(buckets, (a, b) -> a.compareKey(b));
            return new StringTerms("terms", BucketOrder.key((boolean)true), BucketOrder.count((boolean)false), Collections.emptyMap(), DocValueFormat.RAW, this.numShards, true, 0L, buckets, 0L, new TermsAggregator.BucketCountThresholds(1L, 0L, this.topNSize, this.numShards));
        }

        @Override
        public InternalAggregations get(int index) {
            return this.aggsList.get(index);
        }

        @Override
        public int size() {
            return this.aggsList.size();
        }
    }
}

