package io.embrace.android.embracesdk.network.http;

import android.annotation.TargetApi;
import android.os.Build;

import androidx.annotation.Nullable;

import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.ProtocolException;
import java.net.URL;
import java.security.Permission;
import java.security.cert.Certificate;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.zip.GZIPInputStream;

import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSocketFactory;

import io.embrace.android.embracesdk.Embrace;
import io.embrace.android.embracesdk.EmbraceLogger;
import io.embrace.android.embracesdk.network.NetworkCaptureData;
import io.embrace.android.embracesdk.utils.exceptions.Unchecked;
import java9.util.stream.StreamSupport;

/**
 * Wraps an {@link HttpURLConnection} (or {@link HttpsURLConnection}) and forwards every call to it, while
 * recording timings, byte counts, the trace id header and, when network capture rules match, request/response
 * data. The result is reported through {@code Embrace.getInstance().logNetworkCall(...)}, or
 * {@code logNetworkClientError(...)} if gathering the data fails.
 * <p>
 * When {@code enableTransparentGzip} is set and the server returns {@code Content-Encoding: gzip}, the response
 * body is decompressed transparently. The {@code Content-Encoding} and {@code Content-Length} headers are then
 * hidden from the caller, because they describe the compressed body.
 *
 * @param <T> the type of the wrapped connection
 */
class EmbraceUrlConnectionOverride<T extends HttpURLConnection> implements EmbraceUrlConnectionService, EmbraceSslUrlConnectionService {

    /**
     * The content encoding HTTP header.
     */
    private static final String CONTENT_ENCODING = "Content-Encoding";

    /**
     * The content length HTTP header.
     */
    private static final String CONTENT_LENGTH = "Content-Length";

    /**
     * Reference to the wrapped connection.
     */
    private final T connection;

    /**
     * The time at which the connection was created.
     */
    private final long createdTime;

    /**
     * Whether transparent gzip compression is enabled.
     */
    private final boolean enableTransparentGzip;

    /**
     * A reference to the input stream wrapped in a counter, so we can determine the bytes received.
     */
    private volatile CountingInputStreamWithCallback inputStream;

    /**
     * A reference to the output stream wrapped in a counter, so we can determine the bytes sent.
     */
    private volatile CountingOutputStream outputStream;

    /**
     * Whether the network call has already been logged, to prevent duplication.
     */
    private volatile boolean didLogNetworkCall = false;

    /**
     * The time at which the network call ended.
     */
    private volatile Long endTime;

    /**
     * The time at which the network call was initiated.
     */
    private volatile Long startTime;

    /**
     * The trace id specified for the request
     */
    private volatile String traceId;

    /**
     * The request headers captured from the http connection.
     */
    private volatile HashMap<String, String> requestHeaders;

    /**
     * The response body captured from the http connection.
     */
    private volatile byte[] capturedResponseBody;

    /**
     * Wraps an existing {@link HttpURLConnection} or {@link HttpsURLConnection} with the Embrace network logic.
     *
     * @param connection            the connection to wrap
     * @param enableTransparentGzip true if we should transparently ungzip the response, else false
     */
    EmbraceUrlConnectionOverride(T connection, boolean enableTransparentGzip) {
        this.connection = connection;
        this.createdTime = System.currentTimeMillis();
        this.enableTransparentGzip = enableTransparentGzip;
    }

    @Override
    public void addRequestProperty(String key, String value) {
        this.connection.addRequestProperty(key, value);
        // Keep the captured request headers in sync, exactly as setRequestProperty does; otherwise headers
        // added through this method would be missing from the captured data.
        refreshCapturedRequestHeaders();
    }

    @Override
    public void connect() throws IOException {
        identifyTraceId();
        this.connection.connect();
    }

    @Override
    public void disconnect() {
        identifyTraceId();
        // The network call must be logged before we close the transport
        logNetworkCall(this.createdTime, false);
        this.connection.disconnect();
    }

    @Override
    public boolean getAllowUserInteraction() {
        return this.connection.getAllowUserInteraction();
    }

    @Override
    public void setAllowUserInteraction(boolean allowUserInteraction) {
        this.connection.setAllowUserInteraction(allowUserInteraction);
    }

    @Override
    public int getConnectTimeout() {
        return this.connection.getConnectTimeout();
    }

    @Override
    public void setConnectTimeout(int timeout) {
        this.connection.setConnectTimeout(timeout);
    }

    @Override
    public Object getContent() throws IOException {
        identifyTraceId();
        return this.connection.getContent();
    }

    @Override
    public Object getContent(Class[] classes) throws IOException {
        identifyTraceId();
        return this.connection.getContent(classes);
    }

    @Override
    public String getContentEncoding() {
        return shouldUncompressGzip() ? null : this.connection.getContentEncoding();
    }

    @Override
    public int getContentLength() {
        return shouldUncompressGzip() ? -1 : this.connection.getContentLength();
    }

    /**
     * Returns the content length, or -1 when the body is transparently decompressed, because the header then
     * describes the compressed size.
     * <p>
     * {@code URLConnection#getContentLengthLong} is only available from API 24; below that we fall back to the
     * int-valued {@link HttpURLConnection#getContentLength()} instead of always reporting -1.
     */
    @Override
    @TargetApi(24)
    public long getContentLengthLong() {
        if (shouldUncompressGzip()) {
            return -1;
        }
        return Build.VERSION.SDK_INT >= Build.VERSION_CODES.N
                ? this.connection.getContentLengthLong()
                : this.connection.getContentLength();
    }

    @Override
    public String getContentType() {
        return this.connection.getContentType();
    }

    @Override
    public long getDate() {
        return this.connection.getDate();
    }

    @Override
    public boolean getDefaultUseCaches() {
        return this.connection.getDefaultUseCaches();
    }

    @Override
    public void setDefaultUseCaches(boolean defaultUseCaches) {
        this.connection.setDefaultUseCaches(defaultUseCaches);
    }

    @Override
    public boolean getDoInput() {
        return this.connection.getDoInput();
    }

    @Override
    public void setDoInput(boolean doInput) {
        this.connection.setDoInput(doInput);
    }

    @Override
    public boolean getDoOutput() {
        return this.connection.getDoOutput();
    }

    @Override
    public void setDoOutput(boolean doOutput) {
        this.connection.setDoOutput(doOutput);
    }

    /**
     * Returns the (counted, possibly decompressed) error stream, or {@code null} if the wrapped connection has
     * no error stream. This preserves the {@link HttpURLConnection#getErrorStream()} contract that callers rely
     * on for {@code != null} checks.
     */
    @Override
    public InputStream getErrorStream() {
        return wrapResponseStream(this.connection.getErrorStream());
    }

    @Override
    public long getExpiration() {
        return this.connection.getExpiration();
    }

    /**
     * Returns the value of the n-th header visible to the caller.
     * <p>
     * When the body is transparently decompressed, the compression headers are skipped rather than replaced by
     * {@code null}, because {@code null} signals the end of the header list to callers iterating by index.
     */
    @Override
    public String getHeaderField(int n) {
        long startTime = System.currentTimeMillis();
        String headerField = this.connection.getHeaderField(resolveHeaderIndex(n));
        logNetworkCall(startTime, false);
        return headerField;
    }

    @Override
    public String getHeaderField(String name) {

        if (name == null) {
            return null;
        }

        long startTime = System.currentTimeMillis();
        if (shouldStripHeader(name)) {
            // Strip the content encoding and length headers, as we transparently ungzip the content
            return null;
        }
        String headerField = this.connection.getHeaderField(name);

        logNetworkCall(startTime, false);
        return headerField;
    }

    @Override
    public long getHeaderFieldDate(String name, long defaultValue) {
        long startTime = System.currentTimeMillis();
        if (shouldStripHeader(name)) {
            // Strip the content encoding and length headers, as we transparently ungzip the content
            return defaultValue;
        }
        long result = this.connection.getHeaderFieldDate(name, defaultValue);
        logNetworkCall(startTime, false);
        return result;
    }

    @Override
    public int getHeaderFieldInt(String name, int defaultValue) {
        long startTime = System.currentTimeMillis();
        if (shouldStripHeader(name)) {
            // Strip the content encoding and length headers, as we transparently ungzip the content
            return defaultValue;
        }
        int result = this.connection.getHeaderFieldInt(name, defaultValue);
        logNetworkCall(startTime, false);
        return result;
    }

    /**
     * Returns the key of the n-th header visible to the caller. Hidden compression headers are skipped (see
     * {@link #getHeaderField(int)}), so index-based iteration still reaches every remaining header.
     */
    @Override
    public String getHeaderFieldKey(int n) {
        long startTime = System.currentTimeMillis();
        String headerFieldKey = this.connection.getHeaderFieldKey(resolveHeaderIndex(n));
        logNetworkCall(startTime, false);
        return headerFieldKey;
    }

    /**
     * Returns the named header parsed as a long, or {@code defaultValue} if it is missing, unparseable or hidden.
     * <p>
     * {@code URLConnection#getHeaderFieldLong} only exists from API 24. Below that, the raw header value is
     * parsed here with the same semantics, instead of returning -1 regardless of {@code defaultValue}.
     */
    @Override
    @TargetApi(24)
    public long getHeaderFieldLong(String name, long defaultValue) {
        long startTime = System.currentTimeMillis();
        if (shouldStripHeader(name)) {
            // Strip the content encoding and length headers, as we transparently ungzip the content
            return defaultValue;
        }
        long result = Build.VERSION.SDK_INT >= Build.VERSION_CODES.N
                ? this.connection.getHeaderFieldLong(name, defaultValue)
                : parseLongOrDefault(this.connection.getHeaderField(name), defaultValue);
        logNetworkCall(startTime, false);
        return result;
    }

    /**
     * Returns a mutable copy of the response headers.
     * <p>
     * The compression headers are removed (case-insensitively) only when the body is transparently
     * decompressed. Otherwise they accurately describe the body the caller receives and are kept.
     */
    @Override
    public Map<String, List<String>> getHeaderFields() {
        long startTime = System.currentTimeMillis();
        Map<String, List<String>> headerFields = this.connection.getHeaderFields();
        HashMap<String, List<String>> copy = new HashMap<>();
        if (headerFields != null) {
            boolean strip = shouldUncompressGzip();
            for (Map.Entry<String, List<String>> entry : headerFields.entrySet()) {
                if (!(strip && isCompressionHeader(entry.getKey()))) {
                    copy.put(entry.getKey(), entry.getValue());
                }
            }
        }
        logNetworkCall(startTime, false);
        return copy;
    }

    @Override
    public long getIfModifiedSince() {
        return this.connection.getIfModifiedSince();
    }

    @Override
    public void setIfModifiedSince(long ifModifiedSince) {
        this.connection.setIfModifiedSince(ifModifiedSince);
    }

    @Override
    public InputStream getInputStream() throws IOException {
        return wrapResponseStream(this.connection.getInputStream());
    }

    @Override
    public boolean getInstanceFollowRedirects() {
        return this.connection.getInstanceFollowRedirects();
    }

    @Override
    public void setInstanceFollowRedirects(boolean followRedirects) {
        this.connection.setInstanceFollowRedirects(followRedirects);
    }

    @Override
    public long getLastModified() {
        return this.connection.getLastModified();
    }

    /**
     * Returns the request body stream wrapped in a byte counter. The wrapper is created once, from a single call
     * to the wrapped connection's {@code getOutputStream()}, and reused afterwards.
     */
    @Override
    public OutputStream getOutputStream() throws IOException {
        identifyTraceId();
        if (this.outputStream == null) {
            OutputStream rawStream = this.connection.getOutputStream();
            if (rawStream != null) {
                this.outputStream = new CountingOutputStream(rawStream);
            }
        }
        return this.outputStream;
    }

    @Override
    public Permission getPermission() throws IOException {
        return this.connection.getPermission();
    }

    @Override
    public int getReadTimeout() {
        return this.connection.getReadTimeout();
    }

    @Override
    public void setReadTimeout(int timeout) {
        this.connection.setReadTimeout(timeout);
    }

    @Override
    public String getRequestMethod() {
        return this.connection.getRequestMethod();
    }

    @Override
    public void setRequestMethod(String method) throws ProtocolException {
        this.connection.setRequestMethod(method);
    }

    @Override
    public Map<String, List<String>> getRequestProperties() {
        return this.connection.getRequestProperties();
    }

    @Override
    public String getRequestProperty(String key) {
        return this.connection.getRequestProperty(key);
    }

    @Override
    public int getResponseCode() throws IOException {
        identifyTraceId();
        long startTime = System.currentTimeMillis();
        logNetworkCall(startTime, false);
        return this.connection.getResponseCode();
    }

    @Override
    public String getResponseMessage() throws IOException {
        identifyTraceId();
        long startTime = System.currentTimeMillis();
        logNetworkCall(startTime, false);
        return this.connection.getResponseMessage();
    }

    @Override
    public URL getURL() {
        return this.connection.getURL();
    }

    @Override
    public boolean getUseCaches() {
        return this.connection.getUseCaches();
    }

    @Override
    public void setUseCaches(boolean useCaches) {
        this.connection.setUseCaches(useCaches);
    }

    @Override
    public void setChunkedStreamingMode(int chunkLen) {
        this.connection.setChunkedStreamingMode(chunkLen);
    }

    @Override
    public void setFixedLengthStreamingMode(int contentLength) {
        this.connection.setFixedLengthStreamingMode(contentLength);
    }

    @Override
    @TargetApi(19)
    public void setFixedLengthStreamingMode(long contentLength) {
        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.KITKAT) {
            this.connection.setFixedLengthStreamingMode(contentLength);
        }
    }

    @Override
    public void setRequestProperty(String key, String value) {
        this.connection.setRequestProperty(key, value);
        refreshCapturedRequestHeaders();
    }

    @Override
    public String toString() {
        return this.connection.toString();
    }

    @Override
    public boolean usingProxy() {
        return this.connection.usingProxy();
    }

    @Override
    @Nullable
    public String getCipherSuite() {
        if (this.connection instanceof HttpsURLConnection) {
            return ((HttpsURLConnection) this.connection).getCipherSuite();
        }

        return null;
    }

    @Override
    public Certificate[] getLocalCertificates() {
        if (this.connection instanceof HttpsURLConnection) {
            return ((HttpsURLConnection) this.connection).getLocalCertificates();
        }

        return new Certificate[0];
    }

    @Override
    public Certificate[] getServerCertificates() throws SSLPeerUnverifiedException {
        if (this.connection instanceof HttpsURLConnection) {
            return ((HttpsURLConnection) this.connection).getServerCertificates();
        }

        return new Certificate[0];
    }

    @Override
    public SSLSocketFactory getSSLSocketFactory() {
        if (this.connection instanceof HttpsURLConnection) {
            return ((HttpsURLConnection) this.connection).getSSLSocketFactory();
        }

        return null;
    }

    @Override
    public void setSSLSocketFactory(SSLSocketFactory factory) {
        if (this.connection instanceof HttpsURLConnection) {
            ((HttpsURLConnection) this.connection).setSSLSocketFactory(factory);
        }
    }

    @Override
    public HostnameVerifier getHostnameVerifier() {
        if (this.connection instanceof HttpsURLConnection) {
            return ((HttpsURLConnection) this.connection).getHostnameVerifier();
        }

        return null;
    }

    @Override
    public void setHostnameVerifier(HostnameVerifier verifier) {
        if (this.connection instanceof HttpsURLConnection) {
            ((HttpsURLConnection) this.connection).setHostnameVerifier(verifier);
        }
    }

    /**
     * Given a start time (in milliseconds), logs the network call to Embrace using the current time as the end time.
     * <p>
     * If the network call has already been logged for this HttpURLConnection, this method is a no-op and is effectively
     * ignored.
     */
    synchronized void logNetworkCall(long startTime, boolean shouldCaptureBody) {
        logNetworkCall(startTime, System.currentTimeMillis(), false, null, shouldCaptureBody);
    }


    /**
     * Given a start time and end time (in milliseconds), logs the network call to Embrace.
     * <p>
     * If the network call has already been logged for this HttpURLConnection, this method is a no-op unless
     * {@code overwrite} is true.
     *
     * @param startTime         the start of the call, in milliseconds
     * @param endTime           the end of the call, in milliseconds
     * @param overwrite         true to log even if the call was already logged
     * @param bytesIn           the number of response bytes read, or null to use the Content-Length instead
     * @param shouldCaptureBody true to attach capture data when network capture rules apply to this request
     */
    synchronized void logNetworkCall(long startTime, long endTime, boolean overwrite, Long bytesIn, boolean shouldCaptureBody) {
        if (!this.didLogNetworkCall || overwrite) {
            // We are proactive with setting this flag so that we don't get nested calls to log the network call by virtue of
            // extracting the data we need to log the network call.
            this.didLogNetworkCall = true;
            this.startTime = startTime;
            this.endTime = endTime;

            identifyTraceId();
            String url = EmbraceHttpPathOverride.getURLString(new EmbraceHttpUrlConnectionOverride(this.connection));

            NetworkCaptureData networkCaptureData = null;

            // If we don't have network capture rules, it's unnecessary to save these values
            if (hasNetworkCaptureRules() && shouldCaptureBody) {
                networkCaptureData = new NetworkCaptureData()
                        .withRequestQueryParams(connection.getURL().getQuery());

                networkCaptureData.withRequestHeaders(this.requestHeaders);
                if (this.outputStream != null) {
                    networkCaptureData.withRequestBody(this.outputStream.getRequestBody());
                }

                if (this.inputStream != null) {
                    networkCaptureData
                            .withResponseHeaders(getProcessedHeaders(getHeaderFields()))
                            .withResponseBody(capturedResponseBody);
                }
            }

            try {
                long bytesOut = this.outputStream == null ? 0 : Math.max(this.outputStream.getCount(),
                        0);
                long contentLength = bytesIn == null ? Math.max(this.connection.getContentLength(), 0) : bytesIn;
                Embrace.getInstance().logNetworkCall(
                        url,
                        HttpMethod.fromString(getRequestMethod()),
                        getResponseCode(),
                        startTime,
                        endTime,
                        bytesOut,
                        contentLength,
                        traceId,
                        networkCaptureData
                );
            } catch (Exception e) {
                Embrace.getInstance().logNetworkClientError(
                        url,
                        HttpMethod.fromString(getRequestMethod()),
                        startTime,
                        endTime,
                        e.getClass().getCanonicalName(),
                        e.getMessage(),
                        traceId,
                        networkCaptureData
                );
            }
        }
    }

    /**
     * Wraps an input stream with an input stream which counts the number of bytes read, and then
     * updates the network call service with the correct number of bytes read once the stream has
     * reached the end.
     *
     * @param inputStream the input stream to count
     * @return the wrapped input stream
     */
    private CountingInputStreamWithCallback countingInputStream(InputStream inputStream) {
        return new CountingInputStreamWithCallback(inputStream, hasNetworkCaptureRules(), bytes -> {
            if (this.startTime != null && this.endTime != null) {
                logNetworkCall(
                        this.startTime,
                        this.endTime,
                        true,
                        bytes,
                        true);
            }
        },
                body -> capturedResponseBody = body);
    }

    /**
     * Wraps a raw response stream (body or error body) for counting and optional decompression.
     * <p>
     * A {@code null} stream is returned as {@code null}, and the current wrapped stream is left untouched,
     * instead of being wrapped into a stream that fails on first read. The call is still logged in that case.
     *
     * @param rawStream the stream returned by the wrapped connection, possibly null
     * @return the wrapped stream, or null if {@code rawStream} is null
     */
    @Nullable
    private InputStream wrapResponseStream(@Nullable InputStream rawStream) {
        if (rawStream == null) {
            long startTime = System.currentTimeMillis();
            identifyTraceId();
            logNetworkCall(startTime, false);
            return null;
        }
        this.inputStream = countingInputStream(new BufferedInputStream(rawStream));
        return getWrappedInputStream();
    }

    /**
     * We disable the automatic gzip decompression behavior of {@link HttpURLConnection} or
     * {@link HttpsURLConnection} in the {@link EmbraceHttpUrlStreamHandler} to ensure that
     * we can count the bytes in the response from the server.
     * We decompress the response transparently to the user only if both:
     * <ul>
     * <li>The user did not specify an encoding</li>
     * <li>The server returned a gzipped response</li>
     * </ul>
     * <p>
     * If the user specified an encoding, even if it is gzip, we do not transparently decompress
     * the response. This is to mirror the behavior of {@link HttpURLConnection} whilst providing
     * us access to the content length.
     *
     * @return true if we should decompress the response, false otherwise
     * @see <a href="https://developer.android.com/reference/java/net/HttpURLConnection#performance">Android Docs</a>
     * @see <a href="https://android.googlesource.com/platform/external/okhttp/+/master/okhttp/src/main/java/com/squareup/okhttp/internal/http/HttpEngine.java">Android Source Code</a>
     */
    private boolean shouldUncompressGzip() {
        String contentEncoding = this.connection.getContentEncoding();
        return enableTransparentGzip &&
                contentEncoding != null &&
                contentEncoding.equalsIgnoreCase("gzip");
    }

    /**
     * Returns true if {@code name} is one of the headers describing the compressed body
     * (Content-Encoding or Content-Length). HTTP header names are case-insensitive, so the comparison
     * ignores case.
     */
    private static boolean isCompressionHeader(@Nullable String name) {
        return name != null
                && (CONTENT_ENCODING.equalsIgnoreCase(name) || CONTENT_LENGTH.equalsIgnoreCase(name));
    }

    /**
     * Returns true if the header must be hidden from the caller, because the body is transparently
     * decompressed and the header describes the compressed body.
     */
    private boolean shouldStripHeader(@Nullable String name) {
        return isCompressionHeader(name) && shouldUncompressGzip();
    }

    /**
     * Maps a caller-visible header index onto the wrapped connection's header index, skipping hidden
     * compression headers.
     * <p>
     * Without this, hidden headers would surface as {@code null}, which index-based callers treat as the
     * end of the header list. If {@code n} is past the last visible header, the returned index is past the
     * wrapped connection's last header, so delegating with it yields {@code null}.
     */
    private int resolveHeaderIndex(int n) {
        if (n < 0 || !shouldUncompressGzip()) {
            return n;
        }
        int visibleIndex = 0;
        for (int i = 0; ; i++) {
            String key = this.connection.getHeaderFieldKey(i);
            // Index 0 may be the status line (null key, non-null value), so only a null key AND a null value
            // marks the end of the headers.
            if (key == null && this.connection.getHeaderField(i) == null) {
                return i;
            }
            if (!isCompressionHeader(key)) {
                if (visibleIndex == n) {
                    return i;
                }
                visibleIndex++;
            }
        }
    }

    /**
     * Parses {@code value} as a long, returning {@code defaultValue} if it is null or malformed.
     * This mirrors {@code URLConnection#getHeaderFieldLong}.
     */
    private static long parseLongOrDefault(@Nullable String value, long defaultValue) {
        if (value == null) {
            return defaultValue;
        }
        try {
            return Long.parseLong(value);
        } catch (NumberFormatException ignored) {
            return defaultValue;
        }
    }

    private void identifyTraceId() {
        if (traceId == null) {
            try {
                traceId = getRequestProperty(Embrace.getInstance().getTraceIdHeader());
            } catch (Exception e) {
                EmbraceLogger.logDebug("Failed to retrieve actual trace id header. Current: " + traceId);
            }
        }
    }

    /**
     * Re-reads the request headers from the wrapped connection into {@link #requestHeaders}. This is only done
     * when network capture rules apply to this request.
     */
    private void refreshCapturedRequestHeaders() {
        // If we don't have network capture rules, it's unnecessary to save these values
        if (hasNetworkCaptureRules()) {
            this.requestHeaders = getProcessedHeaders(getRequestProperties());
        }
    }

    /**
     * Flattens a multi-valued header map into one string per header name by concatenating the non-null values.
     * A null map or null value list yields an empty result or an empty string, respectively.
     * <p>
     * NOTE(review): multiple values are concatenated with no separator (kept for compatibility with the
     * existing captured-data format); HTTP convention would join them with ", ".
     */
    private HashMap<String, String> getProcessedHeaders(Map<String, List<String>> properties) {
        HashMap<String, String> headers = new HashMap<>();
        if (properties == null) {
            return headers;
        }

        for (Map.Entry<String, List<String>> header : properties.entrySet()) {
            StringBuilder builder = new StringBuilder();
            List<String> values = header.getValue();
            if (values != null) {
                for (String value : values) {
                    if (value != null) {
                        builder.append(value);
                    }
                }
            }
            headers.put(header.getKey(), builder.toString());
        }

        return headers;
    }

    private boolean hasNetworkCaptureRules() {

        if (this.connection.getURL() == null) {
            return false;
        }

        String url = this.connection.getURL().toString();
        String method = this.connection.getRequestMethod();

        return Embrace.getInstance().shouldCaptureNetworkBody(url, method);
    }

    private InputStream getWrappedInputStream() {
        long startTime = System.currentTimeMillis();

        identifyTraceId();

        InputStream in;
        if (this.inputStream != null) {
            if (shouldUncompressGzip()) {
                try {
                    in = Unchecked.wrap(() -> new GZIPInputStream(this.inputStream));
                } catch (Exception e) {
                    // This handles the case where its availability is 0, causing the GZIPInputStream instantiation to fail.
                    in = this.inputStream;
                }
            } else {
                in = this.inputStream;
            }
        } else {
            // The response body is empty, so don't attempt to wrap it with anything
            in = null;
        }

        logNetworkCall(startTime, false);

        return in;
    }
}
