/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.bedrock.converse;

import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationConvention;
import io.micrometer.observation.ObservationRegistry;
import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.net.URL;
import java.net.URLConnection;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.bedrock.converse.BedrockChatOptions;
import org.springframework.ai.bedrock.converse.api.BedrockCacheOptions;
import org.springframework.ai.bedrock.converse.api.BedrockCacheStrategy;
import org.springframework.ai.bedrock.converse.api.BedrockMediaFormat;
import org.springframework.ai.bedrock.converse.api.ConverseApiUtils;
import org.springframework.ai.bedrock.converse.api.ConverseChatResponseStream;
import org.springframework.ai.bedrock.converse.api.URLValidator;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.content.Media;
import org.springframework.ai.model.ModelOptions;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeType;
import org.springframework.util.StreamUtils;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;
import reactor.util.context.ContextView;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.document.Document;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClientBuilder;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClientBuilder;
import software.amazon.awssdk.services.bedrockruntime.model.CachePointBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseMetrics;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
import software.amazon.awssdk.services.bedrockruntime.model.DocumentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.DocumentSource;
import software.amazon.awssdk.services.bedrockruntime.model.ImageBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ImageSource;
import software.amazon.awssdk.services.bedrockruntime.model.InferenceConfiguration;
import software.amazon.awssdk.services.bedrockruntime.model.S3Location;
import software.amazon.awssdk.services.bedrockruntime.model.StopReason;
import software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.Tool;
import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration;
import software.amazon.awssdk.services.bedrockruntime.model.ToolInputSchema;
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification;
import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock;
import software.amazon.awssdk.services.bedrockruntime.model.VideoBlock;
import software.amazon.awssdk.services.bedrockruntime.model.VideoFormat;
import software.amazon.awssdk.services.bedrockruntime.model.VideoSource;

public class BedrockProxyChatModel
implements ChatModel {
    // Logger for request/response dumps, unsupported-option warnings and prompt-caching debug output.
    private static final Logger logger = LoggerFactory.getLogger(BedrockProxyChatModel.class);
    // Fallback convention handed to CHAT_MODEL_OPERATION.observation(...) next to the optional custom convention.
    private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();
    // NOTE(review): not referenced in the visible part of this file; presumably used by a builder/factory further down — confirm.
    private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build();
    // Synchronous Bedrock Runtime client; used by internalCall for the blocking Converse call.
    private final BedrockRuntimeClient bedrockRuntimeClient;
    // Asynchronous Bedrock Runtime client; not used in the visible code — presumably backs streaming. TODO confirm.
    private final BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient;
    // Model-level defaults merged with per-prompt options in buildRequestPrompt. Not null-checked by the
    // constructor even though buildRequestPrompt dereferences it.
    private final BedrockChatOptions defaultOptions;
    // Registry passed to the chat-model observation; not null-checked by the constructor.
    private final ObservationRegistry observationRegistry;
    // Resolves tool definitions for requests and executes tool calls returned by the model.
    private final ToolCallingManager toolCallingManager;
    // Decides (together with the TOOL_USE stop reason) whether internalCall runs the requested tools itself.
    private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate;
    // Optional custom observation convention. Not final and not assigned by the visible constructors, so it is
    // null unless set elsewhere (NOTE(review): presumably via a setter outside this view — confirm).
    private ChatModelObservationConvention observationConvention;

    /**
     * Creates a model that uses a {@link DefaultToolExecutionEligibilityPredicate} to decide whether tool calls
     * returned by Bedrock are executed internally.
     * @param bedrockRuntimeClient synchronous Bedrock Runtime client (required)
     * @param bedrockRuntimeAsyncClient asynchronous Bedrock Runtime client (required)
     * @param defaultOptions model-level default options
     * @param observationRegistry registry used for chat-model observations
     * @param toolCallingManager resolves and executes tools (required)
     */
    public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient,
            BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, BedrockChatOptions defaultOptions,
            ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager) {
        this(bedrockRuntimeClient, bedrockRuntimeAsyncClient, defaultOptions, observationRegistry,
                toolCallingManager, new DefaultToolExecutionEligibilityPredicate());
    }

    /**
     * Creates a model backed by the given Bedrock Runtime clients.
     * <p>
     * The two clients, the tool calling manager and the eligibility predicate are mandatory.
     * {@code defaultOptions} and {@code observationRegistry} are stored as given without validation.
     * @param bedrockRuntimeClient synchronous Bedrock Runtime client
     * @param bedrockRuntimeAsyncClient asynchronous Bedrock Runtime client
     * @param defaultOptions model-level default options
     * @param observationRegistry registry used for chat-model observations
     * @param toolCallingManager resolves and executes tools
     * @param toolExecutionEligibilityPredicate decides whether returned tool calls are executed internally
     * @throws IllegalArgumentException if any mandatory argument is {@code null}
     */
    public BedrockProxyChatModel(BedrockRuntimeClient bedrockRuntimeClient,
            BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, BedrockChatOptions defaultOptions,
            ObservationRegistry observationRegistry, ToolCallingManager toolCallingManager,
            ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
        Assert.notNull(bedrockRuntimeClient, "bedrockRuntimeClient must not be null");
        Assert.notNull(bedrockRuntimeAsyncClient, "bedrockRuntimeAsyncClient must not be null");
        Assert.notNull(toolCallingManager, "toolCallingManager must not be null");
        Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate must not be null");

        // Transport.
        this.bedrockRuntimeClient = bedrockRuntimeClient;
        this.bedrockRuntimeAsyncClient = bedrockRuntimeAsyncClient;

        // Configuration and instrumentation.
        this.defaultOptions = defaultOptions;
        this.observationRegistry = observationRegistry;

        // Tool calling.
        this.toolCallingManager = toolCallingManager;
        this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
    }

    /**
     * Builds {@link BedrockChatOptions} from the portable {@link ChatOptions} fields.
     * <p>
     * Only model, maxTokens, stopSequences, temperature and topP are carried over; every other field of
     * {@code options} is dropped.
     * @param options generic chat options (must not be {@code null})
     * @return a new Bedrock options instance
     */
    private static BedrockChatOptions from(ChatOptions options) {
        return BedrockChatOptions.builder()
            .model(options.getModel())
            .maxTokens(options.getMaxTokens())
            .stopSequences(options.getStopSequences())
            .temperature(options.getTemperature())
            .topP(options.getTopP())
            .build();
    }

    /**
     * Sends {@code prompt} to Bedrock through the synchronous Converse API.
     * <p>
     * The prompt's options are first merged with the model defaults. Tool calls requested by the model may then
     * be executed and fed back automatically.
     * @param prompt the prompt to send
     * @return the model response
     */
    @Override
    public ChatResponse call(Prompt prompt) {
        return this.internalCall(this.buildRequestPrompt(prompt), null);
    }

    /**
     * Performs one observed Converse round-trip and, when required, the internal tool-execution loop.
     * <p>
     * The tool step runs only if the eligibility predicate approves the response and the response carries the
     * {@code TOOL_USE} finish reason. When the tool execution result is {@code returnDirect}, the tool output is
     * returned as the generations. Otherwise the updated conversation history is sent again recursively.
     * @param prompt fully-resolved prompt (options already merged)
     * @param previousChatResponse response of the previous round, or {@code null} on the first call
     * @return the final chat response
     */
    private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
        ConverseRequest request = this.createRequest(prompt);
        ChatModelObservationContext context = ChatModelObservationContext.builder()
            .prompt(prompt)
            .provider(AiProvider.BEDROCK_CONVERSE.value())
            .build();

        Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
                this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> context, this.observationRegistry);
        ChatResponse chatResponse = observation.observe(() -> {
            ConverseResponse converseResponse = this.bedrockRuntimeClient.converse(request);
            logger.debug("ConverseResponse: {}", converseResponse);
            ChatResponse response = this.toChatResponse(converseResponse, previousChatResponse);
            context.setResponse(response);
            return response;
        });

        boolean toolCallRequested = this.toolExecutionEligibilityPredicate
            .isToolExecutionRequired(prompt.getOptions(), chatResponse)
                && chatResponse.hasFinishReasons(Set.of(StopReason.TOOL_USE.toString()));
        if (!toolCallRequested) {
            return chatResponse;
        }

        ToolExecutionResult toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
        if (toolExecutionResult.returnDirect()) {
            // Hand the tool output straight back to the caller instead of asking the model again.
            return ChatResponse.builder()
                .from(chatResponse)
                .generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
                .build();
        }
        Prompt followUpPrompt = new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions());
        return this.internalCall(followUpPrompt, chatResponse);
    }

    /**
     * Returns the model-level default options supplied at construction.
     * @return the default options instance itself (not a copy); may be {@code null} if none were supplied
     */
    @Override
    public ChatOptions getDefaultOptions() {
        return defaultOptions;
    }

    /**
     * Resolves the options that will actually be sent for {@code prompt}.
     * <p>
     * Prompt-level options are first normalized to {@link BedrockChatOptions}:
     * <ul>
     * <li>Bedrock options are copied.</li>
     * <li>{@link ToolCallingChatOptions} are bean-copied.</li>
     * <li>Any other {@link ChatOptions} are rebuilt from their portable fields.</li>
     * </ul>
     * Without prompt options the defaults are copied unchanged. Otherwise each merged field takes the prompt value
     * when it is non-null and falls back to the default. frequencyPenalty, presencePenalty and topK only trigger a
     * warning and are not carried over. Tool callbacks are validated before returning.
     * @param prompt the incoming prompt
     * @return a new prompt with the same instructions and the merged options
     */
    Prompt buildRequestPrompt(Prompt prompt) {
        ChatOptions promptOptions = prompt.getOptions();

        BedrockChatOptions runtimeOptions = null;
        if (promptOptions instanceof BedrockChatOptions bedrockOptions) {
            runtimeOptions = bedrockOptions.copy();
        }
        else if (promptOptions instanceof ToolCallingChatOptions toolCallingOptions) {
            runtimeOptions = (BedrockChatOptions) ModelOptionsUtils.copyToTarget(toolCallingOptions,
                    ToolCallingChatOptions.class, BedrockChatOptions.class);
        }
        else if (promptOptions != null) {
            runtimeOptions = from(promptOptions);
        }

        BedrockChatOptions updatedRuntimeOptions;
        if (runtimeOptions != null) {
            if (runtimeOptions.getFrequencyPenalty() != null) {
                logger.warn("The frequencyPenalty option is not supported by BedrockProxyChatModel. Ignoring.");
            }
            if (runtimeOptions.getPresencePenalty() != null) {
                logger.warn("The presencePenalty option is not supported by BedrockProxyChatModel. Ignoring.");
            }
            if (runtimeOptions.getTopK() != null) {
                logger.warn("The topK option is not supported by BedrockProxyChatModel. Ignoring.");
            }
            BedrockChatOptions overrides = runtimeOptions;
            BedrockChatOptions defaults = this.defaultOptions;
            // Defaults are only read for fields the prompt left null.
            updatedRuntimeOptions = BedrockChatOptions.builder()
                .model(overrides.getModel() != null ? overrides.getModel() : defaults.getModel())
                .maxTokens(overrides.getMaxTokens() != null ? overrides.getMaxTokens() : defaults.getMaxTokens())
                .stopSequences(overrides.getStopSequences() != null ? overrides.getStopSequences()
                        : defaults.getStopSequences())
                .temperature(overrides.getTemperature() != null ? overrides.getTemperature()
                        : defaults.getTemperature())
                .topP(overrides.getTopP() != null ? overrides.getTopP() : defaults.getTopP())
                .toolCallbacks(overrides.getToolCallbacks() != null ? overrides.getToolCallbacks()
                        : defaults.getToolCallbacks())
                .toolNames(overrides.getToolNames() != null ? overrides.getToolNames() : defaults.getToolNames())
                .toolContext(overrides.getToolContext() != null ? overrides.getToolContext()
                        : defaults.getToolContext())
                .internalToolExecutionEnabled(overrides.getInternalToolExecutionEnabled() != null
                        ? overrides.getInternalToolExecutionEnabled() : defaults.getInternalToolExecutionEnabled())
                .cacheOptions(overrides.getCacheOptions() != null ? overrides.getCacheOptions()
                        : defaults.getCacheOptions())
                .build();
        }
        else {
            updatedRuntimeOptions = this.defaultOptions.copy();
        }

        ToolCallingChatOptions.validateToolCallbacks(updatedRuntimeOptions.getToolCallbacks());
        return new Prompt(prompt.getInstructions(), updatedRuntimeOptions);
    }

    /**
     * Translates a prompt whose options were already resolved by {@code buildRequestPrompt} into a Bedrock
     * {@link ConverseRequest}.
     * <p>
     * Non-system messages become Converse messages; tool responses are sent with the USER role. System messages
     * become {@link SystemContentBlock}s. The tool definitions resolved by the {@link ToolCallingManager} become the
     * tool configuration.
     * <p>
     * When {@link BedrockCacheOptions} are present, a {@code "default"} cache point is appended according to the
     * strategy:
     * <ul>
     * <li>{@code CONVERSATION_HISTORY}: to the content of the last USER-type message;</li>
     * <li>{@code SYSTEM_ONLY} / {@code SYSTEM_AND_TOOLS}: after the last system block;</li>
     * <li>{@code TOOLS_ONLY} / {@code SYSTEM_AND_TOOLS}: after the last tool definition.</li>
     * </ul>
     * @param prompt prompt carrying {@link BedrockChatOptions}; the options are copied, not mutated
     * @return the request to send to the Converse API
     * @throws IllegalArgumentException if a message has an unsupported type or carries unsupported media
     */
    ConverseRequest createRequest(Prompt prompt) {
        BedrockChatOptions updatedRuntimeOptions = (BedrockChatOptions) prompt.getOptions().copy();
        BedrockCacheOptions cacheOptions = updatedRuntimeOptions.getCacheOptions();
        BedrockCacheStrategy cacheStrategy = (cacheOptions != null) ? cacheOptions.getStrategy() : null;

        List<Message> nonSystemMessages = prompt.getInstructions()
            .stream()
            .filter(message -> message.getMessageType() != MessageType.SYSTEM)
            .toList();
        List<software.amazon.awssdk.services.bedrockruntime.model.Message> instructionMessages = this
            .buildConverseMessages(nonSystemMessages, cacheStrategy == BedrockCacheStrategy.CONVERSATION_HISTORY);

        boolean shouldCacheSystem = cacheStrategy == BedrockCacheStrategy.SYSTEM_ONLY
                || cacheStrategy == BedrockCacheStrategy.SYSTEM_AND_TOOLS;
        if (logger.isDebugEnabled() && cacheOptions != null) {
            logger.debug("Cache strategy: {}, shouldCacheSystem: {}", cacheStrategy, shouldCacheSystem);
        }
        List<Message> systemMessageList = prompt.getInstructions()
            .stream()
            .filter(m -> m.getMessageType() == MessageType.SYSTEM)
            .toList();
        List<SystemContentBlock> systemMessages = buildSystemContentBlocks(systemMessageList, shouldCacheSystem);

        List<ToolDefinition> toolDefinitions = this.toolCallingManager
            .resolveToolDefinitions((ToolCallingChatOptions) updatedRuntimeOptions);
        boolean shouldCacheTools = cacheStrategy == BedrockCacheStrategy.TOOLS_ONLY
                || cacheStrategy == BedrockCacheStrategy.SYSTEM_AND_TOOLS;
        ToolConfiguration toolConfiguration = buildToolConfiguration(toolDefinitions, shouldCacheTools);

        InferenceConfiguration inferenceConfiguration = InferenceConfiguration.builder()
            .maxTokens(updatedRuntimeOptions.getMaxTokens())
            .stopSequences(updatedRuntimeOptions.getStopSequences())
            .temperature(updatedRuntimeOptions.getTemperature() != null
                    ? Float.valueOf(updatedRuntimeOptions.getTemperature().floatValue()) : null)
            .topP(updatedRuntimeOptions.getTopP() != null
                    ? Float.valueOf(updatedRuntimeOptions.getTopP().floatValue()) : null)
            .build();

        Document additionalModelRequestFields = ConverseApiUtils
            .getChatOptionsAdditionalModelRequestFields(this.defaultOptions, prompt.getOptions());
        Map<String, String> requestMetadata = ConverseApiUtils
            .getRequestMetadata(prompt.getUserMessage().getMetadata());

        return ConverseRequest.builder()
            .modelId(updatedRuntimeOptions.getModel())
            .inferenceConfig(inferenceConfiguration)
            .messages(instructionMessages)
            .system(systemMessages)
            .additionalModelRequestFields(additionalModelRequestFields)
            .toolConfig(toolConfiguration)
            .requestMetadata(requestMetadata)
            .build();
    }

    /**
     * Converts the non-system Spring AI messages into Converse messages.
     * <p>
     * USER messages become text plus media blocks. ASSISTANT messages become text (if any) plus tool-use blocks.
     * TOOL messages become tool-result blocks with the USER role. When {@code cacheConversationHistory} is set, a
     * cache point is appended to the last USER-type message.
     * @throws IllegalArgumentException for any other message type
     */
    private List<software.amazon.awssdk.services.bedrockruntime.model.Message> buildConverseMessages(
            List<Message> messages, boolean cacheConversationHistory) {
        int lastUserMessageIndex = -1;
        if (cacheConversationHistory) {
            for (int i = messages.size() - 1; i >= 0; i--) {
                if (messages.get(i).getMessageType() == MessageType.USER) {
                    lastUserMessageIndex = i;
                    break;
                }
            }
            if (logger.isDebugEnabled()) {
                logger.debug("CONVERSATION_HISTORY caching: lastUserMessageIndex={}, totalMessages={}",
                        lastUserMessageIndex, messages.size());
            }
        }

        List<software.amazon.awssdk.services.bedrockruntime.model.Message> converseMessages = new ArrayList<>(
                messages.size());
        for (int i = 0; i < messages.size(); i++) {
            Message message = messages.get(i);
            MessageType messageType = message.getMessageType();
            if (messageType == MessageType.USER) {
                List<ContentBlock> contents = new ArrayList<>();
                if (message instanceof UserMessage userMessage) {
                    contents.add(ContentBlock.fromText(userMessage.getText()));
                    if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
                        for (Media media : userMessage.getMedia()) {
                            contents.add(this.mapMediaToContentBlock(media));
                        }
                    }
                }
                if (cacheConversationHistory && i == lastUserMessageIndex) {
                    contents.add(ContentBlock.fromCachePoint(defaultCachePoint()));
                    logger.debug("Applied cache point on last user message (conversation history caching)");
                }
                converseMessages.add(software.amazon.awssdk.services.bedrockruntime.model.Message.builder()
                    .content(contents)
                    .role(ConversationRole.USER)
                    .build());
            }
            else if (messageType == MessageType.ASSISTANT) {
                AssistantMessage assistantMessage = (AssistantMessage) message;
                List<ContentBlock> contentBlocks = new ArrayList<>();
                // Bedrock rejects blank text blocks, so text is only sent when present.
                if (StringUtils.hasText(message.getText())) {
                    contentBlocks.add(ContentBlock.fromText(message.getText()));
                }
                if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
                    for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
                        Document argumentsDocument = ConverseApiUtils
                            .convertObjectToDocument(ModelOptionsUtils.jsonToMap(toolCall.arguments()));
                        contentBlocks.add(ContentBlock.fromToolUse(ToolUseBlock.builder()
                            .toolUseId(toolCall.id())
                            .name(toolCall.name())
                            .input(argumentsDocument)
                            .build()));
                    }
                }
                converseMessages.add(software.amazon.awssdk.services.bedrockruntime.model.Message.builder()
                    .content(contentBlocks)
                    .role(ConversationRole.ASSISTANT)
                    .build());
            }
            else if (messageType == MessageType.TOOL) {
                List<ContentBlock> contentBlocks = new ArrayList<>(((ToolResponseMessage) message).getResponses()
                    .stream()
                    .map(toolResponse -> ContentBlock.fromToolResult(ToolResultBlock.builder()
                        .toolUseId(toolResponse.id())
                        .content(ToolResultContentBlock.builder().text(toolResponse.responseData()).build())
                        .build()))
                    .toList());
                // Tool results go back to the model in a USER-role message.
                converseMessages.add(software.amazon.awssdk.services.bedrockruntime.model.Message.builder()
                    .content(contentBlocks)
                    .role(ConversationRole.USER)
                    .build());
            }
            else {
                throw new IllegalArgumentException("Unsupported message type: " + String.valueOf(messageType));
            }
        }
        return converseMessages;
    }

    /**
     * Converts system messages to text blocks, appending one cache point after the last block when
     * {@code cacheSystem} is set and there is at least one system message.
     */
    private static List<SystemContentBlock> buildSystemContentBlocks(List<Message> systemMessages,
            boolean cacheSystem) {
        List<SystemContentBlock> blocks = new ArrayList<>(systemMessages.size() + 1);
        for (Message systemMessage : systemMessages) {
            blocks.add(SystemContentBlock.builder().text(systemMessage.getText()).build());
        }
        if (cacheSystem && !systemMessages.isEmpty()) {
            blocks.add(SystemContentBlock.builder().cachePoint(defaultCachePoint()).build());
            logger.debug("Applied cache point after system message");
        }
        return blocks;
    }

    /**
     * Builds the Bedrock tool configuration from resolved tool definitions, or returns {@code null} when there are
     * none. When {@code cacheTools} is set, a cache point is appended after the last tool.
     */
    private static ToolConfiguration buildToolConfiguration(List<ToolDefinition> toolDefinitions,
            boolean cacheTools) {
        if (CollectionUtils.isEmpty(toolDefinitions)) {
            return null;
        }
        List<Tool> bedrockTools = new ArrayList<>(toolDefinitions.size() + 1);
        for (ToolDefinition toolDefinition : toolDefinitions) {
            Document inputSchema = ConverseApiUtils
                .convertObjectToDocument(ModelOptionsUtils.jsonToMap(toolDefinition.inputSchema()));
            ToolSpecification toolSpecification = ToolSpecification.builder()
                .name(toolDefinition.name())
                .description(toolDefinition.description())
                .inputSchema(ToolInputSchema.fromJson(inputSchema))
                .build();
            bedrockTools.add(Tool.builder().toolSpec(toolSpecification).build());
        }
        if (cacheTools) {
            bedrockTools.add(Tool.builder().cachePoint(defaultCachePoint()).build());
            logger.debug("Applied cache point after tool definitions");
        }
        return ToolConfiguration.builder().tools(bedrockTools).build();
    }

    /** Creates a new {@code "default"}-type cache point marker. */
    private static CachePointBlock defaultCachePoint() {
        return CachePointBlock.builder().type("default").build();
    }

    /**
     * Converts a {@link Media} attachment into a Converse content block, choosing video, image or document
     * handling from its MIME type.
     * <ul>
     * <li>Video: {@code byte[]} is sent inline. A {@code String} or {@link URL} is sent as an S3 location URI.</li>
     * <li>Image: {@code byte[]} is sent inline. A {@code String} that passes {@code URLValidator.isValidURLBasic} is
     * downloaded; any other {@code String} is Base64-decoded. A {@link URL} is downloaded.</li>
     * <li>Document: the bytes from {@link Media#getDataAsByteArray()} are sent inline.</li>
     * </ul>
     * @param media the attachment to convert
     * @return the corresponding content block
     * @throws IllegalArgumentException for an unsupported MIME type or data type, or a failed {@link URL} read
     * @throws RuntimeException when downloading an image from a URL given as a {@code String} fails
     */
    private ContentBlock mapMediaToContentBlock(Media media) {
        MimeType mimeType = media.getMimeType();

        if (BedrockMediaFormat.isSupportedVideoFormat(mimeType)) {
            VideoFormat videoFormat = BedrockMediaFormat.getVideoFormat(mimeType);
            Object videoData = media.getData();
            VideoSource videoSource;
            if (videoData instanceof byte[] rawVideo) {
                videoSource = VideoSource.builder().bytes(SdkBytes.fromByteArrayUnsafe(rawVideo)).build();
            }
            else if (videoData instanceof String s3Uri) {
                videoSource = VideoSource.builder().s3Location(S3Location.builder().uri(s3Uri).build()).build();
            }
            else if (videoData instanceof URL videoUrl) {
                try {
                    S3Location location = S3Location.builder().uri(videoUrl.toURI().toString()).build();
                    videoSource = VideoSource.builder().s3Location(location).build();
                }
                catch (URISyntaxException ex) {
                    throw new IllegalArgumentException(ex);
                }
            }
            else {
                throw new IllegalArgumentException(
                        "Invalid video content type: " + String.valueOf(videoData.getClass()));
            }
            return ContentBlock.fromVideo(VideoBlock.builder().source(videoSource).format(videoFormat).build());
        }

        if (BedrockMediaFormat.isSupportedImageFormat(mimeType)) {
            Object imageData = media.getData();
            SdkBytes imageBytes;
            if (imageData instanceof byte[] rawImage) {
                imageBytes = SdkBytes.fromByteArrayUnsafe(rawImage);
            }
            else if (imageData instanceof String text) {
                if (URLValidator.isValidURLBasic(text)) {
                    try {
                        URLConnection connection = new URL(text).openConnection();
                        try (InputStream in = connection.getInputStream()) {
                            imageBytes = SdkBytes.fromByteArrayUnsafe(StreamUtils.copyToByteArray(in));
                        }
                    }
                    catch (IOException ex) {
                        throw new RuntimeException("Failed to read media data from URL: " + text, ex);
                    }
                }
                else {
                    // Not a URL: treat the text as Base64-encoded image bytes.
                    imageBytes = SdkBytes.fromByteArray(Base64.getDecoder().decode(text));
                }
            }
            else if (imageData instanceof URL url) {
                try (InputStream in = url.openConnection().getInputStream()) {
                    imageBytes = SdkBytes.fromByteArrayUnsafe(StreamUtils.copyToByteArray(in));
                }
                catch (IOException ex) {
                    throw new IllegalArgumentException("Failed to read media data from URL: " + String.valueOf(url),
                            ex);
                }
            }
            else {
                throw new IllegalArgumentException(
                        "Invalid Image content type: " + String.valueOf(imageData.getClass()));
            }
            ImageSource imageSource = ImageSource.builder().bytes(imageBytes).build();
            return ContentBlock.fromImage(ImageBlock.builder()
                .source(imageSource)
                .format(BedrockMediaFormat.getImageFormat(mimeType))
                .build());
        }

        if (BedrockMediaFormat.isSupportedDocumentFormat(mimeType)) {
            return ContentBlock.fromDocument(DocumentBlock.builder()
                .name(media.getName())
                .format(BedrockMediaFormat.getDocumentFormat(mimeType))
                .source(DocumentSource.builder().bytes(SdkBytes.fromByteArray(media.getDataAsByteArray())).build())
                .build());
        }

        throw new IllegalArgumentException("Unsupported media format: " + String.valueOf(mimeType));
    }

    /**
     * Resolves a raw media payload into bytes.
     * <ul>
     * <li>{@code byte[]} is returned as-is (not copied).</li>
     * <li>A {@code String} that passes {@code URLValidator.isValidURLBasic} is downloaded. Any other {@code String}
     * is encoded as UTF-8.</li>
     * <li>A {@link URL} is downloaded.</li>
     * </ul>
     * Streams are always closed via try-with-resources.
     * @param mediaData the payload to resolve
     * @return the payload bytes
     * @throws RuntimeException if a URL cannot be opened or read (the {@link IOException} is kept as the cause)
     * @throws IllegalArgumentException if {@code mediaData} is {@code null} or of an unsupported type
     */
    private static byte[] getContentMediaData(Object mediaData) {
        Assert.notNull(mediaData, "mediaData must not be null");
        if (mediaData instanceof byte[] bytes) {
            return bytes;
        }
        if (mediaData instanceof String text) {
            if (URLValidator.isValidURLBasic(text)) {
                try {
                    URLConnection connection = new URL(text).openConnection();
                    try (InputStream is = connection.getInputStream()) {
                        return StreamUtils.copyToByteArray(is);
                    }
                }
                catch (IOException e) {
                    throw new RuntimeException("Failed to read media data from URL: " + text, e);
                }
            }
            // Explicit charset: the no-arg getBytes() depends on the platform default before Java 18.
            return text.getBytes(java.nio.charset.StandardCharsets.UTF_8);
        }
        if (mediaData instanceof URL url) {
            try (InputStream is = url.openConnection().getInputStream()) {
                return StreamUtils.copyToByteArray(is);
            }
            catch (IOException e) {
                throw new RuntimeException("Failed to read media data from URL: " + String.valueOf(url), e);
            }
        }
        throw new IllegalArgumentException("Unsupported media data type: " + mediaData.getClass().getSimpleName());
    }

    /**
     * Converts a Bedrock {@code ConverseResponse} into a Spring AI {@link ChatResponse}.
     * <p>
     * Every content block whose type is not {@code TOOL_USE} becomes its own {@link Generation}
     * carrying the block's text. When the response has a stop reason but no such block, a single
     * empty generation is added so the stop reason is still reported. All {@code TOOL_USE} blocks
     * are gathered into one extra generation with empty content. Its assistant message carries one
     * {@link AssistantMessage.ToolCall} per block, typed {@code "function"}, with the block's input
     * rendered via {@code toString()} as the arguments.
     * <p>
     * Token counts are added to those of {@code previousChatResponse} when it carries usage. A
     * missing ({@code null}) count on either side is treated as zero instead of failing on unboxing.
     * Cache read/write token counts, when present, are exposed as response metadata.
     *
     * @param response the Bedrock response to convert; must not be {@code null}
     * @param previousChatResponse the response of a preceding turn whose usage should be added, or
     *        {@code null}
     * @return the converted chat response
     */
    private ChatResponse toChatResponse(ConverseResponse response, ChatResponse previousChatResponse) {
        Assert.notNull(response, "'response' must not be null.");
        software.amazon.awssdk.services.bedrockruntime.model.Message message = response.output().message();
        String stopReason = response.stopReasonAsString();

        // One generation per non-tool content block (text and any other non-tool block types).
        List<Generation> allGenerations = new ArrayList<>(message.content()
            .stream()
            .filter(content -> content.type() != ContentBlock.Type.TOOL_USE)
            .map(content -> new Generation(
                    AssistantMessage.builder().content(content.text()).properties(Map.of()).build(),
                    ChatGenerationMetadata.builder().finishReason(stopReason).build()))
            .toList());

        // Make sure the stop reason is surfaced even when the model produced no non-tool content.
        if (stopReason != null && allGenerations.isEmpty()) {
            allGenerations.add(new Generation(AssistantMessage.builder().properties(Map.of()).build(),
                    ChatGenerationMetadata.builder().finishReason(stopReason).build()));
        }

        List<AssistantMessage.ToolCall> toolCalls = message.content()
            .stream()
            .filter(content -> content.type() == ContentBlock.Type.TOOL_USE)
            .map(ContentBlock::toolUse)
            .map(toolUse -> new AssistantMessage.ToolCall(toolUse.toolUseId(), "function", toolUse.name(),
                    toolUse.input().toString()))
            .toList();
        if (!toolCalls.isEmpty()) {
            AssistantMessage assistantMessage = AssistantMessage.builder()
                .content("")
                .properties(Map.of())
                .toolCalls(toolCalls)
                .build();
            allGenerations.add(new Generation(assistantMessage,
                    ChatGenerationMetadata.builder().finishReason(stopReason).build()));
        }

        Integer promptTokens = response.usage().inputTokens();
        Integer generationTokens = response.usage().outputTokens();
        Integer totalTokens = response.usage().totalTokens();
        Usage previousUsage = (previousChatResponse != null && previousChatResponse.getMetadata() != null)
                ? previousChatResponse.getMetadata().getUsage() : null;
        if (previousUsage != null) {
            // Accumulate usage across turns (e.g. tool-calling round trips) without unboxing nulls.
            promptTokens = addTokenCounts(promptTokens, previousUsage.getPromptTokens());
            generationTokens = addTokenCounts(generationTokens, previousUsage.getCompletionTokens());
            totalTokens = addTokenCounts(totalTokens, previousUsage.getTotalTokens());
        }
        DefaultUsage usage = new DefaultUsage(promptTokens, generationTokens, totalTokens);

        ChatResponseMetadata.Builder metadataBuilder = ChatResponseMetadata.builder()
            .id(response.responseMetadata() != null ? response.responseMetadata().requestId() : "Unknown")
            .usage(usage);

        // Prompt-cache statistics are only reported when Bedrock returns them.
        Map<String, Object> additionalMetadata = new HashMap<>();
        Integer cacheReadInputTokens = response.usage().cacheReadInputTokens();
        if (cacheReadInputTokens != null) {
            additionalMetadata.put("cacheReadInputTokens", cacheReadInputTokens);
        }
        Integer cacheWriteInputTokens = response.usage().cacheWriteInputTokens();
        if (cacheWriteInputTokens != null) {
            additionalMetadata.put("cacheWriteInputTokens", cacheWriteInputTokens);
        }
        if (!additionalMetadata.isEmpty()) {
            metadataBuilder.metadata(additionalMetadata);
        }
        return new ChatResponse(allGenerations, metadataBuilder.build());
    }

    /**
     * Adds two token counts, treating a {@code null} count as zero.
     *
     * @param current the count reported for the current response, may be {@code null}
     * @param previous the count carried over from a previous response, may be {@code null}
     * @return the sum of both counts
     */
    private static int addTokenCounts(Integer current, Integer previous) {
        return (current != null ? current : 0) + (previous != null ? previous : 0);
    }

    /**
     * Streams chat responses for the given prompt.
     * <p>
     * The prompt is first passed through {@code buildRequestPrompt}. The result is then streamed as a
     * first turn, i.e. with no previous response.
     *
     * @param prompt the user-supplied prompt
     * @return a stream of chat responses
     */
    public Flux<ChatResponse> stream(Prompt prompt) {
        ChatResponse noPreviousResponse = null;
        return this.internalStream(this.buildRequestPrompt(prompt), noPreviousResponse);
    }

    /**
     * Streams one conversation turn and, when the model asks for tools, the follow-up turns.
     * <p>
     * The turn's {@code ConverseRequest} comes from {@code createRequest}. Its model id, inference
     * config, messages, system blocks, additional request fields, tool config and request metadata
     * are copied into a {@code ConverseStreamRequest}, which is streamed through
     * {@link ConverseChatResponseStream}. A chat-model observation is started for the turn. It is
     * parented to any observation found in the Reactor context, is marked with any stream error, and
     * is stopped when the stream terminates.
     * <p>
     * A streamed response triggers tool handling only when two conditions hold:
     * <ul>
     * <li>the tool-execution predicate requires it, and</li>
     * <li>the response carries the {@code TOOL_USE} finish reason.</li>
     * </ul>
     * In that case the tool calls run on the bounded-elastic scheduler, with the Reactor context made
     * available through {@code ToolCallReactiveContextHolder}. A {@code returnDirect} result is
     * emitted as a single response. Otherwise a new prompt, built from the tool conversation history
     * with the original options, is streamed recursively.
     *
     * @param prompt the prompt to send; must not be {@code null}
     * @param previousChatResponse the response of the preceding turn, or {@code null} on the first
     *        turn. Its usage is handed to {@link ConverseChatResponseStream} (presumably so token
     *        counts accumulate across turns — confirm in that class).
     * @return the aggregated stream of chat responses; the aggregator's callback stores the response
     *         on the observation context
     */
    private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
        Assert.notNull(prompt, "'prompt' must not be null");
        return Flux.deferContextual(outerContext -> {
            ConverseRequest request = this.createRequest(prompt);

            ChatModelObservationContext turnContext = ChatModelObservationContext.builder()
                .prompt(prompt)
                .provider(AiProvider.BEDROCK_CONVERSE.value())
                .build();
            Observation turnObservation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
                    this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> turnContext,
                    this.observationRegistry);
            Observation parent = outerContext.getOrDefault("micrometer.observation", null);
            turnObservation.parentObservation(parent).start();

            ConverseStreamRequest streamRequest = ConverseStreamRequest.builder()
                .modelId(request.modelId())
                .inferenceConfig(request.inferenceConfig())
                .messages(request.messages())
                .system(request.system())
                .additionalModelRequestFields(request.additionalModelRequestFields())
                .toolConfig(request.toolConfig())
                .requestMetadata(request.requestMetadata())
                .build();

            Usage priorUsage = (previousChatResponse == null || previousChatResponse.getMetadata() == null)
                    ? null
                    : previousChatResponse.getMetadata().getUsage();

            Flux<ChatResponse> rawResponses = new ConverseChatResponseStream(this.bedrockRuntimeAsyncClient,
                    streamRequest, priorUsage).stream();

            Flux<ChatResponse> observedResponses = rawResponses.switchMap(current -> {
                boolean toolTurn = this.toolExecutionEligibilityPredicate
                    .isToolExecutionRequired(prompt.getOptions(), current)
                        && current.hasFinishReasons(Set.of(StopReason.TOOL_USE.toString()));
                if (!toolTurn) {
                    return Flux.just(current);
                }
                return Flux.<ChatResponse>deferContextual(innerContext -> {
                    ToolExecutionResult executionResult;
                    try {
                        ToolCallReactiveContextHolder.setContext(innerContext);
                        executionResult = this.toolCallingManager.executeToolCalls(prompt, current);
                    }
                    finally {
                        ToolCallReactiveContextHolder.clearContext();
                    }
                    if (executionResult.returnDirect()) {
                        ChatResponse directResponse = ChatResponse.builder()
                            .from(current)
                            .generations(ToolExecutionResult.buildGenerations(executionResult))
                            .build();
                        return Flux.just(directResponse);
                    }
                    Prompt followUpPrompt = new Prompt(executionResult.conversationHistory(), prompt.getOptions());
                    return this.internalStream(followUpPrompt, current);
                }).subscribeOn(Schedulers.boundedElastic());
            })
                .doOnError(turnObservation::error)
                .doFinally(signal -> turnObservation.stop())
                .contextWrite(ctx -> ctx.put("micrometer.observation", turnObservation));

            return new MessageAggregator().aggregate(observedResponses, turnContext::setResponse);
        });
    }

    /**
     * Replaces the observation convention used when creating chat-model observations.
     *
     * @param observationConvention the convention to use; {@code null} is rejected through
     *        {@code Assert.notNull}
     */
    public void setObservationConvention(ChatModelObservationConvention observationConvention) {
        String nullMessage = "observationConvention cannot be null";
        Assert.notNull(observationConvention, nullMessage);
        this.observationConvention = observationConvention;
    }

    /**
     * Creates a new {@link Builder} for this chat model.
     * <p>
     * The builder's constructor tries to resolve the region from
     * {@code DefaultAwsRegionProviderChain}. If that fails, it keeps its {@code US_EAST_1} default.
     *
     * @return a fresh builder
     */
    public static Builder builder() {
        Builder freshBuilder = new Builder();
        return freshBuilder;
    }

    public static final class Builder {
        private AwsCredentialsProvider credentialsProvider;
        private Region region = Region.US_EAST_1;
        private Duration timeout = Duration.ofMinutes(5L);
        private Duration connectionTimeout = Duration.ofSeconds(5L);
        private Duration asyncReadTimeout = Duration.ofSeconds(30L);
        private Duration connectionAcquisitionTimeout = Duration.ofSeconds(30L);
        private Duration socketTimeout = Duration.ofSeconds(30L);
        private ToolCallingManager toolCallingManager;
        private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate();
        private BedrockChatOptions defaultOptions = BedrockChatOptions.builder().build();
        private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
        private ChatModelObservationConvention customObservationConvention;
        private BedrockRuntimeClient bedrockRuntimeClient;
        private BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient;

        private Builder() {
            try {
                this.region = DefaultAwsRegionProviderChain.builder().build().getRegion();
            }
            catch (SdkClientException e) {
                logger.warn("Failed to load region from DefaultAwsRegionProviderChain, using US_EAST_1", (Throwable)e);
            }
        }

        public Builder toolCallingManager(ToolCallingManager toolCallingManager) {
            this.toolCallingManager = toolCallingManager;
            return this;
        }

        public Builder toolExecutionEligibilityPredicate(ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
            this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
            return this;
        }

        public Builder credentialsProvider(AwsCredentialsProvider credentialsProvider) {
            Assert.notNull((Object)credentialsProvider, (String)"'credentialsProvider' must not be null.");
            this.credentialsProvider = credentialsProvider;
            return this;
        }

        public Builder region(Region region) {
            Assert.notNull((Object)region, (String)"'region' must not be null.");
            this.region = region;
            return this;
        }

        public Builder timeout(Duration timeout) {
            Assert.notNull((Object)timeout, (String)"'timeout' must not be null.");
            this.timeout = timeout;
            return this;
        }

        public Builder connectionTimeout(Duration connectionTimeout) {
            Assert.notNull((Object)connectionTimeout, (String)"'connectionTimeout' must not be null.");
            this.connectionTimeout = connectionTimeout;
            return this;
        }

        public Builder asyncReadTimeout(Duration asyncReadTimeout) {
            Assert.notNull((Object)asyncReadTimeout, (String)"'asyncReadTimeout' must not be null.");
            this.asyncReadTimeout = asyncReadTimeout;
            return this;
        }

        public Builder connectionAcquisitionTimeout(Duration connectionAcquisitionTimeout) {
            Assert.notNull((Object)connectionAcquisitionTimeout, (String)"'connectionAcquisitionTimeout' must not be null.");
            this.connectionAcquisitionTimeout = connectionAcquisitionTimeout;
            return this;
        }

        public Builder socketTimeout(Duration socketTimeout) {
            Assert.notNull((Object)socketTimeout, (String)"'socketTimeout' must not be null.");
            this.socketTimeout = socketTimeout;
            return this;
        }

        public Builder defaultOptions(BedrockChatOptions defaultOptions) {
            Assert.notNull((Object)defaultOptions, (String)"'defaultOptions' must not be null.");
            this.defaultOptions = defaultOptions;
            return this;
        }

        public Builder observationRegistry(ObservationRegistry observationRegistry) {
            Assert.notNull((Object)observationRegistry, (String)"'observationRegistry' must not be null.");
            this.observationRegistry = observationRegistry;
            return this;
        }

        public Builder customObservationConvention(ChatModelObservationConvention observationConvention) {
            Assert.notNull((Object)observationConvention, (String)"'observationConvention' must not be null.");
            this.customObservationConvention = observationConvention;
            return this;
        }

        public Builder bedrockRuntimeClient(BedrockRuntimeClient bedrockRuntimeClient) {
            this.bedrockRuntimeClient = bedrockRuntimeClient;
            return this;
        }

        public Builder bedrockRuntimeAsyncClient(BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient) {
            this.bedrockRuntimeAsyncClient = bedrockRuntimeAsyncClient;
            return this;
        }

        public BedrockProxyChatModel build() {
            ApacheHttpClient.Builder httpClientBuilder;
            if (this.bedrockRuntimeClient == null) {
                httpClientBuilder = ApacheHttpClient.builder().connectionAcquisitionTimeout(this.connectionAcquisitionTimeout).connectionTimeout(this.connectionTimeout).socketTimeout(this.socketTimeout);
                this.bedrockRuntimeClient = (BedrockRuntimeClient)((BedrockRuntimeClientBuilder)((BedrockRuntimeClientBuilder)((BedrockRuntimeClientBuilder)((BedrockRuntimeClientBuilder)BedrockRuntimeClient.builder().region(this.region)).httpClientBuilder((SdkHttpClient.Builder)httpClientBuilder)).credentialsProvider(this.credentialsProvider)).overrideConfiguration(c -> c.apiCallTimeout(this.timeout))).build();
            }
            if (this.bedrockRuntimeAsyncClient == null) {
                httpClientBuilder = NettyNioAsyncHttpClient.builder().tcpKeepAlive(Boolean.valueOf(true)).readTimeout(this.asyncReadTimeout).connectionTimeout(this.connectionTimeout).connectionAcquisitionTimeout(this.connectionAcquisitionTimeout).maxConcurrency(Integer.valueOf(200));
                BedrockRuntimeAsyncClientBuilder builder = (BedrockRuntimeAsyncClientBuilder)((BedrockRuntimeAsyncClientBuilder)((BedrockRuntimeAsyncClientBuilder)((BedrockRuntimeAsyncClientBuilder)BedrockRuntimeAsyncClient.builder().region(this.region)).httpClientBuilder((SdkAsyncHttpClient.Builder)httpClientBuilder)).credentialsProvider(this.credentialsProvider)).overrideConfiguration(c -> c.apiCallTimeout(this.timeout));
                this.bedrockRuntimeAsyncClient = (BedrockRuntimeAsyncClient)builder.build();
            }
            BedrockProxyChatModel bedrockProxyChatModel = null;
            bedrockProxyChatModel = this.toolCallingManager != null ? new BedrockProxyChatModel(this.bedrockRuntimeClient, this.bedrockRuntimeAsyncClient, this.defaultOptions, this.observationRegistry, this.toolCallingManager, this.toolExecutionEligibilityPredicate) : new BedrockProxyChatModel(this.bedrockRuntimeClient, this.bedrockRuntimeAsyncClient, this.defaultOptions, this.observationRegistry, DEFAULT_TOOL_CALLING_MANAGER, this.toolExecutionEligibilityPredicate);
            if (this.customObservationConvention != null) {
                bedrockProxyChatModel.setObservationConvention(this.customObservationConvention);
            }
            return bedrockProxyChatModel;
        }
    }
}

