/*
 * Copyright 2010-2024 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package ksp.org.jetbrains.kotlin.analysis.low.level.api.fir.providers

import ksp.com.intellij.openapi.project.Project
import ksp.com.intellij.psi.search.GlobalSearchScope
import ksp.org.jetbrains.kotlin.analysis.api.platform.declarations.KotlinDirectInheritorsProvider
import ksp.org.jetbrains.kotlin.analysis.api.platform.projectStructure.KotlinModuleDependentsProvider
import ksp.org.jetbrains.kotlin.analysis.api.projectStructure.KaDanglingFileModule
import ksp.org.jetbrains.kotlin.analysis.api.projectStructure.baseContextModule
import ksp.org.jetbrains.kotlin.analysis.low.level.api.fir.projectStructure.llFirModuleData
import ksp.org.jetbrains.kotlin.analysis.low.level.api.fir.sessions.LLFirSessionCache
import ksp.org.jetbrains.kotlin.analysis.low.level.api.fir.symbolProviders.symbolProvider
import ksp.org.jetbrains.kotlin.fir.declarations.FirClass
import ksp.org.jetbrains.kotlin.fir.declarations.FirRegularClass
import ksp.org.jetbrains.kotlin.fir.declarations.SealedClassInheritorsProvider
import ksp.org.jetbrains.kotlin.fir.declarations.SealedClassInheritorsProviderInternals
import ksp.org.jetbrains.kotlin.fir.declarations.sealedInheritorsAttr
import ksp.org.jetbrains.kotlin.fir.declarations.utils.classId
import ksp.org.jetbrains.kotlin.fir.declarations.utils.isExpect
import ksp.org.jetbrains.kotlin.fir.psi
import ksp.org.jetbrains.kotlin.name.ClassId
import ksp.org.jetbrains.kotlin.psi.KtClass
import java.util.concurrent.ConcurrentHashMap
import kotlin.collections.filter
import kotlin.collections.map
import kotlin.collections.mapNotNull
import kotlin.collections.plus
import kotlin.collections.sortedBy
import kotlin.let

/**
 * LL FIR implementation of [SealedClassInheritorsProvider], used by both the IDE and Standalone mode.
 *
 * Inheritors of binary classes that were deserialized from class files are taken directly from metadata. For every other non-local
 * class, the inheritors are found with a direct-inheritors search and the result is memoized in [cache], keyed by the class's [ClassId].
 */
@OptIn(SealedClassInheritorsProviderInternals::class)
internal class LLSealedInheritorsProvider(private val project: Project) : SealedClassInheritorsProvider() {
    val cache: ConcurrentHashMap<ClassId, List<ClassId>> = ConcurrentHashMap()

    override fun getSealedClassInheritors(firClass: FirRegularClass): List<ClassId> {
        // Binary classes deserialized from class files (as opposed to stubs) already know their inheritors from metadata.
        val inheritorsFromMetadata = firClass.sealedInheritorsAttr
        if (inheritorsFromMetadata != null) {
            return inheritorsFromMetadata.value
        }

        val sealedClassId = firClass.classId

        // A local class can never be sealed, so there is nothing to look up (and nothing is cached).
        // NOTE(review): the search below runs inside `ConcurrentHashMap.computeIfAbsent`, so it should not re-enter this provider for
        // another class while computing -- confirm that the direct-inheritors search never requests sealed inheritors.
        return if (sealedClassId.isLocal) {
            emptyList()
        } else {
            cache.computeIfAbsent(sealedClassId) { searchInheritors(firClass) }
        }
    }

    /**
     * Finds the direct Kotlin inheritors of [firClass] that belong to the same package.
     *
     * Design notes:
     *
     *  - Java classes are not searched: a Java class may not extend a sealed Kotlin class, not even from the same package.
     *  - No `PackageScope` is used. The search is already narrow, since it relies on supertype indices and is usually limited to the
     *    owning `KaModule` (`expect` classes being the exception). Obtaining the `PsiPackage` a `PackageScope` needs is costly. Should a
     *    package-restricted scope ever be required, prefer extracting a `PackageNameScope` that works on the qualified package name alone
     *    and avoids `PsiPackage` (the current `PackageScope` implementation makes this possible).
     *  - Local inheritors are excluded to avoid lazy resolve contract violations. See KT-63795.
     *  - For an `expect` class, the scope is widened to every module that has a dependsOn dependency on the owning module. `actual`
     *    classes stay confined to their own module and need no extra handling. See KT-45842.
     *  - KMP libraries are not supported yet. See KT-65591.
     *
     * @return the inheritor class IDs sorted by their string form, or an empty list if the class has no `KtClass` PSI or, for a dangling
     *   file class, its original class cannot be found in the context module.
     */
    private fun searchInheritors(firClass: FirClass): List<ClassId> {
        val ownerModule = firClass.llFirModuleData.ktModule

        val (searchModule, searchFirClass) = if (ownerModule is KaDanglingFileModule) {
            // The search scope is taken from the context module, so the class to search for is the *original* FIR class there,
            // not its dangling counterpart.
            val contextModule = ownerModule.baseContextModule
            val originalClass = LLFirSessionCache.getInstance(project)
                .getSession(contextModule, preferBinary = true)
                .symbolProvider
                .getClassLikeSymbolByClassId(firClass.classId)
                ?.fir as? FirClass
                ?: return emptyList()
            contextModule to originalClass
        } else {
            ownerModule to firClass
        }

        val sealedKtClass = searchFirClass.psi as? KtClass ?: return emptyList()

        // `isExpect` is already set while building FIR, so reading it does not depend on the `STATUS` phase.
        val searchScope = if (!searchFirClass.isExpect) {
            searchModule.contentScope
        } else {
            val refinementDependents = KotlinModuleDependentsProvider.getInstance(project).getRefinementDependents(searchModule)
            val dependentScopes = refinementDependents.map { dependent -> dependent.contentScope }
            GlobalSearchScope.union(dependentScopes + searchModule.contentScope)
        }

        return searchInScope(sealedKtClass, searchFirClass.classId, searchScope)
    }

    /**
     * Runs the direct-inheritors search for [ktClass] in [scope], excluding local inheritors, and returns the class IDs of the inheritors
     * declared in [classId]'s package.
     */
    private fun searchInScope(ktClass: KtClass, classId: ClassId, scope: GlobalSearchScope): List<ClassId> {
        val requiredPackage = classId.packageFqName
        val directInheritors = KotlinDirectInheritorsProvider.getInstance(project)
            .getDirectKotlinInheritors(ktClass, scope, includeLocalInheritors = false)

        // Keep only inheritors that have a class ID and are declared in the sealed class's package.
        val inheritorIds = directInheritors.mapNotNull { inheritor ->
            inheritor.getClassId()?.takeIf { it.packageFqName == requiredPackage }
        }

        // Return the shared `emptyList()` instance when nothing was found. Otherwise sort by string form so the order is deterministic,
        // e.g. for stable test output.
        return if (inheritorIds.isEmpty()) emptyList() else inheritorIds.sortedBy { it.toString() }
    }
}
