// Copyright 2000-2022 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.

package org.jetbrains.uast.kotlin.internal

import com.intellij.openapi.project.Project
import com.intellij.psi.*
import com.intellij.psi.util.PsiTypesUtil
import org.jetbrains.kotlin.analysis.api.*
import org.jetbrains.kotlin.analysis.api.annotations.*
import org.jetbrains.kotlin.analysis.api.calls.KtCallableMemberCall
import org.jetbrains.kotlin.analysis.api.components.buildClassType
import org.jetbrains.kotlin.analysis.api.lifetime.allowAnalysisFromWriteAction
import org.jetbrains.kotlin.analysis.api.lifetime.allowAnalysisOnEdt
import org.jetbrains.kotlin.analysis.api.symbols.*
import org.jetbrains.kotlin.analysis.api.symbols.markers.KtAnnotatedSymbol
import org.jetbrains.kotlin.analysis.api.types.*
import org.jetbrains.kotlin.analysis.project.structure.KtSourceModule
import org.jetbrains.kotlin.analysis.providers.DecompiledPsiDeclarationProvider.findPsi
import org.jetbrains.kotlin.asJava.*
import org.jetbrains.kotlin.asJava.classes.lazyPub
import org.jetbrains.kotlin.descriptors.annotations.AnnotationUseSiteTarget
import org.jetbrains.kotlin.descriptors.annotations.AnnotationUseSiteTarget.PROPERTY_GETTER
import org.jetbrains.kotlin.descriptors.annotations.AnnotationUseSiteTarget.PROPERTY_SETTER
import org.jetbrains.kotlin.idea.KotlinLanguage
import org.jetbrains.kotlin.name.JvmStandardClassIds
import org.jetbrains.kotlin.name.StandardClassIds
import org.jetbrains.kotlin.psi.*
import org.jetbrains.kotlin.psi.psiUtil.containingClass
import org.jetbrains.uast.*
import org.jetbrains.uast.kotlin.*
import org.jetbrains.uast.kotlin.psi.UastFakeDeserializedSourceLightMethod
import org.jetbrains.uast.kotlin.psi.UastFakeDeserializedSymbolLightMethod
import org.jetbrains.uast.kotlin.psi.UastFakeSourceLightMethod
import org.jetbrains.uast.kotlin.psi.UastFakeSourceLightPrimaryConstructor

/**
 * The FIR-based Kotlin UAST language plugin, computed lazily.
 *
 * Uses the [FirKotlinUastLanguagePlugin] registered among [UastLanguagePlugin.getInstances] for
 * [KotlinLanguage] when there is one; otherwise falls back to a fresh [FirKotlinUastLanguagePlugin].
 */
val firKotlinUastPlugin: FirKotlinUastLanguagePlugin by lazyPub {
    // Non-throwing lookup: `single { ... } as FirKotlinUastLanguagePlugin?` would throw when no (FIR) Kotlin
    // plugin is registered, making the fallback below unreachable.
    UastLanguagePlugin.getInstances()
        .filterIsInstance<FirKotlinUastLanguagePlugin>()
        .firstOrNull { it.language == KotlinLanguage.INSTANCE }
        ?: FirKotlinUastLanguagePlugin()
}

/**
 * Runs [action] inside an analysis session for [useSiteKtElement], wrapped so that the analysis is
 * allowed both on the EDT and from within a write action.
 */
@OptIn(KtAllowAnalysisOnEdt::class, KtAllowAnalysisFromWriteAction::class)
internal inline fun <R> analyzeForUast(
    useSiteKtElement: KtElement,
    action: KtAnalysisSession.() -> R,
): R = allowAnalysisOnEdt {
    allowAnalysisFromWriteAction { analyze(useSiteKtElement, action) }
}

/**
 * Returns the [KtClass] owning the constructor described by [ktConstructorSymbol].
 *
 * The symbol's PSI is either the class itself or a constructor declaration (whose containing class is
 * returned); any other PSI, or none, yields `null`.
 */
context(KtAnalysisSession)
internal fun containingKtClass(
    ktConstructorSymbol: KtConstructorSymbol,
): KtClass? =
    when (val declaration = ktConstructorSymbol.psi) {
        is KtConstructor<*> -> declaration.containingClass()
        is KtClass -> declaration
        else -> null
    }

/**
 * Resolves [ktType] to a [PsiClass].
 *
 * If [context] is itself a [KtClass] with a light class, that light class is returned directly
 * (independently of [ktType]). Otherwise [ktType] is converted with [toPsiType] using [typeOwnerKind]
 * and [isBoxed], and the class of the resulting type is returned, if any.
 */
context(KtAnalysisSession)
internal fun toPsiClass(
    ktType: KtType,
    source: UElement?,
    context: KtElement,
    typeOwnerKind: TypeOwnerKind,
    isBoxed: Boolean = true,
): PsiClass? {
    val contextLightClass = (context as? KtClass)?.toLightClass()
    if (contextLightClass != null) return contextLightClass

    val conversion = PsiTypeConversionConfiguration(typeOwnerKind, isBoxed = isBoxed)
    val convertedType = toPsiType(ktType, source, context, conversion)
    return PsiTypesUtil.getPsiClass(convertedType)
}

/**
 * Maps a resolved function-like symbol to the [PsiMethod] UAST should expose for it, faking one when
 * no real light method is available.
 *
 * Order of attempts:
 *  1. A library `inline` function with a `reified` type parameter gets a [UastFakeDeserializedSymbolLightMethod]
 *     in the JVM class named by `getContainingJvmClassName()`, if that class is found in [context]'s resolve scope.
 *  2. Otherwise the declaration PSI is obtained via [psiForUast] and handled by kind:
 *     - no PSI: deserialized lookup via [toPsiMethodForDeserialized];
 *     - a [PsiMethod]: returned as is;
 *     - a [KtClassOrObject]: enum synthetic members by name, then the primary/first constructor,
 *       then a fake primary constructor for local classes;
 *     - a [KtFunction]: local functions are faked, library functions go through the deserialized lookup,
 *       others use their representative light method, falling back to a fake one;
 *     - anything else: its representative light method.
 *
 * @param functionSymbol the resolved function, constructor or accessor symbol.
 * @param context the use-site element; supplies the project and the resolve scope.
 * @return the matching (possibly fake) light method, or `null` if none could be produced.
 */
context(KtAnalysisSession)
internal fun toPsiMethod(
    functionSymbol: KtFunctionLikeSymbol,
    context: KtElement,
): PsiMethod? {
    // `inline` function with a `reified` type parameter from a binary dependency:
    // its source PSI can't be found, so fake a light method in the containing JVM class.
    if (functionSymbol.origin == KtSymbolOrigin.LIBRARY &&
        (functionSymbol as? KtFunctionSymbol)?.isInline == true &&
        functionSymbol.typeParameters.any { it.isReified }
    ) {
        functionSymbol.getContainingJvmClassName()?.let { fqName ->
            JavaPsiFacade.getInstance(context.project)
                .findClass(fqName, context.resolveScope)
                ?.let { containingClass ->
                    return UastFakeDeserializedSymbolLightMethod(
                        functionSymbol.createPointer(),
                        functionSymbol.name.identifier,
                        containingClass,
                        context
                    )
                }
        }
        // NOTE(review): if the class is not found, this falls through to the regular PSI-based lookup below.
    }
    return when (val psi = psiForUast(functionSymbol, context.project)) {
        null -> {
            // Lint/UAST CLI: try `fake` creation for a deserialized declaration
            toPsiMethodForDeserialized(functionSymbol, context, psi)
        }
        is PsiMethod -> psi
        is KtClassOrObject -> {
            // For synthetic members in enum classes, `psi` points to their containing enum class.
            if (psi is KtClass && psi.isEnum()) {
                val lc = psi.toLightClass() ?: return null
                // Match the light method by the symbol's name (only meaningful for KtFunctionSymbol).
                lc.methods.find { it.name == (functionSymbol as? KtFunctionSymbol)?.name?.identifier }?.let { return it }
            }

            // Default primary constructor
            psi.primaryConstructor?.getRepresentativeLightMethod()?.let { return it }
            // Otherwise, any constructor the light class exposes.
            val lc = psi.toLightClass() ?: return null
            lc.constructors.firstOrNull()?.let { return it }
            // Still nothing: fake a primary constructor, but only for local classes.
            if (psi.isLocal) UastFakeSourceLightPrimaryConstructor(psi, lc) else null
        }
        is KtFunction -> {
            // For JVM-invisible methods, such as @JvmSynthetic, LC conversion returns nothing, so fake it.
            // Faking is restricted to declarations in source modules with a containing light class.
            fun handleLocalOrSynthetic(source: KtFunction): PsiMethod? {
                val ktModule = getModule(source)
                if (ktModule !is KtSourceModule) return null
                return getContainingLightClass(source)?.let { UastFakeSourceLightMethod(source, it) }
            }

            when {
                psi.isLocal ->
                    handleLocalOrSynthetic(psi)
                functionSymbol.unwrapFakeOverrides.origin == KtSymbolOrigin.LIBRARY ->
                    // PSI to regular libraries should be handled by [DecompiledPsiDeclarationProvider]
                    // That is, this one is a deserialized declaration.
                    toPsiMethodForDeserialized(functionSymbol, context, psi)
                else ->
                    psi.getRepresentativeLightMethod()
                        ?: handleLocalOrSynthetic(psi)
            }
        }
        else -> psi.getRepresentativeLightMethod()
    }
}

/**
 * Finds (or fakes) a [PsiMethod] for a function-like symbol whose declaration is deserialized from a library.
 *
 * The containing JVM class is located first: for members via their class id, for top-level functions via
 * the file facade (from [psi]) or the symbol's containing JVM class name. That class is then searched for
 * [functionSymbol]:
 *  - constructors are matched by parameter count;
 *  - other functions by JVM name (`@JvmName`, honoring getter/setter use-site targets for accessors),
 *    falling back to the Kotlin name.
 *
 * With several candidates, the one whose JVM signature matches the symbol is preferred, else the first one.
 * With none, a fake light method is created in that class.
 *
 * @param functionSymbol the resolved library symbol.
 * @param context the use-site element; supplies the project, resolve scope and type-conversion context.
 * @param psi deserialized/decompiled PSI of the declaration, when available (`null` e.g. in the Lint/UAST CLI).
 * @return the matching or fake light method, or `null` if no containing class could be found.
 */
context(KtAnalysisSession)
private fun toPsiMethodForDeserialized(
    functionSymbol: KtFunctionLikeSymbol,
    context: KtElement,
    psi: KtFunction?,
): PsiMethod? {

    // Converts [ktType] to the PsiType it has in the compiled signature of [psiMethod].
    // Boxing is requested exactly for nullable types: on the JVM `Int?` is `java.lang.Integer` while `Int`
    // is `int`, and `toPsiType` only boxes its primitive fast path when the configuration says so.
    fun signatureType(ktType: KtType, psiMethod: PsiMethod, typeMappingMode: KtTypeMappingMode): PsiType =
        toPsiType(
            ktType,
            psiMethod,
            context,
            PsiTypeConversionConfiguration(
                TypeOwnerKind.DECLARATION,
                isBoxed = ktType.isMarkedNullable,
                typeMappingMode = typeMappingMode,
            )
        )

    // Whether [psiMethod]'s parameter types (and, for non-constructors, its return type) match [functionSymbol].
    fun equalSignatures(psiMethod: PsiMethod): Boolean {
        val methodParameters: Array<PsiParameter> = psiMethod.parameterList.parameters
        val symbolParameters: List<KtValueParameterSymbol> = functionSymbol.valueParameters
        if (methodParameters.size != symbolParameters.size) {
            return false
        }

        val parametersMatch = methodParameters.zip(symbolParameters).all { (methodParameter, symbolParameter) ->
            methodParameter.type == signatureType(symbolParameter.returnType, psiMethod, KtTypeMappingMode.VALUE_PARAMETER)
        }
        if (!parametersMatch) return false

        // A constructor's PSI return type is `null` (mapped to `void` below) while the constructor symbol's return
        // type is the constructed class, so comparing them would reject every constructor overload.
        if (functionSymbol is KtConstructorSymbol) return true

        val psiMethodReturnType = psiMethod.returnType ?: PsiTypes.voidType()
        return psiMethodReturnType == signatureType(functionSymbol.returnType, psiMethod, KtTypeMappingMode.RETURN_TYPE)
    }

    // Searches this class for a method corresponding to [functionSymbol], faking one if nothing matches by name.
    fun PsiClass.lookup(): PsiMethod? {
        val candidates =
            if (functionSymbol is KtConstructorSymbol)
                constructors.filter { it.parameterList.parameters.size == functionSymbol.valueParameters.size }
            else {
                val jvmName = when (functionSymbol) {
                    is KtPropertyGetterSymbol -> {
                        functionSymbol.getJvmNameFromAnnotation(PROPERTY_GETTER.toOptionalFilter())
                    }
                    is KtPropertySetterSymbol -> {
                        functionSymbol.getJvmNameFromAnnotation(PROPERTY_SETTER.toOptionalFilter())
                    }
                    else -> {
                        functionSymbol.getJvmNameFromAnnotation()
                    }
                }
                val id = jvmName
                    ?: functionSymbol.callableIdIfNonLocal?.callableName?.identifierOrNullIfSpecial
                    ?: psi?.name
                methods.filter { it.name == id }
            }

        return when (candidates.size) {
            0 -> {
                if (psi != null) {
                    UastFakeDeserializedSourceLightMethod(psi, this@lookup)
                } else if (functionSymbol is KtFunctionSymbol) {
                    UastFakeDeserializedSymbolLightMethod(
                        functionSymbol.createPointer(),
                        functionSymbol.name.identifier,
                        this@lookup,
                        context
                    )
                } else null
            }
            1 -> {
                candidates.single()
            }
            else -> {
                // Overloads: prefer an exact signature match, else keep the previous first-candidate behavior.
                candidates.firstOrNull { equalSignatures(it) } ?: candidates.first()
            }
        }
    }

    // Deserialized member function
    val classId = psi?.containingClass()?.getClassId()
        ?: functionSymbol.callableIdIfNonLocal?.classId
    if (classId != null) {
        toPsiClass(
            buildClassType(classId),
            source = null,
            context,
            TypeOwnerKind.DECLARATION,
        )?.lookup()?.let { return it }
    }
    // Deserialized top-level function
    return if (psi != null) {
        // Lint/UAST IDE: with deserialized PSI
        psi.containingKtFile.findFacadeClass()?.lookup()
    } else if (functionSymbol is KtFunctionSymbol) {
        // Lint/UAST CLI: attempt to find the binary class
        //   with the facade fq name from the resolved symbol
        functionSymbol.getContainingJvmClassName()?.let { fqName ->
            JavaPsiFacade.getInstance(context.project)
                .findClass(fqName, context.resolveScope)
                ?.lookup()
        }
    } else null
}

/**
 * Returns the name given by a `@JvmName` annotation on this symbol, considering only annotations
 * accepted by [useSiteTargetFilter]. Yields `null` when there is no such annotation or its first
 * argument is not a constant `String`.
 */
private fun KtAnnotatedSymbol.getJvmNameFromAnnotation(
    useSiteTargetFilter: AnnotationUseSiteTargetFilter = AnyAnnotationUseSiteTargetFilter,
): String? {
    val jvmNameAnnotation = annotationsByClassId(JvmStandardClassIds.JVM_NAME_CLASS_ID, useSiteTargetFilter)
        .firstOrNull()
        ?: return null
    val firstArgumentValue = jvmNameAnnotation.arguments.firstOrNull()?.expression
    return (firstArgumentValue as? KtConstantAnnotationValue)?.constantValue?.value as? String
}

/**
 * Builds a filter that accepts whatever [NoAnnotationUseSiteTargetFilter] accepts (presumably annotations
 * without an explicit use-site target) as well as annotations with this use-site target.
 */
private fun AnnotationUseSiteTarget.toOptionalFilter(): AnnotationUseSiteTargetFilter =
    annotationUseSiteTargetFilterOf(NoAnnotationUseSiteTargetFilter, this.toFilter())

/**
 * Combines [filters] into a single filter that allows a use-site target when any of them allows it.
 */
private fun annotationUseSiteTargetFilterOf(
    vararg filters: AnnotationUseSiteTargetFilter,
): AnnotationUseSiteTargetFilter {
    return AnnotationUseSiteTargetFilter { target ->
        filters.any { it.isAllowed(target) }
    }
}

/**
 * Converts [ktType] to a [PsiType], using as conversion context the Java PSI of the closest
 * [UDeclaration] found from [source] upwards (non-strict lookup), when it is a [PsiModifierListOwner].
 */
context(KtAnalysisSession)
internal fun toPsiType(
    ktType: KtType,
    source: UElement?,
    context: KtElement,
    config: PsiTypeConversionConfiguration,
): PsiType {
    val enclosingDeclaration = source?.getParentOfType<UDeclaration>(false)
    val lightDeclaration = enclosingDeclaration?.javaPsi as? PsiModifierListOwner
    return toPsiType(ktType, lightDeclaration, context, config)
}

/**
 * Converts [ktType] to a [PsiType].
 *
 * Well-known non-generic classes are mapped directly:
 *  - primitives, boxed when [PsiTypeConversionConfiguration.isBoxed] is set;
 *  - `Unit`, via [convertUnitToVoidIfNeeded];
 *  - `String`.
 *
 * If that mapping yields nothing (e.g. the boxed class can't be resolved), or for any other type, the
 * conversion goes through `asPsiType`. Its use site is [containingLightDeclaration], or [context] when
 * that is absent, and [UastErrorType] is returned on failure.
 */
context(KtAnalysisSession)
internal fun toPsiType(
    ktType: KtType,
    containingLightDeclaration: PsiModifierListOwner?,
    context: KtElement,
    config: PsiTypeConversionConfiguration,
): PsiType {
    if (ktType is KtNonErrorClassType && ktType.ownTypeArguments.isEmpty()) {
        fun PsiPrimitiveType.boxedIfRequested(): PsiType? =
            if (config.isBoxed) getBoxedType(context) else this

        val wellKnownType: PsiType? = when (ktType.classId) {
            StandardClassIds.Boolean -> PsiTypes.booleanType().boxedIfRequested()
            StandardClassIds.Byte -> PsiTypes.byteType().boxedIfRequested()
            StandardClassIds.Char -> PsiTypes.charType().boxedIfRequested()
            StandardClassIds.Short -> PsiTypes.shortType().boxedIfRequested()
            StandardClassIds.Int -> PsiTypes.intType().boxedIfRequested()
            StandardClassIds.Long -> PsiTypes.longType().boxedIfRequested()
            StandardClassIds.Float -> PsiTypes.floatType().boxedIfRequested()
            StandardClassIds.Double -> PsiTypes.doubleType().boxedIfRequested()
            StandardClassIds.Unit -> convertUnitToVoidIfNeeded(context, config.typeOwnerKind, config.isBoxed)
            StandardClassIds.String -> PsiType.getJavaLangString(context.manager, context.resolveScope)
            else -> null
        }
        if (wellKnownType != null) return wellKnownType
    }

    // Convert against the light declaration when known, otherwise against the use-site element.
    val useSite: PsiElement = containingLightDeclaration ?: context
    return ktType.asPsiType(
        useSite,
        allowErrorTypes = false,
        config.typeMappingMode,
        isAnnotationMethod = false,
    ) ?: UastErrorType
}

/**
 * Returns the (boxed) [PsiType] of the receiver of [ktCall].
 *
 * The receiver type is taken from the call signature, otherwise from the extension receiver, and
 * finally from the dispatch receiver. Returns `null` when there is none or it is an error type.
 */
context(KtAnalysisSession)
internal fun receiverType(
    ktCall: KtCallableMemberCall<*, *>,
    source: UElement,
    context: KtElement,
): PsiType? {
    val applied = ktCall.partiallyAppliedSymbol
    val receiverKtType = applied.signature.receiverType
        ?: applied.extensionReceiver?.type
        ?: applied.dispatchReceiver?.type
        ?: return null
    if (receiverKtType is KtErrorType) return null
    return toPsiType(
        receiverKtType,
        source,
        context,
        PsiTypeConversionConfiguration.create(context, isBoxed = true),
    )
}

/**
 * `true` when this type expands to a named class whose symbol reports `isInline`
 * (i.e. a value/inline class); `false` otherwise.
 */
context(KtAnalysisSession)
internal val KtType.typeForValueClass: Boolean
    get() = (expandedClassSymbol as? KtNamedClassOrObjectSymbol)?.isInline == true

/**
 * Whether [ktType] is a type parameter whose nullability is inherited from its declaration:
 * a bare `T` (not written `T?`) that can still be `null` per [nullability] (so not e.g. `T : Any`).
 */
context(KtAnalysisSession)
internal fun isInheritedGenericType(ktType: KtType?): Boolean {
    if (ktType !is KtTypeParameterType) return false
    // `T?` is nullable by its own marker, nothing is inherited.
    if (ktType.isMarkedNullable) return false
    // A non-null upper bound (e.g. `T : Any`) makes `T` definitely non-null.
    return nullability(ktType) != KtTypeNullability.NON_NULLABLE
}

/**
 * Classifies [ktType] as [KtTypeNullability.NULLABLE] when its fully expanded type can be `null`,
 * else [KtTypeNullability.NON_NULLABLE]. Returns `null` for a missing or error type.
 */
context(KtAnalysisSession)
internal fun nullability(ktType: KtType?): KtTypeNullability? {
    if (ktType == null || ktType is KtErrorType) return null
    return when {
        ktType.fullyExpandedType.canBeNull -> KtTypeNullability.NULLABLE
        else -> KtTypeNullability.NON_NULLABLE
    }
}

/**
 * Returns the declared/inferred return type of [ktCallableDeclaration], or `null` when its symbol
 * is not a callable symbol.
 */
context(KtAnalysisSession)
internal fun getKtType(ktCallableDeclaration: KtCallableDeclaration): KtType? {
    val callableSymbol = ktCallableDeclaration.getSymbol() as? KtCallableSymbol ?: return null
    return callableSymbol.returnType
}

/**
 * Finds the [PsiElement] UAST should use for [symbol].
 *
 * - Library symbols: looked up via `findPsi` (Java stub-based / binary-module PSI, used by the
 *   UAST/Lint CLI), falling back to the symbol's own (decompiled) PSI in the IDE.
 * - Callable intersection/substitution overrides: unwrapped to the original declaration, which is
 *   then resolved recursively.
 * - Everything else: the symbol's own PSI.
 */
context(KtAnalysisSession)
internal tailrec fun psiForUast(symbol: KtSymbol, project: Project): PsiElement? {
    val origin = symbol.origin
    if (origin == KtSymbolOrigin.LIBRARY) {
        // UAST/Lint CLI first ([DecompiledPsiDeclarationProvider]), then UAST/Lint IDE decompiled PSI.
        return findPsi(symbol, project) ?: symbol.psi
    }

    val isFakeOverride =
        origin == KtSymbolOrigin.INTERSECTION_OVERRIDE || origin == KtSymbolOrigin.SUBSTITUTION_OVERRIDE
    if (symbol is KtCallableSymbol && isFakeOverride) {
        val unwrapped = symbol.unwrapFakeOverrides
        if (unwrapped !== symbol) return psiForUast(unwrapped, project)
    }

    return symbol.psi
}

/**
 * Returns the light (Java) element representing this Kotlin element.
 *
 * For a [KtProperty] the result is chosen in this order:
 *  1. its backing field, if any;
 *  2. the setter, when [sourcePsi] is a write access;
 *  3. the getter, when [sourcePsi] is a read access.
 *
 * Otherwise, or when the preferred accessor is absent, the first of [toLightElements] is returned.
 */
internal fun KtElement.toPsiElementAsLightElement(
    sourcePsi: KtExpression? = null
): PsiElement? {
    if (this is KtProperty) {
        val accessors = getAccessorLightMethods()
        accessors.backingField?.let { return it }

        val access = sourcePsi?.readWriteAccess()
        val preferredAccessor = when {
            access?.isWrite == true -> accessors.setter
            access?.isRead == true -> accessors.getter
            else -> null
        }
        if (preferredAccessor != null) return preferredAccessor
    }
    return toLightElements().firstOrNull()
}
