/*
 * Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package ksp.org.jetbrains.kotlin.fir.analysis.jvm.checkers.expression

import ksp.org.jetbrains.kotlin.diagnostics.DiagnosticReporter
import ksp.org.jetbrains.kotlin.diagnostics.reportOn
import ksp.org.jetbrains.kotlin.fir.analysis.checkers.MppCheckerKind
import ksp.org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext
import ksp.org.jetbrains.kotlin.fir.analysis.checkers.expression.FirFunctionCallChecker
import ksp.org.jetbrains.kotlin.fir.analysis.diagnostics.jvm.FirJvmErrors
import ksp.org.jetbrains.kotlin.fir.declarations.FirAnonymousFunction
import ksp.org.jetbrains.kotlin.fir.declarations.FirFunction
import ksp.org.jetbrains.kotlin.fir.declarations.utils.isSuspend
import ksp.org.jetbrains.kotlin.fir.expressions.FirFunctionCall
import ksp.org.jetbrains.kotlin.fir.expressions.resolvedArgumentMapping
import ksp.org.jetbrains.kotlin.fir.references.toResolvedCallableSymbol
import ksp.org.jetbrains.kotlin.fir.resolve.transformers.unwrapAnonymousFunctionExpression
import ksp.org.jetbrains.kotlin.fir.types.FirResolvedTypeRef
import ksp.org.jetbrains.kotlin.fir.types.isSuspendOrKSuspendFunctionType
import ksp.org.jetbrains.kotlin.name.CallableId
import ksp.org.jetbrains.kotlin.name.FqName
import ksp.org.jetbrains.kotlin.name.Name
import ksp.org.jetbrains.kotlin.utils.addToStdlib.runIf

/**
 * Reports [FirJvmErrors.SUSPENSION_POINT_INSIDE_CRITICAL_SECTION] for a suspend call that is made inside the lambda
 * passed to `kotlin.synchronized` (its `block` parameter) or `kotlin.concurrent.withLock`, provided that the
 * surrounding code is itself a suspend function or a suspend lambda.
 */
object FirJvmSuspensionPointInsideMutexLockChecker : FirFunctionCallChecker(MppCheckerKind.Common) {
    private val synchronizedCallableId = CallableId(FqName("kotlin"), Name.identifier("synchronized"))
    private val withLockCallableId = CallableId(FqName("kotlin.concurrent"), Name.identifier("withLock"))
    private val synchronizedBlockParamName = Name.identifier("block")

    context(context: CheckerContext, reporter: DiagnosticReporter)
    override fun check(expression: FirFunctionCall) {
        val symbol = expression.calleeReference.toResolvedCallableSymbol() ?: return
        if (!symbol.isSuspend) return

        // The most recently passed anonymous function while walking outwards from `expression`; used to find out
        // which parameter of an enclosing call that lambda was bound to.
        var anonymousFunctionArg: FirAnonymousFunction? = null
        var isMutexLockFound = false
        var isSuspendFunctionFound = false

        // Walk the enclosing elements from the innermost to the outermost one.
        for (element in context.containingElements.asReversed()) {
            if (element is FirFunctionCall) {
                val callableSymbol = element.calleeReference.toResolvedCallableSymbol() ?: continue

                // Parameter of `element` that received `anonymousFunctionArg` as a direct argument, or null when the
                // suspend call is not inside a lambda argument of `element` (e.g. it is in the receiver or in a plain argument).
                val enclosingAnonymousFuncParam = element.resolvedArgumentMapping?.firstNotNullOfOrNull { entry ->
                    entry.key.unwrapAnonymousFunctionExpression()?.let {
                        runIf(it == anonymousFunctionArg) { entry.value }
                    }
                }

                // A lambda bound to a suspend function type parameter is a suspend lambda, i.e. a valid suspension context.
                if ((enclosingAnonymousFuncParam?.returnTypeRef as? FirResolvedTypeRef)?.coneType?.isSuspendOrKSuspendFunctionType(context.session) == true) {
                    isSuspendFunctionFound = true
                    break
                }

                // Only the lambda passed to the lock call is executed while the lock is held; suspend calls in the
                // lock/receiver expressions themselves are not inside the critical section.
                val isInsideCriticalSection = enclosingAnonymousFuncParam != null && when (callableSymbol.callableId) {
                    synchronizedCallableId -> enclosingAnonymousFuncParam.name == synchronizedBlockParamName
                    withLockCallableId -> true
                    else -> false
                }
                if (isInsideCriticalSection) {
                    isMutexLockFound = true
                }
            } else if (element is FirFunction) {
                if (element.isSuspend) {
                    isSuspendFunctionFound = true
                    break
                }
                if (element is FirAnonymousFunction) {
                    // For an anonymous function argument, `isSuspend` is detected from the parameter it is bound to
                    // when the enclosing call is reached.
                    anonymousFunctionArg = element
                }
            }
        }

        // SUSPENSION_POINT_INSIDE_CRITICAL_SECTION is not reported when no enclosing suspend function is found,
        // because ILLEGAL_SUSPEND_FUNCTION_CALL is reported in that case.
        if (isMutexLockFound && isSuspendFunctionFound) {
            reporter.reportOn(expression.source, FirJvmErrors.SUSPENSION_POINT_INSIDE_CRITICAL_SECTION, symbol)
        }
    }
}
