package com.atlassian.httpclient.apache.httpcomponents;

import java.io.IOException;

import org.apache.http.ContentTooLongException;
import org.apache.http.HttpEntity;
import org.apache.http.HttpResponse;
import org.apache.http.entity.ContentType;
import org.apache.http.nio.ContentDecoder;
import org.apache.http.nio.IOControl;
import org.apache.http.nio.entity.ContentBufferEntity;
import org.apache.http.nio.protocol.AbstractAsyncResponseConsumer;
import org.apache.http.nio.util.ByteBufferAllocator;
import org.apache.http.nio.util.HeapByteBufferAllocator;
import org.apache.http.nio.util.SimpleInputBuffer;
import org.apache.http.protocol.HttpContext;
import org.apache.http.util.Asserts;
import com.google.common.primitives.Ints;

/**
 * An AsyncResponseConsumer that buffers at most {@code maxEntitySize} bytes of the response entity. If the entity
 * turns out to be larger, the excess content is read and discarded (so the exchange can still complete without
 * retaining it), and once the entity has been fully read a {@link ContentTooLongException}
 * (an {@code EntityTooLargeException} carrying the response) is thrown.
 *
 * @since 0.23.5
 */
public class BoundedAsyncResponseConsumer extends AbstractAsyncResponseConsumer<HttpResponse> {

    // limit the amount of memory that is pre-allocated based on the reported Content-Length. Let's be a bit paranoid
    private static final int MAX_INITIAL_BUFFER_SIZE = 256 * 1024;

    // initial buffer size used when the entity does not report its length
    private static final int UNKNOWN_LENGTH_INITIAL_BUFFER_SIZE = 4096;

    private final int maxEntitySize;

    private volatile BoundedInputBuffer contentBuffer;
    private volatile DiscardBuffer discardBuffer;
    private volatile HttpResponse response;

    // true once at least one byte beyond maxEntitySize has been discarded for the current entity
    private volatile boolean excessContentDiscarded;

    /**
     * @param maxEntitySize the maximum number of entity bytes that will be buffered
     */
    BoundedAsyncResponseConsumer(int maxEntitySize) {
        this.maxEntitySize = maxEntitySize;
    }

    /**
     * Returns the response passed to {@link #onResponseReceived}; if it carried an entity, that entity is backed by
     * the bounded content buffer.
     */
    @Override
    protected HttpResponse buildResult(HttpContext context) {
        return response;
    }

    /**
     * Buffers newly available entity content. Once the buffer is full at {@code maxEntitySize} bytes and more room
     * is requested, all further content is drained into the discard buffer instead.
     */
    @Override
    protected void onContentReceived(ContentDecoder decoder, IOControl ioctrl) throws IOException {
        Asserts.notNull(contentBuffer, "Content buffer");

        if (contentBuffer.isDiscardMode()) {
            discardRemainingContent(decoder);
        } else {
            try {
                contentBuffer.consumeContent(decoder);
            } catch (BufferFullException e) {
                // The content buffer is now full at maxEntitySize bytes; continue by draining into the discard buffer.
                discardRemainingContent(decoder);
            }
        }
    }

    /**
     * Allocates a bounded content buffer for the entity and attaches it to the response as its entity content.
     * The initial buffer size follows the reported content length, capped by both {@code maxEntitySize} and
     * {@code MAX_INITIAL_BUFFER_SIZE}.
     */
    @Override
    protected void onEntityEnclosed(HttpEntity entity, ContentType contentType) throws IOException {
        int length = Math.min(Ints.saturatedCast(entity.getContentLength()), maxEntitySize);
        if (length < 0) {
            // length not reported (negative content length): start with a 4k buffer
            length = Math.min(UNKNOWN_LENGTH_INITIAL_BUFFER_SIZE, maxEntitySize);
        }
        int initialBufferSize = Math.min(MAX_INITIAL_BUFFER_SIZE, length);

        contentBuffer = new BoundedInputBuffer(initialBufferSize, maxEntitySize, new HeapByteBufferAllocator());
        discardBuffer = new DiscardBuffer();
        excessContentDiscarded = false;

        Asserts.notNull(response, "response");
        response.setEntity(new ContentBufferEntity(entity, contentBuffer));
    }

    /**
     * Remembers the response head so the entity can be attached to it later.
     */
    @Override
    protected void onResponseReceived(HttpResponse response) throws IOException {
        this.response = response;
    }

    /**
     * Drops all per-exchange state.
     */
    @Override
    protected void releaseResources() {
        this.response = null;
        this.contentBuffer = null;
        this.discardBuffer = null;
        this.excessContentDiscarded = false;
    }

    /**
     * Drains whatever content the decoder currently has available without retaining it, and throws once the entity
     * has been completely read if any content beyond {@code maxEntitySize} was actually received.
     */
    private void discardRemainingContent(ContentDecoder decoder) throws IOException {
        // We're done with the content buffer, so we manually set it to endofstream
        contentBuffer.shutdown();

        // Drain the available content; the discard buffer recycles its storage instead of growing.
        if (discardBuffer.consumeContent(decoder) != 0) {
            excessContentDiscarded = true;
        }

        // Discarding starts as soon as the full content buffer (exactly maxEntitySize bytes) is asked to grow. A
        // decoder may return 0 for a full destination buffer before it has detected the end of the entity, so the
        // entity is only too large if at least one byte was actually discarded.
        // Throw only when the decoder has been fully read; otherwise return so further chunks can be drained.
        if (decoder.isCompleted() && excessContentDiscarded) {
            throw new EntityTooLargeException(
                    response, "Entity content is too long; larger than " + maxEntitySize + " bytes");
        }
    }

    /**
     * An input buffer whose capacity never exceeds {@code maxSize}. When asked to grow beyond that limit while
     * already at it, the buffer enters discard mode and throws {@link BufferFullException} to break out of the
     * caller's read loop.
     */
    private static final class BoundedInputBuffer extends SimpleInputBuffer {

        private final int maxSize;

        private boolean discardMode;

        BoundedInputBuffer(int initialSize, int maxSize, ByteBufferAllocator allocator) {
            super(Math.min(maxSize, initialSize), allocator);

            this.maxSize = maxSize;
        }

        /**
         * @return whether the buffer has reached {@code maxSize} and more space was requested
         */
        public boolean isDiscardMode() {
            return discardMode;
        }

        /**
         * Grows the buffer by roughly 50% (at least to 2 bytes), subject to the {@code maxSize} bound enforced by
         * {@link #ensureCapacity(int)}.
         */
        @Override
        protected void expand() {
            int capacity = buffer.capacity();
            int newCapacity = capacity < 2 ? 2 : capacity + (capacity >>> 1);
            if (newCapacity < capacity) {
                // must be integer overflow
                newCapacity = Integer.MAX_VALUE;
            }
            ensureCapacity(newCapacity);
        }

        /**
         * Grows the buffer to {@code min(requiredCapacity, maxSize)}, or switches to discard mode and throws
         * {@link BufferFullException} if the buffer is already at {@code maxSize}.
         */
        @Override
        protected void ensureCapacity(int requiredCapacity) {
            if (buffer.capacity() == maxSize && requiredCapacity > maxSize) {
                // Switch to discard mode if required capacity exceeds the maximum size
                discardMode = true;

                // We still need to throw so that the consumeContent loop is broken
                throw new BufferFullException();
            }
            super.ensureCapacity(Math.min(requiredCapacity, maxSize));
        }
    }

    /**
     * A custom input buffer used to discard excess content when the main buffer
     * has reached its maximum capacity. This buffer reads incoming data to allow
     * the session to complete but does not retain the data.
     * <p>
     * When the {@link BoundedInputBuffer} enters discard mode due to exceeding
     * the maximum entity size, the {@link DiscardBuffer} is used to consume and
     * discard the remaining content from the {@link ContentDecoder}. This ensures
     * that the session can properly complete without consuming additional memory
     * for the discarded content.
     */
    private static final class DiscardBuffer extends SimpleInputBuffer {

        /**
         * Constructs a new DiscardBuffer with a fixed size of 4 KB.
         * This size is used to temporarily hold data while it is discarded.
         */
        public DiscardBuffer() {
            super(4096);
        }

        /**
         * Discards the buffered contents by clearing the buffer instead of growing it.
         * This method is called when the buffer is full and more data is available.
         */
        @Override
        protected void expand() {
            buffer.clear();
        }
    }

    /**
     * Thrown by {@link BoundedInputBuffer} purely to break out of the read loop once the buffer is full at its
     * maximum size. Because it is used only for control flow, it records no stack trace and no suppressed exceptions.
     */
    private static final class BufferFullException extends RuntimeException {

        private static final long serialVersionUID = 1L;

        BufferFullException() {
            super(null, null, false, false);
        }
    }
}
