/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.dataprepper.plugins.ml_inference.processor.common;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Collectors;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.common.utils.RetryUtil;
import org.opensearch.dataprepper.logging.DataPrepperMarkers;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.event.EventKey;
import org.opensearch.dataprepper.model.failures.DlqObject;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.plugins.ml_inference.processor.MLProcessor;
import org.opensearch.dataprepper.plugins.ml_inference.processor.MLProcessorConfig;
import org.opensearch.dataprepper.plugins.ml_inference.processor.common.AbstractBatchJobCreator;
import org.opensearch.dataprepper.plugins.ml_inference.processor.dlq.DlqPushHandler;
import org.opensearch.dataprepper.plugins.ml_inference.processor.exception.MLBatchJobException;
import org.slf4j.Logger;

/**
 * Creates Bedrock batch jobs, one per input record, by sending a JSON payload through
 * the inherited {@code mlCommonRequester}.
 *
 * <p>Each record is expected to carry an S3 bucket and key; these are combined into the
 * input S3 URI of the job. Requests that are rejected as throttled (see
 * {@link #shouldRetry(int, String)}) are parked in an in-memory queue and retried when
 * {@link #addProcessedBatchRecordsToResults(List)} is invoked. All other failures are
 * tagged, counted and, when a DLQ handler is configured, written to the DLQ.
 */
public class BedrockBatchJobCreator
extends AbstractBatchJobCreator {
    private static final String BEDROCK_PAYLOAD_TEMPLATE = "{\"parameters\": {\"inputDataConfig\": {\"s3InputDataConfig\": {\"s3Uri\": \"s3://\"}},\"jobName\": \"\", \"outputDataConfig\": {\"s3OutputDataConfig\": {\"s3Uri\": \"s3://\"}}}}";

    /** Pause inserted after every submission attempt, in milliseconds. */
    private static final long SUBMISSION_PAUSE_MILLIS = 1L;

    private static final int STATUS_BAD_REQUEST = 400;
    private static final int STATUS_REQUEST_TIMEOUT = 408;
    private static final int STATUS_TOO_MANY_REQUESTS = 429;
    private static final int STATUS_INTERNAL_ERROR = 500;

    private final AwsCredentialsSupplier awsCredentialsSupplier;
    /** Records whose job creation was throttled and that are waiting for another attempt. */
    private final ConcurrentLinkedQueue<AbstractBatchJobCreator.RetryRecord> throttledRecords = new ConcurrentLinkedQueue<>();
    /** Ensures only one thread at a time drains and retries the throttled queue. */
    private final Lock processingLock;

    /**
     * Builds a creator; all arguments are forwarded to the superclass constructor.
     *
     * @param mlProcessorConfig processor configuration (input key, output path, ...)
     * @param awsCredentialsSupplier supplier of AWS credentials
     * @param pluginMetrics metrics registry handed to the superclass
     * @param dlqPushHandler DLQ writer
     */
    public BedrockBatchJobCreator(MLProcessorConfig mlProcessorConfig, AwsCredentialsSupplier awsCredentialsSupplier, PluginMetrics pluginMetrics, DlqPushHandler dlqPushHandler) {
        super(mlProcessorConfig, awsCredentialsSupplier, pluginMetrics, dlqPushHandler);
        this.awsCredentialsSupplier = awsCredentialsSupplier;
        this.processingLock = new ReentrantLock();
    }

    /**
     * Attempts to create one Bedrock batch job per input record.
     *
     * @param inputRecords records describing the S3 objects to process
     * @param resultRecords receives successful records as well as failure-tagged records
     * @throws MLBatchJobException if at least one record failed with a non-retryable error
     */
    @Override
    public void createMLBatchJob(List<Record<Event>> inputRecords, List<Record<Event>> resultRecords) {
        processRecords(inputRecords, resultRecords, null);
    }

    /**
     * Submits every record, then reports the accumulated failures in one go.
     *
     * @param records records to submit
     * @param resultRecords receives successful and failure-tagged records
     * @param retryRecords retry bookkeeping aligned index-by-index with {@code records},
     *                     or {@code null} when this is a first attempt
     * @throws MLBatchJobException if one or more records failed
     */
    private void processRecords(List<Record<Event>> records, List<Record<Event>> resultRecords, List<AbstractBatchJobCreator.RetryRecord> retryRecords) {
        List<Record<Event>> failedRecords = new ArrayList<>();
        List<DlqObject> dlqObjects = new ArrayList<>();
        for (int i = 0; i < records.size(); i++) {
            AbstractBatchJobCreator.RetryRecord retryRecord = retryRecords != null ? retryRecords.get(i) : null;
            processRecord(records.get(i), resultRecords, failedRecords, dlqObjects, retryRecord);
            pauseBetweenSubmissions();
        }
        if (failedRecords.isEmpty()) {
            return;
        }
        pushToDlq(dlqObjects);
        // Count the records that failed, not the DLQ entries: the DLQ list is empty when no
        // DLQ handler is configured, which previously left this metric at zero.
        numberOfRecordsFailedCounter.increment(failedRecords.size());
        throw new MLBatchJobException(String.format("Failed to process %d records out of %d total records", failedRecords.size(), records.size()), new Throwable("Batch job processing failed due to one or more failed records"));
    }

    /**
     * Sleeps briefly after a submission. If the thread is interrupted the interrupt flag
     * is restored and processing continues.
     */
    private void pauseBetweenSubmissions() {
        try {
            Thread.sleep(SUBMISSION_PAUSE_MILLIS);
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            MLProcessor.LOG.debug("Interrupted during sleep: {}", e.getMessage());
        }
    }

    /**
     * Submits a single record and routes the outcome: success goes to {@code resultRecords},
     * a throttled request goes to the retry queue, anything else is handled as a failure.
     *
     * @param record record to submit
     * @param resultRecords receives the record on success, or its failure-tagged form
     * @param failedRecords receives the record on non-retryable failure
     * @param dlqObjects receives a DLQ entry on non-retryable failure
     * @param throttledRecord retry bookkeeping when this is a retry, otherwise {@code null}
     */
    private void processRecord(Record<Event> record, List<Record<Event>> resultRecords, List<Record<Event>> failedRecords, List<DlqObject> dlqObjects, AbstractBatchJobCreator.RetryRecord throttledRecord) {
        final boolean isRetry = throttledRecord != null;
        try {
            String s3Uri = generateS3Uri(record);
            String payload = createPayloadBedrock(s3Uri, this.mlProcessorConfig);
            RetryUtil.RetryResult result = RetryUtil.retryWithBackoffWithResult(() -> this.mlCommonRequester.sendRequestToMLCommons(payload), (Logger)MLProcessor.LOG);
            if (result.isSuccess()) {
                String logMessage = isRetry
                        ? String.format("Successfully retried Bedrock batch job for the S3Uri: %s (attempt %d)", s3Uri, throttledRecord.getRetryCount())
                        : String.format("Successfully created Bedrock batch job for the S3Uri: %s", s3Uri);
                MLProcessor.LOG.info(logMessage);
                resultRecords.add(record);
                this.numberOfRecordsSuccessCounter.increment();
                incrementSuccessCounter();
                return;
            }

            Exception e = result.getLastException();
            String attemptSuffix = isRetry ? String.format(" (attempt %d)", throttledRecord.getRetryCount()) : "";
            String errorMessage = String.format("Failed to %s Bedrock batch job%s for S3Uri: %s. Error: %s", isRetry ? "retry" : "create", attemptSuffix, s3Uri, e.getMessage());
            int statusCode = STATUS_INTERNAL_ERROR;
            if (e instanceof MLBatchJobException) {
                MLBatchJobException mlException = (MLBatchJobException)e;
                statusCode = mlException.getStatusCode();
                if (shouldRetry(statusCode, mlException.getMessage())) {
                    // NOTE(review): the explicit `this` argument looks like a decompiler rendering of an
                    // inner-class constructor - confirm against RetryRecord's declaration.
                    AbstractBatchJobCreator.RetryRecord queuedRecord = isRetry ? throttledRecord : new AbstractBatchJobCreator.RetryRecord(this, record);
                    this.throttledRecords.offer(queuedRecord);
                    MLProcessor.LOG.info("Request{} throttled{}, added to retry queue: {}", isRetry ? " still" : "", attemptSuffix, s3Uri);
                    return;
                }
                MLProcessor.LOG.error(DataPrepperMarkers.NOISY, errorMessage);
            } else {
                MLProcessor.LOG.error(DataPrepperMarkers.NOISY, errorMessage, e);
            }
            handleFailure(record, resultRecords, failedRecords, dlqObjects, e, statusCode);
        }
        catch (IllegalArgumentException e) {
            MLProcessor.LOG.error(DataPrepperMarkers.NOISY, "Invalid arguments for BedRock batch job. Error: {}", e.getMessage());
            handleFailure(record, resultRecords, failedRecords, dlqObjects, e, STATUS_BAD_REQUEST);
        }
        catch (Exception e) {
            MLProcessor.LOG.error(DataPrepperMarkers.NOISY, "Unexpected Error occurred while processing batch job through BedRock. Error: {}", e.getMessage(), e);
            handleFailure(record, resultRecords, failedRecords, dlqObjects, e, STATUS_INTERNAL_ERROR);
        }
    }

    /**
     * Drains the throttled-record queue: expired entries are failed, the rest are retried.
     * If another thread already holds the processing lock this call returns immediately.
     * Exceptions raised while retrying are logged and not propagated.
     *
     * @param resultRecords receives records that succeeded on retry or were failure-tagged
     */
    @Override
    public void addProcessedBatchRecordsToResults(List<Record<Event>> resultRecords) {
        if (!this.processingLock.tryLock()) {
            MLProcessor.LOG.debug("Another thread is currently processing results, skipping this attempt");
            return;
        }
        try {
            processThrottledRecords(resultRecords);
        }
        catch (Exception e) {
            // Pass the exception as the trailing argument so the stack trace is logged too.
            MLProcessor.LOG.error("Error in batch processing throttled records. Error: {}", e.getMessage(), e);
        }
        finally {
            this.processingLock.unlock();
        }
    }

    /**
     * Decides whether a failed request should be queued for retry.
     *
     * @param statusCode status code reported by the failed request
     * @param errorMessage error text of the failure, may be {@code null}
     * @return {@code true} for status 429, or for status 400 whose message contains one of
     *         the known quota/throttling phrases (case-sensitive match)
     */
    private boolean shouldRetry(int statusCode, String errorMessage) {
        if (statusCode == STATUS_TOO_MANY_REQUESTS) {
            return true;
        }
        if (errorMessage == null || statusCode != STATUS_BAD_REQUEST) {
            return false;
        }
        return errorMessage.contains("quota for number of concurrent invoke-model jobs")
                || errorMessage.contains("throttling")
                || errorMessage.contains("request was denied due to remote server throttling");
    }

    /**
     * Records a non-retryable failure: adds the failure-tagged record to the results,
     * bumps the failure counter, and prepares a DLQ entry when a DLQ handler is configured.
     *
     * @param record the failed record
     * @param resultRecords receives the failure-tagged record
     * @param failedRecords receives the failed record
     * @param dlqObjects receives the DLQ entry, if one could be created
     * @param throwable cause of the failure; its message is stored in the DLQ entry
     * @param statusCode status code stored in the DLQ entry
     */
    private void handleFailure(Record<Event> record, List<Record<Event>> resultRecords, List<Record<Event>> failedRecords, List<DlqObject> dlqObjects, Throwable throwable, int statusCode) {
        resultRecords.addAll(addFailureTags(Collections.singletonList(record)));
        incrementFailureCounter();
        failedRecords.add(record);
        if (this.dlqPushHandler == null) {
            return;
        }
        try {
            Event event = record.getData();
            if (event != null) {
                dlqObjects.add(createDlqObjectFromEvent(event, statusCode, throwable.getMessage()));
            }
        }
        catch (Exception ex) {
            MLProcessor.LOG.error(DataPrepperMarkers.NOISY, "Exception occured during error handling: {}", ex.getMessage());
        }
    }

    /**
     * Writes the given entries to the DLQ in a single call. Does nothing when no DLQ
     * handler is configured or the list is empty; DLQ write errors are logged, not thrown.
     *
     * @param dlqObjects entries to write
     */
    private void pushToDlq(List<DlqObject> dlqObjects) {
        if (this.dlqPushHandler == null || dlqObjects.isEmpty()) {
            return;
        }
        try {
            this.dlqPushHandler.perform(dlqObjects);
            MLProcessor.LOG.info("Successfully pushed {} failed records to DLQ", dlqObjects.size());
        }
        catch (Exception e) {
            MLProcessor.LOG.error("Failed to push {} records to DLQ: {}", dlqObjects.size(), e.getMessage(), e);
        }
    }

    /**
     * Empties the throttled queue, splitting its entries into expired ones (failed
     * immediately) and live ones (retry count incremented, then resubmitted).
     *
     * @param resultRecords receives the outcome of expired and retried records
     */
    private void processThrottledRecords(List<Record<Event>> resultRecords) {
        List<AbstractBatchJobCreator.RetryRecord> expiredRecords = new ArrayList<>();
        List<AbstractBatchJobCreator.RetryRecord> recordsToRetry = new ArrayList<>();
        AbstractBatchJobCreator.RetryRecord queued;
        while ((queued = this.throttledRecords.poll()) != null) {
            if (queued.isExpired()) {
                expiredRecords.add(queued);
            } else {
                recordsToRetry.add(queued);
                queued.incrementRetryCount();
            }
        }
        handleExpiredRecords(expiredRecords, resultRecords);
        retryThrottledRecords(recordsToRetry, resultRecords);
    }

    /**
     * Resubmits the given throttled records.
     *
     * @param recordsToRetry retry entries to resubmit
     * @param resultRecords receives successful and failure-tagged records
     * @throws MLBatchJobException if one or more retried records failed
     */
    private void retryThrottledRecords(List<AbstractBatchJobCreator.RetryRecord> recordsToRetry, List<Record<Event>> resultRecords) {
        if (recordsToRetry.isEmpty()) {
            return;
        }
        MLProcessor.LOG.info("Retrying {} throttled records", recordsToRetry.size());
        List<Record<Event>> records = recordsToRetry.stream()
                .map(AbstractBatchJobCreator.RetryRecord::getRecord)
                .collect(Collectors.toCollection(ArrayList::new));
        processRecords(records, resultRecords, recordsToRetry);
    }

    /**
     * Fails every expired retry entry with status 408 and pushes the resulting DLQ
     * entries in one batch.
     *
     * @param expiredRecords retry entries reported as expired
     * @param resultRecords receives the failure-tagged records
     */
    private void handleExpiredRecords(List<AbstractBatchJobCreator.RetryRecord> expiredRecords, List<Record<Event>> resultRecords) {
        if (expiredRecords.isEmpty()) {
            return;
        }
        List<Record<Event>> failedRecords = new ArrayList<>();
        List<DlqObject> dlqObjects = new ArrayList<>();
        for (AbstractBatchJobCreator.RetryRecord expiredRecord : expiredRecords) {
            String errorMessage = String.format("Record expired after %d retries over %d minutes", expiredRecord.getRetryCount(), this.maxRetryTimeWindow / 60000L);
            MLProcessor.LOG.error(DataPrepperMarkers.NOISY, "Record expired from throttle queue: {}", errorMessage);
            handleFailure(expiredRecord.getRecord(), resultRecords, failedRecords, dlqObjects, new MLBatchJobException(STATUS_REQUEST_TIMEOUT, errorMessage), STATUS_REQUEST_TIMEOUT);
        }
        if (!failedRecords.isEmpty()) {
            pushToDlq(dlqObjects);
            // Count failed records rather than DLQ entries so the metric is correct without a DLQ.
            this.numberOfRecordsFailedCounter.increment(failedRecords.size());
        }
    }

    /**
     * Builds the {@code s3://bucket/key} URI for a record. The bucket is read from the
     * event's {@code bucket} field; the key comes from the configured input key when one
     * is set, otherwise from the event's {@code key} field.
     *
     * @param record record describing the S3 object
     * @return the S3 URI of the object
     * @throws IllegalArgumentException if the bucket or the key is missing
     */
    private String generateS3Uri(Record<Event> record) {
        Event event = record.getData();
        String bucket = Optional.ofNullable(event.getJsonNode().path("bucket").asText(null))
                .orElseThrow(() -> new IllegalArgumentException("Missing 'bucket' in record."));
        EventKey inputKey = this.mlProcessorConfig.getInputKey();
        String rawKey = inputKey == null
                ? event.getJsonNode().path("key").asText(null)
                : event.get(inputKey, String.class);
        String key = Optional.ofNullable(rawKey)
                .orElseThrow(() -> new IllegalArgumentException("Missing 'S3 Key' in record."));
        return "s3://" + bucket + "/" + key;
    }

    /**
     * Fills the payload template with the input URI, a generated job name and the
     * configured output path.
     *
     * @param s3Uri S3 URI of the batch input
     * @param mlProcessorConfig configuration providing the output path
     * @return the request payload serialized as JSON
     * @throws IllegalArgumentException if {@code s3Uri} is null or empty
     * @throws RuntimeException if the payload cannot be built or serialized
     */
    private String createPayloadBedrock(String s3Uri, MLProcessorConfig mlProcessorConfig) {
        if (s3Uri == null || s3Uri.isEmpty()) {
            throw new IllegalArgumentException("Invalid S3Uri: S3Uri is either null or empty. Please ensure the correct input S3 uris are provided");
        }
        String jobName = generateJobName();
        try {
            JsonNode rootNode = OBJECT_MAPPER.readTree(BEDROCK_PAYLOAD_TEMPLATE);
            ((ObjectNode)rootNode.at("/parameters/inputDataConfig/s3InputDataConfig")).put("s3Uri", s3Uri);
            ((ObjectNode)rootNode.at("/parameters")).put("jobName", jobName);
            ((ObjectNode)rootNode.at("/parameters/outputDataConfig/s3OutputDataConfig")).put("s3Uri", mlProcessorConfig.getOutputPath());
            return OBJECT_MAPPER.writeValueAsString(rootNode);
        }
        catch (Exception e) {
            MLProcessor.LOG.error("Failed to create BedRock batch job payload with input {}.", s3Uri, e);
            throw new RuntimeException("Failed to create payload for BedRock batch job", e);
        }
    }

    /**
     * Returns the live queue of records waiting to be retried (not a copy).
     *
     * @return the internal throttled-record queue
     */
    public ConcurrentLinkedQueue<AbstractBatchJobCreator.RetryRecord> getThrottledRecords() {
        return this.throttledRecords;
    }
}

