package org.springframework.ai.bedrock.anthropic3;

import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;

/* loaded from: input_file:org/springframework/ai/bedrock/anthropic3/BedrockAnthropic3ChatModel.class */
/**
 * Chat model backed by {@link Anthropic3ChatBedrockApi}, offering a blocking entry point
 * ({@link #call(Prompt)}) and a streaming one ({@link #stream(Prompt)}).
 *
 * <p>Every prompt is translated into an {@code AnthropicChatRequest}: SYSTEM messages are joined
 * into the request's system text, USER and ASSISTANT messages become chat-completion messages,
 * and the result is merged with the default options and with the prompt's own options.
 */
public class BedrockAnthropic3ChatModel implements ChatModel, StreamingChatModel {
    private final Anthropic3ChatBedrockApi anthropicChatApi;
    private final Anthropic3ChatOptions defaultOptions;

    /**
     * Creates a model with built-in default options: temperature {@code 0.8}, max tokens
     * {@code 500}, top-k {@code 10} and Anthropic version {@code "bedrock-2023-05-31"}.
     *
     * @param anthropic3ChatBedrockApi the API client used to execute requests
     */
    public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi anthropic3ChatBedrockApi) {
        this(anthropic3ChatBedrockApi,
                Anthropic3ChatOptions.builder()
                    .withTemperature(0.8)
                    .withMaxTokens(500)
                    .withTopK(10)
                    .withAnthropicVersion("bedrock-2023-05-31")
                    .build());
    }

    /**
     * Creates a model with caller-supplied default options.
     *
     * @param anthropic3ChatBedrockApi the API client used to execute requests
     * @param anthropic3ChatOptions options merged into every request; {@code null} is tolerated
     *        by {@link #createRequest(Prompt)}, which then skips the default-options merge
     */
    public BedrockAnthropic3ChatModel(Anthropic3ChatBedrockApi anthropic3ChatBedrockApi, Anthropic3ChatOptions anthropic3ChatOptions) {
        this.anthropicChatApi = anthropic3ChatBedrockApi;
        this.defaultOptions = anthropic3ChatOptions;
    }

    /**
     * Executes the prompt synchronously.
     *
     * <p>Each content element of the API response becomes one {@link Generation} whose metadata
     * carries the response's stop reason. The response-level metadata carries the response id,
     * the model and the usage computed by {@link #extractUsage}.
     *
     * @param prompt the prompt to send
     * @return the chat response with one generation per returned content element
     */
    public ChatResponse call(Prompt prompt) {
        Anthropic3ChatBedrockApi.AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(createRequest(prompt));

        List<Generation> generations = response.content()
            .stream()
            // The second metadata argument is deliberately null, exactly as before.
            .map(content -> new Generation(new AssistantMessage(content.text()),
                    ChatGenerationMetadata.from(response.stopReason(), (Object) null)))
            .toList();

        ChatResponseMetadata responseMetadata = ChatResponseMetadata.builder()
            .withId(response.id())
            .withModel(response.model())
            .withUsage(extractUsage(response))
            .build();

        return new ChatResponse(generations, responseMetadata);
    }

    /**
     * Executes the prompt as a stream, mapping every streaming event to a {@link ChatResponse}
     * that holds exactly one {@link Generation}.
     *
     * <ul>
     * <li>{@code MESSAGE_START}: the event's input-token count is remembered.</li>
     * <li>{@code CONTENT_BLOCK_DELTA}: the generation text is the delta text; for every other
     * event type the text is the empty string.</li>
     * <li>{@code MESSAGE_DELTA}: the generation metadata is built from the delta's stop reason
     * and an {@code AnthropicUsage} combining the remembered input tokens with the event's
     * output tokens; for every other event type the metadata is {@code null}.</li>
     * </ul>
     *
     * @param prompt the prompt to send
     * @return a flux emitting one chat response per streaming event
     */
    public Flux<ChatResponse> stream(Prompt prompt) {
        Flux<Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse> events = this.anthropicChatApi.chatCompletionStream(createRequest(prompt));

        // Holder created once per stream(...) call and captured by the mapping lambda below;
        // it carries the input-token count from MESSAGE_START over to MESSAGE_DELTA.
        AtomicReference<Integer> inputTokens = new AtomicReference<>(0);

        return events.map(event -> {
            Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse.StreamingType type = event.type();

            if (type == Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse.StreamingType.MESSAGE_START) {
                inputTokens.set(event.message().usage().inputTokens());
            }

            String text = "";
            if (type == Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse.StreamingType.CONTENT_BLOCK_DELTA) {
                text = event.delta().text();
            }

            ChatGenerationMetadata generationMetadata = null;
            if (type == Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse.StreamingType.MESSAGE_DELTA) {
                generationMetadata = ChatGenerationMetadata.from(event.delta().stopReason(),
                        new Anthropic3ChatBedrockApi.AnthropicUsage(inputTokens.get(), event.usage().outputTokens()));
            }

            return new ChatResponse(List.of(new Generation(new AssistantMessage(text), generationMetadata)));
        });
    }

    /**
     * Converts the token counts of an API response into a {@link Usage}.
     *
     * @param anthropicChatResponse the response whose usage is read
     * @return a {@link DefaultUsage} built from the response's input and output token counts
     */
    protected Usage extractUsage(Anthropic3ChatBedrockApi.AnthropicChatResponse anthropicChatResponse) {
        long inputTokens = anthropicChatResponse.usage().inputTokens().longValue();
        long outputTokens = anthropicChatResponse.usage().outputTokens().longValue();
        return new DefaultUsage(inputTokens, outputTokens);
    }

    /**
     * Builds the API request for a prompt and merges the configured options into it.
     *
     * <p>The merge order is kept exactly as it was: the prompt-derived request is passed to
     * {@code ModelOptionsUtils.merge} as the first argument with the default options second;
     * afterwards the prompt options (converted to {@link Anthropic3ChatOptions}) are passed
     * first with the request second.
     * NOTE(review): presumably the first argument's non-null values take precedence over the
     * second's — confirm against {@code ModelOptionsUtils.merge}.
     *
     * @param prompt the prompt to translate
     * @return the request to send to the API
     */
    Anthropic3ChatBedrockApi.AnthropicChatRequest createRequest(Prompt prompt) {
        Anthropic3ChatBedrockApi.AnthropicChatRequest request = Anthropic3ChatBedrockApi.AnthropicChatRequest
            .builder(toAnthropicMessages(prompt))
            .withSystem(toAnthropicSystemContext(prompt))
            .build();

        if (this.defaultOptions != null) {
            request = ModelOptionsUtils.merge(request, this.defaultOptions, Anthropic3ChatBedrockApi.AnthropicChatRequest.class);
        }

        if (prompt.getOptions() != null) {
            Anthropic3ChatOptions runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, Anthropic3ChatOptions.class);
            request = ModelOptionsUtils.merge(runtimeOptions, request, Anthropic3ChatBedrockApi.AnthropicChatRequest.class);
        }

        return request;
    }

    /**
     * Joins the content of all SYSTEM messages of the prompt, separated by the platform line
     * separator. Yields the empty string when the prompt has no SYSTEM message.
     */
    private String toAnthropicSystemContext(Prompt prompt) {
        return prompt.getInstructions()
            .stream()
            .filter(message -> message.getMessageType() == MessageType.SYSTEM)
            .map(message -> message.getContent())
            .collect(Collectors.joining(System.lineSeparator()));
    }

    /**
     * Converts the USER and ASSISTANT messages of the prompt into chat-completion messages;
     * messages of any other type are dropped here.
     *
     * <p>Each converted message starts with one text content element. For a {@link UserMessage}
     * with media attached, one additional content element per media item follows, built from the
     * media's MIME type and its data as converted by {@link #fromMediaData(Object)}.
     */
    private List<Anthropic3ChatBedrockApi.ChatCompletionMessage> toAnthropicMessages(Prompt prompt) {
        return prompt.getInstructions()
            .stream()
            .filter(message -> message.getMessageType() == MessageType.USER
                    || message.getMessageType() == MessageType.ASSISTANT)
            .map(message -> {
                // Mutable list: media entries may be appended after the text entry.
                List<Anthropic3ChatBedrockApi.MediaContent> contents = new ArrayList<>(
                        List.of(new Anthropic3ChatBedrockApi.MediaContent(message.getContent())));

                if (message instanceof UserMessage userMessage && !CollectionUtils.isEmpty(userMessage.getMedia())) {
                    contents.addAll(userMessage.getMedia()
                        .stream()
                        .map(media -> new Anthropic3ChatBedrockApi.MediaContent(media.getMimeType().toString(),
                                fromMediaData(media.getData())))
                        .toList());
                }

                // The role is looked up by name, so the message type name must match a Role constant.
                return new Anthropic3ChatBedrockApi.ChatCompletionMessage(contents,
                        Anthropic3ChatBedrockApi.ChatCompletionMessage.Role.valueOf(message.getMessageType().name()));
            })
            .toList();
    }

    /**
     * Converts a media payload into the string form placed in the request.
     *
     * @param obj a {@code byte[]} (Base64-encoded here) or a {@code String} (returned unchanged;
     *        presumably already encoded by the caller — TODO confirm)
     * @return the string representation of the payload
     * @throws IllegalArgumentException if the payload is {@code null} or of any other type
     */
    private String fromMediaData(Object obj) {
        if (obj instanceof byte[] bytes) {
            return Base64.getEncoder().encodeToString(bytes);
        }
        if (obj instanceof String text) {
            return text;
        }
        // A null payload would otherwise fail with a bare NullPointerException on getClass().
        String typeName = (obj == null) ? "null" : obj.getClass().getSimpleName();
        throw new IllegalArgumentException("Unsupported media data type: " + typeName);
    }

    /**
     * Returns the default options of this model as produced by
     * {@code Anthropic3ChatOptions.fromOptions} (presumably a copy — TODO confirm).
     */
    public ChatOptions getDefaultOptions() {
        return Anthropic3ChatOptions.fromOptions(this.defaultOptions);
    }
}
