package io.embrace.android.embracesdk.network.http;

import com.fernandocejas.arrow.checks.Preconditions;

import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;

import java9.util.function.Consumer;

/**
 * Wraps an input stream, counting the bytes read or skipped, and invokes a callback with the
 * running count the first time the end of the stream is reached.
 * <p>
 * When body capture is enabled, every non-empty chunk obtained through the array-based
 * {@code read} methods is additionally forwarded to a second callback. Bytes obtained through
 * the single-byte {@link #read()} are counted but are not forwarded to the capture callback.
 */
final class CountingInputStreamWithCallback extends FilterInputStream {
    /**
     * The value of {@link #count} when {@link #mark(int)} was last called, or -1 if it has never
     * been called.
     */
    private volatile long streamMark = -1;

    /**
     * The callback to be invoked with num of bytes after reaching the end of the stream.
     */
    private final Consumer<Long> callback;

    /**
     * The callback to be invoked with each captured chunk of the body.
     */
    private final Consumer<byte[]> capturedBodyCallback;

    /**
     * true if the end-of-stream callback has been invoked since construction or the last
     * successful {@link #reset()}, false otherwise.
     */
    private volatile boolean callbackCompleted;

    /**
     * The count of the number of bytes which have been read or skipped.
     */
    private long count;

    /**
     * true if chunks read through the array-based read methods should be forwarded to
     * {@link #capturedBodyCallback}.
     */
    private final boolean shouldCaptureBody;

    /**
     * Wraps another input stream, counting the number of bytes read.
     *
     * @param in                   the input stream to be wrapped; must not be null
     * @param shouldCaptureBody    whether chunks read into byte arrays are forwarded to
     *                             {@code capturedBodyCallback}
     * @param callback             invoked with the byte count when the end of the stream is
     *                             reached; must not be null
     * @param capturedBodyCallback invoked with a copy of each chunk read when capture is enabled;
     *                             must not be null
     */
    CountingInputStreamWithCallback(InputStream in, boolean shouldCaptureBody, Consumer<Long> callback, Consumer<byte[]> capturedBodyCallback) {
        super(Preconditions.checkNotNull(in));
        this.callback = Preconditions.checkNotNull(callback);
        this.capturedBodyCallback = Preconditions.checkNotNull(capturedBodyCallback);
        this.shouldCaptureBody = shouldCaptureBody;
    }

    /**
     * Returns the number of bytes read.
     */
    long getCount() {
        return count;
    }

    /**
     * Reads a single byte, counting it, or triggers the end-of-stream callback when the wrapped
     * stream returns -1. The byte is not forwarded to the capture callback.
     */
    @Override
    public int read() throws IOException {
        int result = in.read();
        if (result != -1) {
            count++;
        } else {
            notifyEndOfStream();
        }

        return result;
    }

    /**
     * Reads into the whole of {@code b}. Counting, capture and end-of-stream handling all happen
     * in {@link #read(byte[], int, int)}.
     */
    @Override
    public int read(byte[] b) throws IOException {
        // Delegate straight to the three-argument overload. The previous implementation went
        // through super.read(b), which itself calls read(b, 0, b.length), and then captured the
        // buffer a second time, so each chunk reached the capture callback twice.
        return read(b, 0, b.length);
    }

    /**
     * Reads up to {@code len} bytes into {@code b} starting at {@code off}. The bytes actually
     * read are added to the count and, if capture is enabled, a copy of exactly those bytes is
     * passed to the capture callback. A result of -1 triggers the end-of-stream callback.
     */
    @Override
    public int read(byte[] b, int off, int len) throws IOException {
        int result = in.read(b, off, len);

        if (result != -1) {
            count += result;
            conditionallyCaptureBody(b, off, result);
        } else {
            notifyEndOfStream();
        }

        return result;
    }

    /**
     * Skips bytes on the wrapped stream and adds the number reported as skipped to the count.
     */
    @Override
    public long skip(long n) throws IOException {
        long result = in.skip(n);
        count += result;
        return result;
    }

    /**
     * Marks the wrapped stream and remembers the current count so that {@link #reset()} can
     * restore it.
     */
    @Override
    public synchronized void mark(int readlimit) {
        in.mark(readlimit);
        streamMark = count;
        // it's okay to streamMark even if streamMark isn't supported, as reset won't work
    }

    /**
     * Resets the wrapped stream to the last mark, restores the count recorded at that mark and
     * re-arms the end-of-stream callback so it fires again when the end is next reached.
     *
     * @throws IOException if the wrapped stream does not support mark, if no mark has been set,
     *                     or if the wrapped stream fails to reset
     */
    @Override
    public synchronized void reset() throws IOException {
        if (!in.markSupported()) {
            throw new IOException("Mark not supported");
        }
        if (streamMark == -1) {
            throw new IOException("Mark not set");
        }

        in.reset();
        count = streamMark;
        callbackCompleted = false;
    }

    /**
     * Invokes the byte-count callback unless it has already been invoked since construction or
     * the last reset.
     */
    private void notifyEndOfStream() {
        if (!callbackCompleted) {
            callbackCompleted = true;
            callback.accept(count);
        }
    }

    /**
     * Forwards the {@code length} bytes starting at {@code off} in {@code buffer} to the capture
     * callback, if capture is enabled and at least one byte was read.
     */
    private synchronized void conditionallyCaptureBody(byte[] buffer, int off, int length) {
        if (!shouldCaptureBody || length <= 0) {
            return;
        }

        // Hand over a copy of only the bytes just read: the caller's buffer may be larger than
        // the read, may hold stale data outside [off, off + length), and is typically reused
        // for the next read.
        byte[] chunk = new byte[length];
        System.arraycopy(buffer, off, chunk, 0, length);
        capturedBodyCallback.accept(chunk);
    }
}
