/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.bedrock.internal;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.ModelProvider;
import dev.langchain4j.model.bedrock.internal.Json;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ChatRequestParameters;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.ChatResponseMetadata;
import dev.langchain4j.model.output.Response;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest;

/**
 * Shared configuration and helper logic for Bedrock chat model implementations.
 *
 * <p>Resolves sampling and connection settings from an {@link AbstractSharedBedrockChatModelBuilder}
 * (falling back to the {@code DEFAULT_*} constants for anything not set), renders chat messages into a
 * single "Human:/Assistant:"-style text prompt, and adapts requests, responses and errors for the
 * registered {@link ChatModelListener}s.
 */
public abstract class AbstractSharedBedrockChatModel {
    private static final Logger log = LoggerFactory.getLogger(AbstractSharedBedrockChatModel.class);
    protected static final String HUMAN_PROMPT = "Human:";
    protected static final String ASSISTANT_PROMPT = "Assistant:";
    protected static final String DEFAULT_ANTHROPIC_VERSION = "bedrock-2023-05-31";
    protected static final Integer DEFAULT_MAX_RETRIES = 2;
    protected static final Region DEFAULT_REGION = Region.US_EAST_1;
    protected static final AwsCredentialsProvider DEFAULT_CREDENTIALS_PROVIDER = DefaultCredentialsProvider.builder().build();
    protected static final int DEFAULT_MAX_TOKENS = 300;
    protected static final double DEFAULT_TEMPERATURE = 1.0;
    protected static final float DEFAULT_TOP_P = 0.999f;
    protected static final String[] DEFAULT_STOP_SEQUENCES = new String[0];
    protected static final int DEFAULT_TOP_K = 250;
    protected static final Duration DEFAULT_TIMEOUT = Duration.ofMinutes(1L);
    protected static final List<ChatModelListener> DEFAULT_LISTENERS = Collections.emptyList();
    protected final String humanPrompt;
    protected final String assistantPrompt;
    protected final Integer maxRetries;
    protected final Region region;
    protected final AwsCredentialsProvider credentialsProvider;
    protected final int maxTokens;
    protected final double temperature;
    protected final float topP;
    protected final String[] stopSequences;
    protected final int topK;
    protected final Duration timeout;
    protected final String anthropicVersion;
    protected final List<ChatModelListener> listeners;

    /**
     * Resolves every setting from the builder, using the matching {@code DEFAULT_*} constant for any
     * property that was never set on the builder.
     *
     * @param builder the builder carrying the user-supplied settings
     */
    protected AbstractSharedBedrockChatModel(AbstractSharedBedrockChatModelBuilder<?, ?> builder) {
        this.humanPrompt = builder.isHumanPromptSet ? builder.humanPrompt : HUMAN_PROMPT;
        this.assistantPrompt = builder.isAssistantPromptSet ? builder.assistantPrompt : ASSISTANT_PROMPT;
        this.maxRetries = builder.isMaxRetriesSet ? builder.maxRetries : DEFAULT_MAX_RETRIES;
        this.region = builder.isRegionSet ? builder.region : DEFAULT_REGION;
        this.credentialsProvider = builder.isCredentialsProviderSet ? builder.credentialsProvider : DEFAULT_CREDENTIALS_PROVIDER;
        // Refer to the named defaults instead of duplicating their literal values.
        this.maxTokens = builder.isMaxTokensSet ? builder.maxTokens : DEFAULT_MAX_TOKENS;
        this.temperature = builder.isTemperatureSet ? builder.temperature : DEFAULT_TEMPERATURE;
        this.topP = builder.isTopPSet ? builder.topP : DEFAULT_TOP_P;
        // Copy so neither the caller's array nor the shared static default can be mutated through this instance.
        this.stopSequences = copyOrNull(builder.isStopSequencesSet ? builder.stopSequences : DEFAULT_STOP_SEQUENCES);
        this.topK = builder.isTopKSet ? builder.topK : DEFAULT_TOP_K;
        this.timeout = builder.isTimeoutSet ? builder.timeout : DEFAULT_TIMEOUT;
        this.anthropicVersion = builder.isAnthropicVersionSet ? builder.anthropicVersion : DEFAULT_ANTHROPIC_VERSION;
        // A null list would make listener notification fail with an NPE; treat it as "no listeners".
        this.listeners = builder.isListenersSet && builder.listeners != null ? builder.listeners : DEFAULT_LISTENERS;
    }

    /** Returns a shallow copy of {@code values}, or {@code null} when {@code values} is {@code null}. */
    private static String[] copyOrNull(String[] values) {
        return values == null ? null : values.clone();
    }

    /**
     * Renders one chat message as a line of the text prompt.
     *
     * <p>System messages are returned verbatim; user messages are prefixed with the configured human
     * prompt and AI messages with the configured assistant prompt, each followed by a single space.
     *
     * @param message the message to render
     * @return the rendered text
     * @throws IllegalArgumentException if the message is not a system, user, or AI message
     */
    protected String chatMessageToString(ChatMessage message) {
        if (message instanceof SystemMessage) {
            SystemMessage systemMessage = (SystemMessage) message;
            return systemMessage.text();
        }
        if (message instanceof UserMessage) {
            UserMessage userMessage = (UserMessage) message;
            return this.humanPrompt + " " + userMessage.singleText();
        }
        if (message instanceof AiMessage) {
            AiMessage aiMessage = (AiMessage) message;
            // NOTE(review): an AiMessage without text (e.g. tool calls only) renders as "<prompt> null".
            return this.assistantPrompt + " " + aiMessage.text();
        }
        throw new IllegalArgumentException("Unsupported message type: " + String.valueOf(message.type()));
    }

    /**
     * Builds the JSON request body for a list of chat messages.
     *
     * <p>The system messages are joined with newlines to form a leading context section. The remaining
     * messages are rendered with {@link #chatMessageToString(ChatMessage)} and joined with newlines. The
     * prompt is the context, a blank line, the conversation, and finally the configured assistant prompt.
     * The prompt is then wrapped by {@link #getRequestParameters(String)} and serialized with
     * {@code Json.toJson}.
     *
     * @param messages the conversation to send
     * @return the serialized request body
     */
    protected String convertMessagesToAwsBody(List<ChatMessage> messages) {
        String context = messages.stream()
                .filter(message -> message.type() == ChatMessageType.SYSTEM)
                .map(message -> ((SystemMessage) message).text())
                .collect(Collectors.joining("\n"));
        String userMessages = messages.stream()
                .filter(message -> message.type() != ChatMessageType.SYSTEM)
                .map(this::chatMessageToString)
                .collect(Collectors.joining("\n"));
        // Use the configured assistant prompt (not the constant) so the trailing turn marker matches
        // the markers produced by chatMessageToString for earlier AI turns.
        String prompt = String.format("%s\n\n%s\n%s", context, userMessages, this.assistantPrompt);
        Map<String, Object> requestParameters = this.getRequestParameters(prompt);
        return Json.toJson(requestParameters);
    }

    /**
     * Assembles the request parameter map sent to the model.
     *
     * @param prompt the fully rendered text prompt
     * @return a mutable map with the keys {@code prompt}, {@code max_tokens_to_sample},
     *         {@code temperature}, {@code top_k}, {@code top_p}, {@code stop_sequences} and
     *         {@code anthropic_version}
     */
    protected Map<String, Object> getRequestParameters(String prompt) {
        // Default capacity: the previous capacity of 7 always triggered a resize at the 7th entry.
        Map<String, Object> parameters = new HashMap<>();
        parameters.put("prompt", prompt);
        parameters.put("max_tokens_to_sample", this.getMaxTokens());
        parameters.put("temperature", this.getTemperature());
        parameters.put("top_k", this.topK);
        parameters.put("top_p", Float.valueOf(this.getTopP()));
        parameters.put("stop_sequences", this.getStopSequences());
        parameters.put("anthropic_version", this.anthropicVersion);
        return parameters;
    }

    /**
     * Notifies every listener that a request failed.
     *
     * <p>If the cause of {@code e} is an {@link SdkClientException}, that cause is reported instead of
     * {@code e}. Exceptions thrown by a listener are logged and do not stop the remaining listeners from
     * being notified.
     *
     * @param e the failure
     * @param listenerRequest the request previously reported to the listeners
     * @param modelProvider the provider to report in the error context
     * @param attributes attributes shared with the listeners for this request
     */
    protected void listenerErrorResponse(Throwable e, ChatRequest listenerRequest, ModelProvider modelProvider, Map<Object, Object> attributes) {
        Throwable error = e.getCause() instanceof SdkClientException ? e.getCause() : e;
        ChatModelErrorContext errorContext = new ChatModelErrorContext(error, listenerRequest, modelProvider, attributes);
        this.listeners.forEach(listener -> {
            try {
                listener.onError(errorContext);
            } catch (Exception listenerFailure) {
                log.warn("Exception while calling model listener", listenerFailure);
            }
        });
    }

    /**
     * Describes a synchronous invocation as a {@link ChatRequest} for the listeners.
     *
     * @param invokeModelRequest the Bedrock request; only its model id is used
     * @param messages the conversation being sent
     * @param toolSpecifications the tools offered to the model
     * @return the request to report to the listeners
     */
    protected ChatRequest createListenerRequest(InvokeModelRequest invokeModelRequest, List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
        return this.buildListenerRequest(invokeModelRequest.modelId(), messages, toolSpecifications);
    }

    /**
     * Describes a streaming invocation as a {@link ChatRequest} for the listeners.
     *
     * @param invokeModelRequest the Bedrock streaming request; only its model id is used
     * @param messages the conversation being sent
     * @param toolSpecifications the tools offered to the model
     * @return the request to report to the listeners
     */
    protected ChatRequest createListenerRequest(InvokeModelWithResponseStreamRequest invokeModelRequest, List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
        return this.buildListenerRequest(invokeModelRequest.modelId(), messages, toolSpecifications);
    }

    /** Shared body of both {@code createListenerRequest} overloads. */
    private ChatRequest buildListenerRequest(String modelId, List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
        return ChatRequest.builder()
                .messages(messages)
                .parameters(ChatRequestParameters.builder()
                        .modelName(modelId)
                        .temperature(Double.valueOf(this.temperature))
                        .topP(Double.valueOf(this.topP))
                        .maxOutputTokens(Integer.valueOf(this.maxTokens))
                        .toolSpecifications(toolSpecifications)
                        .build())
                .build();
    }

    /**
     * Converts a model response into a {@link ChatResponse} for the listeners.
     *
     * @param responseId the id to put into the response metadata
     * @param responseModel the model name to put into the response metadata
     * @param response the model response; may be {@code null}
     * @return the listener response, or {@code null} when {@code response} is {@code null}
     */
    protected ChatResponse createListenerResponse(String responseId, String responseModel, Response<AiMessage> response) {
        if (response == null) {
            return null;
        }
        return ChatResponse.builder()
                .aiMessage(response.content())
                .metadata(ChatResponseMetadata.builder()
                        .id(responseId)
                        .modelName(responseModel)
                        .tokenUsage(response.tokenUsage())
                        .finishReason(response.finishReason())
                        .build())
                .build();
    }

    /** Returns the Bedrock model id used by the concrete implementation. */
    protected abstract String getModelId();

    /** Returns the prefix used for user turns in the text prompt. */
    public String getHumanPrompt() {
        return this.humanPrompt;
    }

    /** Returns the prefix used for assistant turns in the text prompt. */
    public String getAssistantPrompt() {
        return this.assistantPrompt;
    }

    /** Returns the configured maximum number of retries. */
    public Integer getMaxRetries() {
        return this.maxRetries;
    }

    /** Returns the configured AWS region. */
    public Region getRegion() {
        return this.region;
    }

    /** Returns the configured AWS credentials provider. */
    public AwsCredentialsProvider getCredentialsProvider() {
        return this.credentialsProvider;
    }

    /** Returns the value sent as {@code max_tokens_to_sample}. */
    public int getMaxTokens() {
        return this.maxTokens;
    }

    /** Returns the value sent as {@code temperature}. */
    public double getTemperature() {
        return this.temperature;
    }

    /** Returns the value sent as {@code top_p}. */
    public float getTopP() {
        return this.topP;
    }

    /**
     * Returns the stop sequences sent as {@code stop_sequences}.
     *
     * @return a copy of the configured array (so callers cannot mutate this model), or {@code null} if
     *         {@code null} was explicitly configured
     */
    public String[] getStopSequences() {
        return copyOrNull(this.stopSequences);
    }

    /** Returns the value sent as {@code top_k}. */
    public int getTopK() {
        return this.topK;
    }

    /** Returns the configured request timeout. */
    public Duration getTimeout() {
        return this.timeout;
    }

    /** Returns the value sent as {@code anthropic_version}. */
    public String getAnthropicVersion() {
        return this.anthropicVersion;
    }

    /** Returns the registered listeners; never {@code null}. */
    public List<ChatModelListener> getListeners() {
        return this.listeners;
    }

    /**
     * Base builder for {@link AbstractSharedBedrockChatModel} subclasses.
     *
     * <p>Each property has an {@code is...Set} flag so the model constructor can tell an explicitly set
     * value apart from one that should fall back to its default.
     *
     * @param <C> the concrete model type produced by {@link #build()}
     * @param <B> the concrete builder type returned by the fluent setters
     */
    public static abstract class AbstractSharedBedrockChatModelBuilder<C extends AbstractSharedBedrockChatModel, B extends AbstractSharedBedrockChatModelBuilder<C, B>> {
        private boolean isHumanPromptSet;
        private String humanPrompt;
        private boolean isAssistantPromptSet;
        private String assistantPrompt;
        private boolean isMaxRetriesSet;
        private Integer maxRetries;
        private boolean isRegionSet;
        private Region region;
        private boolean isCredentialsProviderSet;
        private AwsCredentialsProvider credentialsProvider;
        private boolean isMaxTokensSet;
        private int maxTokens;
        private boolean isTemperatureSet;
        private double temperature;
        private boolean isTopPSet;
        private float topP;
        private boolean isStopSequencesSet;
        private String[] stopSequences;
        private boolean isTopKSet;
        private int topK;
        private boolean isTimeoutSet;
        private Duration timeout;
        private boolean isAnthropicVersionSet;
        private String anthropicVersion;
        private boolean isListenersSet;
        private List<ChatModelListener> listeners;

        /** Sets the prefix used for user turns. */
        public B humanPrompt(String humanPrompt) {
            this.humanPrompt = humanPrompt;
            this.isHumanPromptSet = true;
            return this.self();
        }

        /** Sets the prefix used for assistant turns. */
        public B assistantPrompt(String assistantPrompt) {
            this.assistantPrompt = assistantPrompt;
            this.isAssistantPromptSet = true;
            return this.self();
        }

        /** Sets the maximum number of retries. */
        public B maxRetries(Integer maxRetries) {
            this.maxRetries = maxRetries;
            this.isMaxRetriesSet = true;
            return this.self();
        }

        /** Sets the AWS region. */
        public B region(Region region) {
            this.region = region;
            this.isRegionSet = true;
            return this.self();
        }

        /** Sets the AWS credentials provider. */
        public B credentialsProvider(AwsCredentialsProvider credentialsProvider) {
            this.credentialsProvider = credentialsProvider;
            this.isCredentialsProviderSet = true;
            return this.self();
        }

        /** Sets the value sent as {@code max_tokens_to_sample}. */
        public B maxTokens(int maxTokens) {
            this.maxTokens = maxTokens;
            this.isMaxTokensSet = true;
            return this.self();
        }

        /** Sets the value sent as {@code temperature}. */
        public B temperature(double temperature) {
            this.temperature = temperature;
            this.isTemperatureSet = true;
            return this.self();
        }

        /** Sets the value sent as {@code top_p}. */
        public B topP(float topP) {
            this.topP = topP;
            this.isTopPSet = true;
            return this.self();
        }

        /** Sets the stop sequences; the model stores its own copy of the array. */
        public B stopSequences(String[] stopSequences) {
            this.stopSequences = stopSequences;
            this.isStopSequencesSet = true;
            return this.self();
        }

        /** Sets the value sent as {@code top_k}. */
        public B topK(int topK) {
            this.topK = topK;
            this.isTopKSet = true;
            return this.self();
        }

        /** Sets the request timeout. */
        public B timeout(Duration timeout) {
            this.timeout = timeout;
            this.isTimeoutSet = true;
            return this.self();
        }

        /** Sets the value sent as {@code anthropic_version}. */
        public B anthropicVersion(String anthropicVersion) {
            this.anthropicVersion = anthropicVersion;
            this.isAnthropicVersionSet = true;
            return this.self();
        }

        /** Sets the listeners; {@code null} is treated as an empty list by the model. */
        public B listeners(List<ChatModelListener> listeners) {
            this.listeners = listeners;
            this.isListenersSet = true;
            return this.self();
        }

        /** Returns {@code this} typed as the concrete builder. */
        protected abstract B self();

        /** Builds the concrete model from the collected settings. */
        public abstract C build();

        @Override
        public String toString() {
            return "AbstractSharedBedrockChatModel.AbstractSharedBedrockChatModelBuilder(humanPrompt$value=" + this.humanPrompt + ", assistantPrompt$value=" + this.assistantPrompt + ", maxRetries$value=" + this.maxRetries + ", region$value=" + String.valueOf(this.region) + ", credentialsProvider$value=" + String.valueOf(this.credentialsProvider) + ", maxTokens$value=" + this.maxTokens + ", temperature$value=" + this.temperature + ", topP$value=" + this.topP + ", stopSequences$value=" + Arrays.deepToString(this.stopSequences) + ", topK$value=" + this.topK + ", timeout$value=" + String.valueOf(this.timeout) + ", anthropicVersion$value=" + this.anthropicVersion + ", listeners$value=" + String.valueOf(this.listeners) + ")";
        }
    }
}

