/*
 * Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
 */

package io.ktor.client.plugins.logging

import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.plugins.*
import io.ktor.client.plugins.api.*
import io.ktor.client.plugins.observer.*
import io.ktor.client.request.*
import io.ktor.client.request.forms.*
import io.ktor.client.statement.*
import io.ktor.client.utils.*
import io.ktor.http.*
import io.ktor.http.content.*
import io.ktor.util.*
import io.ktor.util.pipeline.*
import io.ktor.utils.io.*
import io.ktor.utils.io.charsets.*
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.launch

// Attribute under which the per-call HttpClientCallLogger created while logging a request is stored,
// so that the response-side hooks can look it up on the same call.
private val ClientCallLogger = AttributeKey<HttpClientCallLogger>("CallLogger")
// Marker attribute put on a request that was rejected by the configured filters;
// response-side hooks check for it to skip logging.
private val DisableLogging = AttributeKey<Unit>("DisableLogging")

/**
 * Output formats supported by the [Logging] plugin. Selected through [LoggingConfig.format].
 */
public enum class LoggingFormat {
    /**
     * The format used when nothing else is configured ([LoggingConfig.format] is initialized with this value).
     */
    Default,

    /**
     * [OkHttp logging format](https://github.com/square/okhttp/blob/parent-4.12.0/okhttp-logging-interceptor/src/main/kotlin/okhttp3/logging/HttpLoggingInterceptor.kt#L48-L105).
     * Writes only application-level logs because the low-level HTTP communication is hidden within the engine implementations.
     *
     * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.client.plugins.logging.LoggingFormat.OkHttp)
     */
    OkHttp
}

/**
 * Holds the settings of the [Logging] plugin.
 *
 * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.client.plugins.logging.LoggingConfig)
 */
@KtorDsl
public class LoggingConfig {
    internal var filters = mutableListOf<(HttpRequestBuilder) -> Boolean>()
    internal val sanitizedHeaders = mutableListOf<SanitizedHeader>()

    // Explicitly configured logger; stays `null` until [logger] is assigned.
    private var customLogger: Logger? = null

    /**
     * The overall layout used for logged requests and responses. See [LoggingFormat].
     *
     * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.client.plugins.logging.LoggingConfig.format)
     */
    public var format: LoggingFormat = LoggingFormat.Default

    /**
     * The [Logger] that receives the log output.
     * Reading this property before a logger was assigned yields [Logger.DEFAULT].
     *
     * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.client.plugins.logging.LoggingConfig.logger)
     */
    public var logger: Logger
        get() = customLogger ?: Logger.DEFAULT
        set(value) {
            customLogger = value
        }

    /**
     * The amount of detail to log. Defaults to [LogLevel.HEADERS].
     *
     * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.client.plugins.logging.LoggingConfig.level)
     */
    public var level: LogLevel = LogLevel.HEADERS

    /**
     * The filter that decides how a body appears in the logs.
     *
     * The plugin passes request bodies to [LogBodyFilter.filterRequest] and response bodies to
     * [LogBodyFilter.filterResponse]; the result determines whether the body is logged as text,
     * reported as empty, or omitted. By default, the filter is set to [BinaryLogBodyFilter].
     *
     * A custom filter can be supplied to handle specific logging requirements such as hiding
     * sensitive data, truncating long bodies, or changing the format of the logged body content.
     *
     * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.client.plugins.logging.LoggingConfig.bodyFilter)
     */
    public var bodyFilter: LogBodyFilter = BinaryLogBodyFilter

    /**
     * Registers a [predicate] that selects the calls to log.
     * When several predicates are registered, a call is logged if any of them accepts the request.
     *
     * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.client.plugins.logging.LoggingConfig.filter)
     */
    public fun filter(predicate: (HttpRequestBuilder) -> Boolean) {
        filters.add(predicate)
    }

    /**
     * Registers a rule that keeps the values of sensitive headers out of the logs.
     * The [predicate] receives a header name; [placeholder] is the replacement text stored with the rule.
     * In the example below, the Authorization header value is replaced with '***' when logging:
     * ```kotlin
     * sanitizeHeader { header -> header == HttpHeaders.Authorization }
     * ```
     *
     * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.client.plugins.logging.LoggingConfig.sanitizeHeader)
     */
    public fun sanitizeHeader(placeholder: String = "***", predicate: (String) -> Boolean) {
        sanitizedHeaders += SanitizedHeader(placeholder, predicate)
    }
}

/**
 * A client's plugin that provides the capability to log HTTP calls.
 *
 * You can learn more from [Logging](https://ktor.io/docs/client-logging.html).
 *
 * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.client.plugins.logging.Logging)
 */
@OptIn(InternalAPI::class, DelicateCoroutinesApi::class)
public val Logging: ClientPlugin<LoggingConfig> = createClientPlugin("Logging", ::LoggingConfig) {
    // Plugin settings, read once from the configuration when the plugin body runs.
    val logger = pluginConfig.logger
    val level = pluginConfig.level
    val filters: List<(HttpRequestBuilder) -> Boolean> = pluginConfig.filters
    val sanitizedHeaders: List<SanitizedHeader> = pluginConfig.sanitizedHeaders
    val bodyFilter = pluginConfig.bodyFilter
    val okHttpFormat = pluginConfig.format == LoggingFormat.OkHttp

    // A request is logged when no filters are configured or when at least one filter accepts it.
    fun shouldBeLogged(request: HttpRequestBuilder): Boolean {
        if (filters.isEmpty()) return true
        return filters.any { accepts -> accepts(request) }
    }

    // Exact-level checks used by the OkHttp format; note that isBody() covers both BODY and ALL.
    fun isNone(): Boolean = LogLevel.NONE == level
    fun isInfo(): Boolean = LogLevel.INFO == level
    fun isHeaders(): Boolean = LogLevel.HEADERS == level
    fun isBody(): Boolean = when (level) {
        LogLevel.BODY, LogLevel.ALL -> true
        else -> false
    }

    /**
     * Passes the request [body] through [bodyFilter] and appends the outcome to [logLines]
     * in the OkHttp format, always finishing with a `--> END <method> (...)` line.
     *
     * [contentLength], the content type of [content], [url] and [headers] are handed to the filter as is.
     */
    suspend fun logRequestBody(
        url: Url,
        content: OutgoingContent,
        contentLength: Long?,
        headers: Headers,
        method: HttpMethod,
        logLines: MutableList<String>,
        body: ByteReadChannel
    ) {
        val endPrefix = "--> END ${method.value} ("

        when (val result = bodyFilter.filterRequest(url, contentLength, content.contentType, headers, body)) {
            is BodyFilterResult.Empty -> logLines.add("${endPrefix}0-byte body)")

            is BodyFilterResult.Skip -> {
                // Both the reason and the size are optional; each present part is followed by a space.
                val reasonPart = result.reason?.let { "$it " }.orEmpty()
                val sizePart = result.byteSize?.let { "$it-byte " }.orEmpty()
                logLines.add("$endPrefix$reasonPart${sizePart}body omitted)")
            }

            is BodyFilterResult.Content -> {
                // The text is read before byteSize is queried, same as the END line order in the log.
                logLines.add(result.read())
                logLines.add("$endPrefix${result.byteSize}-byte body)")
            }
        }
    }

    /**
     * Appends the OkHttp-format rendition of the outgoing [content] to [logLines].
     *
     * Wrapper contents are unwrapped and their delegate is logged; multipart bodies are listed part by part;
     * channel-based bodies are split so that one copy can be logged while the other one is still sent.
     *
     * NOTE(review): [process] is only forwarded to the recursive calls and is never applied to a channel
     * in this function — confirm whether that is intended.
     *
     * @return the content to send instead of [content] when its channel was split for logging,
     * or `null` when [content] itself can be sent.
     */
    suspend fun logOutgoingContent(
        url: Url,
        content: OutgoingContent,
        method: HttpMethod,
        headers: Headers,
        logLines: MutableList<String>,
        process: (ByteReadChannel) -> ByteReadChannel = { it }
    ): OutgoingContent? {
        val endLine = "--> END ${method.value}"

        return when (content) {
            is io.ktor.client.content.ObservableContent ->
                logOutgoingContent(url, content.delegate, method, headers, logLines, process)

            is MultiPartFormDataContent -> {
                val delimiter = "--${content.boundary}"

                content.parts.forEach { part ->
                    logLines.add(delimiter)
                    part.headers.entries().forEach { (name, values) ->
                        logLines.add("$name: ${values.joinToString("; ")}")
                    }

                    when (part) {
                        is PartData.FormItem -> {
                            // Form items are textual, so their value is logged in full.
                            logLines.add("${HttpHeaders.ContentLength}: ${part.value.length}")
                            logLines.add("")
                            logLines.add(part.value)
                        }

                        else -> {
                            // Any other part is reported by size only (when the part declares one).
                            logLines.add("")
                            val partLength = part.headers[HttpHeaders.ContentLength]
                            logLines.add(
                                if (partLength == null) {
                                    "binary body omitted"
                                } else {
                                    "binary $partLength-byte body omitted"
                                }
                            )
                        }
                    }
                }

                logLines.add("$delimiter--")
                logLines.add(endLine)

                null
            }

            is OutgoingContent.ByteArrayContent -> {
                val payload = content.bytes()
                logRequestBody(url, content, payload.size.toLong(), headers, method, logLines, ByteReadChannel(payload))
                null
            }

            is OutgoingContent.ContentWrapper ->
                logOutgoingContent(url, content.delegate(), method, headers, logLines, process)

            is OutgoingContent.NoContent, is OutgoingContent.ProtocolUpgrade -> {
                logLines.add(endLine)
                null
            }

            is OutgoingContent.ReadChannelContent -> {
                val (origChannel, newChannel) = content.readFrom().split(client)
                logRequestBody(url, content, content.contentLength, headers, method, logLines, newChannel)
                LoggedContent(content, origChannel)
            }

            is OutgoingContent.WriteChannelContent -> {
                // Let the content write itself into an intermediate channel that can be split afterwards.
                val written = ByteChannel()

                client.launch {
                    content.writeTo(written)
                    written.close()
                }

                val (origChannel, newChannel) = written.split(client)
                logRequestBody(url, content, content.contentLength, headers, method, logLines, newChannel)
                LoggedContent(content, origChannel)
            }
        }
    }

    /**
     * Writes the OkHttp-format log of [request] into [logLines]: the start line, then (at HEADERS level and above)
     * the headers, then (at body level, for methods other than GET and HEAD) the body.
     *
     * @return the content to send instead of the request body when the body had to be replaced for logging,
     * or `null` when the request can proceed unchanged.
     */
    suspend fun logRequestOkHttpFormat(request: HttpRequestBuilder, logLines: MutableList<String>): OutgoingContent? {
        if (isNone()) return null

        val uri = URLBuilder().takeFrom(request.url).build().pathQuery()
        val body = request.body
        val methodName = request.method.value
        val bodyless = request.method == HttpMethod.Get || request.method == HttpMethod.Head

        // Content-Type and Content-Length of the body are added only if the request headers don't already have them.
        val headersBuilder = HeadersBuilder()
        if (body is OutgoingContent && !bodyless && body !is EmptyContent) {
            body.contentType?.let { headersBuilder.appendIfNameAbsent(HttpHeaders.ContentType, it.toString()) }
            body.contentLength?.let { headersBuilder.appendIfNameAbsent(HttpHeaders.ContentLength, it.toString()) }
        }
        headersBuilder.appendAll(request.headers)
        val headers = headersBuilder.build()

        val contentLength = headers[HttpHeaders.ContentLength]?.toLongOrNull()
        val plainStart = "--> $methodName $uri"
        val startLine = when {
            bodyless -> plainStart
            (isHeaders() || isBody()) && contentLength != null -> plainStart
            isHeaders() && contentLength == null -> plainStart
            headers.contains(HttpHeaders.ContentEncoding) -> plainStart

            isInfo() && contentLength != null -> "$plainStart ($contentLength-byte body)"

            body is OutgoingContent.WriteChannelContent ||
                body is OutgoingContent.ReadChannelContent -> "$plainStart (unknown-byte body)"

            else -> "$plainStart (${computeRequestBodySize(request.body)}-byte body)"
        }

        logLines.add(startLine)

        if (!(isHeaders() || isBody())) return null

        headers.entries().forEach { (name, values) ->
            val masked = sanitizedHeaders.any { rule -> rule.predicate(name) }
            logLines.add(if (masked) "$name: ██" else "$name: ${values.joinToString(separator = ", ")}")
        }

        val endLine = "--> END $methodName"

        if (!isBody() || bodyless) {
            logLines.add(endLine)
            return null
        }

        logLines.add("")

        if (body !is OutgoingContent) {
            logLines.add(endLine)
            return null
        }

        val gzipped = request.headers[HttpHeaders.ContentEncoding] == "gzip"
        val target = request.url.build()

        return if (gzipped) {
            logOutgoingContent(target, body, request.method, headers, logLines) { channel ->
                GZipEncoder.decode(channel)
            }
        } else {
            logOutgoingContent(target, body, request.method, headers, logLines)
        }
    }

    /**
     * Appends an empty separator line, then passes the response [body] through [bodyFilter] and appends
     * the outcome to [logLines] in the OkHttp format, finishing with a `<-- END HTTP (...)` line that
     * includes the difference between the response and request timestamps.
     */
    suspend fun logResponseBody(response: HttpResponse, body: ByteReadChannel, logLines: MutableList<String>) {
        logLines.add("")

        val result = bodyFilter.filterResponse(
            response.call.request.url,
            response.contentLength(),
            response.contentType(),
            response.headers,
            body,
        )
        val elapsed = response.responseTime.timestamp - response.requestTime.timestamp
        val endPrefix = "<-- END HTTP (${elapsed}ms, "

        when (result) {
            is BodyFilterResult.Empty -> logLines.add("${endPrefix}0-byte body)")

            is BodyFilterResult.Skip -> {
                // Both the reason and the size are optional; each present part is followed by a space.
                val reasonPart = result.reason?.let { "$it " }.orEmpty()
                val sizePart = result.byteSize?.let { "$it-byte " }.orEmpty()
                logLines.add("$endPrefix$reasonPart${sizePart}body omitted)")
            }

            is BodyFilterResult.Content -> {
                // The text is read before byteSize is queried.
                logLines.add(result.read())
                logLines.add("$endPrefix${result.byteSize}-byte body)")
            }
        }
    }

    /**
     * Writes the OkHttp-format log of [response] into [logLines]: the start line, then (at HEADERS level and above)
     * the headers, then (at body level) the body.
     *
     * @return [response] itself, or the response of a replacement call when the raw content channel
     * had to be split so that the body could be both logged and consumed.
     */
    suspend fun logResponseOkHttpFormat(response: HttpResponse, logLines: MutableList<String>): HttpResponse {
        if (isNone()) return response

        val contentLength = response.headers[HttpHeaders.ContentLength]?.toLongOrNull()
        val request = response.request
        val duration = response.responseTime.timestamp - response.requestTime.timestamp

        val statusAndPath = "<-- ${response.status} ${request.url.pathQuery()}"
        val unknownSizeLine = "$statusAndPath (${duration}ms, unknown-byte body)"
        val chunked = response.headers[HttpHeaders.TransferEncoding] == "chunked"

        val startLine = when {
            chunked && (isInfo() || isHeaders()) -> unknownSizeLine

            isInfo() && contentLength != null -> "$statusAndPath (${duration}ms, $contentLength-byte body)"

            isBody() ||
                (isInfo() && contentLength == null) ||
                (isHeaders() && contentLength != null) ||
                (response.headers[HttpHeaders.ContentEncoding] == "gzip") -> "$statusAndPath (${duration}ms)"

            else -> unknownSizeLine
        }

        logLines.add(startLine)

        if (!(isHeaders() || isBody())) return response

        response.headers.entries().forEach { (name, values) ->
            val masked = sanitizedHeaders.any { rule -> rule.predicate(name) }
            logLines.add(if (masked) "$name: ██" else "$name: ${values.joinToString(separator = ", ")}")
        }

        if (!isBody()) {
            logLines.add("<-- END HTTP")
            return response
        }

        // A declared empty body needs no reading at all.
        if (contentLength == 0L) {
            logLines.add("<-- END HTTP (${duration}ms, $contentLength-byte body)")
            return response
        }

        // Event streams are left untouched.
        if (response.contentType() == ContentType.Text.EventStream) {
            logLines.add("<-- END HTTP (streaming)")
            return response
        }

        // A saved response is logged straight from its raw content, without replacing the call.
        if (response.isSaved) {
            logResponseBody(response, response.rawContent, logLines)
            return response
        }

        // Otherwise the channel is split: one copy is logged, the other one backs the replacement response.
        val (origChannel, newChannel) = response.rawContent.split(response)
        logResponseBody(response, newChannel, logLines)

        return response.call.replaceResponse { origChannel }.response
    }

    /**
     * Default-format body logging: returns a copy of [content] that is observed through a side channel,
     * while a background coroutine reads that channel as text and writes the collected body log to [logger].
     *
     * The request log is always written and closed, even if reading the body fails.
     */
    @OptIn(DelicateCoroutinesApi::class)
    suspend fun logRequestBody(
        content: OutgoingContent,
        logger: HttpClientCallLogger
    ): OutgoingContent {
        val requestLog = StringBuilder().apply {
            appendLine("BODY Content-Type: ${content.contentType}")
        }

        // The charset of the content type is used for decoding, with UTF-8 as the fallback.
        val charset = content.contentType?.charset() ?: Charsets.UTF_8

        val observed = ByteChannel()
        GlobalScope.launch(Dispatchers.Default + MDCContext()) {
            try {
                val text = observed.tryReadText(charset) ?: "[request body omitted]"
                with(requestLog) {
                    appendLine("BODY START")
                    appendLine(text)
                    append("BODY END")
                }
            } finally {
                logger.logRequest(requestLog.toString())
                logger.closeRequestLog()
            }
        }

        return content.observe(observed)
    }

    // Reports a failed request to the plugin logger; does nothing unless the level includes info output.
    fun logRequestException(context: HttpRequestBuilder, cause: Throwable) {
        if (!level.info) return
        logger.log("REQUEST ${Url(context.url)} failed with exception: $cause")
    }

    /**
     * Default-format request logging. Creates the call logger for [request], stores it in the request attributes
     * and logs the URL/method and the headers as far as the configured level allows.
     *
     * @return the observed content to send when the body is logged as well, or `null` when the request body
     * can be sent unchanged.
     */
    suspend fun logRequest(request: HttpRequestBuilder): OutgoingContent? {
        val content = request.body as OutgoingContent
        val callLogger = HttpClientCallLogger(logger)
        request.attributes.put(ClientCallLogger, callLogger)

        // Placeholder of the first sanitizing rule that matches the header name, if there is one.
        fun placeholderFor(headerName: String): String? =
            sanitizedHeaders.firstOrNull { rule -> rule.predicate(headerName) }?.placeholder

        val message = StringBuilder().apply {
            if (level.info) {
                appendLine("REQUEST: ${Url(request.url)}")
                appendLine("METHOD: ${request.method}")
            }

            if (level.headers) {
                appendLine("COMMON HEADERS")
                logHeaders(request.headers.entries(), sanitizedHeaders)

                appendLine("CONTENT HEADERS")
                val maskedLength = placeholderFor(HttpHeaders.ContentLength)
                val maskedType = placeholderFor(HttpHeaders.ContentType)
                content.contentLength?.let { length ->
                    logHeader(HttpHeaders.ContentLength, maskedLength ?: length.toString())
                }
                content.contentType?.let { type ->
                    logHeader(HttpHeaders.ContentType, maskedType ?: type.toString())
                }
                logHeaders(content.headers.entries(), sanitizedHeaders)
            }
        }.toString()

        val hasMessage = message.isNotEmpty()
        if (hasMessage) {
            callLogger.logRequest(message)
        }

        // The body is logged only when something was logged already and the level includes the body.
        if (hasMessage && level.body) {
            return logRequestBody(content, callLogger)
        }

        callLogger.closeRequestLog()
        return null
    }

    // Appends a failure note for the response of [request] to [log]; does nothing unless the level includes info output.
    fun logResponseException(log: StringBuilder, request: HttpRequest, cause: Throwable) {
        if (level.info) {
            log.append("RESPONSE ${request.url} failed with exception: $cause")
        }
    }

    // Request side: logs the outgoing request and continues the send pipeline.
    on(SendHook) { request ->
        if (!shouldBeLogged(request)) {
            // Mark the request so that the response-side hooks can tell it was filtered out.
            request.attributes.put(DisableLogging, Unit)
            return@on
        }

        if (!okHttpFormat) {
            // Default format: a failure while logging must not fail the request itself.
            val loggedRequest = try {
                logRequest(request)
            } catch (_: Throwable) {
                null
            }

            try {
                proceedWith(loggedRequest ?: request.body)
            } catch (cause: Throwable) {
                logRequestException(request, cause)
                throw cause
            }

            return@on
        }

        val requestLogLines = mutableListOf<String>()
        val replacement = logRequestOkHttpFormat(request, requestLogLines)

        if (requestLogLines.isNotEmpty()) {
            logger.log(requestLogLines.joinToString(separator = "\n"))
        }

        try {
            if (replacement == null) {
                proceed()
            } else {
                proceedWith(replacement)
            }
        } catch (cause: Throwable) {
            logger.log("<-- HTTP FAILED: $cause")
            throw cause
        }
    }

    // Response side of the OkHttp format: logs the response and, when the body channel had to be split
    // for logging, continues the pipeline with the replacement response.
    on(ResponseAfterEncodingHook) { response ->
        if (!okHttpFormat) return@on

        // Honor the request filters: the send hook marks filtered-out requests with DisableLogging,
        // so the response of such a call must not be logged either (the default format does the same check).
        if (response.call.attributes.contains(DisableLogging)) return@on

        val responseLogLines = mutableListOf<String>()
        val newResponse = logResponseOkHttpFormat(response, responseLogLines)

        if (responseLogLines.isNotEmpty()) {
            logger.log(responseLogLines.joinToString(separator = "\n"))
        }

        if (newResponse != response) {
            proceedWith(newResponse)
        }
    }

    // Response side of the default format: logs the response header and, for saved responses, the body.
    on(ResponseHook) { response ->
        if (okHttpFormat) return@on

        val loggingDisabled = level == LogLevel.NONE || response.call.attributes.contains(DisableLogging)
        if (loggingDisabled) return@on

        val callLogger = response.call.attributes[ClientCallLogger]
        val header = StringBuilder()

        var failed = false
        try {
            logResponseHeader(header, response.call.response, level, sanitizedHeaders)
            proceed()
        } catch (cause: Throwable) {
            logResponseException(header, response.call.request, cause)
            failed = true
            throw cause
        } finally {
            callLogger.logResponseHeader(header.toString())
            when {
                // Nothing more will be logged for this response.
                failed || !level.body -> callLogger.closeResponseLog()

                // Only a saved response body is logged here; other bodies are left to the response observer.
                response.isSaved -> {
                    callLogger.logResponseBody(response)
                    callLogger.closeResponseLog()
                }
            }
        }
    }

    // Default format only: logs a failure raised while the response is being received, then rethrows it.
    on(ReceiveHook) { call ->
        val loggingSkipped = okHttpFormat || level == LogLevel.NONE || call.attributes.contains(DisableLogging)
        if (loggingSkipped) return@on

        try {
            proceed()
        } catch (cause: Throwable) {
            val callLogger = call.attributes[ClientCallLogger]
            val failureLog = StringBuilder()
            logResponseException(failureLog, call.request, cause)
            callLogger.logResponseException(failureLog.toString())
            callLogger.closeResponseLog()
            throw cause
        }
    }

    // The observer below is needed only by the default format, and only when bodies are logged.
    if (okHttpFormat || !level.body) return@createClientPlugin

    val responseObserver = ResponseObserver.prepare {
        // Observe only responses that aren't saved; saved ones are logged by the response hook.
        filter { call -> !call.response.isSaved }

        onResponse { response ->
            val attributes = response.call.attributes
            if (level == LogLevel.NONE || attributes.contains(DisableLogging)) return@onResponse

            val callLogger = attributes[ClientCallLogger]
            callLogger.logResponseBody(response)
            callLogger.closeResponseLog()
        }
    }
    ResponseObserver.install(responseObserver, client)
}

/**
 * Renders the encoded path of this URL (or `/` when the path is empty),
 * followed by `?` and the encoded query when the query is not empty.
 */
private fun Url.pathQuery(): String {
    val path = encodedPath.ifEmpty { "/" }
    return if (encodedQuery.isEmpty()) path else "$path?$encodedQuery"
}

/**
 * Returns the size in bytes of a request body whose size can be determined without reading a channel.
 *
 * Wrappers are unwrapped down to their innermost delegate. Byte array content reports the array size,
 * while content without a body and protocol upgrades report zero.
 *
 * @throws IllegalStateException if [content] is not an [OutgoingContent] or is of any other content type.
 */
private fun computeRequestBodySize(content: Any): Long {
    check(content is OutgoingContent)

    var current: OutgoingContent = content
    while (current is OutgoingContent.ContentWrapper) {
        current = current.delegate()
    }

    return when (current) {
        is OutgoingContent.ByteArrayContent -> current.bytes().size.toLong()
        is OutgoingContent.NoContent, is OutgoingContent.ProtocolUpgrade -> 0L
        else -> error("Unable to calculate the size for type ${current::class.simpleName}")
    }
}

/**
 * Configures and installs [Logging] in [HttpClient].
 *
 * @param block configuration applied to the plugin's [LoggingConfig]; the default empty block
 * leaves every setting at its initial value.
 *
 * [Report a problem](https://ktor.io/feedback/?fqname=io.ktor.client.plugins.logging.Logging)
 */
@Suppress("FunctionName")
public fun HttpClientConfig<*>.Logging(block: LoggingConfig.() -> Unit = {}) {
    install(Logging, block)
}

/**
 * A header-masking rule registered through [LoggingConfig.sanitizeHeader].
 *
 * @property placeholder the replacement text configured for the header value.
 * @property predicate receives a header name and returns `true` when that header must be masked.
 */
internal class SanitizedHeader(
    val placeholder: String,
    val predicate: (String) -> Boolean
)

/**
 * Hook whose handler runs in the [HttpReceivePipeline.State] phase of the client's receive pipeline.
 */
private object ResponseHook : ClientHook<suspend ResponseHook.Context.(response: HttpResponse) -> Unit> {

    /** Lets the handler continue the execution of the intercepted pipeline. */
    class Context(private val pipeline: PipelineContext<HttpResponse, Unit>) {
        suspend fun proceed() = pipeline.proceed()
    }

    override fun install(
        client: HttpClient,
        handler: suspend Context.(response: HttpResponse) -> Unit
    ) {
        client.receivePipeline.intercept(HttpReceivePipeline.State) {
            val hookContext = Context(this)
            hookContext.handler(subject)
        }
    }
}

/**
 * Hook whose handler runs in a dedicated `AfterState` phase, which is inserted into the client's
 * receive pipeline right after [HttpReceivePipeline.State] when the hook is installed.
 */
private object ResponseAfterEncodingHook :
    ClientHook<suspend ResponseAfterEncodingHook.Context.(response: HttpResponse) -> Unit> {

    /** Lets the handler continue the intercepted pipeline with a different response. */
    class Context(private val pipeline: PipelineContext<HttpResponse, Unit>) {
        suspend fun proceedWith(response: HttpResponse) = pipeline.proceedWith(response)
    }

    override fun install(
        client: HttpClient,
        handler: suspend Context.(response: HttpResponse) -> Unit
    ) {
        val receivePipeline = client.receivePipeline
        val afterState = PipelinePhase("AfterState")

        receivePipeline.insertPhaseAfter(HttpReceivePipeline.State, afterState)
        receivePipeline.intercept(afterState) {
            val hookContext = Context(this)
            hookContext.handler(subject)
        }
    }
}

/**
 * Hook whose handler runs in the [HttpSendPipeline.Monitoring] phase of the client's send pipeline
 * and receives the request builder of the intercepted call.
 */
private object SendHook : ClientHook<suspend SendHook.Context.(response: HttpRequestBuilder) -> Unit> {

    /** Lets the handler continue the intercepted pipeline, optionally with a replaced subject. */
    class Context(private val pipeline: PipelineContext<Any, HttpRequestBuilder>) {
        suspend fun proceedWith(content: Any) = pipeline.proceedWith(content)
        suspend fun proceed() = pipeline.proceed()
    }

    override fun install(
        client: HttpClient,
        handler: suspend Context.(request: HttpRequestBuilder) -> Unit
    ) {
        client.sendPipeline.intercept(HttpSendPipeline.Monitoring) {
            val hookContext = Context(this)
            hookContext.handler(context)
        }
    }
}

/**
 * Hook whose handler runs in the [HttpResponsePipeline.Receive] phase of the client's response pipeline
 * and receives the call being processed.
 */
private object ReceiveHook : ClientHook<suspend ReceiveHook.Context.(call: HttpClientCall) -> Unit> {

    /** Lets the handler continue the execution of the intercepted pipeline. */
    class Context(private val pipeline: PipelineContext<HttpResponseContainer, HttpClientCall>) {
        suspend fun proceed() = pipeline.proceed()
    }

    override fun install(
        client: HttpClient,
        handler: suspend Context.(call: HttpClientCall) -> Unit
    ) {
        client.responsePipeline.intercept(HttpResponsePipeline.Receive) {
            val hookContext = Context(this)
            hookContext.handler(context)
        }
    }
}
