/*
 * Copyright 2010-2022 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package ksp.org.jetbrains.kotlin.fir.analysis.native.checkers

import ksp.org.jetbrains.kotlin.builtins.StandardNames
import ksp.org.jetbrains.kotlin.diagnostics.DiagnosticReporter
import ksp.org.jetbrains.kotlin.diagnostics.reportOn
import ksp.org.jetbrains.kotlin.fir.FirElement
import ksp.org.jetbrains.kotlin.fir.FirSession
import ksp.org.jetbrains.kotlin.fir.analysis.checkers.MppCheckerKind
import ksp.org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext
import ksp.org.jetbrains.kotlin.fir.analysis.checkers.declaration.FirBasicDeclarationChecker
import ksp.org.jetbrains.kotlin.fir.analysis.checkers.extractClassFromArgument
import ksp.org.jetbrains.kotlin.fir.analysis.checkers.hasModifier
import ksp.org.jetbrains.kotlin.fir.analysis.checkers.unsubstitutedScope
import ksp.org.jetbrains.kotlin.fir.analysis.diagnostics.native.FirNativeErrors
import ksp.org.jetbrains.kotlin.fir.containingClassLookupTag
import ksp.org.jetbrains.kotlin.fir.declarations.*
import ksp.org.jetbrains.kotlin.fir.declarations.utils.isExpect
import ksp.org.jetbrains.kotlin.fir.expressions.*
import ksp.org.jetbrains.kotlin.fir.isSubstitutionOrIntersectionOverride
import ksp.org.jetbrains.kotlin.fir.references.isError
import ksp.org.jetbrains.kotlin.fir.resolve.toRegularClassSymbol
import ksp.org.jetbrains.kotlin.fir.scopes.MemberWithBaseScope
import ksp.org.jetbrains.kotlin.fir.scopes.ScopeFunctionRequiresPrewarm
import ksp.org.jetbrains.kotlin.fir.scopes.getDirectOverriddenFunctionsWithBaseScope
import ksp.org.jetbrains.kotlin.fir.symbols.impl.FirNamedFunctionSymbol
import ksp.org.jetbrains.kotlin.fir.symbols.impl.FirTypeAliasSymbol
import ksp.org.jetbrains.kotlin.fir.types.hasError
import ksp.org.jetbrains.kotlin.lexer.KtTokens
import ksp.org.jetbrains.kotlin.name.ClassId
import ksp.org.jetbrains.kotlin.name.FqName
import ksp.org.jetbrains.kotlin.resolve.annotations.KOTLIN_THROWS_ANNOTATION_FQ_NAME
import ksp.org.jetbrains.kotlin.utils.addToStdlib.runUnless

/**
 * Validates `@Throws` (see [KOTLIN_THROWS_ANNOTATION_FQ_NAME]) on declarations.
 *
 * Diagnostics reported:
 * - [FirNativeErrors.INCOMPATIBLE_THROWS_INHERITED]: a function inherits two or more different `@Throws`
 *   filters from the functions it overrides;
 * - [FirNativeErrors.INCOMPATIBLE_THROWS_OVERRIDE]: a function's own explicit `@Throws` differs from the
 *   filter it inherits;
 * - [FirNativeErrors.THROWS_LIST_EMPTY]: `@Throws` lists no exception classes;
 * - [FirNativeErrors.MISSING_EXCEPTION_IN_THROWS_ON_SUSPEND]: a `suspend` declaration's `@Throws` lists
 *   neither `CancellationException` nor any class from [cancellationExceptionAndSupersClassIds].
 *
 * The work is split by multiplatform kind: [Regular] skips `expect` member declarations, while
 * [ForExpectClass] handles only those.
 */
sealed class FirNativeThrowsChecker(mppKind: MppCheckerKind) : FirBasicDeclarationChecker(mppKind) {
    /** Platform-kind instance: runs the common logic for everything except `expect` member declarations. */
    object Regular : FirNativeThrowsChecker(MppCheckerKind.Platform) {
        context(context: CheckerContext, reporter: DiagnosticReporter)
        override fun check(declaration: FirDeclaration) {
            val isExpectMember = (declaration as? FirMemberDeclaration)?.isExpect == true
            if (!isExpectMember) super.check(declaration)
        }
    }

    /** Common-kind instance: runs the common logic only for `expect` member declarations. */
    object ForExpectClass : FirNativeThrowsChecker(MppCheckerKind.Common) {
        context(context: CheckerContext, reporter: DiagnosticReporter)
        override fun check(declaration: FirDeclaration) {
            val isExpectMember = (declaration as? FirMemberDeclaration)?.isExpect == true
            if (isExpectMember) super.check(declaration)
        }
    }

    context(context: CheckerContext, reporter: DiagnosticReporter)
    override fun check(declaration: FirDeclaration) {
        val session = context.session
        val throwsAnnotation = declaration.getAnnotationByClassId(throwsClassId, session)

        // An inheritance conflict has already been reported; the argument-level checks below would only add noise.
        // Unresolved arguments make the listed classes unreliable, so nothing further is checked either.
        val inheritanceIsConsistent = checkInheritance(declaration, throwsAnnotation, context, reporter)
        if (!inheritanceIsConsistent || throwsAnnotation.hasUnresolvedArgument()) return

        // No `@Throws` on this declaration: nothing left to validate.
        val listedClassIds = throwsAnnotation?.getClassIds(session) ?: return

        if (listedClassIds.isEmpty()) {
            reporter.reportOn(throwsAnnotation.source, FirNativeErrors.THROWS_LIST_EMPTY)
            return
        }

        val isSuspend = declaration.hasModifier(KtTokens.SUSPEND_KEYWORD)
        val coversCancellation = listedClassIds.any { classId -> classId in cancellationExceptionAndSupersClassIds }
        if (isSuspend && !coversCancellation) {
            reporter.reportOn(
                throwsAnnotation.source,
                FirNativeErrors.MISSING_EXCEPTION_IN_THROWS_ON_SUSPEND,
                cancellationExceptionFqName
            )
        }
    }

    /**
     * Verifies that [declaration] (only simple functions are inspected) agrees with the `@Throws` filters
     * it inherits from overridden functions.
     *
     * @param throwsAnnotation the `@Throws` found directly on [declaration], or `null` if absent.
     * @return `false` if an incompatibility was detected (a diagnostic may have been reported), `true` otherwise.
     */
    private fun checkInheritance(
        declaration: FirDeclaration,
        throwsAnnotation: FirAnnotation?,
        context: CheckerContext,
        reporter: DiagnosticReporter
    ): Boolean {
        val function = declaration as? FirSimpleFunction ?: return true
        val session = context.session

        // Several overridden functions may carry the same filter; only genuinely different filters conflict.
        val distinctInherited = getInheritedThrows(function, throwsAnnotation, context).entries.distinctBy { it.value }

        if (distinctInherited.size > 1) {
            val conflictingClasses = distinctInherited.mapNotNull { (member, _) ->
                member.containingClassLookupTag()?.toRegularClassSymbol(session)
            }
            reporter.reportOn(
                function.source,
                FirNativeErrors.INCOMPATIBLE_THROWS_INHERITED,
                conflictingClasses,
                context
            )
            return false
        }

        // Empty when nothing is inherited, e.g. a top-level function or a member that overrides nothing.
        val (overriddenMember, overriddenThrows) = distinctInherited.singleOrNull() ?: return true

        // Only an annotation that exists in source is reported as a mismatching override.
        if (throwsAnnotation?.source != null && decodeThrowsFilter(throwsAnnotation, session) != overriddenThrows) {
            overriddenMember.containingClassLookupTag()?.toRegularClassSymbol(session)?.let { overriddenClass ->
                reporter.reportOn(
                    throwsAnnotation.source,
                    FirNativeErrors.INCOMPATIBLE_THROWS_OVERRIDE,
                    overriddenClass,
                    context
                )
            }
            return false
        }

        return true
    }

    /**
     * Walks the override hierarchy of [function] and collects, for each "effective" ancestor, the `@Throws`
     * filter that applies there.
     *
     * Starting from [function], the walk descends through overridden functions. It continues past any
     * function that has no `@Throws` of its own but does override something further. It stops and records
     * a filter at the first function that either has a `@Throws` or overrides nothing. Substitution and
     * intersection overrides are treated as having no annotation of their own.
     *
     * The returned map preserves discovery order. It is empty for top-level functions and for functions
     * that override nothing.
     */
    private fun getInheritedThrows(
        function: FirSimpleFunction,
        throwsAnnotation: FirAnnotation?,
        context: CheckerContext
    ): Map<FirNamedFunctionSymbol, ThrowsFilter> {
        val session = context.session
        val alreadySeen = hashSetOf<FirNamedFunctionSymbol>()
        // Insertion order matters: callers pick/report entries in discovery order.
        val collected = linkedMapOf<FirNamedFunctionSymbol, ThrowsFilter>()

        fun visit(annotationOnMember: FirAnnotation?, memberWithScope: MemberWithBaseScope<FirNamedFunctionSymbol>) {
            val member = memberWithScope.member
            if (!alreadySeen.add(member)) return

            val directOverridden =
                @OptIn(ScopeFunctionRequiresPrewarm::class)
                memberWithScope.baseScope.getDirectOverriddenFunctionsWithBaseScope(member)

            val isStartingFunction = member == function.symbol
            val passesThrough = annotationOnMember == null && directOverridden.isNotEmpty()
            if (!isStartingFunction && !passesThrough) {
                collected[member] = decodeThrowsFilter(annotationOnMember, session)
                return
            }

            directOverridden.forEach { overriddenWithScope ->
                val overridden = overriddenWithScope.member
                val ownAnnotation =
                    if (overridden.isSubstitutionOrIntersectionOverride) null
                    else overridden.getAnnotationByClassId(throwsClassId, session)
                visit(ownAnnotation, overriddenWithScope)
            }
        }

        val containingScope = function.symbol.containingClassLookupTag()
            ?.toRegularClassSymbol(session)
            ?.unsubstitutedScope(context)
            ?: return collected

        // Prewarm the scope for this name before querying overridden functions
        // (presumably what @ScopeFunctionRequiresPrewarm demands — keep this call first).
        containingScope.processFunctionsByName(function.name) {}
        visit(throwsAnnotation, MemberWithBaseScope(function.symbol, containingScope))

        return collected
    }

    /**
     * Returns `true` if this element, or any argument nested inside it, has an erroneous callee reference,
     * or is a qualifier resolved to a type alias whose expanded type contains an error. A `null` receiver
     * yields `false`.
     */
    private fun FirElement?.hasUnresolvedArgument(): Boolean = when (this) {
        null -> false
        is FirWrappedArgumentExpression -> expression.hasUnresolvedArgument()
        else -> (this is FirResolvable && calleeReference.isError()) ||
                (this is FirVarargArgumentsExpression && this.arguments.any { it.hasUnresolvedArgument() }) ||
                (this is FirCall && this.argumentList.arguments.any { it.hasUnresolvedArgument() }) ||
                // TODO: accept also FirClassSymbol<*>, like `FirClassLikeSymbol<*>.getSuperTypes()` does. Write test for this use-case.
                (this is FirResolvedQualifier &&
                        (this.symbol as? FirTypeAliasSymbol)?.resolvedExpandedTypeRef?.coneType?.hasError() == true)
    }

    /** Converts an optional `@Throws` into a comparable [ThrowsFilter]; no annotation means `classes == null`. */
    private fun decodeThrowsFilter(throwsAnnotation: FirAnnotation?, session: FirSession): ThrowsFilter =
        ThrowsFilter(classes = throwsAnnotation?.getClassIds(session)?.toSet())

    /** Class ids referenced by the annotation's first mapped argument (vararg values are flattened). */
    private fun FirAnnotation.getClassIds(session: FirSession): List<ClassId> {
        val firstArgument = argumentMapping.mapping.values.firstOrNull() ?: return emptyList()
        return firstArgument.unwrapVarargValue().mapNotNull { argument ->
            argument.extractClassFromArgument(session)?.classId
        }
    }

    /** Set of exception classes listed by `@Throws`; `null` when the function carries no `@Throws` at all. */
    private data class ThrowsFilter(val classes: Set<ClassId>?)

    companion object {
        private val throwsClassId = ClassId.topLevel(KOTLIN_THROWS_ANNOTATION_FQ_NAME)

        private val cancellationExceptionFqName = FqName("kotlin.coroutines.cancellation.CancellationException")

        /** `CancellationException` and its supertypes: listing any of them satisfies the `suspend` check. */
        private val cancellationExceptionAndSupersClassIds: Set<ClassId> = listOf(
            StandardNames.FqNames.throwable,
            FqName("kotlin.Exception"),
            FqName("kotlin.RuntimeException"),
            FqName("kotlin.IllegalStateException"),
            cancellationExceptionFqName,
        ).map { fqName -> ClassId.topLevel(fqName) }.toSet()
    }
}
