package org.springframework.ai.bedrock.cohere;

import java.util.List;
import org.springframework.ai.bedrock.BedrockUsage;
import org.springframework.ai.bedrock.MessageToPromptConverter;
import org.springframework.ai.bedrock.cohere.api.CohereChatBedrockApi;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;

/* loaded from: input_file:org/springframework/ai/bedrock/cohere/BedrockCohereChatModel.class */
/**
 * {@link ChatModel} and {@link StreamingChatModel} implementation that delegates to a
 * {@link CohereChatBedrockApi} client.
 *
 * <p>Every request is first populated from the default {@link BedrockCohereChatOptions}
 * given at construction time; options attached to the individual {@link Prompt}, when
 * present, are then merged into that request.
 */
public class BedrockCohereChatModel implements ChatModel, StreamingChatModel {

    /** Client used to issue both blocking and streaming completion requests. */
    private final CohereChatBedrockApi chatApi;

    /** Options used to populate every request before prompt-level options are merged in. */
    private final BedrockCohereChatOptions defaultOptions;

    /**
     * Creates a model that uses the options produced by an unconfigured
     * {@code BedrockCohereChatOptions.builder()}.
     *
     * @param cohereChatBedrockApi the API client; must not be {@code null}
     * @throws IllegalArgumentException if {@code cohereChatBedrockApi} is {@code null}
     */
    public BedrockCohereChatModel(CohereChatBedrockApi cohereChatBedrockApi) {
        this(cohereChatBedrockApi, BedrockCohereChatOptions.builder().build());
    }

    /**
     * Creates a model with explicit default options.
     *
     * @param cohereChatBedrockApi the API client; must not be {@code null}
     * @param bedrockCohereChatOptions the default options; must not be {@code null}
     * @throws IllegalArgumentException if either argument is {@code null}
     */
    public BedrockCohereChatModel(CohereChatBedrockApi cohereChatBedrockApi, BedrockCohereChatOptions bedrockCohereChatOptions) {
        Assert.notNull(cohereChatBedrockApi, "CohereChatBedrockApi must not be null");
        Assert.notNull(bedrockCohereChatOptions, "BedrockCohereChatOptions must not be null");
        this.chatApi = cohereChatBedrockApi;
        this.defaultOptions = bedrockCohereChatOptions;
    }

    /**
     * Sends a non-streaming completion request and wraps the text of every returned
     * generation in a {@link Generation}.
     *
     * @param prompt the prompt whose instructions and optional options form the request
     * @return a response containing one {@link Generation} per generation returned by the API
     */
    public ChatResponse call(Prompt prompt) {
        List<Generation> generations = this.chatApi.chatCompletion(createRequest(prompt, false))
            .generations()
            .stream()
            .map(candidate -> new Generation(candidate.text()))
            .toList();
        return new ChatResponse(generations);
    }

    /**
     * Sends a streaming completion request and maps every streamed chunk to a
     * single-generation {@link ChatResponse}.
     *
     * <p>A chunk flagged as finished yields an empty-text {@link Generation} carrying the
     * finish reason and the usage derived from the chunk's invocation metrics. Any other
     * chunk yields a {@link Generation} holding the chunk's text.
     *
     * @param prompt the prompt whose instructions and optional options form the request
     * @return a stream with one {@link ChatResponse} per chunk received from the API
     */
    public Flux<ChatResponse> stream(Prompt prompt) {
        return this.chatApi.chatCompletionStream(createRequest(prompt, true)).map(chunk -> {
            // Null-safe test: a chunk without the flag is treated as a content chunk
            // rather than failing the stream with a NullPointerException on unboxing.
            if (Boolean.TRUE.equals(chunk.isFinished())) {
                String finishReason = chunk.finishReason().name();
                ChatGenerationMetadata metadata = ChatGenerationMetadata.from(finishReason,
                        BedrockUsage.from(chunk.amazonBedrockInvocationMetrics()));
                return new ChatResponse(List.of(new Generation("").withGenerationMetadata(metadata)));
            }
            return new ChatResponse(List.of(new Generation(chunk.text())));
        });
    }

    /**
     * Builds the API request for the given prompt.
     *
     * <p>The prompt's instructions are converted to a single prompt string and combined
     * with the default options. If the prompt carries its own options, they are copied into
     * a {@link BedrockCohereChatOptions} and merged into the request through
     * {@link ModelOptionsUtils#merge}; which side wins for a given property is determined
     * by that utility.
     *
     * @param prompt the prompt to translate into a request
     * @param stream the value passed to the request builder's {@code withStream}
     * @return the request to send to the API
     */
    CohereChatBedrockApi.CohereChatRequest createRequest(Prompt prompt, boolean stream) {
        String promptText = MessageToPromptConverter.create().toPrompt(prompt.getInstructions());

        CohereChatBedrockApi.CohereChatRequest request = CohereChatBedrockApi.CohereChatRequest.builder(promptText)
            .withTemperature(this.defaultOptions.getTemperature())
            .withTopP(this.defaultOptions.getTopP())
            .withTopK(this.defaultOptions.getTopK())
            .withMaxTokens(this.defaultOptions.getMaxTokens())
            .withStopSequences(this.defaultOptions.getStopSequences())
            .withReturnLikelihoods(this.defaultOptions.getReturnLikelihoods())
            .withStream(stream)
            .withNumGenerations(this.defaultOptions.getNumGenerations())
            .withLogitBias(this.defaultOptions.getLogitBias())
            .withTruncate(this.defaultOptions.getTruncate())
            .build();

        ChatOptions runtimeOptions = prompt.getOptions();
        if (runtimeOptions == null) {
            return request;
        }

        // No instanceof guard is needed here: a non-null value of static type ChatOptions
        // is always an instance of ChatOptions.
        BedrockCohereChatOptions cohereOptions = ModelOptionsUtils.copyToTarget(runtimeOptions, ChatOptions.class,
                BedrockCohereChatOptions.class);
        return ModelOptionsUtils.merge(cohereOptions, request, CohereChatBedrockApi.CohereChatRequest.class);
    }

    /**
     * Returns the default options as produced by
     * {@code BedrockCohereChatOptions.fromOptions} (presumably a copy, so callers do not
     * receive this model's own instance — confirm against that factory).
     *
     * @return the default chat options
     */
    public ChatOptions getDefaultOptions() {
        return BedrockCohereChatOptions.fromOptions(this.defaultOptions);
    }
}
