/*
 * Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package ksp.org.jetbrains.kotlin.fir.analysis.checkers.declaration

import ksp.org.jetbrains.kotlin.KtRealSourceElementKind
import ksp.org.jetbrains.kotlin.descriptors.ClassKind
import ksp.org.jetbrains.kotlin.diagnostics.DiagnosticReporter
import ksp.org.jetbrains.kotlin.diagnostics.reportOn
import ksp.org.jetbrains.kotlin.fir.analysis.checkers.*
import ksp.org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext
import ksp.org.jetbrains.kotlin.fir.analysis.diagnostics.FirErrors
import ksp.org.jetbrains.kotlin.fir.declarations.*
import ksp.org.jetbrains.kotlin.fir.declarations.utils.isOverride
import ksp.org.jetbrains.kotlin.fir.resolve.toRegularClassSymbol
import ksp.org.jetbrains.kotlin.fir.symbols.FirBasedSymbol
import ksp.org.jetbrains.kotlin.fir.symbols.impl.FirCallableSymbol
import ksp.org.jetbrains.kotlin.fir.symbols.impl.FirConstructorSymbol
import ksp.org.jetbrains.kotlin.fir.symbols.impl.FirNamedFunctionSymbol
import ksp.org.jetbrains.kotlin.fir.symbols.impl.FirPropertySymbol
import ksp.org.jetbrains.kotlin.fir.symbols.impl.FirRegularClassSymbol
import ksp.org.jetbrains.kotlin.fir.symbols.impl.FirTypeAliasSymbol
import ksp.org.jetbrains.kotlin.fir.types.*
import ksp.org.jetbrains.kotlin.types.model.KotlinTypeMarker
import ksp.org.jetbrains.kotlin.types.model.TypeCheckerProviderContext

/**
 * Validates the upper bounds declared on a type parameter.
 *
 * Two concrete checkers share the same logic and differ only in which containing declarations they accept:
 * [Regular] handles type parameters whose innermost containing declaration is not `expect`,
 * [ForExpectClass] handles those whose innermost containing declaration is `expect`.
 */
sealed class FirTypeParameterBoundsChecker(mppKind: MppCheckerKind) : FirTypeParameterChecker(mppKind) {
    /** Runs the bound checks for type parameters of non-`expect` declarations (platform checker kind). */
    object Regular : FirTypeParameterBoundsChecker(MppCheckerKind.Platform) {
        context(context: CheckerContext, reporter: DiagnosticReporter)
        override fun check(declaration: FirTypeParameter) {
            val owner = context.containingDeclarations.lastOrNull() ?: return
            if (!owner.isExpect()) {
                check(declaration, owner, context, reporter)
            }
        }
    }

    /** Runs the bound checks for type parameters of `expect` declarations (common checker kind). */
    object ForExpectClass : FirTypeParameterBoundsChecker(MppCheckerKind.Common) {
        context(context: CheckerContext, reporter: DiagnosticReporter)
        override fun check(declaration: FirTypeParameter) {
            val owner = context.containingDeclarations.lastOrNull() ?: return
            if (owner.isExpect()) {
                check(declaration, owner, context, reporter)
            }
        }
    }

    /** Class kinds counted as "class" bounds when looking for more than one class bound. */
    private val classKinds = setOf(ClassKind.CLASS, ClassKind.ENUM_CLASS, ClassKind.OBJECT)

    /**
     * Runs every bound check on [declaration], whose innermost container is [containingDeclaration].
     *
     * Type parameters owned by constructors are skipped entirely. The single-type-parameter-bound
     * check is skipped for callables that are inline-only. The order of the sub-checks determines
     * the order in which diagnostics are reported.
     */
    protected fun check(
        declaration: FirTypeParameter,
        containingDeclaration: FirBasedSymbol<*>,
        context: CheckerContext,
        reporter: DiagnosticReporter,
    ) {
        if (containingDeclaration is FirConstructorSymbol) return

        checkFinalUpperBounds(declaration, containingDeclaration, context, reporter)
        checkExtensionFunctionTypeBound(declaration, context, reporter)

        val isInlineOnlyCallable = (containingDeclaration as? FirCallableSymbol)?.isInlineOnly(context.session) == true
        if (!isInlineOnlyCallable) {
            checkOnlyOneTypeParameterBound(declaration, context, reporter)
        }

        checkBoundUniqueness(declaration, context, reporter)
        checkConflictingBounds(declaration, context, reporter)
        checkTypeAliasBound(declaration, containingDeclaration, context, reporter)
        checkDynamicBounds(declaration, context, reporter)
        checkInconsistentTypeParameterBounds(declaration, context, reporter)
    }

    /**
     * Reports FINAL_UPPER_BOUND on every resolved bound whose type cannot have subtypes
     * (per `canHaveSubtypesAccordingToK1`). Overriding named functions and properties are not checked.
     */
    private fun checkFinalUpperBounds(
        declaration: FirTypeParameter,
        containingDeclaration: FirBasedSymbol<*>,
        context: CheckerContext,
        reporter: DiagnosticReporter
    ) {
        val ownerIsOverride = when (containingDeclaration) {
            is FirNamedFunctionSymbol -> containingDeclaration.isOverride
            is FirPropertySymbol -> containingDeclaration.isOverride
            else -> false
        }
        if (ownerIsOverride) return

        for (bound in declaration.symbol.resolvedBounds) {
            val boundType = bound.coneType
            // Dynamic bounds get DYNAMIC_UPPER_BOUND from checkDynamicBounds instead.
            if (boundType is ConeDynamicType) continue
            if (boundType.canHaveSubtypesAccordingToK1(context.session)) continue
            reporter.reportOn(bound.source, FirErrors.FINAL_UPPER_BOUND, bound.coneType, context)
        }
    }

    /** Reports UPPER_BOUND_IS_EXTENSION_FUNCTION_TYPE on each resolved bound that is an extension function type. */
    private fun checkExtensionFunctionTypeBound(declaration: FirTypeParameter, context: CheckerContext, reporter: DiagnosticReporter) {
        for (bound in declaration.symbol.resolvedBounds) {
            if (!bound.isExtensionFunctionType(context.session)) continue
            reporter.reportOn(bound.source, FirErrors.UPPER_BOUND_IS_EXTENSION_FUNCTION_TYPE, context)
        }
    }

    /**
     * For type parameters of a type alias, reports BOUND_ON_TYPE_ALIAS_PARAMETER_NOT_ALLOWED on every
     * declared bound that has a real (non-synthetic) source element.
     */
    private fun checkTypeAliasBound(
        declaration: FirTypeParameter,
        containingDeclaration: FirBasedSymbol<*>,
        context: CheckerContext,
        reporter: DiagnosticReporter
    ) {
        if (containingDeclaration !is FirTypeAliasSymbol) return

        for (bound in declaration.bounds) {
            if (bound.source?.kind != KtRealSourceElementKind) continue
            reporter.reportOn(bound.source, FirErrors.BOUND_ON_TYPE_ALIAS_PARAMETER_NOT_ALLOWED, context)
        }
    }

    /**
     * Reports BOUNDS_NOT_ALLOWED_IF_BOUNDED_BY_TYPE_PARAMETER when, after de-duplicating by type, the bounds
     * contain either two or more type-parameter bounds, or one type-parameter bound plus any other bound.
     */
    private fun checkOnlyOneTypeParameterBound(declaration: FirTypeParameter, context: CheckerContext, reporter: DiagnosticReporter) {
        val distinctBounds = declaration.symbol.resolvedBounds.distinctBy { it.coneType }
        val (typeParameterBounds, otherBounds) = distinctBounds.partition { it.coneType is ConeTypeParameterType }

        val violates = typeParameterBounds.size > 1 || (typeParameterBounds.size == 1 && otherBounds.isNotEmpty())
        if (!violates) return

        // Only bounds located in a type constraint (presumably `where`-clause entries) are used as the report
        // position, so the diagnostic lands where the old frontend put it.
        val constraintBounds = with(SourceNavigator.forElement(declaration)) {
            distinctBounds.filter { it.isInTypeConstraint() }.toSet()
        }

        // With exactly two distinct bounds there is a single offending bound to point at; otherwise
        // the whole type parameter is reported.
        val reportTarget = if (distinctBounds.size == 2) {
            val offending = otherBounds.firstOrNull() ?: typeParameterBounds.last()
            if (offending in constraintBounds) offending.source else declaration.source
        } else {
            declaration.source
        }
        reporter.reportOn(reportTarget, FirErrors.BOUNDS_NOT_ALLOWED_IF_BOUNDED_BY_TYPE_PARAMETER, context)
    }

    /**
     * Reports ONLY_ONE_CLASS_BOUND_ALLOWED on every class-kind bound after the first, and REPEATED_BOUND on
     * every non-error bound that duplicates an earlier one. Duplicates are detected by fully expanded class id,
     * or by the cone type when there is no class id.
     */
    private fun checkBoundUniqueness(declaration: FirTypeParameter, context: CheckerContext, reporter: DiagnosticReporter) {
        val nonErrorBounds = declaration.symbol.resolvedBounds.filter { it !is FirErrorTypeRef }
        val uniqueBounds = nonErrorBounds.distinctBy { it.coneType.fullyExpandedClassId(context.session) ?: it.coneType }

        val classBoundsSeen = mutableSetOf<FirRegularClassSymbol>()
        for (bound in uniqueBounds) {
            val classSymbol = bound.coneType.toRegularClassSymbol(context.session) ?: continue
            if (classSymbol.classKind !in classKinds) continue
            if (classBoundsSeen.add(classSymbol) && classBoundsSeen.size > 1) {
                reporter.reportOn(bound.source, FirErrors.ONLY_ONE_CLASS_BOUND_ALLOWED, context)
            }
        }

        for (duplicate in nonErrorBounds.minus(uniqueBounds)) {
            reporter.reportOn(duplicate.source, FirErrors.REPEATED_BOUND, context)
        }
    }

    /**
     * Reports CONFLICTING_UPPER_BOUNDS on the type parameter when it declares at least two bounds and some bound
     * that cannot have subtypes is neither a subtype nor a supertype of another, different bound.
     */
    private fun checkConflictingBounds(declaration: FirTypeParameter, context: CheckerContext, reporter: DiagnosticReporter) {
        if (declaration.bounds.size < 2) return

        val boundTypes = declaration.symbol.resolvedBounds.map { it.coneType }
        val hasConflict = boundTypes.any { candidate ->
            !candidate.canHaveSubtypesAccordingToK1(context.session) &&
                boundTypes.any { other ->
                    candidate != other && !candidate.isRelated(context.session.typeContext, other)
                }
        }

        if (hasConflict) {
            reporter.reportOn(declaration.source, FirErrors.CONFLICTING_UPPER_BOUNDS, declaration.symbol, context)
        }
    }

    /** Reports DYNAMIC_UPPER_BOUND on every declared bound whose type is `dynamic`. */
    private fun checkDynamicBounds(declaration: FirTypeParameter, context: CheckerContext, reporter: DiagnosticReporter) {
        for (bound in declaration.bounds) {
            if (bound.coneType !is ConeDynamicType) continue
            reporter.reportOn(bound.source, FirErrors.DYNAMIC_UPPER_BOUND, context)
        }
    }

    /** True when this type is a subtype or a supertype of [type] under [context]. */
    private fun KotlinTypeMarker.isRelated(context: TypeCheckerProviderContext, type: KotlinTypeMarker?): Boolean =
        isSubtypeOf(context, type) || isSupertypeOf(context, type)

    /**
     * For type parameters with more than one declared bound, collects the bounds that resolve to regular classes
     * and delegates to `checkInconsistentTypeParameters` with `isValues = false`.
     */
    private fun checkInconsistentTypeParameterBounds(
        declaration: FirTypeParameter,
        context: CheckerContext,
        reporter: DiagnosticReporter
    ) {
        if (declaration.bounds.size <= 1) return

        val boundsWithClasses = mutableListOf<Pair<FirTypeRef, FirRegularClassSymbol>>()
        val visitedClasses = mutableSetOf<FirRegularClassSymbol>()

        for (bound in declaration.symbol.resolvedBounds) {
            val classSymbol = bound.toRegularClassSymbol(context.session) ?: continue
            // The same class appearing twice is a repeated bound, already reported as REPEATED_BOUND,
            // so INCONSISTENT_TYPE_PARAMETER_BOUNDS is not reported on top of it.
            if (!visitedClasses.add(classSymbol)) return
            boundsWithClasses.add(bound to classSymbol)
        }

        checkInconsistentTypeParameters(boundsWithClasses, context, reporter, declaration.source, isValues = false)
    }
}
