package org.springframework.ai.bedrock.titan;

import java.util.List;
import org.springframework.ai.bedrock.MessageToPromptConverter;
import org.springframework.ai.bedrock.titan.api.TitanChatBedrockApi;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;

/**
 * Chat model that delegates to a {@link TitanChatBedrockApi}, offering both a blocking
 * {@link #call(Prompt)} and a reactive {@link #stream(Prompt)} entry point.
 *
 * <p>Every request is built from the model's default options first and then from the
 * options carried by the prompt, so a non-null prompt-level value replaces the
 * corresponding default.
 */
public class BedrockTitanChatModel implements ChatModel, StreamingChatModel {
    private final TitanChatBedrockApi chatApi;
    private final BedrockTitanChatOptions defaultOptions;

    /**
     * Creates a model whose default options only set the temperature, to {@code 0.8}.
     *
     * @param titanChatBedrockApi the low-level Titan API client; must not be null
     */
    public BedrockTitanChatModel(TitanChatBedrockApi titanChatBedrockApi) {
        this(titanChatBedrockApi, BedrockTitanChatOptions.builder().withTemperature(Double.valueOf(0.8d)).build());
    }

    /**
     * Creates a model with explicit default options.
     *
     * @param titanChatBedrockApi the low-level Titan API client; must not be null
     * @param bedrockTitanChatOptions options applied to every request before the prompt's own options; must not be null
     * @throws IllegalArgumentException if either argument is null
     */
    public BedrockTitanChatModel(TitanChatBedrockApi titanChatBedrockApi, BedrockTitanChatOptions bedrockTitanChatOptions) {
        Assert.notNull(titanChatBedrockApi, "ChatApi must not be null");
        Assert.notNull(bedrockTitanChatOptions, "DefaultOptions must not be null");
        this.chatApi = titanChatBedrockApi;
        this.defaultOptions = bedrockTitanChatOptions;
    }

    /**
     * Sends the prompt in a single blocking request.
     *
     * @param prompt the prompt whose instructions and options make up the request
     * @return a response holding one {@link Generation} per result returned by the API, each wrapping the result's output text
     */
    public ChatResponse call(Prompt prompt) {
        List<Generation> generations = this.chatApi.chatCompletion(createRequest(prompt))
            .results()
            .stream()
            .map(result -> new Generation(new AssistantMessage(result.outputText())))
            .toList();
        return new ChatResponse(generations);
    }

    /**
     * Sends the prompt as a streaming request.
     *
     * <p>Each chunk is mapped to a {@link ChatResponse} with exactly one generation. Metadata is
     * attached only when the chunk carries invocation metrics, or both the input and the total
     * output token counts; otherwise the generation metadata is {@code null}.
     *
     * @param prompt the prompt whose instructions and options make up the request
     * @return a flux emitting one response per streamed chunk
     */
    public Flux<ChatResponse> stream(Prompt prompt) {
        return this.chatApi.chatCompletionStream(createRequest(prompt)).map(chunk -> {
            ChatGenerationMetadata metadata = null;
            if (chunk.amazonBedrockInvocationMetrics() != null) {
                metadata = ChatGenerationMetadata.from(finishReasonOf(chunk), chunk.amazonBedrockInvocationMetrics());
            } else if (chunk.inputTextTokenCount() != null && chunk.totalOutputTextTokenCount() != null) {
                metadata = ChatGenerationMetadata.from(finishReasonOf(chunk), extractUsage(chunk));
            }
            return new ChatResponse(List.of(new Generation(new AssistantMessage(chunk.outputText()), metadata)));
        });
    }

    /**
     * Builds the API request for a prompt: the instructions are flattened into the input text,
     * then the default options and finally the prompt's options are applied.
     *
     * @param prompt the prompt to convert
     * @return the request to send to the Titan API
     */
    TitanChatBedrockApi.TitanChatRequest createRequest(Prompt prompt) {
        TitanChatBedrockApi.TitanChatRequest.Builder builder = TitanChatBedrockApi.TitanChatRequest.builder(MessageToPromptConverter.create().toPrompt(prompt.getInstructions()));
        // The constructor guarantees defaultOptions is non-null, so no guard is needed here.
        update(builder, this.defaultOptions);
        if (prompt.getOptions() != null) {
            // Applied after the defaults so that prompt-level values take precedence.
            update(builder, (BedrockTitanChatOptions) ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, BedrockTitanChatOptions.class));
        }
        return builder.build();
    }

    /**
     * Copies every non-null option onto the builder; null options leave the builder untouched.
     *
     * @param builder the request builder to configure
     * @param bedrockTitanChatOptions the options to read from
     * @return the same builder instance that was passed in
     */
    private TitanChatBedrockApi.TitanChatRequest.Builder update(TitanChatBedrockApi.TitanChatRequest.Builder builder, BedrockTitanChatOptions bedrockTitanChatOptions) {
        if (bedrockTitanChatOptions.getTemperature() != null) {
            builder.withTemperature(bedrockTitanChatOptions.getTemperature());
        }
        if (bedrockTitanChatOptions.getTopP() != null) {
            builder.withTopP(bedrockTitanChatOptions.getTopP());
        }
        if (bedrockTitanChatOptions.getMaxTokenCount() != null) {
            builder.withMaxTokenCount(bedrockTitanChatOptions.getMaxTokenCount());
        }
        if (bedrockTitanChatOptions.getStopSequences() != null) {
            builder.withStopSequences(bedrockTitanChatOptions.getStopSequences());
        }
        return builder;
    }

    /**
     * Resolves the finish reason of a chunk without assuming a completion reason is present.
     * Dereferencing it unconditionally would throw inside the stream mapper and abort the
     * whole flux for a chunk that reports metrics or token counts but no completion reason.
     *
     * @param chunk the streamed chunk
     * @return the completion reason's enum name, or {@code null} when the chunk has none
     */
    private static String finishReasonOf(TitanChatBedrockApi.TitanChatResponseChunk chunk) {
        return chunk.completionReason() != null ? chunk.completionReason().name() : null;
    }

    /**
     * Exposes the token counts of a chunk as a {@link Usage} view.
     *
     * <p>The caller must have verified that both token counts are non-null; the returned
     * object reads them from the chunk on each access.
     *
     * @param titanChatResponseChunk the chunk carrying the token counts
     * @return a usage view backed by the chunk
     */
    private Usage extractUsage(final TitanChatBedrockApi.TitanChatResponseChunk titanChatResponseChunk) {
        return new Usage() {
            public Long getPromptTokens() {
                return Long.valueOf(titanChatResponseChunk.inputTextTokenCount().longValue());
            }

            public Long getGenerationTokens() {
                return Long.valueOf(titanChatResponseChunk.totalOutputTextTokenCount().longValue());
            }
        };
    }

    /**
     * Returns the default options via {@code BedrockTitanChatOptions.fromOptions}, rather than
     * handing out the internal instance directly.
     *
     * @return the options derived from this model's defaults
     */
    public ChatOptions getDefaultOptions() {
        return BedrockTitanChatOptions.fromOptions(this.defaultOptions);
    }
}
