/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.dataprepper.plugins.ml_inference.processor.common;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.nio.charset.StandardCharsets;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import org.json.JSONArray;
import org.json.JSONObject;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.common.utils.RetryUtil;
import org.opensearch.dataprepper.logging.DataPrepperMarkers;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.event.EventKey;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.plugins.ml_inference.processor.MLProcessor;
import org.opensearch.dataprepper.plugins.ml_inference.processor.MLProcessorConfig;
import org.opensearch.dataprepper.plugins.ml_inference.processor.client.S3ClientFactory;
import org.opensearch.dataprepper.plugins.ml_inference.processor.common.AbstractBatchJobCreator;
import org.slf4j.Logger;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;

/**
 * Buffers incoming records and submits them as SageMaker batch transform jobs through ML Commons.
 *
 * <p>Records passed to {@link #createMLBatchJob} are queued. Each call to
 * {@link #checkAndProcessBatch} flushes the queue when either condition holds:
 * <ul>
 *   <li>it holds at least {@code MAX_BATCH_SIZE} records, or</li>
 *   <li>nothing has been added for {@code INACTIVITY_TIMEOUT_MS}.</li>
 * </ul>
 *
 * <p>A flush does the following:
 * <ol>
 *   <li>Uploads an S3 manifest listing the records' object keys.</li>
 *   <li>Builds a transform-job payload that points at that manifest.</li>
 *   <li>Sends the payload, with retries, via {@code mlCommonRequester}.</li>
 * </ol>
 * The flushed records are then queued for hand-off. They are tagged via {@code addFailureTags}
 * when the job could not be created. They are returned to the caller by
 * {@link #addProcessedBatchRecordsToResults}.
 */
public class SageMakerBatchJobCreator extends AbstractBatchJobCreator {
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
    /** A non-empty batch below the size limit is flushed once nothing has been added for this long. */
    private static final long INACTIVITY_TIMEOUT_MS = 60000L;
    /**
     * Skeleton transform-job request. {@code S3Uri} and {@code TransformJobName} are filled in
     * per batch. {@code S3OutputPath} is filled in when an output path is configured; otherwise
     * the whole {@code TransformOutput} section is dropped (see {@link #createPayloadSageMaker}).
     */
    private static final String SAGEMAKER_PAYLOAD_TEMPLATE = "{\"parameters\":{\"TransformInput\":{\"ContentType\":\"application/json\",\"DataSource\":{\"S3DataSource\":{\"S3DataType\":\"ManifestFile\",\"S3Uri\":\"\"}},\"SplitType\":\"Line\"},\"TransformJobName\":\"\",\"TransformOutput\":{\"AssembleWith\":\"Line\",\"Accept\":\"application/json\",\"S3OutputPath\":\"s3://\"}}}";
    private final AwsCredentialsSupplier awsCredentialsSupplier;
    private final S3Client s3Client;
    /** Formats the timestamp embedded in manifest folder and file names. */
    private final DateTimeFormatter dateTimeFormatter;
    /** Records waiting to be submitted in the next batch job. */
    private final ConcurrentLinkedQueue<Record<Event>> batch_records = new ConcurrentLinkedQueue<>();
    /** Records whose batch has been handled (successfully or failure-tagged), awaiting hand-off. */
    private final ConcurrentLinkedQueue<Record<Event>> processedBatchRecords = new ConcurrentLinkedQueue<>();
    private final int MAX_BATCH_SIZE;
    /** Wall-clock millis of the most recent add to {@code batch_records}; -1 when not armed. */
    private final AtomicLong lastUpdateTimestamp = new AtomicLong(-1L);

    /**
     * Creates the batch job creator.
     *
     * @param mlProcessorConfig processor configuration; used to build the S3 client and to read
     *                          the maximum batch size
     * @param awsCredentialsSupplier credentials used to build the S3 client for manifest uploads
     * @param pluginMetrics metrics forwarded to the base class
     */
    public SageMakerBatchJobCreator(MLProcessorConfig mlProcessorConfig, AwsCredentialsSupplier awsCredentialsSupplier, PluginMetrics pluginMetrics) {
        super(mlProcessorConfig, awsCredentialsSupplier, pluginMetrics);
        this.awsCredentialsSupplier = awsCredentialsSupplier;
        this.s3Client = S3ClientFactory.createS3Client(mlProcessorConfig, awsCredentialsSupplier);
        this.dateTimeFormatter = DateTimeFormatter.ofPattern("yyyyMMddHHmmss");
        this.MAX_BATCH_SIZE = mlProcessorConfig.getMaxBatchSize();
    }

    /**
     * Queues {@code inputRecords} for a later batch job and re-arms the inactivity timer.
     *
     * @param inputRecords records to queue; an empty list is ignored
     * @param resultRecords not used by this implementation. Queued records are returned later
     *                      through {@link #addProcessedBatchRecordsToResults}.
     */
    @Override
    public void createMLBatchJob(List<Record<Event>> inputRecords, List<Record<Event>> resultRecords) {
        if (inputRecords.isEmpty()) {
            return;
        }
        this.batch_records.addAll(inputRecords);
        this.lastUpdateTimestamp.set(System.currentTimeMillis());
        MLProcessor.LOG.info("Added {} records to batch. Current batch size: {}", inputRecords.size(), this.batch_records.size());
    }

    /**
     * Moves every record whose batch has been handled into {@code resultRecords}.
     *
     * @param resultRecords destination list; records are appended
     */
    @Override
    public void addProcessedBatchRecordsToResults(List<Record<Event>> resultRecords) {
        // Drain element by element instead of addAll() followed by clear(). If a batch completes
        // concurrently, records it publishes between those two calls would otherwise be
        // cleared without ever being handed off.
        int moved = 0;
        Record<Event> processed;
        while ((processed = this.processedBatchRecords.poll()) != null) {
            resultRecords.add(processed);
            moved++;
        }
        if (moved > 0) {
            MLProcessor.LOG.info("Result records updated: {} processed records added, new total size: {}", moved, resultRecords.size());
        }
    }

    /**
     * Flushes the pending batch if it has reached {@code MAX_BATCH_SIZE}, or if no record has been
     * added for {@code INACTIVITY_TIMEOUT_MS}. Any exception is logged and swallowed.
     */
    @Override
    public void checkAndProcessBatch() {
        try {
            if (this.batch_records.isEmpty()) {
                return;
            }
            long currentTime = System.currentTimeMillis();
            long lastUpdate = this.lastUpdateTimestamp.get();
            int pendingCount = this.batch_records.size();
            boolean shouldProcess = false;
            if (pendingCount >= this.MAX_BATCH_SIZE) {
                shouldProcess = true;
                MLProcessor.LOG.info("Processing batch due to size limit reached: {}", pendingCount);
            } else if (lastUpdate != -1L && currentTime - lastUpdate >= INACTIVITY_TIMEOUT_MS) {
                shouldProcess = true;
                MLProcessor.LOG.info("Processing batch due to inactivity timeout. Time since last update: {} ms", currentTime - lastUpdate);
            }
            if (shouldProcess) {
                // Disarm the timer before draining. A record added after the drain then re-arms it
                // with its own timestamp. Resetting after the drain could overwrite that timestamp
                // with -1 and leave a below-size batch that never times out.
                this.lastUpdateTimestamp.set(-1L);
                List<Record<Event>> currentBatch = this.drainBatchRecords();
                if (!currentBatch.isEmpty()) {
                    this.processCurrentBatch(currentBatch);
                }
            }
        }
        catch (Exception e) {
            MLProcessor.LOG.error("Error in batch processing check: ", e);
        }
    }

    /**
     * Removes and returns every record currently queued in {@code batch_records}.
     * Uses {@code poll()} so that a record is either returned here or left in the queue, never
     * dropped. A snapshot followed by {@code clear()} would lose records added in between.
     */
    private List<Record<Event>> drainBatchRecords() {
        List<Record<Event>> drained = new ArrayList<>();
        Record<Event> next;
        while ((next = this.batch_records.poll()) != null) {
            drained.add(next);
        }
        return drained;
    }

    /**
     * Submits one batch: uploads its manifest, builds the SageMaker payload and sends it with
     * retries. On success the records are queued for hand-off unchanged. On any failure they are
     * failure-tagged and queued for hand-off instead.
     *
     * @param currentBatch records to submit; an empty batch is ignored
     */
    private void processCurrentBatch(List<Record<Event>> currentBatch) {
        if (currentBatch.isEmpty()) {
            return;
        }
        try {
            // The manifest is written to, and references, the bucket of the first record.
            // NOTE(review): assumes every record in a batch shares the same "bucket" — confirm.
            String customerBucket = currentBatch.get(0).getData().getJsonNode().get("bucket").asText();
            String commonPrefix = this.findCommonPrefix(currentBatch);
            String manifestUrl = this.generateManifest(currentBatch, customerBucket, commonPrefix);
            String payload = this.createPayloadSageMaker(manifestUrl, this.mlProcessorConfig);
            boolean success = RetryUtil.retryWithBackoff(() -> this.mlCommonRequester.sendRequestToMLCommons(payload), (Logger)MLProcessor.LOG);
            if (success) {
                MLProcessor.LOG.info("Successfully created SageMaker batch job for manifest URL: {}", manifestUrl);
                this.processedBatchRecords.addAll(currentBatch);
                this.incrementSuccessCounter();
            } else {
                this.handleFailure(currentBatch, this.processedBatchRecords);
                MLProcessor.LOG.error("SageMaker batch job failed after multiple retries for manifest URL: {}", manifestUrl);
            }
        }
        catch (IllegalArgumentException e) {
            MLProcessor.LOG.error(DataPrepperMarkers.NOISY, "Invalid arguments for SageMaker batch job. Error: {}", e.getMessage());
            this.handleFailure(currentBatch, this.processedBatchRecords);
        }
        catch (RuntimeException e) {
            // Pass the exception itself so the stack trace is kept (e.g. an NPE from a missing
            // "bucket"/"key" field has a null message).
            MLProcessor.LOG.error(DataPrepperMarkers.NOISY, "Runtime Exception for SageMaker batch job. Error: {}", e.getMessage(), e);
            this.handleFailure(currentBatch, this.processedBatchRecords);
        }
        catch (Exception e) {
            MLProcessor.LOG.error(DataPrepperMarkers.NOISY, "Unexpected Error occurred while creating a batch job through SageMaker: {}", e.getMessage(), e);
            this.handleFailure(currentBatch, this.processedBatchRecords);
        }
    }

    /**
     * Failure-tags {@code record} (via {@code addFailureTags}) into {@code resultRecords} and
     * bumps the failure counter.
     */
    private void handleFailure(List<Record<Event>> record, ConcurrentLinkedQueue<Record<Event>> resultRecords) {
        resultRecords.addAll(this.addFailureTags(record));
        this.incrementFailureCounter();
    }

    /** No preparation is needed for this implementation. */
    @Override
    public void prepareForShutdown() {
    }

    /** @return {@code true} once no records are waiting to be submitted */
    @Override
    public boolean isReadyForShutdown() {
        return this.batch_records.isEmpty();
    }

    /** Submits whatever is still queued, regardless of size or age, then runs shutdown prep. */
    @Override
    public void shutdown() {
        this.processRemainingBatch();
        this.prepareForShutdown();
    }

    /** Drains and submits every queued record, disarming the inactivity timer first. */
    private void processRemainingBatch() {
        this.lastUpdateTimestamp.set(-1L);
        List<Record<Event>> currentBatch = this.drainBatchRecords();
        if (!currentBatch.isEmpty()) {
            this.processCurrentBatch(currentBatch);
        }
    }

    /**
     * Computes the longest common directory prefix (ending in {@code '/'}, or {@code ""}) of the
     * records' object keys. Keys come from the configured input key, or from the event's
     * {@code "key"} field when no input key is configured.
     *
     * @throws IllegalArgumentException if {@code records} is empty
     */
    private String findCommonPrefix(Collection<Record<Event>> records) {
        EventKey inputKey = this.mlProcessorConfig.getInputKey();
        List<String> keys = records.stream()
                .map(record -> inputKey == null
                        ? record.getData().getJsonNode().get("key").asText()
                        : record.getData().get(inputKey, String.class))
                .collect(Collectors.toList());
        if (keys.isEmpty()) {
            throw new IllegalArgumentException("Empty inputs identified from input key : " + String.valueOf(inputKey));
        }
        if (keys.size() == 1) {
            String singleKey = keys.get(0);
            int lastSlashIndex = singleKey.lastIndexOf('/');
            return lastSlashIndex >= 0 ? singleKey.substring(0, lastSlashIndex + 1) : "";
        }
        String prefix = keys.get(0);
        // Narrow pairwise; once the prefix is empty no further key can extend it.
        for (int i = 1; i < keys.size() && !prefix.isEmpty(); i++) {
            prefix = this.findCommonPrefix(prefix, keys.get(i));
        }
        return prefix;
    }

    /**
     * Returns the common prefix of {@code s1} and {@code s2}, truncated back to (and including)
     * the last {@code '/'} within it, or {@code ""} if that span has no slash.
     */
    private String findCommonPrefix(String s1, String s2) {
        int minLength = Math.min(s1.length(), s2.length());
        int i = 0;
        while (i < minLength && s1.charAt(i) == s2.charAt(i)) {
            i++;
        }
        // lastIndexOf with fromIndex -1 (no common chars) returns -1, giving "".
        int lastSlashIndex = s1.lastIndexOf('/', i - 1);
        return lastSlashIndex >= 0 ? s1.substring(0, lastSlashIndex + 1) : "";
    }

    /**
     * Uploads a SageMaker manifest file listing the records' keys.
     *
     * <p>The manifest is written to
     * {@code s3://<bucket>/<prefix>batch-<ts>/batch-<ts>.manifest}. Its first entry is
     * {@code {"prefix": "s3://<bucket>/"}}, followed by each record's {@code "key"} field.
     *
     * @return the manifest's S3 URI, or {@code null} if building or uploading it failed
     *         (the error is logged)
     */
    private String generateManifest(Collection<Record<Event>> records, String customerBucket, String prefix) {
        try {
            String timestamp = LocalDateTime.now().format(this.dateTimeFormatter);
            String folderName = prefix + "batch-" + timestamp;
            String fileName = folderName + "/batch-" + timestamp + ".manifest";
            JSONArray manifestArray = new JSONArray();
            manifestArray.put(new JSONObject().put("prefix", "s3://" + customerBucket + "/"));
            // NOTE(review): entries always read the "key" field, even when findCommonPrefix used a
            // configured input key instead — confirm this is intended.
            for (Record<Event> record : records) {
                manifestArray.put(record.getData().getJsonNode().get("key").asText());
            }
            // Encode explicitly as UTF-8 rather than with the platform default charset.
            byte[] jsonData = manifestArray.toString(4).getBytes(StandardCharsets.UTF_8);
            PutObjectRequest putObjectRequest = PutObjectRequest.builder().bucket(customerBucket).key(fileName).build();
            this.s3Client.putObject(putObjectRequest, RequestBody.fromBytes(jsonData));
            return "s3://" + customerBucket + "/" + fileName;
        }
        catch (Exception e) {
            MLProcessor.LOG.error("Unexpected error while generating manifest file for SageMaker job.", e);
            return null;
        }
    }

    /**
     * Builds the JSON request body for a SageMaker transform job.
     *
     * <p>The body reads {@code manifestUri} as its manifest input and uses a freshly generated job
     * name. When an output path is configured, output goes to {@code <outputPath>/<jobName>};
     * otherwise the {@code TransformOutput} section is omitted.
     *
     * @throws IllegalArgumentException if {@code manifestUri} is null or empty
     * @throws RuntimeException if the payload cannot be built or serialized
     */
    private String createPayloadSageMaker(String manifestUri, MLProcessorConfig mlProcessorConfig) {
        if (manifestUri == null || manifestUri.isEmpty()) {
            throw new IllegalArgumentException("Invalid manifest URI: manifestUri is either null or empty. Please ensure the correct input S3 uris are provided");
        }
        try {
            String jobName = this.generateJobName();
            String outputPath = mlProcessorConfig.getOutputPath();
            if (outputPath != null) {
                outputPath = outputPath.concat(outputPath.endsWith("/") ? "" : "/").concat(jobName);
            }
            JsonNode rootNode = OBJECT_MAPPER.readTree(SAGEMAKER_PAYLOAD_TEMPLATE);
            ((ObjectNode)rootNode.at("/parameters/TransformInput/DataSource/S3DataSource")).put("S3Uri", manifestUri);
            ((ObjectNode)rootNode.at("/parameters")).put("TransformJobName", jobName);
            if (outputPath != null) {
                ((ObjectNode)rootNode.at("/parameters/TransformOutput")).put("S3OutputPath", outputPath);
            } else {
                // Drop only the placeholder output section. Removing "parameters" from the root
                // (as done previously) discarded the entire request, leaving an empty "{}" body.
                ((ObjectNode)rootNode.at("/parameters")).remove("TransformOutput");
            }
            return OBJECT_MAPPER.writeValueAsString(rootNode);
        }
        catch (JsonProcessingException e) {
            MLProcessor.LOG.error("Failed to process the JSON payload for SageMaker batch job. Error: {}", e.getMessage());
            throw new RuntimeException("Error processing JSON payload for SageMaker batch job", e);
        }
        catch (Exception e) {
            MLProcessor.LOG.error("Failed to create SageMaker batch job payload with input {}.", manifestUri, e);
            throw new RuntimeException("Failed to create payload for SageMaker batch job", e);
        }
    }
}

