/*
 * Decompiled with CFR 0.152.
 */
package com.google.adk.tools.mcp;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.adk.JsonBaseModel;
import com.google.adk.agents.ReadonlyContext;
import com.google.adk.tools.BaseTool;
import com.google.adk.tools.BaseToolset;
import com.google.adk.tools.NamedToolPredicate;
import com.google.adk.tools.ToolPredicate;
import com.google.adk.tools.mcp.McpAsyncTool;
import com.google.adk.tools.mcp.McpSessionManager;
import com.google.adk.tools.mcp.McpToolsetException;
import com.google.adk.tools.mcp.SseServerParameters;
import com.google.common.collect.ImmutableList;
import io.modelcontextprotocol.client.McpAsyncClient;
import io.modelcontextprotocol.client.transport.ServerParameters;
import io.modelcontextprotocol.spec.McpSchema;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.core.Maybe;
import io.reactivex.rxjava3.core.Single;
import java.time.Duration;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;
import reactor.util.retry.RetrySpec;

/**
 * A {@link BaseToolset} exposing the tools advertised by an MCP server as {@link McpAsyncTool}s.
 *
 * <p>The tool list is loaded lazily when the {@link Flowable} returned by {@link #getTools} is
 * subscribed. A successful load is cached and shared by later calls. A failed load is not cached:
 * the next call to {@link #getTools} starts a fresh load. {@link #close()} drops the cached list
 * and closes the MCP session(s) referenced by the loaded tools.
 */
public class McpAsyncToolset implements BaseToolset {
    private static final Logger logger = LoggerFactory.getLogger(McpAsyncToolset.class);

    /** Retries (after the first attempt) allowed for unexpected tool-loading failures. */
    private static final int MAX_RETRIES = 3;

    /** Delay before each retry of a failed tool load. */
    private static final Duration RETRY_DELAY = Duration.ofMillis(100L);

    private final McpSessionManager mcpSessionManager;
    private final ObjectMapper objectMapper;

    /** Predicate selecting exposed tools; may be {@code null} (passed on as an empty Optional). */
    private final ToolPredicate toolFilter;

    /** Shared, cached tool-loading pipeline; {@code null} before first use, after close or failure. */
    private final AtomicReference<Mono<List<McpAsyncTool>>> mcpTools = new AtomicReference<>();

    /**
     * Creates a toolset backed by an MCP server reached through SSE connection parameters.
     *
     * @param connectionParams SSE connection parameters; must not be null
     * @param objectMapper mapper handed to every created {@link McpAsyncTool}; must not be null
     * @param toolFilter predicate selecting which tools are exposed; may be null
     * @throws NullPointerException if {@code connectionParams} or {@code objectMapper} is null
     */
    public McpAsyncToolset(
            SseServerParameters connectionParams, ObjectMapper objectMapper, ToolPredicate toolFilter) {
        Objects.requireNonNull(connectionParams, "connectionParams");
        Objects.requireNonNull(objectMapper, "objectMapper");
        this.objectMapper = objectMapper;
        this.mcpSessionManager = new McpSessionManager(connectionParams);
        this.toolFilter = toolFilter;
    }

    /**
     * Creates a toolset backed by an MCP server described by {@link ServerParameters}.
     *
     * @param connectionParams server connection parameters; must not be null
     * @param objectMapper mapper handed to every created {@link McpAsyncTool}; must not be null
     * @param toolFilter predicate selecting which tools are exposed; may be null
     * @throws NullPointerException if {@code connectionParams} or {@code objectMapper} is null
     */
    public McpAsyncToolset(
            ServerParameters connectionParams, ObjectMapper objectMapper, ToolPredicate toolFilter) {
        Objects.requireNonNull(connectionParams, "connectionParams");
        Objects.requireNonNull(objectMapper, "objectMapper");
        this.objectMapper = objectMapper;
        this.mcpSessionManager = new McpSessionManager(connectionParams);
        this.toolFilter = toolFilter;
    }

    /**
     * Returns the server's tools that pass the configured filter for the given context.
     *
     * <p>Loading starts on subscription. Failures are surfaced as {@link McpToolsetException}: those
     * already of that type pass through unchanged; any other error is wrapped in an
     * {@link McpToolsetException.McpInitializationException}.
     *
     * @param readonlyContext context forwarded to the tool filter; may be null
     * @return a flowable emitting the selected tools
     */
    @Override
    public Flowable<BaseTool> getTools(ReadonlyContext readonlyContext) {
        return Maybe.defer(() -> Maybe.fromCompletionStage(initAndGetTools().toFuture()))
                .defaultIfEmpty(ImmutableList.of())
                .map(tools -> tools.stream()
                        .filter(tool -> isToolSelected(
                                tool, Optional.ofNullable(toolFilter), Optional.ofNullable(readonlyContext)))
                        .toList())
                .onErrorResumeNext(err -> Single.error(
                        err instanceof McpToolsetException
                                ? err
                                : new McpToolsetException.McpInitializationException(
                                        "Failed to reinitialize session during tool loading retry (unexpected error).",
                                        err)))
                .flattenAsFlowable(tools -> tools);
    }

    /**
     * Returns the shared tool-loading pipeline, creating and caching it on first use.
     *
     * <p>{@code Mono.cache()} also caches terminal errors. So when a subscriber observes a failure,
     * the cached instance is dropped (only if it is still the current one). The next call then
     * performs a fresh load instead of replaying the same error forever.
     */
    private Mono<List<McpAsyncTool>> initAndGetTools() {
        Mono<List<McpAsyncTool>> cached =
                mcpTools.updateAndGet(prev -> prev != null ? prev : initTools().cache());
        return cached.doOnError(err -> mcpTools.compareAndSet(cached, null));
    }

    /**
     * Builds a cold pipeline that opens a new MCP session, initializes it and lists its tools.
     *
     * <p>Each attempt (including retries) creates its own session. A session whose attempt fails
     * is closed before the error reaches the retry logic, so failed attempts do not leak sessions.
     */
    private Mono<List<McpAsyncTool>> initTools() {
        return Mono.defer(() -> {
                    McpAsyncClient mcpSession = mcpSessionManager.createAsyncSession();
                    return mcpSession
                            .initialize()
                            .doOnSuccess(initResult -> logger.debug("Initialize Client Result: {}", initResult))
                            .then(Mono.defer(() -> loadTools(mcpSession)))
                            .onErrorResume(err ->
                                    closeQuietly(mcpSession).then(Mono.<List<McpAsyncTool>>error(err)));
                })
                // Read failure/totalRetries immediately: Reactor may reuse RetrySignal instances.
                .retryWhen(RetrySpec.from(signals -> signals.flatMap(
                        signal -> retryOrFail(signal.failure(), signal.totalRetries()))));
    }

    /** Lists the tools of an initialized session and wraps each one as an {@link McpAsyncTool}. */
    private Mono<List<McpAsyncTool>> loadTools(McpAsyncClient mcpSession) {
        return mcpSession
                .listTools()
                .map(toolsResponse -> toolsResponse.tools().stream()
                        .map(tool -> new McpAsyncTool(tool, mcpSession, mcpSessionManager, objectMapper))
                        .toList());
    }

    /**
     * Decides what to do with one tool-loading failure.
     *
     * <p>An {@link IllegalArgumentException} is never retried. Other errors are retried after
     * {@link #RETRY_DELAY} while fewer than {@link #MAX_RETRIES} retries have been made.
     *
     * @param err the failure of the latest attempt
     * @param totalRetries number of retries already performed (0 on the first failure)
     * @return a delayed element to trigger a retry, or an error to stop retrying
     */
    private Mono<Throwable> retryOrFail(Throwable err, long totalRetries) {
        if (err instanceof IllegalArgumentException) {
            logger.error("Invalid argument encountered during tool loading.", err);
            return Mono.error(new McpToolsetException.McpToolLoadingException(
                    "Invalid argument encountered during tool loading.", err));
        }
        logger.error("Unexpected error during tool loading, retry attempt {}", totalRetries + 1, err);
        if (totalRetries < MAX_RETRIES) {
            logger.info("Reinitializing MCP session before next retry for unexpected error.");
            return Mono.just(err).delayElement(RETRY_DELAY);
        }
        logger.error("Failed to load tools after multiple retries due to unexpected error.", err);
        return Mono.error(new McpToolsetException.McpToolLoadingException(
                "Failed to load tools after multiple retries due to unexpected error.", err));
    }

    /** Closes a session gracefully, logging (rather than propagating) any close failure. */
    private static Mono<Void> closeQuietly(McpAsyncClient session) {
        return session.closeGracefully().onErrorResume(e -> {
            logger.error("Failed to close MCP session", e);
            return Mono.empty();
        });
    }

    /**
     * Drops the cached tool list and asynchronously closes the sessions its tools reference.
     *
     * <p>Tools from one load share a single client, so each distinct session is closed once. If the
     * cached load had failed, there is nothing to close and the failure is only logged at debug.
     * NOTE(review): when the server advertised zero tools, no tool references the session, so it is
     * not reachable from here and stays open; consider tracking the session separately.
     */
    @Override
    public void close() {
        Mono<List<McpAsyncTool>> tools = mcpTools.getAndSet(null);
        if (tools == null) {
            return;
        }
        tools.flatMapIterable(list -> list)
                .map(tool -> (McpAsyncClient) tool.mcpSession)
                .distinct()
                .flatMap(McpAsyncToolset::closeQuietly)
                .subscribe(
                        ignored -> { },
                        err -> logger.debug("Skipping MCP session cleanup: tool loading had failed.", err),
                        () -> logger.debug("MCP session closed successfully."));
    }

    /**
     * Fluent builder for {@link McpAsyncToolset}.
     *
     * <p>Defaults: {@link JsonBaseModel#getMapper()} for the object mapper and an accept-all
     * predicate for the tool filter. Connection parameters are required.
     */
    public static class Builder {
        /** Either a {@link ServerParameters} or a {@link SseServerParameters}. */
        private Object connectionParams = null;
        private ObjectMapper objectMapper = null;
        private ToolPredicate toolFilter = null;

        /** Uses the given server parameters for the connection. */
        public Builder connectionParams(ServerParameters connectionParams) {
            this.connectionParams = connectionParams;
            return this;
        }

        /** Uses the given SSE parameters for the connection. */
        public Builder connectionParams(SseServerParameters connectionParams) {
            this.connectionParams = connectionParams;
            return this;
        }

        /** Sets the object mapper handed to created tools. */
        public Builder objectMapper(ObjectMapper objectMapper) {
            this.objectMapper = objectMapper;
            return this;
        }

        /** Sets the predicate used to select exposed tools. */
        public Builder toolFilter(ToolPredicate toolFilter) {
            this.toolFilter = toolFilter;
            return this;
        }

        /** Restricts exposed tools to the given names via {@link NamedToolPredicate}. */
        public Builder toolFilter(List<String> toolNames) {
            this.toolFilter = new NamedToolPredicate(toolNames);
            return this;
        }

        /**
         * Builds the toolset, applying defaults for unset optional settings.
         *
         * @throws IllegalArgumentException if no supported connection parameters were set
         */
        public McpAsyncToolset build() {
            ObjectMapper mapper = objectMapper != null ? objectMapper : JsonBaseModel.getMapper();
            ToolPredicate filter = toolFilter != null ? toolFilter : (tool, context) -> true;
            if (connectionParams instanceof ServerParameters serverParameters) {
                return new McpAsyncToolset(serverParameters, mapper, filter);
            }
            if (connectionParams instanceof SseServerParameters sseServerParameters) {
                return new McpAsyncToolset(sseServerParameters, mapper, filter);
            }
            throw new IllegalArgumentException(
                    "connectionParams must be either ServerParameters or SseServerParameters");
        }
    }
}

