/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.rag;

import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.aggregator.ContentAggregator;
import dev.langchain4j.rag.content.injector.ContentInjector;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.rag.query.router.DefaultQueryRouter;
import dev.langchain4j.rag.query.router.QueryRouter;
import dev.langchain4j.rag.query.transformer.QueryTransformer;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.Mockito;

/**
 * Tests for {@link DefaultRetrievalAugmentor}.
 *
 * <p>Every test is parameterized with each executor produced by {@link #executors()}, so the
 * augmentation pipeline is exercised with a cached thread pool and with fixed pools of 1 to 4
 * threads. Each test shuts its executor down when it finishes so pool threads are not leaked.
 */
class DefaultRetrievalAugmentorTest {

    @ParameterizedTest
    @MethodSource("executors")
    void should_augment_user_message(Executor executor) {
        try {
            // The transformer expands the incoming query into two queries.
            Query query1 = Query.from("query 1");
            Query query2 = Query.from("query 2");
            QueryTransformer queryTransformer = Mockito.spy(new TestQueryTransformer(query1, query2));

            // Two retrievers, each returning the same two contents for any query.
            Content content1 = Content.from("content 1");
            Content content2 = Content.from("content 2");
            ContentRetriever contentRetriever1 = Mockito.spy(new TestContentRetriever(content1, content2));

            Content content3 = Content.from("content 3");
            Content content4 = Content.from("content 4");
            ContentRetriever contentRetriever2 = Mockito.spy(new TestContentRetriever(content3, content4));

            QueryRouter queryRouter = Mockito.spy(new DefaultQueryRouter(contentRetriever1, contentRetriever2));
            ContentAggregator contentAggregator = Mockito.spy(new TestContentAggregator());
            ContentInjector contentInjector = Mockito.spy(new TestContentInjector());

            DefaultRetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
                    .queryTransformer(queryTransformer)
                    .queryRouter(queryRouter)
                    .contentAggregator(contentAggregator)
                    .contentInjector(contentInjector)
                    .executor(executor)
                    .build();

            UserMessage userMessage = UserMessage.from("query");
            Metadata metadata = Metadata.from(userMessage, null, null);

            UserMessage augmented = retrievalAugmentor.augment(userMessage, metadata);

            // 2 queries x 2 retrievers x 2 contents = 8 contents appended after the original text.
            Assertions.assertThat(augmented.singleText())
                    .isEqualTo("query\ncontent 1\ncontent 2\ncontent 3\ncontent 4\ncontent 1\ncontent 2\ncontent 3\ncontent 4");

            Mockito.verify(queryTransformer).transform(Query.from("query", metadata));
            Mockito.verifyNoMoreInteractions(queryTransformer);

            Mockito.verify(queryRouter).route(query1);
            Mockito.verify(queryRouter).route(query2);
            Mockito.verifyNoMoreInteractions(queryRouter);

            Mockito.verify(contentRetriever1).retrieve(query1);
            Mockito.verify(contentRetriever1).retrieve(query2);
            Mockito.verifyNoMoreInteractions(contentRetriever1);

            Mockito.verify(contentRetriever2).retrieve(query1);
            Mockito.verify(contentRetriever2).retrieve(query2);
            Mockito.verifyNoMoreInteractions(contentRetriever2);

            // Each query maps to one content list per retriever it was routed to.
            Map<Query, Collection<List<Content>>> queryToContents = new HashMap<>();
            queryToContents.put(query1, Arrays.asList(Arrays.asList(content1, content2), Arrays.asList(content3, content4)));
            queryToContents.put(query2, Arrays.asList(Arrays.asList(content1, content2), Arrays.asList(content3, content4)));
            Mockito.verify(contentAggregator).aggregate(queryToContents);
            Mockito.verifyNoMoreInteractions(contentAggregator);

            List<Content> expectedContents =
                    Arrays.asList(content1, content2, content3, content4, content1, content2, content3, content4);
            // Both inject overloads are expected to be hit; presumably the ChatMessage overload
            // forwards to the UserMessage one — confirm against ContentInjector.
            Mockito.verify(contentInjector).inject(expectedContents, userMessage);
            Mockito.verify(contentInjector).inject(expectedContents, (ChatMessage) userMessage);
            Mockito.verifyNoMoreInteractions(contentInjector);
        } finally {
            shutdown(executor);
        }
    }

    @ParameterizedTest
    @MethodSource("executors")
    void should_not_augment_when_router_does_not_return_retrievers(Executor executor) {
        try {
            List<ContentRetriever> retrievers = Collections.emptyList();
            QueryRouter queryRouter = Mockito.spy(new TestQueryRouter(retrievers));

            DefaultRetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
                    .queryRouter(queryRouter)
                    .executor(executor)
                    .build();

            UserMessage userMessage = UserMessage.from("query");
            Metadata metadata = Metadata.from(userMessage, null, null);

            UserMessage augmentedUserMessage = retrievalAugmentor.augment(userMessage, metadata);

            // With no retrievers selected, the original message must come back unchanged.
            Assertions.assertThat(augmentedUserMessage).isEqualTo(userMessage);

            Mockito.verify(queryRouter).route(Query.from("query", metadata));
            Mockito.verifyNoMoreInteractions(queryRouter);
        } finally {
            shutdown(executor);
        }
    }

    /**
     * Supplies the executors each parameterized test runs against. Fresh executors are created on
     * every call; the consuming test is responsible for shutting them down.
     *
     * @return one {@link Arguments} per executor: a cached pool and fixed pools of 1 to 4 threads
     */
    static Stream<Arguments> executors() {
        return Stream.of(
                Arguments.of(Executors.newCachedThreadPool()),
                Arguments.of(Executors.newFixedThreadPool(1)),
                Arguments.of(Executors.newFixedThreadPool(2)),
                Arguments.of(Executors.newFixedThreadPool(3)),
                Arguments.of(Executors.newFixedThreadPool(4)));
    }

    /**
     * Shuts down {@code executor} if it is an {@link ExecutorService}, so that its worker threads
     * do not outlive the test that used it.
     */
    private static void shutdown(Executor executor) {
        if (executor instanceof ExecutorService) {
            ((ExecutorService) executor).shutdown();
        }
    }

    /** Query transformer stub that ignores its input and always returns a fixed list of queries. */
    static class TestQueryTransformer implements QueryTransformer {

        private final List<Query> queries;

        TestQueryTransformer(Query... queries) {
            this.queries = Arrays.asList(queries);
        }

        @Override
        public Collection<Query> transform(Query query) {
            return queries;
        }
    }

    /** Content retriever stub that returns the same fixed contents for every query. */
    static class TestContentRetriever implements ContentRetriever {

        private final List<Content> contents;

        TestContentRetriever(Content... contents) {
            this.contents = Arrays.asList(contents);
        }

        @Override
        public List<Content> retrieve(Query query) {
            return contents;
        }
    }

    /**
     * Content aggregator stub that flattens every content list of every query into one list.
     * The order across queries follows the map's iteration order.
     */
    static class TestContentAggregator implements ContentAggregator {

        @Override
        public List<Content> aggregate(Map<Query, Collection<List<Content>>> queryToContents) {
            return queryToContents.values().stream()
                    .flatMap(Collection::stream)
                    .flatMap(Collection::stream)
                    .collect(Collectors.toList());
        }
    }

    /**
     * Content injector stub that appends each content's text, newline-separated, after the user
     * message text.
     */
    static class TestContentInjector implements ContentInjector {

        @Override
        public UserMessage inject(List<Content> contents, UserMessage userMessage) {
            String joinedContents = contents.stream()
                    .map(it -> it.textSegment().text())
                    .collect(Collectors.joining("\n"));
            return UserMessage.from(userMessage.singleText() + "\n" + joinedContents);
        }
    }

    /** Query router stub that returns a fixed collection of retrievers for every query. */
    static class TestQueryRouter implements QueryRouter {

        private final Collection<ContentRetriever> retrievers;

        TestQueryRouter(Collection<ContentRetriever> retrievers) {
            this.retrievers = retrievers;
        }

        @Override
        public Collection<ContentRetriever> route(Query query) {
            return retrievers;
        }
    }
}

