package com.unity3d.ads.gatewayclient

import com.google.protobuf.InvalidProtocolBufferException
import com.unity3d.ads.core.data.model.OperationType
import com.unity3d.ads.core.data.model.exception.NetworkTimeoutException
import com.unity3d.ads.core.data.model.exception.UnityAdsNetworkException
import com.unity3d.ads.core.data.repository.SessionRepository
import com.unity3d.ads.core.domain.HandleGatewayUniversalResponse
import com.unity3d.ads.core.domain.SendDiagnosticEvent
import com.unity3d.ads.core.domain.SendDiagnosticEvent.Companion.NETWORK_CLIENT
import com.unity3d.ads.core.domain.SendDiagnosticEvent.Companion.NETWORK_FAILURE
import com.unity3d.ads.core.domain.SendDiagnosticEvent.Companion.NETWORK_PARSE
import com.unity3d.ads.core.domain.SendDiagnosticEvent.Companion.NETWORK_SUCCESS
import com.unity3d.ads.core.domain.SendDiagnosticEvent.Companion.OPERATION
import com.unity3d.ads.core.domain.SendDiagnosticEvent.Companion.PROTOCOL
import com.unity3d.ads.core.domain.SendDiagnosticEvent.Companion.REASON
import com.unity3d.ads.core.domain.SendDiagnosticEvent.Companion.REASON_CODE
import com.unity3d.ads.core.domain.SendDiagnosticEvent.Companion.REASON_DEBUG
import com.unity3d.ads.core.domain.SendDiagnosticEvent.Companion.REASON_PROTOBUF_PARSING
import com.unity3d.ads.core.domain.SendDiagnosticEvent.Companion.RETRIES
import com.unity3d.ads.core.extensions.elapsedMillis
import com.unity3d.services.UnityAdsConstants
import com.unity3d.services.core.log.DeviceLog
import gatewayprotocol.v1.UniversalResponseOuterClass.UniversalResponse
import gatewayprotocol.v1.UniversalRequestOuterClass.UniversalRequest
import com.unity3d.services.core.network.core.HttpClient
import com.unity3d.services.core.network.model.HttpRequest
import com.unity3d.services.core.network.model.HttpResponse
import com.unity3d.services.core.network.model.RequestType
import com.unity3d.services.core.network.model.isSuccessful
import com.unity3d.services.core.network.model.toHttpResponse
import gatewayprotocol.v1.*
import gatewayprotocol.v1.universalResponse
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.delay
import kotlinx.coroutines.withTimeout
import kotlin.math.min

import kotlin.math.pow
import kotlin.random.Random
import kotlin.time.ExperimentalTime
import kotlin.time.TimeMark
import kotlin.time.TimeSource

/**
 * [GatewayClient] that POSTs protobuf-encoded [UniversalRequest]s to the gateway through [httpClient]
 * and parses the protobuf [UniversalResponse] it returns.
 *
 * Unsuccessful attempts are retried with exponential backoff plus random jitter, as configured by the
 * [RequestPolicy] passed to [request]. Every attempt is reported through [sendDiagnosticEvent], except
 * for [OperationType.UNIVERSAL_EVENT] operations.
 */
@OptIn(ExperimentalTime::class)
class CommonGatewayClient(
    private val httpClient: HttpClient,
    private val handleGatewayUniversalResponse: HandleGatewayUniversalResponse,
    private val sendDiagnosticEvent: SendDiagnosticEvent,
    private val sessionRepository: SessionRepository,
): GatewayClient {

    /**
     * Sends [request] to [url] and returns the parsed gateway response.
     *
     * The whole exchange (every attempt plus the waits between them) is wrapped in [withTimeout]
     * using [RequestPolicy.maxDuration] milliseconds.
     *
     * @throws NetworkTimeoutException when the request could not be completed within the retry
     *   policy (see [executeWithRetry] and [executeRequest]).
     */
    override suspend fun request(
        url: String,
        request: UniversalRequest,
        requestPolicy: RequestPolicy,
        operationType: OperationType
    ): UniversalResponse {
        return withTimeout(requestPolicy.maxDuration.toLong()) {
            executeWithRetry(url, request, requestPolicy, operationType)
        }
    }

    /**
     * Executes the request until it succeeds or the retry policy says to stop.
     *
     * On a successful HTTP response the body is parsed, handed to [handleGatewayUniversalResponse],
     * and returned. Otherwise the next backoff delay is computed, and the loop either waits and
     * retries or throws.
     *
     * @throws NetworkTimeoutException when [shouldRetry] rejects another attempt, or when the wait
     *   before the next attempt is cancelled by the surrounding [withTimeout].
     */
    private suspend fun executeWithRetry(url: String,
                                         request: UniversalRequest,
                                         requestPolicy: RequestPolicy,
                                         operationType: OperationType): UniversalResponse {
        var retryCount = 0
        val gatewayUrl = getGatewayUrl(url)
        val timer = TimeSource.Monotonic.markNow()
        // Carries the previous delay so that calculateExponentialBackoff can scale it on the next failure.
        var delayTime = requestPolicy.retryWaitBase.toLong()
        do {
            val headers = getHeaders(retryCount)
            val httpRequest = buildHttpRequest(gatewayUrl, headers, requestPolicy, request)
            val httpResponse = executeRequest(httpRequest, retryCount, operationType)

            if (httpResponse.isSuccessful()) {
                return getUniversalResponse(
                    httpResponse,
                    operationType
                ).also { handleGatewayUniversalResponse(it) }
            }

            delayTime = calculateDelayTime(delayTime, requestPolicy, retryCount)
            val currentDuration = timer.elapsedMillis().toLong()
            // Only retry if the wait itself still fits inside the overall duration budget.
            val durationWithDelay = currentDuration + delayTime

            // NOTE(review): non-retriable status codes (outside 400..599) also end up here and are
            // reported as a NetworkTimeoutException; callers may rely on that exception type.
            if (!shouldRetry(
                    httpResponse.statusCode,
                    durationWithDelay,
                    requestPolicy.maxDuration
                )
            ) {
                throw NetworkTimeoutException("Gateway request failed after $retryCount retries  currentDuration: ${currentDuration}ms maxDuration: ${requestPolicy.maxDuration}ms")
            }

            try {
                delay(delayTime)
            }
            catch (e: TimeoutCancellationException) {
                // The outer withTimeout fired while waiting; surface it as a network timeout instead.
                throw NetworkTimeoutException("Gateway was canceled while waiting for next request, after $retryCount retries currentDuration: ${timer.elapsedMillis().toLong()}ms maxDuration: ${requestPolicy.maxDuration}ms")
            }
            retryCount++
        } while (true)
    }

    /**
     * Performs a single HTTP attempt and reports it as a diagnostic event.
     *
     * A [UnityAdsNetworkException] is converted into an [HttpResponse] via `toHttpResponse()` so that
     * the retry loop can decide based on its status code rather than aborting.
     *
     * @throws NetworkTimeoutException if the attempt is cancelled by the surrounding [withTimeout].
     */
    private suspend fun executeRequest(
        httpRequest: HttpRequest,
        retryCount: Int,
        operationType: OperationType
    ): HttpResponse {
        val startTime = TimeSource.Monotonic.markNow()
        return try {
            httpClient.execute(httpRequest).also { response ->
                sendNetworkSuccessDiagnosticEvent(response, retryCount, operationType, startTime)
            }
        } catch (e: UnityAdsNetworkException) {
            sendNetworkErrorDiagnosticEvent(e, retryCount, operationType, startTime)
            e.toHttpResponse()
        } catch (e: TimeoutCancellationException) {
            val unityAdsNetworkException = NetworkTimeoutException("Gateway request was canceled due to exceeding timeout for operation")
            sendNetworkErrorDiagnosticEvent(unityAdsNetworkException, retryCount, operationType, startTime)
            throw unityAdsNetworkException
        }
    }

    /** Builds the protobuf POST request for [request], taking all timeouts from [requestPolicy]. */
    private fun buildHttpRequest(
        gatewayUrl: String,
        headers: Map<String, List<String>>,
        requestPolicy: RequestPolicy,
        request: UniversalRequest
    ): HttpRequest {
        return HttpRequest(
            baseURL = gatewayUrl,
            method = RequestType.POST,
            body = request.toByteArray(),
            headers = headers,
            connectTimeout = requestPolicy.connectTimeout,
            readTimeout = requestPolicy.readTimeout,
            writeTimeout = requestPolicy.writeTimeout,
            callTimeout = requestPolicy.overallTimeout,
            isProtobuf = true
        )
    }

    /**
     * Returns the request headers: always the protobuf content type, plus [HEADER_RETRY_ATTEMPT]
     * carrying [retryCount] on retries (when [retryCount] > 0).
     */
    private fun getHeaders(retryCount: Int): Map<String, List<String>> {
        return buildMap {
            put(HEADER_CONTENT_TYPE, listOf(HEADER_PROTOBUF))
            if (retryCount > 0) {
                put(HEADER_RETRY_ATTEMPT, listOf(retryCount.toString()))
            }
        }
    }

    /**
     * Resolves the URL to call. An explicit [url] is used as-is; the default gateway URL placeholder
     * is replaced by the session's current gateway URL.
     */
    private fun getGatewayUrl(url: String): String {
        return if (url != UnityAdsConstants.DefaultUrls.GATEWAY_URL) {
            url
        } else {
            sessionRepository.gatewayUrl
        }
    }

    /** Reports a failed attempt as [NETWORK_FAILURE]; skipped for [OperationType.UNIVERSAL_EVENT]. */
    private fun sendNetworkErrorDiagnosticEvent(
        e: UnityAdsNetworkException,
        retryCount: Int,
        operationType: OperationType,
        startTime: TimeMark,
    ) {
        if (operationType == OperationType.UNIVERSAL_EVENT) return

        val tags = mutableMapOf(
            OPERATION to operationType.toString(),
            RETRIES to retryCount.toString(),
            PROTOCOL to e.protocol.toString(),
            NETWORK_CLIENT to e.client.toString(),
            REASON_CODE to e.code.toString(),
            REASON_DEBUG to e.message
        )
        sendDiagnosticEvent(NETWORK_FAILURE, startTime.elapsedMillis(), tags = tags)
    }

    /**
     * Reports a completed HTTP attempt as [NETWORK_SUCCESS] (whatever its status code); skipped for
     * [OperationType.UNIVERSAL_EVENT].
     */
    private fun sendNetworkSuccessDiagnosticEvent(
        httpResponse: HttpResponse,
        retryCount: Int,
        operationType: OperationType,
        startTime: TimeMark,
    ) {
        if (operationType == OperationType.UNIVERSAL_EVENT) return

        val tags = mutableMapOf(
            OPERATION to operationType.toString(),
            RETRIES to retryCount.toString(),
            PROTOCOL to httpResponse.protocol,
            NETWORK_CLIENT to httpResponse.client,
            REASON_CODE to httpResponse.statusCode.toString(),
        )
        sendDiagnosticEvent(NETWORK_SUCCESS, startTime.elapsedMillis(), tags = tags)
    }

    /**
     * Parses the response body (a `ByteArray`, or a `String` encoded as UTF-8) into a [UniversalResponse].
     *
     * Does not throw on malformed input. It logs, sends a [NETWORK_PARSE] diagnostic, and returns a
     * response whose `error` field is set.
     */
    private fun getUniversalResponse(response: HttpResponse, operationType: OperationType): UniversalResponse {
        return try {
            when (val responseBody = response.body) {
                is ByteArray -> UniversalResponse.parseFrom(responseBody)
                is String -> UniversalResponse.parseFrom(responseBody.toByteArray(Charsets.UTF_8))
                else -> throw InvalidProtocolBufferException("Could not parse response from gateway service")
            }
        } catch (e: InvalidProtocolBufferException) {
            DeviceLog.debug("Failed to parse response from gateway service with exception: %s", e.localizedMessage)
            sendDiagnosticEvent(
                NETWORK_PARSE,
                tags = mapOf(
                    OPERATION to operationType.toString(),
                    REASON to REASON_PROTOBUF_PARSING,
                    // NOTE(review): for ByteArray bodies toString() yields an identity string, not the contents.
                    REASON_DEBUG to response.body.toString()
                )
            )
            universalResponse {
                error = error {
                    errorText = "ERROR: Could not parse response from gateway service"
                }
            }
        }
    }

    /**
     * Computes the wait before the next attempt: the exponential backoff plus jitter, capped at
     * [RequestPolicy.retryMaxInterval] and never negative. A negative value would make the duration
     * budget check optimistic and would poison the next backoff scaling.
     */
    private fun calculateDelayTime(currentDelay: Long, requestPolicy: RequestPolicy, retryCount: Int): Long {
        val retryWaitTime = calculateExponentialBackoff(currentDelay, requestPolicy, retryCount)
        val jitter = calculateJitter(requestPolicy.retryWaitBase, requestPolicy.retryJitterPct)
        return min(retryWaitTime + jitter, requestPolicy.retryMaxInterval.toLong()).coerceAtLeast(0L)
    }

    /**
     * Scales [currentDelay] by [RequestPolicy.retryScalingFactor]. [currentDelay] is the previous
     * computed delay, including its jitter and cap.
     */
    private fun calculateExponentialBackoff(currentDelay: Long, requestPolicy: RequestPolicy, retryCount: Int): Long {
        // On first retry (retryCount is 0) we don't want to scale the delay
        if (retryCount == 0) return currentDelay
        return (currentDelay * requestPolicy.retryScalingFactor).toLong()
    }

    /**
     * Returns a random jitter in `[-range, range)`, where `range = retryWaitBase * retryJitterPct`.
     * Returns 0 when the range is not positive (zero/negative/NaN percentage, or a range that
     * truncates to 0), because `Random.nextLong` rejects an empty range.
     */
    private fun calculateJitter(retryWaitBase: Int, retryJitterPct: Float): Long {
        val jitterRange = (retryWaitBase * retryJitterPct).toLong()
        if (jitterRange <= 0L) return 0L
        return Random.nextLong(-jitterRange, jitterRange)
    }

    /** Retries only on 4xx/5xx responses, and only while [duration] stays below [maxDuration]. */
    private fun shouldRetry(responseCode: Int, duration: Long, maxDuration: Int): Boolean {
        return responseCode in CODE_400..CODE_599 && duration < maxDuration
    }

    companion object {
        const val HEADER_RETRY_ATTEMPT = "X-RETRY-ATTEMPT"
        const val HEADER_CONTENT_TYPE = "Content-Type"
        const val HEADER_PROTOBUF = "application/x-protobuf"
        const val CODE_400 = 400
        const val CODE_599 = 599
    }
}