/*
 * Decompiled with CFR 0.152.
 */
package com.azure.storage.blob.specialized.cryptography;

import com.azure.core.cryptography.AsyncKeyEncryptionKey;
import com.azure.core.cryptography.AsyncKeyEncryptionKeyResolver;
import com.azure.core.http.HttpHeaderName;
import com.azure.core.http.HttpHeaders;
import com.azure.core.http.HttpMethod;
import com.azure.core.http.HttpPipelineCallContext;
import com.azure.core.http.HttpPipelineNextPolicy;
import com.azure.core.http.HttpResponse;
import com.azure.core.http.policy.HttpPipelinePolicy;
import com.azure.core.util.FluxUtil;
import com.azure.core.util.logging.ClientLogger;
import com.azure.storage.blob.models.BlobRange;
import com.azure.storage.blob.specialized.cryptography.CryptographyConstants;
import com.azure.storage.blob.specialized.cryptography.Decryptor;
import com.azure.storage.blob.specialized.cryptography.EncryptedBlobRange;
import com.azure.storage.blob.specialized.cryptography.EncryptionData;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.atomic.AtomicLong;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
 * HTTP pipeline policy that decrypts the bodies of client-side encrypted blob downloads.
 *
 * <p>Non-range GET responses are inspected for the encryption metadata header and decrypted when it
 * is present. For range requests, the caller must have placed the blob's {@link EncryptionData} in
 * the pipeline context under {@value #ENCRYPTION_DATA_CONTEXT_KEY}. The outgoing Range header is then
 * rewritten to the range computed by {@link EncryptedBlobRange}, and the decrypted plaintext is
 * trimmed back to the range the caller originally asked for.
 */
public class BlobDecryptionPolicy implements HttpPipelinePolicy {
    private static final ClientLogger LOGGER = new ClientLogger(BlobDecryptionPolicy.class);

    /** Pipeline-context key under which pre-fetched {@link EncryptionData} is expected for range requests. */
    private static final String ENCRYPTION_DATA_CONTEXT_KEY = "encryptiondata";

    /**
     * Size of the trailing region of the blob that is checked for padding under protocol "1.0".
     * NOTE(review): presumably one cipher block; confirm against the encryption protocol spec.
     */
    private static final long PADDING_REGION_BYTES = 16L;

    private final AsyncKeyEncryptionKeyResolver keyResolver;
    private final AsyncKeyEncryptionKey keyWrapper;
    private final boolean requiresEncryption;

    /**
     * Creates the policy.
     *
     * @param key key-encryption key used to unwrap content keys (may be null if a resolver is given)
     * @param keyResolver resolver used to look up the key-encryption key (may be null if a key is given)
     * @param requiresEncryption whether a non-range download must carry encryption metadata
     */
    BlobDecryptionPolicy(AsyncKeyEncryptionKey key, AsyncKeyEncryptionKeyResolver keyResolver, boolean requiresEncryption) {
        this.keyWrapper = key;
        this.keyResolver = keyResolver;
        this.requiresEncryption = requiresEncryption;
    }

    /**
     * Sends the request and, for encrypted downloads, replaces the response with one whose body is
     * the decrypted plaintext.
     *
     * @param context the pipeline call context; for range requests it may carry {@link EncryptionData}
     * @param next the next policy in the pipeline
     * @return the (possibly decrypted) response
     */
    @Override
    public Mono<HttpResponse> process(HttpPipelineCallContext context, HttpPipelineNextPolicy next) {
        HttpHeaders requestHeaders = context.getHttpRequest().getHeaders();
        String initialRangeHeader = requestHeaders.getValue(CryptographyConstants.RANGE_HEADER);

        // No range to expand: send as-is and decide from the response metadata whether to decrypt.
        if (!isRangeRequest(initialRangeHeader)) {
            return next.process().map(this::decryptFullDownloadIfEncrypted);
        }

        // Range request without pre-fetched encryption data: the response must not be encrypted either.
        if (!isEncryptedBlob(context)) {
            return validateEncryptionDataConsistency(next.process());
        }

        // isEncryptedBlob(context) above guarantees the value is present.
        EncryptionData encryptionData = (EncryptionData) context.getData(ENCRYPTION_DATA_CONTEXT_KEY).get();
        EncryptedBlobRange encryptedRange =
            EncryptedBlobRange.getEncryptedBlobRangeFromHeader(initialRangeHeader, encryptionData);

        // The Range header is known to be present here (initialRangeHeader != null); replace it with
        // the range EncryptedBlobRange computed for the encrypted data.
        requestHeaders.set(CryptographyConstants.RANGE_HEADER, encryptedRange.toBlobRange().toString());

        return next.process().map(httpResponse -> {
            if (!isDownloadResponse(httpResponse)
                || httpResponse.getHeaderValue(CryptographyConstants.ENCRYPTION_METADATA_HEADER) == null) {
                return httpResponse;
            }
            return toDecryptedResponse(httpResponse, encryptedRange, encryptionData);
        });
    }

    /** Decrypts a non-range download when its response carries (valid) encryption metadata. */
    private HttpResponse decryptFullDownloadIfEncrypted(HttpResponse httpResponse) {
        if (!isDownloadResponse(httpResponse)) {
            return httpResponse;
        }
        EncryptionData encryptionData = EncryptionData.getAndValidateEncryptionData(
            httpResponse.getHeaderValue(CryptographyConstants.ENCRYPTION_METADATA_HEADER), requiresEncryption);
        if (!isEncryptedBlob(encryptionData)) {
            return httpResponse;
        }
        EncryptedBlobRange encryptedRange = new EncryptedBlobRange(new BlobRange(0L), encryptionData);
        return toDecryptedResponse(httpResponse, encryptedRange, encryptionData);
    }

    /** Records the downloaded size on the range, determines padding and wraps the decrypted body. */
    private DecryptedResponse toDecryptedResponse(HttpResponse httpResponse, EncryptedBlobRange encryptedRange,
        EncryptionData encryptionData) {
        HttpHeaders responseHeaders = httpResponse.getHeaders();
        encryptedRange.setAdjustedDownloadCount(
            Long.parseLong(responseHeaders.getValue(HttpHeaderName.CONTENT_LENGTH)));
        boolean padding = hasPadding(responseHeaders, encryptionData, encryptedRange);
        Flux<ByteBuffer> plainTextData = decryptBlob(httpResponse.getBody(), encryptedRange, padding,
            encryptionData, httpResponse.getRequest().getUrl());
        return new DecryptedResponse(httpResponse, plainTextData);
    }

    private boolean isRangeRequest(String rangeHeader) {
        return rangeHeader != null;
    }

    private boolean isEncryptedBlob(HttpPipelineCallContext context) {
        return context.getData(ENCRYPTION_DATA_CONTEXT_KEY).isPresent();
    }

    private boolean isEncryptedBlob(EncryptionData encryptionData) {
        return encryptionData != null;
    }

    /**
     * Fails the response if it carries encryption metadata even though no encryption data was
     * supplied for the request.
     *
     * @throws IllegalStateException (as an error signal) when the response has encryption metadata
     */
    private Mono<HttpResponse> validateEncryptionDataConsistency(Mono<HttpResponse> responseMono) {
        return responseMono.map(response -> {
            if (response.getHeaderValue(CryptographyConstants.ENCRYPTION_METADATA_HEADER) != null) {
                throw LOGGER.logExceptionAsError(new IllegalStateException(
                    "GetProperties did not find encryption data, but download request returned encryption data."));
            }
            return response;
        });
    }

    /**
     * Returns true when the protocol is "1.0" and the downloaded range reaches into the last
     * {@link #PADDING_REGION_BYTES} bytes of the blob.
     */
    private boolean hasPadding(HttpHeaders responseHeaders, EncryptionData encryptionData,
        EncryptedBlobRange encryptedRange) {
        if (!"1.0".equals(encryptionData.getEncryptionAgent().getProtocol())) {
            return false;
        }
        BlobRange downloadRange = encryptedRange.toBlobRange();
        return downloadRange.getOffset() + downloadRange.getCount()
            > blobSize(responseHeaders) - PADDING_REGION_BYTES;
    }

    /** A download response is a GET whose response exposes a body. */
    private boolean isDownloadResponse(HttpResponse httpResponse) {
        return httpResponse.getRequest().getHttpMethod() == HttpMethod.GET && httpResponse.getBody() != null;
    }

    /**
     * Decrypts the encrypted body stream and trims the plaintext to the originally requested range.
     *
     * @param encryptedFlux encrypted response body
     * @param encryptedBlobRange range bookkeeping for this download
     * @param padding whether the downloaded data may end with padding
     * @param encryptionData encryption metadata for the blob
     * @param requestUri request URL; only host and path are forwarded (as a string for logging)
     * @return the trimmed plaintext stream
     */
    Flux<ByteBuffer> decryptBlob(Flux<ByteBuffer> encryptedFlux, EncryptedBlobRange encryptedBlobRange,
        boolean padding, EncryptionData encryptionData, URL requestUri) {
        String uriToLog = requestUri.getHost() + requestUri.getPath();
        AtomicLong totalInputBytes = new AtomicLong(0L);
        AtomicLong totalOutputBytes = new AtomicLong(0L);
        Decryptor decryptor = Decryptor.getDecryptor(keyResolver, keyWrapper, encryptionData);
        Flux<ByteBuffer> dataToTrim = decryptor.getKeyEncryptionKey()
            .flatMapMany(key -> decryptor.decrypt(encryptedFlux, encryptedBlobRange, padding, uriToLog,
                totalInputBytes, (byte[]) key));
        return trimData(encryptedBlobRange, totalOutputBytes, dataToTrim);
    }

    /**
     * Adjusts each plaintext buffer's position/limit so that only bytes inside the originally
     * requested range remain visible.
     *
     * <p>The code treats each incoming buffer's data as spanning index 0 to {@code limit()}.
     * {@code totalOutputBytes} accumulates the plaintext bytes seen so far across buffers.
     *
     * @param encryptedBlobRange supplies the leading bytes to skip and the originally requested count
     * @param totalOutputBytes running count of plaintext bytes already emitted (mutated)
     * @param dataToTrim decrypted plaintext stream
     * @return the same buffers, with position/limit narrowed to the requested range
     */
    Flux<ByteBuffer> trimData(EncryptedBlobRange encryptedBlobRange, AtomicLong totalOutputBytes,
        Flux<ByteBuffer> dataToTrim) {
        return dataToTrim.map(plaintextByteBuffer -> {
            int decryptedBytes = plaintextByteBuffer.limit();
            long bytesBeforeThisBuffer = totalOutputBytes.longValue();
            int amountToSkip = encryptedBlobRange.getAmountPlaintextToSkip();

            // Still within the leading bytes the caller did not ask for: move position past them.
            if (bytesBeforeThisBuffer <= (long) amountToSkip) {
                int remainingAdjustment = amountToSkip - (int) bytesBeforeThisBuffer;
                plaintextByteBuffer.position(Math.min(remainingAdjustment, decryptedBytes));
            }

            // Offset (in total plaintext) where the requested data ends; unbounded if no count was given.
            Long requestedCount = encryptedBlobRange.getOriginalRange().getCount();
            long beginningOfEndAdjustment = requestedCount == null
                ? Long.MAX_VALUE
                : (long) amountToSkip + requestedCount;
            long bytesAfterThisBuffer = (long) decryptedBytes + bytesBeforeThisBuffer;

            if (bytesAfterThisBuffer > beginningOfEndAdjustment) {
                // This buffer reaches past the requested end: cut it off there, or hide it entirely
                // when it starts past the end.
                long amountPastEnd = bytesAfterThisBuffer - beginningOfEndAdjustment;
                int newLimit = bytesBeforeThisBuffer <= beginningOfEndAdjustment
                    ? decryptedBytes - (int) amountPastEnd
                    : plaintextByteBuffer.position();
                plaintextByteBuffer.limit(newLimit);
            } else if (bytesAfterThisBuffer > (long) amountToSkip) {
                plaintextByteBuffer.limit(decryptedBytes);
            } else {
                // Entirely inside the skipped prefix: expose nothing.
                plaintextByteBuffer.limit(plaintextByteBuffer.position());
            }
            totalOutputBytes.addAndGet(decryptedBytes);
            return plaintextByteBuffer;
        });
    }

    /**
     * Total blob size: taken from the part after '/' in Content-Range when that header is present,
     * otherwise from Content-Length.
     */
    private long blobSize(HttpHeaders headers) {
        String contentRange = headers.getValue(HttpHeaderName.CONTENT_RANGE);
        if (contentRange != null) {
            return Long.parseLong(contentRange.split("/")[1]);
        }
        return Long.parseLong(headers.getValue(HttpHeaderName.CONTENT_LENGTH));
    }

    /**
     * Response wrapper that keeps the original request, status and headers but substitutes the
     * decrypted plaintext for the body.
     */
    static class DecryptedResponse extends HttpResponse {
        private final Flux<ByteBuffer> plainTextBody;
        private final HttpHeaders httpHeaders;
        private final int statusCode;

        DecryptedResponse(HttpResponse httpResponse, Flux<ByteBuffer> plainTextBody) {
            super(httpResponse.getRequest());
            this.plainTextBody = plainTextBody;
            this.httpHeaders = httpResponse.getHeaders();
            this.statusCode = httpResponse.getStatusCode();
        }

        @Override
        public int getStatusCode() {
            return this.statusCode;
        }

        @Override
        public String getHeaderValue(String name) {
            return this.httpHeaders.getValue(name);
        }

        @Override
        public HttpHeaders getHeaders() {
            return this.httpHeaders;
        }

        @Override
        public Flux<ByteBuffer> getBody() {
            return this.plainTextBody;
        }

        @Override
        public Mono<byte[]> getBodyAsByteArray() {
            return FluxUtil.collectBytesInByteBufferStream(this.plainTextBody);
        }

        /** Decodes the plaintext as UTF-8 rather than with the platform-default charset. */
        @Override
        public Mono<String> getBodyAsString() {
            return getBodyAsString(StandardCharsets.UTF_8);
        }

        @Override
        public Mono<String> getBodyAsString(Charset charset) {
            return FluxUtil.collectBytesInByteBufferStream(this.plainTextBody)
                .map(bytes -> new String(bytes, charset));
        }
    }
}

