/*
 * Decompiled with CFR 0.152.
 */
package com.google.adk.flows.llmflows;

import com.google.adk.Telemetry;
import com.google.adk.agents.ActiveStreamingTool;
import com.google.adk.agents.BaseAgent;
import com.google.adk.agents.CallbackContext;
import com.google.adk.agents.Callbacks;
import com.google.adk.agents.InvocationContext;
import com.google.adk.agents.LiveRequest;
import com.google.adk.agents.LlmAgent;
import com.google.adk.agents.ReadonlyContext;
import com.google.adk.agents.RunConfig;
import com.google.adk.events.Event;
import com.google.adk.flows.BaseFlow;
import com.google.adk.flows.llmflows.Functions;
import com.google.adk.flows.llmflows.RequestProcessor;
import com.google.adk.flows.llmflows.ResponseProcessor;
import com.google.adk.models.BaseLlm;
import com.google.adk.models.BaseLlmConnection;
import com.google.adk.models.LlmCallsLimitExceededException;
import com.google.adk.models.LlmRegistry;
import com.google.adk.models.LlmRequest;
import com.google.adk.models.LlmResponse;
import com.google.adk.tools.ToolContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.genai.types.FunctionResponse;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.StatusCode;
import io.opentelemetry.context.Scope;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.core.CompletableObserver;
import io.reactivex.rxjava3.core.CompletableSource;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.core.MaybeSource;
import io.reactivex.rxjava3.core.Single;
import io.reactivex.rxjava3.core.SingleSource;
import io.reactivex.rxjava3.disposables.Disposable;
import io.reactivex.rxjava3.observers.DisposableCompletableObserver;
import io.reactivex.rxjava3.schedulers.Schedulers;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for flows that drive an {@link LlmAgent} through one or more model calls.
 *
 * <p>Each step:
 * <ul>
 *   <li>runs the configured {@link RequestProcessor}s and the agent's canonical tools over a fresh
 *       {@link LlmRequest};
 *   <li>calls the model, wrapped in plugin and agent before/after-model callbacks;
 *   <li>runs the configured {@link ResponseProcessor}s over every model response, executing any
 *       function calls the response contains.
 * </ul>
 */
public abstract class BaseLlmFlow implements BaseFlow {
    private static final Logger logger = LoggerFactory.getLogger(BaseLlmFlow.class);

    protected final List<RequestProcessor> requestProcessors;
    protected final List<ResponseProcessor> responseProcessors;

    /**
     * Cumulative number of steps started by {@link #run} on this instance, across all calls.
     *
     * <p>Kept for subclasses that read it. The {@link #maxSteps} limit is enforced with a per-run
     * counter instead, because this field is never reset. Otherwise the budget would leak between
     * runs that reuse the same flow instance.
     */
    protected int stepsCompleted = 0;

    /** Maximum number of steps one {@link #run} call may start; {@code Integer.MAX_VALUE} if unset. */
    protected final int maxSteps;

    public BaseLlmFlow(List<RequestProcessor> requestProcessors, List<ResponseProcessor> responseProcessors) {
        this(requestProcessors, responseProcessors, Optional.empty());
    }

    public BaseLlmFlow(
            List<RequestProcessor> requestProcessors,
            List<ResponseProcessor> responseProcessors,
            Optional<Integer> maxSteps) {
        this.requestProcessors = requestProcessors;
        this.responseProcessors = responseProcessors;
        this.maxSteps = maxSteps.orElse(Integer.MAX_VALUE);
    }

    /**
     * Runs every request processor, in order, on {@code llmRequest}. Each of the agent's canonical
     * tools may then amend the resulting request.
     *
     * @param context the current invocation; its agent must be an {@link LlmAgent}
     * @param llmRequest the request to start from
     * @return the final request plus the events emitted by all request processors, in processor
     *     order. Processors whose {@code events()} is {@code null} contribute nothing.
     */
    protected Single<RequestProcessor.RequestProcessingResult> preprocess(
            InvocationContext context, LlmRequest llmRequest) {
        List<Iterable<Event>> eventIterables = new ArrayList<>();
        LlmAgent agent = (LlmAgent) context.agent();
        Single<LlmRequest> currentLlmRequest = Single.just(llmRequest);
        for (RequestProcessor processor : requestProcessors) {
            currentLlmRequest =
                    currentLlmRequest
                            .flatMap(request -> processor.processRequest(context, request))
                            .doOnSuccess(
                                    result -> {
                                        if (result.events() != null) {
                                            eventIterables.add(result.events());
                                        }
                                    })
                            .map(RequestProcessor.RequestProcessingResult::updatedRequest);
        }
        return currentLlmRequest.flatMap(
                processedRequest -> {
                    // Tools mutate this builder in place; it is only built once all tools have run.
                    LlmRequest.Builder updatedRequestBuilder = processedRequest.toBuilder();
                    return agent.canonicalTools(new ReadonlyContext(context))
                            .concatMapCompletable(
                                    tool ->
                                            tool.processLlmRequest(
                                                    updatedRequestBuilder, ToolContext.builder(context).build()))
                            .andThen(
                                    Single.fromCallable(
                                            () ->
                                                    RequestProcessor.RequestProcessingResult.create(
                                                            updatedRequestBuilder.build(),
                                                            Iterables.concat(eventIterables))));
                });
    }

    /**
     * Runs every response processor, in order, on {@code llmResponse} and turns the outcome into
     * events.
     *
     * <p>If the processed response has no content and no error code, and is neither interrupted nor
     * turn-complete, only the processor events are returned, with no transfer target. Otherwise:
     * <ul>
     *   <li>a model-response event is built from {@code baseEventForLlmResponse};
     *   <li>any function calls it contains are executed, using the live variant when the streaming
     *       mode is {@link RunConfig.StreamingMode#BIDI};
     *   <li>the resulting function-call event (if any) is appended, and its {@code transferToAgent}
     *       action becomes the transfer target.
     * </ul>
     */
    protected Single<ResponseProcessor.ResponseProcessingResult> postprocess(
            InvocationContext context,
            Event baseEventForLlmResponse,
            LlmRequest llmRequest,
            LlmResponse llmResponse) {
        List<Iterable<Event>> eventIterables = new ArrayList<>();
        Single<LlmResponse> currentLlmResponse = Single.just(llmResponse);
        for (ResponseProcessor processor : responseProcessors) {
            currentLlmResponse =
                    currentLlmResponse
                            .flatMap(response -> processor.processResponse(context, response))
                            .doOnSuccess(
                                    result -> {
                                        if (result.events() != null) {
                                            eventIterables.add(result.events());
                                        }
                                    })
                            .map(ResponseProcessor.ResponseProcessingResult::updatedResponse);
        }
        return currentLlmResponse.flatMap(
                updatedResponse -> {
                    boolean nothingToEmit =
                            updatedResponse.content().isEmpty()
                                    && updatedResponse.errorCode().isEmpty()
                                    && !updatedResponse.interrupted().orElse(false)
                                    && !updatedResponse.turnComplete().orElse(false);
                    if (nothingToEmit) {
                        return Single.just(
                                ResponseProcessor.ResponseProcessingResult.create(
                                        updatedResponse, Iterables.concat(eventIterables), Optional.empty()));
                    }
                    Event modelResponseEvent =
                            buildModelResponseEvent(baseEventForLlmResponse, llmRequest, updatedResponse);
                    eventIterables.add(Collections.singleton(modelResponseEvent));

                    Maybe<Event> maybeFunctionCallEvent;
                    if (modelResponseEvent.functionCalls().isEmpty()) {
                        maybeFunctionCallEvent = Maybe.empty();
                    } else if (context.runConfig().streamingMode() == RunConfig.StreamingMode.BIDI) {
                        maybeFunctionCallEvent =
                                Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools());
                    } else {
                        maybeFunctionCallEvent =
                                Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools());
                    }

                    return maybeFunctionCallEvent
                            .map(Optional::of)
                            .defaultIfEmpty(Optional.empty())
                            .map(
                                    functionCallEventOpt -> {
                                        Optional<String> transferToAgent = Optional.empty();
                                        if (functionCallEventOpt.isPresent()) {
                                            Event functionCallEvent = functionCallEventOpt.get();
                                            eventIterables.add(Collections.singleton(functionCallEvent));
                                            transferToAgent = functionCallEvent.actions().transferToAgent();
                                        }
                                        return ResponseProcessor.ResponseProcessingResult.create(
                                                updatedResponse, Iterables.concat(eventIterables), transferToAgent);
                                    });
                });
    }

    /**
     * Calls the agent's model with {@code llmRequest}, unless a before-model callback supplies a
     * response first.
     *
     * <p>Error handling:
     * <ul>
     *   <li>model errors are offered to the plugin manager's on-model-error callback;
     *   <li>if that callback yields no response, the original error is propagated.
     * </ul>
     *
     * <p>Every response, whether from the model or from the plugin error callback, is passed
     * through the after-model callbacks. The model call is traced in a {@code call_llm} span.
     */
    private Flowable<LlmResponse> callLlm(
            InvocationContext context, LlmRequest llmRequest, Event eventForCallbackUsage) {
        LlmAgent agent = (LlmAgent) context.agent();
        LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder();
        return handleBeforeModelCallback(context, llmRequestBuilder, eventForCallbackUsage)
                .flatMapPublisher(
                        beforeResponse -> {
                            if (beforeResponse.isPresent()) {
                                return Flowable.just(beforeResponse.get());
                            }
                            BaseLlm llm = resolveLlm(agent);
                            Flowable<LlmResponse> modelResponses =
                                    Flowable.defer(
                                            () -> {
                                                Span llmCallSpan =
                                                        Telemetry.getTracer().spanBuilder("call_llm").startSpan();
                                                try (Scope scope = llmCallSpan.makeCurrent()) {
                                                    boolean stream =
                                                            context.runConfig().streamingMode()
                                                                    == RunConfig.StreamingMode.SSE;
                                                    return llm.generateContent(llmRequestBuilder.build(), stream)
                                                            .onErrorResumeNext(
                                                                    exception ->
                                                                            context
                                                                                    .pluginManager()
                                                                                    .runOnModelErrorCallback(
                                                                                            new CallbackContext(
                                                                                                    context,
                                                                                                    eventForCallbackUsage.actions()),
                                                                                            llmRequest,
                                                                                            exception)
                                                                                    .switchIfEmpty(Single.error(exception))
                                                                                    .toFlowable())
                                                            .doOnNext(
                                                                    llmResp -> {
                                                                        // Responses may arrive on another thread;
                                                                        // re-enter the span so tracing attaches to it.
                                                                        try (Scope innerScope = llmCallSpan.makeCurrent()) {
                                                                            Telemetry.traceCallLlm(
                                                                                    context,
                                                                                    eventForCallbackUsage.id(),
                                                                                    llmRequest,
                                                                                    llmResp);
                                                                        }
                                                                    })
                                                            .doOnError(
                                                                    error -> {
                                                                        llmCallSpan.setStatus(
                                                                                StatusCode.ERROR, error.getMessage());
                                                                        llmCallSpan.recordException(error);
                                                                    })
                                                            .doFinally(() -> llmCallSpan.end());
                                                }
                                            });
                            return modelResponses.concatMap(
                                    llmResp ->
                                            handleAfterModelCallback(context, llmResp, eventForCallbackUsage)
                                                    .toFlowable());
                        });
    }

    /**
     * Resolves the agent's model.
     *
     * @return the concrete model instance when one is set; otherwise the model registered under the
     *     agent's model name
     */
    private static BaseLlm resolveLlm(LlmAgent agent) {
        return agent.resolvedModel().model().isPresent()
                ? agent.resolvedModel().model().get()
                : LlmRegistry.getLlm(agent.resolvedModel().modelName().get());
    }

    /**
     * Runs before-model callbacks.
     *
     * <p>Order of evaluation:
     * <ul>
     *   <li>plugin callbacks run first and receive a built snapshot of the request;
     *   <li>only if they return nothing, the agent's callbacks run in order with the mutable request
     *       builder, stopping at the first one that returns a response.
     * </ul>
     *
     * @return the short-circuit response, or empty if the model should be called
     */
    private Single<Optional<LlmResponse>> handleBeforeModelCallback(
            InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent) {
        Event callbackEvent = modelResponseEvent.toBuilder().build();
        CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions());
        Maybe<LlmResponse> pluginResult =
                context.pluginManager().runBeforeModelCallback(callbackContext, llmRequestBuilder.build());
        LlmAgent agent = (LlmAgent) context.agent();
        Optional<List<Callbacks.BeforeModelCallback>> callbacksOpt = agent.beforeModelCallback();
        if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) {
            return pluginResult.map(Optional::of).defaultIfEmpty(Optional.empty());
        }
        List<Callbacks.BeforeModelCallback> callbacks = callbacksOpt.get();
        Maybe<LlmResponse> callbackResult =
                Maybe.defer(
                        () ->
                                Flowable.fromIterable(callbacks)
                                        .concatMapMaybe(callback -> callback.call(callbackContext, llmRequestBuilder))
                                        .firstElement());
        return pluginResult.switchIfEmpty(callbackResult).map(Optional::of).defaultIfEmpty(Optional.empty());
    }

    /**
     * Runs after-model callbacks on {@code llmResponse}.
     *
     * <p>Plugin callbacks run first. Only if they return nothing, the agent's callbacks run in order,
     * stopping at the first one that returns a replacement.
     *
     * @return the replacement response, or {@code llmResponse} itself if no callback returned one
     */
    private Single<LlmResponse> handleAfterModelCallback(
            InvocationContext context, LlmResponse llmResponse, Event modelResponseEvent) {
        Event callbackEvent = modelResponseEvent.toBuilder().build();
        CallbackContext callbackContext = new CallbackContext(context, callbackEvent.actions());
        Maybe<LlmResponse> pluginResult =
                context.pluginManager().runAfterModelCallback(callbackContext, llmResponse);
        LlmAgent agent = (LlmAgent) context.agent();
        Optional<List<Callbacks.AfterModelCallback>> callbacksOpt = agent.afterModelCallback();
        if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) {
            return pluginResult.defaultIfEmpty(llmResponse);
        }
        List<Callbacks.AfterModelCallback> callbacks = callbacksOpt.get();
        Maybe<LlmResponse> callbackResult =
                Maybe.defer(
                        () ->
                                Flowable.fromIterable(callbacks)
                                        .concatMapMaybe(callback -> callback.call(callbackContext, llmResponse))
                                        .firstElement());
        return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse);
    }

    /**
     * Executes one preprocess, model call, postprocess cycle. Preprocessing events are always emitted
     * first.
     *
     * <p>Outcomes:
     * <ul>
     *   <li>If invocation end is requested during preprocessing, only the preprocessing events are
     *       emitted.
     *   <li>If the LLM call limit is exceeded, the preprocessing events are followed by the
     *       {@link LlmCallsLimitExceededException}.
     *   <li>When a response's function call requests a transfer, the target agent's
     *       {@code runAsync} events follow. An {@link IllegalStateException} is emitted if no
     *       such agent exists.
     * </ul>
     */
    private Flowable<Event> runOneStep(InvocationContext context) {
        LlmRequest initialLlmRequest = LlmRequest.builder().build();
        return preprocess(context, initialLlmRequest)
                .flatMapPublisher(
                        preResult -> {
                            LlmRequest llmRequestAfterPreprocess = preResult.updatedRequest();
                            Iterable<Event> preEvents = preResult.events();
                            if (context.endInvocation()) {
                                logger.debug("End invocation requested during preprocessing.");
                                return Flowable.fromIterable(preEvents);
                            }
                            try {
                                context.incrementLlmCallsCount();
                            } catch (LlmCallsLimitExceededException e) {
                                logger.error("LLM calls limit exceeded.", e);
                                return Flowable.fromIterable(preEvents).concatWith(Flowable.<Event>error(e));
                            }

                            // Shared template for callbacks and response events. Its id is regenerated
                            // after each processed response so every LlmResponse gets a distinct event id.
                            Event mutableEventTemplate =
                                    Event.builder()
                                            .id(Event.generateEventId())
                                            .invocationId(context.invocationId())
                                            .author(context.agent().name())
                                            .branch(context.branch())
                                            .build();
                            mutableEventTemplate.setTimestamp(0L);

                            Flowable<Event> restOfFlow =
                                    callLlm(context, llmRequestAfterPreprocess, mutableEventTemplate)
                                            .concatMap(
                                                    llmResponse ->
                                                            postprocess(
                                                                            context,
                                                                            mutableEventTemplate,
                                                                            llmRequestAfterPreprocess,
                                                                            llmResponse)
                                                                    .doOnSuccess(
                                                                            ignored -> {
                                                                                String oldId = mutableEventTemplate.id();
                                                                                mutableEventTemplate.setId(
                                                                                        Event.generateEventId());
                                                                                logger.debug(
                                                                                        "Updated mutableEventTemplate ID from {} to {} for next LlmResponse",
                                                                                        oldId,
                                                                                        mutableEventTemplate.id());
                                                                            })
                                                                    .toFlowable())
                                            .concatMap(
                                                    postResult -> {
                                                        Flowable<Event> postProcessedEvents =
                                                                Flowable.fromIterable(postResult.events());
                                                        if (postResult.transferToAgent().isEmpty()) {
                                                            return postProcessedEvents;
                                                        }
                                                        String agentToTransfer = postResult.transferToAgent().get();
                                                        logger.debug("Transferring to agent: {}", agentToTransfer);
                                                        BaseAgent rootAgent = context.agent().rootAgent();
                                                        BaseAgent nextAgent = rootAgent.findAgent(agentToTransfer);
                                                        if (nextAgent == null) {
                                                            String errorMsg = "Agent not found for transfer: " + agentToTransfer;
                                                            logger.error(errorMsg);
                                                            return postProcessedEvents.concatWith(
                                                                    Flowable.<Event>error(new IllegalStateException(errorMsg)));
                                                        }
                                                        return postProcessedEvents.concatWith(
                                                                Flowable.defer(() -> nextAgent.runAsync(context)));
                                                    });
                            return restOfFlow.startWithIterable(preEvents);
                        });
    }

    /**
     * Runs steps until a stop condition holds.
     *
     * <p>The flow stops when any of the following is true:
     * <ul>
     *   <li>a step emits no events;
     *   <li>a step's last event is a final response;
     *   <li>a step's last event carries an {@code endInvocation} action;
     *   <li>{@link #maxSteps} steps have been started within this call.
     * </ul>
     */
    @Override
    public Flowable<Event> run(InvocationContext invocationContext) {
        return runSteps(invocationContext, 0);
    }

    /**
     * Runs one step, then the next recursively, unless a stop condition holds.
     *
     * @param stepsSoFar number of steps already started within the current {@link #run} call. This
     *     keeps the step budget per run instead of per flow instance.
     */
    private Flowable<Event> runSteps(InvocationContext invocationContext, int stepsSoFar) {
        // cache(): the step's events are both emitted downstream and collected by toList() below;
        // caching ensures the step itself executes only once.
        Flowable<Event> currentStepEvents = runOneStep(invocationContext).cache();
        stepsCompleted++; // informational, kept for subclasses; not used for the limit
        int stepsStarted = stepsSoFar + 1;
        if (stepsStarted >= maxSteps) {
            logger.debug("Ending flow execution because max steps reached.");
            return currentStepEvents;
        }
        Flowable<Event> nextSteps =
                currentStepEvents
                        .toList()
                        .flatMapPublisher(
                                eventList -> {
                                    boolean shouldStop = eventList.isEmpty();
                                    if (!shouldStop) {
                                        Event lastEvent = Iterables.getLast(eventList);
                                        shouldStop =
                                                lastEvent.finalResponse()
                                                        || lastEvent.actions().endInvocation().orElse(false);
                                    }
                                    if (shouldStop) {
                                        logger.debug(
                                                "Ending flow execution based on final response, endInvocation action or empty event list.");
                                        return Flowable.empty();
                                    }
                                    logger.debug("Continuing to next step of the flow.");
                                    return Flowable.defer(() -> runSteps(invocationContext, stepsStarted));
                                });
        return currentStepEvents.concatWith(nextSteps);
    }

    /**
     * Runs the flow over a bidirectional model connection.
     *
     * <p>After preprocessing (which emits only its own events if invocation end was requested), this
     * method opens a connection and then:
     * <ul>
     *   <li>sends the preprocessed history, if any, in a {@code send_data} span;
     *   <li>forwards each live request from the invocation's queue, on the agent's executor or the IO
     *       scheduler: content goes via {@code sendContent}, blobs via {@code sendRealtime}, and any
     *       other request closes the connection;
     *   <li>copies every live request to active streaming tools that have a stream;
     *   <li>turns received responses into events via {@link #postprocess}, feeding function
     *       responses back into the live request queue.
     * </ul>
     *
     * <p>A {@code transferToAgent} function response or an {@code endInvocation} action disposes
     * the send task and closes the connection. A transfer appends the target agent's
     * {@code runLive} events. The returned stream ends just before the first event carrying
     * {@code endInvocation}; that event itself is not emitted.
     */
    @Override
    public Flowable<Event> runLive(InvocationContext invocationContext) {
        LlmRequest llmRequest = LlmRequest.builder().build();
        return preprocess(invocationContext, llmRequest)
                .flatMapPublisher(
                        preResult -> {
                            LlmRequest llmRequestAfterPreprocess = preResult.updatedRequest();
                            if (invocationContext.endInvocation()) {
                                return Flowable.fromIterable(preResult.events());
                            }
                            String eventIdForSendData = Event.generateEventId();
                            LlmAgent agent = (LlmAgent) invocationContext.agent();
                            BaseLlm llm = resolveLlm(agent);
                            BaseLlmConnection connection = llm.connect(llmRequestAfterPreprocess);

                            Completable historySent =
                                    llmRequestAfterPreprocess.contents().isEmpty()
                                            ? Completable.complete()
                                            : Completable.defer(
                                                    () -> {
                                                        Span sendDataSpan =
                                                                Telemetry.getTracer().spanBuilder("send_data").startSpan();
                                                        try (Scope scope = sendDataSpan.makeCurrent()) {
                                                            return connection
                                                                    .sendHistory(llmRequestAfterPreprocess.contents())
                                                                    .doOnComplete(
                                                                            () -> {
                                                                                try (Scope innerScope = sendDataSpan.makeCurrent()) {
                                                                                    Telemetry.traceSendData(
                                                                                            invocationContext,
                                                                                            eventIdForSendData,
                                                                                            llmRequestAfterPreprocess.contents());
                                                                                }
                                                                            })
                                                                    .doOnError(
                                                                            error -> {
                                                                                sendDataSpan.setStatus(
                                                                                        StatusCode.ERROR, error.getMessage());
                                                                                sendDataSpan.recordException(error);
                                                                                try (Scope innerScope = sendDataSpan.makeCurrent()) {
                                                                                    Telemetry.traceSendData(
                                                                                            invocationContext,
                                                                                            eventIdForSendData,
                                                                                            llmRequestAfterPreprocess.contents());
                                                                                }
                                                                            })
                                                                    .doFinally(() -> sendDataSpan.end());
                                                        }
                                                    });

                            Flowable<LiveRequest> liveRequests =
                                    invocationContext
                                            .liveRequestQueue()
                                            .get()
                                            .get()
                                            .doOnNext(
                                                    request -> {
                                                        for (ActiveStreamingTool activeStreamingTool :
                                                                invocationContext.activeStreamingTools().values()) {
                                                            if (activeStreamingTool.stream() != null) {
                                                                activeStreamingTool.stream().send(request);
                                                            }
                                                        }
                                                    });

                            // NOTE(review): sendTask is disposed only when a transfer or endInvocation
                            // event is seen below. If the subscriber cancels the returned stream first,
                            // it keeps running. Consider also disposing it on cancellation.
                            Disposable sendTask =
                                    historySent
                                            .observeOn(agent.executor().map(Schedulers::from).orElse(Schedulers.io()))
                                            .andThen(
                                                    liveRequests
                                                            .onBackpressureBuffer()
                                                            .concatMapCompletable(
                                                                    request -> {
                                                                        if (request.content().isPresent()) {
                                                                            return connection.sendContent(request.content().get());
                                                                        }
                                                                        if (request.blob().isPresent()) {
                                                                            return connection.sendRealtime(request.blob().get());
                                                                        }
                                                                        return Completable.fromAction(connection::close);
                                                                    }))
                                            .subscribeWith(
                                                    new DisposableCompletableObserver() {
                                                        @Override
                                                        public void onComplete() {
                                                            connection.close();
                                                        }

                                                        @Override
                                                        public void onError(Throwable e) {
                                                            connection.close(e);
                                                        }
                                                    });

                            Event.Builder liveEventBuilderTemplate =
                                    Event.builder()
                                            .invocationId(invocationContext.invocationId())
                                            .author(invocationContext.agent().name())
                                            .branch(invocationContext.branch());
                            Flowable<Event> receiveFlow =
                                    connection
                                            .receive()
                                            .flatMapSingle(
                                                    llmResponse -> {
                                                        Event baseEventForThisLlmResponse =
                                                                liveEventBuilderTemplate.id(Event.generateEventId()).build();
                                                        return postprocess(
                                                                invocationContext,
                                                                baseEventForThisLlmResponse,
                                                                llmRequestAfterPreprocess,
                                                                llmResponse);
                                                    })
                                            .flatMap(
                                                    postResult -> {
                                                        Flowable<Event> events = Flowable.fromIterable(postResult.events());
                                                        if (postResult.transferToAgent().isPresent()) {
                                                            BaseAgent rootAgent = invocationContext.agent().rootAgent();
                                                            BaseAgent nextAgent =
                                                                    rootAgent.findAgent(postResult.transferToAgent().get());
                                                            if (nextAgent == null) {
                                                                throw new IllegalStateException(
                                                                        "Agent not found: " + postResult.transferToAgent().get());
                                                            }
                                                            Flowable<Event> nextAgentEvents = nextAgent.runLive(invocationContext);
                                                            events = Flowable.concat(events, nextAgentEvents);
                                                        }
                                                        return events;
                                                    })
                                            .doOnNext(
                                                    event -> {
                                                        ImmutableList<FunctionResponse> functionResponses =
                                                                event.functionResponses();
                                                        if (!functionResponses.isEmpty()) {
                                                            // Feed function results back to the model over the live queue.
                                                            invocationContext.liveRequestQueue().get().content(event.content().get());
                                                        }
                                                        boolean transferred =
                                                                functionResponses.stream()
                                                                        .anyMatch(
                                                                                functionResponse ->
                                                                                        functionResponse.name().orElse("").equals("transferToAgent"));
                                                        if (transferred || event.actions().endInvocation().orElse(false)) {
                                                            sendTask.dispose();
                                                            connection.close();
                                                        }
                                                    });
                            return receiveFlow
                                    .takeWhile(event -> !event.actions().endInvocation().orElse(false))
                                    .startWithIterable(preResult.events());
                        });
    }

    /**
     * Builds the event for one model response.
     *
     * <p>The event copies {@code baseEventForLlmResponse} and fills in the response's content,
     * partial flag, error code/message, interrupted and turn-complete flags, and grounding metadata.
     * If the event contains function calls, it also:
     * <ul>
     *   <li>populates their client function-call ids;
     *   <li>records the ids of long-running tool calls, if any.
     * </ul>
     */
    private Event buildModelResponseEvent(Event baseEventForLlmResponse, LlmRequest llmRequest, LlmResponse llmResponse) {
        Event.Builder eventBuilder =
                baseEventForLlmResponse.toBuilder()
                        .content(llmResponse.content())
                        .partial(llmResponse.partial())
                        .errorCode(llmResponse.errorCode())
                        .errorMessage(llmResponse.errorMessage())
                        .interrupted(llmResponse.interrupted())
                        .turnComplete(llmResponse.turnComplete())
                        .groundingMetadata(llmResponse.groundingMetadata());
        Event event = eventBuilder.build();
        if (!event.functionCalls().isEmpty()) {
            Functions.populateClientFunctionCallId(event);
            Set<String> longRunningToolIds =
                    Functions.getLongRunningFunctionCalls(event.functionCalls(), llmRequest.tools());
            if (!longRunningToolIds.isEmpty()) {
                event.setLongRunningToolIds(Optional.of(longRunningToolIds));
            }
        }
        return event;
    }
}

