package com.instabug.apm.networkinterception;

import static com.instabug.apm.constants.ErrorMessages.NETWORK_REQUEST_STARTED;
import static com.instabug.apm.networkinterception.utils.UrlConnectionHeaderUtilsKt.injectExternalNetworkTraceIdHeaderIfPossible;

import android.annotation.SuppressLint;
import android.os.Build;

import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.annotation.RequiresApi;

import com.instabug.apm.di.InterceptorsServiceLocator;
import com.instabug.apm.di.ServiceLocator;
import com.instabug.apm.logger.internal.Logger;
import com.instabug.apm.networkinterception.utils.APMCountableInputStream;
import com.instabug.apm.networkinterception.utils.APMCountableOutputStream;
import com.instabug.library.networkv2.BodyBufferHelper;
import com.instabug.library.util.ObjectMapper;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.ProtocolException;
import java.net.URL;
import java.security.Permission;
import java.security.Principal;
import java.security.SecureRandom;
import java.security.cert.Certificate;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.concurrent.TimeUnit;

import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLPeerUnverifiedException;

/**
 * Wrapper for HttpUrlConnection to log network calls made using {@link HttpsURLConnection}
 */
public class APMHttpsUrlConnection extends HttpsURLConnection implements APMCountableInputStream.Callback {
    private static final String CONTENT_TYPE = "content-type";
    private final Logger logger = ServiceLocator.getApmLogger();
    private final HttpsURLConnection connection;
    private final long startTimeNano;
    private final HashMap<String, String> requestHeaders = new HashMap<>();
    private final APMNetworkLogWrapper networkLogWrapper = new APMNetworkLogWrapper();
    private long startTime;
    @Nullable
    private APMCountableOutputStream countableOutputStream;
    @Nullable
    private APMCountableInputStream apmCountableInputStream;

    APMHttpsUrlConnection(HttpsURLConnection connection) {
        super(connection.getURL());
        this.connection = connection;
        startTime = System.currentTimeMillis() * 1000;
        startTimeNano = System.nanoTime();
        networkLogWrapper.setUrl(connection.getURL().toString());
        injectExternalNetworkTraceIdHeaderIfPossible(connection,networkLogWrapper);
    }

    @Override
    public void connect() throws IOException {
        startTime = System.currentTimeMillis() * 1000;
        logger.d(NETWORK_REQUEST_STARTED
                .replace("$method", connection.getRequestMethod())
                .replace("$url", connection.getURL().toString()));
        logNetworkCall(null);
        try {
            connection.connect();
        } catch (Exception e) {
            networkLogWrapper.setErrorMessage(e.getClass().getSimpleName());
            logNetworkCall(e);
            throw e;
        } finally {
        }
    }

    private void logNetworkCall(@Nullable Exception withException) {
        logNetworkCall(startTime, null, withException);
    }

    private void logNetworkCall(
            long startTime,
            @Nullable Long responseSize,
            @Nullable Exception withException
    ) {

        if (countableOutputStream == null) {
            networkLogWrapper.setRequestBodySize(0);
        } else {
            networkLogWrapper.setRequestBodySize(countableOutputStream.getCount());
        }
        if (responseSize != null) {
            networkLogWrapper.setResponseBodySize(responseSize);
        }

        networkLogWrapper.setStartTime(startTime);
        long durationMillis = getDurationMillis(startTimeNano);
        networkLogWrapper.setTotalDuration(durationMillis);
        networkLogWrapper.setRequestHeaders(ObjectMapper.toJson(requestHeaders).toString());
        networkLogWrapper.setRequestBody(getRequestBody());
        networkLogWrapper.setResponseBody(getResponseBody());
        if (networkLogWrapper.getResponseCode() > 0) {
            networkLogWrapper.setErrorMessage(null);
        }
        networkLogWrapper.insert(withException, InterceptorsServiceLocator.getHttpUrlConnectionSanitizer());
    }

    @Override
    public void addRequestProperty(@Nullable String key, @Nullable String value) {
        if (key != null) {
            requestHeaders.put(key, value);
            if (key.equalsIgnoreCase(CONTENT_TYPE)) {
                if (countableOutputStream != null) {
                    countableOutputStream.setDisableBodyBuffer(BodyBufferHelper.isMultipartType(value));
                }
                networkLogWrapper.setRequestContentType(value);
            }
            if (value != null) {
                this.connection.addRequestProperty(key, value);
            }
        }
    }

    @Nullable
    private String getResponseBody() {
        if (apmCountableInputStream != null) {
            return apmCountableInputStream.getBody();
        }
        return null;
    }

    @Nullable
    private String getRequestBody() {
        return null;
    }

    private long getDurationMillis(long sinceNanos) {
        return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - sinceNanos);
    }

    @Override
    public String getCipherSuite() {
        return connection.getCipherSuite();
    }

    @Override
    public boolean usingProxy() {
        return connection.usingProxy();
    }

    @Override
    public Certificate[] getLocalCertificates() {
        return connection.getLocalCertificates();
    }

    @Override
    public Certificate[] getServerCertificates() throws SSLPeerUnverifiedException {
        try {
            return connection.getServerCertificates();
        } catch (Exception e) {
            networkLogWrapper.setErrorMessage(e.getClass().getSimpleName());
            logNetworkCall(e);
            throw e;
        }
    }

    @Override
    public Principal getPeerPrincipal() throws SSLPeerUnverifiedException {
        try {
            return connection.getPeerPrincipal();
        } catch (SSLPeerUnverifiedException e) {
            networkLogWrapper.setErrorMessage(e.getClass().getSimpleName());
            logNetworkCall(e);
            throw e;
        }
    }

    @Override
    public String getHeaderFieldKey(int n) {
        return connection.getHeaderFieldKey(n);
    }

    @Override
    public void setFixedLengthStreamingMode(int contentLength) {
        connection.setFixedLengthStreamingMode(contentLength);
    }

    @RequiresApi(api = Build.VERSION_CODES.KITKAT)
    @Override
    public void setFixedLengthStreamingMode(long contentLength) {
        connection.setFixedLengthStreamingMode(contentLength);
    }

    @Override
    public void setChunkedStreamingMode(int chunkLen) {
        connection.setChunkedStreamingMode(chunkLen);
    }

    @Override
    public String getHeaderField(int n) {
        return connection.getHeaderField(n);
    }

    @Override
    public void onReadCompleted(long count) {
        HashMap<String, String> responseHeaders = new HashMap<>();
        for (Entry entry : connection.getHeaderFields().entrySet()) {
            if (entry.getKey() != null) {
                responseHeaders.put(entry.getKey().toString(), entry.getValue().toString());
                if (entry.getKey().toString().equalsIgnoreCase(CONTENT_TYPE)) {
                    String contentType = entry.getValue().toString();
                    networkLogWrapper.setResponseContentType(contentType);
                }
            }
        }
        networkLogWrapper.setResponseHeaders(ObjectMapper.toJson(responseHeaders).toString());
        logNetworkCall(startTime, count, null);
    }


    @Override
    public void disconnect() {
        logNetworkCall(null);
        connection.disconnect();
    }

    @Override
    public boolean getAllowUserInteraction() {
        return connection.getAllowUserInteraction();
    }

    @Override
    public void setAllowUserInteraction(boolean allowUserInteraction) {
        connection.setAllowUserInteraction(allowUserInteraction);
    }

    @Override
    public int getConnectTimeout() {
        return connection.getConnectTimeout();
    }

    @Override
    public void setConnectTimeout(int timeout) {
        connection.setConnectTimeout(timeout);
    }

    @Override
    public Object getContent() throws IOException {
        try {
            return connection.getContent();
        } catch (IOException e) {
            networkLogWrapper.setErrorMessage(e.getClass().getSimpleName());
            logNetworkCall(e);
            throw e;
        }
    }

    @Override
    public Object getContent(Class[] classes) throws IOException {
        try {
            return connection.getContent(classes);
        } catch (IOException e) {
            networkLogWrapper.setErrorMessage(e.getClass().getSimpleName());
            logNetworkCall(e);
            throw e;
        }
    }

    @Override
    public int getContentLength() {
        int contentLength = connection.getContentLength();
        if (networkLogWrapper.getResponseBodySize() == 0L) {
            networkLogWrapper.setResponseBodySize(contentLength);
            logNetworkCall(null);
        }
        return contentLength;
    }

    @RequiresApi(api = Build.VERSION_CODES.N)
    @Override
    public long getContentLengthLong() {
        return connection.getContentLengthLong();
    }

    @Override
    public String getContentType() {
        return connection.getContentType();
    }

    @Override
    public long getDate() {
        return connection.getDate();
    }

    @Override
    public boolean getDefaultUseCaches() {
        return connection.getDefaultUseCaches();
    }

    @Override
    public void setDefaultUseCaches(boolean defaultUseCaches) {
        connection.setDefaultUseCaches(defaultUseCaches);
    }

    @Override
    public String getContentEncoding() {
        return connection.getContentEncoding();
    }

    @SuppressLint("ERADICATE_INCONSISTENT_SUBCLASS_RETURN_ANNOTATION")
    @Override
    public InputStream getInputStream() throws IOException {
        try {
            InputStream countingInputStream = countingInputStream(connection.getInputStream());
            return (countingInputStream != null) ? countingInputStream : connection.getInputStream();
        } catch (IOException e) {
            networkLogWrapper.setErrorMessage(e.getClass().getSimpleName());
            logNetworkCall(e);
            throw e;
        }
    }

    @Override
    public OutputStream getOutputStream() throws IOException {
        try {
            if (countableOutputStream == null) {
                countableOutputStream = new APMCountableOutputStream(connection.getOutputStream());
            }
            return countableOutputStream;
        } catch (IOException e) {
            networkLogWrapper.setErrorMessage(e.getClass().getSimpleName());
            logNetworkCall(e);
            throw e;
        }
    }

    @Override
    public void setRequestMethod(String method) throws ProtocolException {
        try {
            connection.setRequestMethod(method);
            networkLogWrapper.setMethod(method);
        } catch (ProtocolException e) {
            networkLogWrapper.setErrorMessage(e.getClass().getSimpleName());
            logNetworkCall(e);
            throw e;
        }
    }

    @Override
    public void setRequestProperty(@Nullable String key, @Nullable String value) {
        if (key != null) {
            requestHeaders.put(key, value);
            if (key.equalsIgnoreCase(CONTENT_TYPE)) {
                networkLogWrapper.setRequestContentType(value);
            }

            if (value != null) {
                connection.setRequestProperty(key, value);
            }
        }

    }

    @Override
    public void setDoOutput(boolean dooutput) {
        connection.setDoOutput(dooutput);
    }

    @Override
    public void setDoInput(boolean doinput) {
        connection.setDoInput(doinput);
    }

    @Override
    public boolean getDoInput() {
        return connection.getDoInput();
    }

    @Override
    public boolean getDoOutput() {
        return connection.getDoOutput();
    }


    @SuppressLint("ERADICATE_INCONSISTENT_SUBCLASS_RETURN_ANNOTATION")
    @Override
    @Nullable
    public InputStream getErrorStream() {

        InputStream parentStream;
        if (connection.getContentLength() > 0) {
            InputStream countingInputStream = countingInputStream(connection.getErrorStream());
            if (countingInputStream != null) {
                parentStream = countingInputStream;
            } else {
                parentStream = connection.getErrorStream();
            }
        } else {
            parentStream = connection.getErrorStream();
        }
        logNetworkCall(null);
        return parentStream;
    }

    @Override
    public long getExpiration() {
        return connection.getExpiration();
    }

    @Override
    public String getHeaderField(String name) {
        if (name != null)
            return connection.getHeaderField(name);
        return "";
    }

    @Override
    public int getHeaderFieldInt(String name, int Default) {
        return connection.getHeaderFieldInt(name, Default);
    }

    @Override
    public long getHeaderFieldDate(String name, long Default) {
        return connection.getHeaderFieldDate(name, Default);
    }

    @RequiresApi(api = Build.VERSION_CODES.N)
    @Override
    public long getHeaderFieldLong(String name, long Default) {
        return connection.getHeaderFieldLong(name, Default);
    }

    @Override
    public Map<String, List<String>> getHeaderFields() {
        return connection.getHeaderFields();
    }

    @Override
    public long getIfModifiedSince() {
        return connection.getIfModifiedSince();
    }

    @Override
    public void setIfModifiedSince(long ifModifiedSince) {
        connection.setIfModifiedSince(ifModifiedSince);
    }

    @Override
    public boolean getInstanceFollowRedirects() {
        return connection.getInstanceFollowRedirects();
    }

    @Override
    public void setInstanceFollowRedirects(boolean followRedirects) {
        connection.setInstanceFollowRedirects(followRedirects);
    }

    @Override
    public long getLastModified() {
        return connection.getLastModified();
    }

    @Override
    public Permission getPermission() throws IOException {
        try {
            return connection.getPermission();
        } catch (IOException e) {
            networkLogWrapper.setErrorMessage(e.getClass().getSimpleName());
            logNetworkCall(e);
            throw e;
        }
    }

    @Override
    public int getReadTimeout() {
        return connection.getReadTimeout();
    }

    @Override
    public void setReadTimeout(int timeout) {
        connection.setReadTimeout(timeout);
    }

    @Override
    public String getRequestMethod() {
        return connection.getRequestMethod();
    }

    @Override
    public Map<String, List<String>> getRequestProperties() {
        return connection.getRequestProperties();
    }

    @Override
    public String getRequestProperty(String key) {
        return connection.getRequestProperty(key);
    }

    @Nullable
    private InputStream countingInputStream(InputStream inputStream) {
        if (inputStream == null) {
            return null;
        }
        apmCountableInputStream = new APMCountableInputStream(inputStream, this);
        return apmCountableInputStream;
    }

    @Override
    public int getResponseCode() throws IOException {
        try {
            int responseCode = connection.getResponseCode();
            networkLogWrapper.setResponseCode(responseCode);
            logNetworkCall(null);
            return responseCode;
        } catch (IOException e) {
            networkLogWrapper.setErrorMessage(e.getClass().getSimpleName());
            throw e;
        }
    }

    @Override
    public String getResponseMessage() throws IOException {
        try {
            return connection.getResponseMessage();
        } catch (IOException e) {
            networkLogWrapper.setErrorMessage(e.getClass().getSimpleName());
            logNetworkCall(e);
            throw e;
        }
    }

    @Override
    public URL getURL() {
        return connection.getURL();
    }

    @Override
    public boolean getUseCaches() {
        return connection.getUseCaches();
    }

    @Override
    public void setUseCaches(boolean useCaches) {
        connection.setUseCaches(useCaches);
    }

    @Override
    @NonNull
    public String toString() {
        return connection.toString();
    }
}

