package io.embrace.android.embracesdk.network.http;

import android.annotation.TargetApi;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.ProtocolException;
import java.net.URL;
import java.security.Permission;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.zip.GZIPInputStream;

import io.embrace.android.embracesdk.Embrace;
import io.embrace.android.embracesdk.EmbraceLogger;
import io.embrace.android.embracesdk.utils.exceptions.Unchecked;

/**
 * Wraps {@link HttpURLConnection} to log network calls to Embrace. The wrapper also wraps the
 * InputStream to get an accurate count of bytes received if a Content-Length is not provided by
 * the server.
 * <p>
 * When transparent gzip decompression applies, the wrapper decompresses the response itself and
 * hides the {@code Content-Length} and {@code Content-Encoding} headers from callers. Typically
 * decompression is handled transparently by {@link HttpURLConnection}, but that would prevent us
 * from accessing the (compressed) {@code Content-Length}.
 * <p>
 * Network logging currently does not follow redirects. The duration is logged from initiation of
 * the network call (upon invocation of any method which would initiate the network call), to the
 * retrieval of the response.
 * <p>
 * As network calls are initiated lazily, we log the network call prior to the calling of any
 * wrapped method which would result in the network call actually being executed, and store a
 * flag to prevent duplication of calls.
 */
public class EmbraceHttpUrlConnection<T extends HttpURLConnection> extends HttpURLConnection {

    /**
     * The content encoding HTTP header.
     */
    private static final String CONTENT_ENCODING = "Content-Encoding";

    /**
     * The content length HTTP header.
     */
    private static final String CONTENT_LENGTH = "Content-Length";

    /**
     * Reference to the wrapped connection.
     */
    private final T connection;

    /**
     * The time at which the connection was created.
     */
    private final long createdTime;

    /**
     * Whether transparent gzip decompression is enabled for this connection.
     */
    private final boolean enableTransparentGzip;

    /**
     * A reference to the output stream wrapped in a counter, so we can determine the bytes sent.
     */
    private volatile CountingOutputStream outputStream;

    /**
     * Whether the network call has already been logged, to prevent duplication.
     */
    private volatile boolean didLogNetworkCall = false;

    /**
     * The time at which the network call ended.
     */
    private volatile Long endTime;

    /**
     * The time at which the network call was initiated.
     */
    private volatile Long startTime;

    /**
     * The trace id specified for the request.
     */
    private volatile String traceId;

    /**
     * Wraps an existing {@link HttpURLConnection} with the Embrace network logic.
     *
     * @param connection            the connection to wrap
     * @param enableTransparentGzip true if we should transparently ungzip the response, else false
     */
    public EmbraceHttpUrlConnection(T connection, boolean enableTransparentGzip) {
        super(connection.getURL());

        this.connection = connection;
        this.createdTime = System.currentTimeMillis();
        this.enableTransparentGzip = enableTransparentGzip;
    }

    @Override
    public void addRequestProperty(String key, String value) {
        this.connection.addRequestProperty(key, value);
    }

    @Override
    public void connect() throws IOException {
        identifyTraceId();
        this.connection.connect();
    }

    @Override
    public void disconnect() {
        identifyTraceId();
        // The network call must be logged before we close the transport
        logNetworkCall(this.createdTime);
        this.connection.disconnect();
    }

    @Override
    public boolean getAllowUserInteraction() {
        return this.connection.getAllowUserInteraction();
    }

    @Override
    public void setAllowUserInteraction(boolean allowUserInteraction) {
        this.connection.setAllowUserInteraction(allowUserInteraction);
    }

    @Override
    public int getConnectTimeout() {
        return this.connection.getConnectTimeout();
    }

    @Override
    public void setConnectTimeout(int timeout) {
        this.connection.setConnectTimeout(timeout);
    }

    @Override
    public Object getContent() throws IOException {
        identifyTraceId();
        return this.connection.getContent();
    }

    @Override
    public Object getContent(Class[] classes) throws IOException {
        identifyTraceId();
        return this.connection.getContent(classes);
    }

    /**
     * Returns {@code null} when we transparently decompress the response, since the caller
     * receives plain (already decoded) bytes.
     */
    @Override
    public String getContentEncoding() {
        return shouldUncompressGzip() ? null : this.connection.getContentEncoding();
    }

    /**
     * Returns -1 when we transparently decompress the response, since the server-provided length
     * describes the compressed body rather than the bytes the caller will read.
     */
    @Override
    public int getContentLength() {
        return shouldUncompressGzip() ? -1 : this.connection.getContentLength();
    }

    @Override
    @TargetApi(24)
    public long getContentLengthLong() {
        return shouldUncompressGzip() ? -1 : this.connection.getContentLengthLong();
    }

    @Override
    public String getContentType() {
        return this.connection.getContentType();
    }

    @Override
    public long getDate() {
        return this.connection.getDate();
    }

    @Override
    public boolean getDefaultUseCaches() {
        return this.connection.getDefaultUseCaches();
    }

    @Override
    public void setDefaultUseCaches(boolean defaultUseCaches) {
        this.connection.setDefaultUseCaches(defaultUseCaches);
    }

    @Override
    public boolean getDoInput() {
        return this.connection.getDoInput();
    }

    @Override
    public void setDoInput(boolean doInput) {
        this.connection.setDoInput(doInput);
    }

    @Override
    public boolean getDoOutput() {
        return this.connection.getDoOutput();
    }

    @Override
    public void setDoOutput(boolean doOutput) {
        this.connection.setDoOutput(doOutput);
    }

    /**
     * Returns the wrapped connection's error stream, counted and/or gunzipped as needed, and logs
     * the network call.
     *
     * @return the wrapped error stream, or {@code null} if the wrapped connection has none
     */
    @Override
    public InputStream getErrorStream() {
        identifyTraceId();
        long startTime = System.currentTimeMillis();
        // Read the length first: this may be what initiates the request on the wrapped connection.
        boolean contentLengthUnknown = this.connection.getContentLength() < 0;
        InputStream in = wrapResponseStream(this.connection.getErrorStream(), contentLengthUnknown);
        logNetworkCall(startTime);
        return in;
    }

    @Override
    public long getExpiration() {
        return this.connection.getExpiration();
    }

    /**
     * Returns the value of the n-th header, or {@code null} if that header is hidden because we
     * transparently decompress the response.
     * <p>
     * NOTE(review): returning {@code null} for a hidden header in the middle of the list terminates
     * caller loops that iterate by index until {@code null}; remapping indices would avoid that.
     */
    @Override
    public String getHeaderField(int n) {
        identifyTraceId();
        long startTime = System.currentTimeMillis();
        String key = this.connection.getHeaderFieldKey(n);
        String headerField = isStrippedHeader(key) ? null : this.connection.getHeaderField(n);
        logNetworkCall(startTime);
        return headerField;
    }

    @Override
    public String getHeaderField(String name) {
        identifyTraceId();
        long startTime = System.currentTimeMillis();
        String headerField = isStrippedHeader(name) ? null : this.connection.getHeaderField(name);
        logNetworkCall(startTime);
        return headerField;
    }

    @Override
    public long getHeaderFieldDate(String name, long defaultValue) {
        identifyTraceId();
        long startTime = System.currentTimeMillis();
        long result = isStrippedHeader(name)
                ? defaultValue
                : this.connection.getHeaderFieldDate(name, defaultValue);
        logNetworkCall(startTime);
        return result;
    }

    @Override
    public int getHeaderFieldInt(String name, int defaultValue) {
        identifyTraceId();
        long startTime = System.currentTimeMillis();
        int result = isStrippedHeader(name)
                ? defaultValue
                : this.connection.getHeaderFieldInt(name, defaultValue);
        logNetworkCall(startTime);
        return result;
    }

    @Override
    public String getHeaderFieldKey(int n) {
        identifyTraceId();
        long startTime = System.currentTimeMillis();
        String headerFieldKey = this.connection.getHeaderFieldKey(n);
        String result = isStrippedHeader(headerFieldKey) ? null : headerFieldKey;
        logNetworkCall(startTime);
        return result;
    }

    @Override
    @TargetApi(24)
    public long getHeaderFieldLong(String name, long defaultValue) {
        identifyTraceId();
        long startTime = System.currentTimeMillis();
        long result = isStrippedHeader(name)
                ? defaultValue
                : this.connection.getHeaderFieldLong(name, defaultValue);
        logNetworkCall(startTime);
        return result;
    }

    /**
     * Returns a copy of the wrapped connection's headers. The content encoding and length headers
     * are removed only when we transparently decompress the response, consistent with the other
     * header accessors.
     */
    @Override
    public Map<String, List<String>> getHeaderFields() {
        identifyTraceId();
        long startTime = System.currentTimeMillis();
        Map<String, List<String>> headerFields = this.connection.getHeaderFields();
        HashMap<String, List<String>> copy = new HashMap<>(headerFields);
        if (shouldUncompressGzip()) {
            // Iterate the source map so the copy can be mutated safely. Names are matched
            // case-insensitively because servers may send them in any case.
            for (String key : headerFields.keySet()) {
                if (isGzipRelatedHeaderName(key)) {
                    copy.remove(key);
                }
            }
        }
        logNetworkCall(startTime);
        return copy;
    }

    @Override
    public long getIfModifiedSince() {
        return this.connection.getIfModifiedSince();
    }

    @Override
    public void setIfModifiedSince(long ifModifiedSince) {
        this.connection.setIfModifiedSince(ifModifiedSince);
    }

    /**
     * Returns the wrapped connection's input stream, counted and/or gunzipped as needed, and logs
     * the network call.
     */
    @Override
    public InputStream getInputStream() throws IOException {
        identifyTraceId();
        long startTime = System.currentTimeMillis();
        // Read the length first: this may be what initiates the request on the wrapped connection.
        boolean contentLengthUnknown = this.connection.getContentLength() < 0;
        InputStream in = wrapResponseStream(this.connection.getInputStream(), contentLengthUnknown);
        logNetworkCall(startTime);
        return in;
    }

    @Override
    public boolean getInstanceFollowRedirects() {
        return this.connection.getInstanceFollowRedirects();
    }

    @Override
    public void setInstanceFollowRedirects(boolean followRedirects) {
        this.connection.setInstanceFollowRedirects(followRedirects);
    }

    @Override
    public long getLastModified() {
        return this.connection.getLastModified();
    }

    /**
     * Returns the wrapped connection's output stream wrapped in a byte counter. The counting
     * stream is created once and reused on subsequent calls.
     */
    @Override
    public OutputStream getOutputStream() throws IOException {
        identifyTraceId();
        if (this.outputStream == null) {
            OutputStream rawStream = this.connection.getOutputStream();
            if (rawStream != null) {
                this.outputStream = new CountingOutputStream(rawStream);
            }
        }
        return this.outputStream;
    }

    @Override
    public Permission getPermission() throws IOException {
        return this.connection.getPermission();
    }

    @Override
    public int getReadTimeout() {
        return this.connection.getReadTimeout();
    }

    @Override
    public void setReadTimeout(int timeout) {
        this.connection.setReadTimeout(timeout);
    }

    @Override
    public String getRequestMethod() {
        return this.connection.getRequestMethod();
    }

    @Override
    public void setRequestMethod(String method) throws ProtocolException {
        this.connection.setRequestMethod(method);
    }

    @Override
    public Map<String, List<String>> getRequestProperties() {
        return this.connection.getRequestProperties();
    }

    @Override
    public String getRequestProperty(String key) {
        return this.connection.getRequestProperty(key);
    }

    @Override
    public int getResponseCode() throws IOException {
        identifyTraceId();
        long startTime = System.currentTimeMillis();
        int responseCode = this.connection.getResponseCode();
        logNetworkCall(startTime);
        return responseCode;
    }

    @Override
    public String getResponseMessage() throws IOException {
        identifyTraceId();
        long startTime = System.currentTimeMillis();
        String responseMsg = this.connection.getResponseMessage();
        logNetworkCall(startTime);
        return responseMsg;
    }

    @Override
    public URL getURL() {
        return this.connection.getURL();
    }

    @Override
    public boolean getUseCaches() {
        return this.connection.getUseCaches();
    }

    @Override
    public void setUseCaches(boolean useCaches) {
        this.connection.setUseCaches(useCaches);
    }

    @Override
    public void setChunkedStreamingMode(int chunkLen) {
        this.connection.setChunkedStreamingMode(chunkLen);
    }

    @Override
    public void setFixedLengthStreamingMode(int contentLength) {
        this.connection.setFixedLengthStreamingMode(contentLength);
    }

    @Override
    @TargetApi(19)
    public void setFixedLengthStreamingMode(long contentLength) {
        this.connection.setFixedLengthStreamingMode(contentLength);
    }

    @Override
    public void setRequestProperty(String key, String value) {
        this.connection.setRequestProperty(key, value);
    }

    @Override
    public String toString() {
        return this.connection.toString();
    }

    @Override
    public boolean usingProxy() {
        return this.connection.usingProxy();
    }

    /**
     * Given a start time (in milliseconds), logs the network call to Embrace using the current time as the end time.
     * <p>
     * If the network call has already been logged for this HttpURLConnection, this method is a no-op and is effectively
     * ignored.
     */
    synchronized void logNetworkCall(long startTime) {
        logNetworkCall(startTime, System.currentTimeMillis(), false, null);
    }

    /**
     * Given a start time and end time (in milliseconds), logs the network call to Embrace.
     * <p>
     * If the network call has already been logged for this HttpURLConnection, this method is a no-op unless
     * {@code overwrite} is true.
     *
     * @param startTime the time the call started, in milliseconds
     * @param endTime   the time the call ended, in milliseconds
     * @param overwrite true to log even if the call was already logged (e.g. once the byte count is known)
     * @param bytesIn   the number of response bytes received, or null to use the Content-Length
     */
    synchronized void logNetworkCall(long startTime, long endTime, boolean overwrite, Long bytesIn) {
        if (!this.didLogNetworkCall || overwrite) {
            // We are proactive with setting this flag so that we don't get nested calls to log the network call by virtue of
            // extracting the data we need to log the network call.
            this.didLogNetworkCall = true;
            this.startTime = startTime;
            this.endTime = endTime;

            String url = EmbraceHttpPathOverride.getURLString(new EmbraceHttpUrlConnectionOverride(this.connection));

            try {
                long bytesOut = this.outputStream == null ? 0 : Math.max(this.outputStream.getCount(), 0);
                long contentLength = bytesIn == null ? Math.max(this.connection.getContentLength(), 0) : bytesIn;
                Embrace.getInstance().logNetworkCall(
                        url,
                        HttpMethod.fromString(getRequestMethod()),
                        getResponseCode(),
                        startTime,
                        endTime,
                        bytesOut,
                        contentLength,
                        traceId
                );
            } catch (Exception e) {
                // Most commonly getResponseCode() failing with an IOException, i.e. the request failed.
                Embrace.getInstance().logNetworkClientError(
                        url,
                        HttpMethod.fromString(getRequestMethod()),
                        startTime,
                        endTime,
                        e.getClass().getCanonicalName(),
                        e.getMessage(),
                        traceId
                );
            }
        }
    }

    /**
     * Wraps a raw response body stream obtained from the wrapped connection.
     * <p>
     * The raw stream is first wrapped in a byte counter when the Content-Length is unknown, then
     * in a {@link GZIPInputStream} when transparent decompression applies.
     *
     * @param rawStream  the stream returned by the wrapped connection; may be {@code null}
     * @param countBytes true if the bytes read should be counted (Content-Length unknown)
     * @return the wrapped stream, or {@code null} if {@code rawStream} is {@code null}
     */
    private InputStream wrapResponseStream(InputStream rawStream, boolean countBytes) {
        if (rawStream == null) {
            // The response body is empty, so don't attempt to wrap it with anything
            return null;
        }
        final InputStream parentStream = countBytes ? countingInputStream(rawStream) : rawStream;
        if (!shouldUncompressGzip()) {
            return parentStream;
        }
        try {
            return Unchecked.wrap(() -> new GZIPInputStream(parentStream));
        } catch (Exception e) {
            // GZIPInputStream reads the gzip header in its constructor, so it fails when the stream's
            // availability is 0 (e.g. an empty body). Fall back to the unwrapped stream in that case.
            return parentStream;
        }
    }

    /**
     * Wraps an input stream with an input stream which counts the number of bytes read, and then
     * updates the network call service with the correct number of bytes read once the stream has
     * reached the end.
     *
     * @param inputStream the input stream to count
     * @return the wrapped input stream
     */
    private InputStream countingInputStream(InputStream inputStream) {
        return new CountingInputStreamWithCallback(inputStream, bytes -> {
            if (EmbraceHttpUrlConnection.this.startTime != null && EmbraceHttpUrlConnection.this.endTime != null) {
                logNetworkCall(
                        EmbraceHttpUrlConnection.this.startTime,
                        EmbraceHttpUrlConnection.this.endTime,
                        true,
                        bytes);
            }
        });
    }

    /**
     * Returns true if the header name is Content-Encoding or Content-Length (case-insensitive).
     *
     * @param name the header name; may be {@code null} (e.g. the status line key)
     */
    private static boolean isGzipRelatedHeaderName(String name) {
        return name != null
                && (CONTENT_ENCODING.equalsIgnoreCase(name) || CONTENT_LENGTH.equalsIgnoreCase(name));
    }

    /**
     * Returns true if the named header must be hidden from callers because we transparently
     * decompress the response body.
     *
     * @param name the header name; may be {@code null}
     */
    private boolean isStrippedHeader(String name) {
        return isGzipRelatedHeaderName(name) && shouldUncompressGzip();
    }

    /**
     * We disable the automatic gzip decompression behavior of {@link HttpURLConnection} in the
     * {@link EmbraceHttpUrlStreamHandler} to ensure that we can count the bytes in the response
     * from the server. We decompress the response transparently to the user only if both:
     * <ul>
     * <li>The user did not specify an encoding</li>
     * <li>The server returned a gzipped response</li>
     * </ul>
     * <p>
     * This method only checks {@link #enableTransparentGzip} and the response's Content-Encoding;
     * presumably the first condition is already folded into {@code enableTransparentGzip} by
     * whoever constructs this wrapper.
     * <p>
     * If the user specified an encoding, even if it is gzip, we do not transparently decompress
     * the response. This is to mirror the behavior of {@link HttpURLConnection} whilst providing
     * us access to the content length.
     *
     * @return true if we should decompress the response, false otherwise
     * @see <a href="https://developer.android.com/reference/java/net/HttpURLConnection#performance">Android Docs</a>
     * @see <a href="https://android.googlesource.com/platform/external/okhttp/+/master/okhttp/src/main/java/com/squareup/okhttp/internal/http/HttpEngine.java">Android Source Code</a>
     */
    private boolean shouldUncompressGzip() {
        String contentEncoding = this.connection.getContentEncoding();
        return enableTransparentGzip &&
                contentEncoding != null &&
                contentEncoding.equalsIgnoreCase("gzip");
    }

    /**
     * Captures the trace id request header, once, so it can be attached to the logged call.
     * Failures (e.g. reading request properties after the connection is established) are logged
     * at debug level and otherwise ignored.
     */
    private void identifyTraceId() {
        if (traceId == null) {
            try {
                traceId = getRequestProperty(Embrace.getInstance().getTraceIdHeader());
            } catch (Exception e) {
                EmbraceLogger.logDebug("Failed to retrieve actual trace id header. Current: "+ traceId);
            }
        }
    }
}
