/*
 * Decompiled with CFR 0.152.
 */
package jadx.core.dex.visitors;

import jadx.api.plugins.input.data.attributes.IJadxAttribute;
import jadx.core.clsp.ClspClass;
import jadx.core.clsp.ClspMethod;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.MethodOverrideAttr;
import jadx.core.dex.attributes.nodes.RenameReasonAttr;
import jadx.core.dex.info.AccessInfo;
import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.IMethodDetails;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.dex.visitors.JadxVisitor;
import jadx.core.dex.visitors.rename.RenameVisitor;
import jadx.core.dex.visitors.typeinference.TypeCompare;
import jadx.core.dex.visitors.typeinference.TypeCompareEnum;
import jadx.core.dex.visitors.typeinference.TypeInferenceVisitor;
import jadx.core.utils.Utils;
import jadx.core.utils.exceptions.JadxException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.stream.Collectors;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

@JadxVisitor(name="OverrideMethodVisitor", desc="Mark override methods and revert type erasure", runBefore={TypeInferenceVisitor.class, RenameVisitor.class})
public class OverrideMethodVisitor extends AbstractVisitor {

    /**
     * Processes every method of {@code cls} against the class's collected super types.
     *
     * @return always {@code true}
     */
    @Override
    public boolean visit(ClassNode cls) throws JadxException {
        this.processCls(cls);
        return true;
    }

    /** Runs {@link #processMth} for each method, skipping classes without any super types to check. */
    private void processCls(ClassNode cls) {
        List<ArgType> superTypes = this.collectSuperTypes(cls);
        if (superTypes.isEmpty()) {
            return;
        }
        for (MethodNode mth : cls.getMethods()) {
            this.processMth(cls, superTypes, mth);
        }
    }

    /**
     * Attaches a {@link MethodOverrideAttr} to {@code mth} if it overrides something, then tries to restore
     * generic return/argument types from the last entry of the override list. If any type changed and
     * renaming is allowed, resolves signature collisions with sibling methods.
     */
    private void processMth(ClassNode cls, List<ArgType> superTypes, MethodNode mth) {
        AccessInfo accessFlags = mth.getAccessFlags();
        if (mth.isConstructor() || accessFlags.isStatic() || accessFlags.isPrivate()) {
            return;
        }
        MethodOverrideAttr attr = this.processOverrideMethods(cls, mth, superTypes);
        if (attr == null) {
            return;
        }
        mth.addAttr(attr);
        IMethodDetails baseMth = Utils.last(attr.getOverrideList());
        if (baseMth == null) {
            return;
        }
        // both fixes must always run, so no short-circuit '||' here
        boolean updated = this.fixMethodReturnType(mth, baseMth, superTypes);
        updated |= this.fixMethodArgTypes(mth, baseMth, superTypes);
        if (updated && cls.root().getArgs().isRenameValid()) {
            this.fixMethodSignatureCollisions(mth);
        }
    }

    /**
     * Builds the override attribute for {@code mth}, or returns the one it already has.
     *
     * <p>Super types are scanned in order. For a class loaded in the project, the first visible non-static
     * method with a matching signature (return type excluded) is recorded. If that method already carries an
     * override attribute, its chain is merged in and the scan stops. For a type known only from the
     * classpath, every classpath method whose short id starts with the signature is recorded.
     *
     * @return the attribute, or {@code null} if no overridden method was found
     */
    private MethodOverrideAttr processOverrideMethods(ClassNode cls, MethodNode mth, List<ArgType> superTypes) {
        MethodOverrideAttr result = mth.get(AType.METHOD_OVERRIDE);
        if (result != null) {
            return result;
        }
        RootNode root = cls.root();
        String signature = mth.getMethodInfo().makeSignature(false);
        List<IMethodDetails> overrideList = new ArrayList<>();
        for (ArgType superType : superTypes) {
            ClassNode classNode = root.resolveClass(superType);
            if (classNode != null) {
                MethodNode ovrdMth = this.searchOverriddenMethod(classNode, signature);
                if (ovrdMth != null && this.isMethodVisibleInCls(ovrdMth, cls)) {
                    overrideList.add(ovrdMth);
                    MethodOverrideAttr attr = ovrdMth.get(AType.METHOD_OVERRIDE);
                    if (attr != null) {
                        // super method was already processed: reuse its chain instead of scanning further
                        return this.buildOverrideAttr(mth, overrideList, attr);
                    }
                }
            } else {
                ClspClass clsDetails = root.getClsp().getClsDetails(superType);
                if (clsDetails != null) {
                    for (Map.Entry<String, ClspMethod> entry : clsDetails.getMethodsMap().entrySet()) {
                        if (entry.getKey().startsWith(signature)) {
                            overrideList.add(entry.getValue());
                        }
                    }
                }
            }
        }
        return this.buildOverrideAttr(mth, overrideList, null);
    }

    /**
     * Returns the first non-static method of {@code cls} whose short id starts with {@code signature},
     * or {@code null} if there is none.
     */
    @Nullable
    private MethodNode searchOverriddenMethod(ClassNode cls, String signature) {
        for (MethodNode supMth : cls.getMethods()) {
            if (!supMth.getAccessFlags().isStatic() && supMth.getMethodInfo().getShortId().startsWith(signature)) {
                return supMth;
            }
        }
        return null;
    }

    /**
     * De-duplicates the collected override list and applies it.
     *
     * <p>When {@code attr} (an existing attribute of a super method) is given, its override list is merged
     * in and the existing related-method sets are updated.
     *
     * @return the new attribute for {@code mth}, or {@code null} if nothing was found and no attr was given
     */
    @Nullable
    private MethodOverrideAttr buildOverrideAttr(MethodNode mth, List<IMethodDetails> overrideList, @Nullable MethodOverrideAttr attr) {
        if (attr == null) {
            if (overrideList.isEmpty()) {
                return null;
            }
            List<IMethodDetails> cleanOverrideList = overrideList.stream().distinct().collect(Collectors.toList());
            return this.applyOverrideAttr(mth, cleanOverrideList, false);
        }
        List<IMethodDetails> mergedOverrideList = Utils.mergeLists(overrideList, attr.getOverrideList());
        List<IMethodDetails> cleanOverrideList = mergedOverrideList.stream().distinct().collect(Collectors.toList());
        return this.applyOverrideAttr(mth, cleanOverrideList, true);
    }

    /**
     * Wires override attributes along the chain {@code mth -> overrideList}.
     *
     * <p>Every project method in the chain shares one set of related methods. If any chain entry is not a
     * {@link MethodNode} (i.e. it came from the classpath), all project methods in the chain get
     * {@link AFlag#DONT_RENAME}. Each project super method gets an attribute holding the tail of the
     * override list. In update mode, methods that already have an attribute only get their related set
     * replaced.
     *
     * @return the attribute for {@code mth} itself (not attached here; the caller adds it)
     */
    private MethodOverrideAttr applyOverrideAttr(MethodNode mth, List<IMethodDetails> overrideList, boolean update) {
        // presumably classpath methods cannot be renamed, so the whole chain must keep its name
        boolean dontRename = overrideList.stream().anyMatch(m -> !(m instanceof MethodNode));
        List<MethodNode> mthNodes = this.getMethodNodes(mth, overrideList);
        SortedSet<MethodNode> relatedMethods = update
                ? this.mergeRelatedMethods(mthNodes)
                : new TreeSet<>(mthNodes);
        int depth = 0;
        for (MethodNode mthNode : mthNodes) {
            if (dontRename) {
                mthNode.add(AFlag.DONT_RENAME);
            }
            if (depth == 0) {
                // first node is 'mth' itself; its attribute is returned below
                depth = 1;
                continue;
            }
            if (update) {
                MethodOverrideAttr ovrdAttr = mthNode.get(AType.METHOD_OVERRIDE);
                if (ovrdAttr != null) {
                    ovrdAttr.setRelatedMthNodes(relatedMethods);
                    continue;
                }
            }
            mthNode.addAttr(new MethodOverrideAttr(Utils.listTail(overrideList, depth), relatedMethods));
            depth++;
        }
        return new MethodOverrideAttr(overrideList, relatedMethods);
    }

    /**
     * Builds the shared related-methods set for update mode.
     *
     * <p>Reuses (and mutates) the set of the first node that already has an override attribute, adds all
     * {@code mthNodes}, then merges in every other distinct existing set. If no node has an attribute, a new
     * set of {@code mthNodes} is created.
     */
    private SortedSet<MethodNode> mergeRelatedMethods(List<MethodNode> mthNodes) {
        SortedSet<MethodNode> relatedMethods = null;
        for (MethodNode mthNode : mthNodes) {
            MethodOverrideAttr ovrdAttr = mthNode.get(AType.METHOD_OVERRIDE);
            if (ovrdAttr != null) {
                relatedMethods = ovrdAttr.getRelatedMthNodes();
                break;
            }
        }
        if (relatedMethods == null) {
            relatedMethods = new TreeSet<>(mthNodes);
        } else {
            relatedMethods.addAll(mthNodes);
        }
        for (MethodNode mthNode : mthNodes) {
            MethodOverrideAttr ovrdAttr = mthNode.get(AType.METHOD_OVERRIDE);
            if (ovrdAttr == null) {
                continue;
            }
            SortedSet<MethodNode> set = ovrdAttr.getRelatedMthNodes();
            // identity check: skip the set that is already the merge target
            if (set != relatedMethods) {
                relatedMethods.addAll(set);
            }
        }
        return relatedMethods;
    }

    /** Returns {@code mth} followed by every {@link MethodNode} in {@code overrideList}, in order. */
    @NotNull
    private List<MethodNode> getMethodNodes(MethodNode mth, List<IMethodDetails> overrideList) {
        List<MethodNode> list = new ArrayList<>(1 + overrideList.size());
        list.add(mth);
        for (IMethodDetails md : overrideList) {
            if (md instanceof MethodNode) {
                list.add((MethodNode) md);
            }
        }
        return list;
    }

    /**
     * Checks whether {@code superMth} can be overridden from {@code cls}. Private: never.
     * Public/protected: always. Package-private: only from the same package.
     */
    private boolean isMethodVisibleInCls(MethodNode superMth, ClassNode cls) {
        AccessInfo accessFlags = superMth.getAccessFlags();
        if (accessFlags.isPrivate()) {
            return false;
        }
        if (accessFlags.isPublic() || accessFlags.isProtected()) {
            return true;
        }
        return Objects.equals(superMth.getParentClass().getPackage(), cls.getPackage());
    }

    /**
     * Collects the super class and interfaces of {@code cls} transitively, one entry per raw object name,
     * in first-seen order.
     */
    private List<ArgType> collectSuperTypes(ClassNode cls) {
        Map<String, ArgType> superTypes = new LinkedHashMap<>();
        this.collectSuperTypes(cls, superTypes);
        if (superTypes.isEmpty()) {
            return Collections.emptyList();
        }
        return new ArrayList<>(superTypes.values());
    }

    /** Adds the direct super class (unless it equals {@link ArgType#OBJECT}) and interfaces of {@code cls}. */
    private void collectSuperTypes(ClassNode cls, Map<String, ArgType> superTypes) {
        RootNode root = cls.root();
        ArgType superClass = cls.getSuperClass();
        if (superClass != null && !Objects.equals(superClass, ArgType.OBJECT)) {
            this.addSuperType(root, superTypes, superClass);
        }
        for (ArgType iface : cls.getInterfaces()) {
            this.addSuperType(root, superTypes, iface);
        }
    }

    /**
     * Records {@code superType} and then its own super types: recursively for classes loaded in the
     * project, otherwise from the classpath. Re-adding a known name replaces the stored type but keeps
     * its original position.
     */
    private void addSuperType(RootNode root, Map<String, ArgType> superTypesMap, ArgType superType) {
        superTypesMap.put(superType.getObject(), superType);
        ClassNode classNode = root.resolveClass(superType);
        if (classNode != null) {
            this.collectSuperTypes(classNode, superTypesMap);
            return;
        }
        for (String superCls : root.getClsp().getSuperTypes(superType.getObject())) {
            ArgType type = ArgType.object(superCls);
            superTypesMap.put(type.getObject(), type);
        }
    }

    /**
     * Restores a generic return type from {@code baseMth} for non-void methods, and adds a debug comment
     * when the type changes.
     *
     * @return {@code true} if the return type was updated
     */
    private boolean fixMethodReturnType(MethodNode mth, IMethodDetails baseMth, List<ArgType> superTypes) {
        ArgType returnType = mth.getReturnType();
        if (returnType == ArgType.VOID) {
            return false;
        }
        boolean updated = this.updateReturnType(mth, baseMth, superTypes);
        if (updated) {
            mth.addDebugComment("Return type fixed from '" + returnType + "' to match base method");
        }
        return updated;
    }

    /** Replaces the return type of {@code mth} with the resolved base return type, if one is found. */
    private boolean updateReturnType(MethodNode mth, IMethodDetails baseMth, List<ArgType> superTypes) {
        ArgType targetRetType = this.resolveBaseType(mth, baseMth, superTypes, mth.getReturnType(), baseMth.getReturnType());
        if (targetRetType == null) {
            return false;
        }
        mth.updateReturnType(targetRetType);
        return true;
    }

    /**
     * Restores generic argument types from {@code baseMth}. Only applies when both methods have the same
     * number of arguments.
     *
     * @return {@code true} if at least one argument type was updated
     */
    private boolean fixMethodArgTypes(MethodNode mth, IMethodDetails baseMth, List<ArgType> superTypes) {
        List<ArgType> mthArgTypes = mth.getArgTypes();
        List<ArgType> baseArgTypes = baseMth.getArgTypes();
        if (mthArgTypes.equals(baseArgTypes)) {
            return false;
        }
        int argCount = mthArgTypes.size();
        if (argCount != baseArgTypes.size()) {
            return false;
        }
        boolean changed = false;
        List<ArgType> newArgTypes = new ArrayList<>(argCount);
        for (int argNum = 0; argNum < argCount; argNum++) {
            ArgType newType = this.updateArgType(mth, baseMth, superTypes, argNum);
            if (newType != null) {
                changed = true;
                newArgTypes.add(newType);
            } else {
                newArgTypes.add(mthArgTypes.get(argNum));
            }
        }
        if (changed) {
            mth.updateArgTypes(newArgTypes, "Method arguments types fixed to match base method");
        }
        return changed;
    }

    /** Returns the resolved type for argument {@code argNum}, or {@code null} to keep the current one. */
    private ArgType updateArgType(MethodNode mth, IMethodDetails baseMth, List<ArgType> superTypes, int argNum) {
        return this.resolveBaseType(mth, baseMth, superTypes, mth.getArgTypes().get(argNum), baseMth.getArgTypes().get(argNum));
    }

    /**
     * Resolves {@code baseType}, which is declared in {@code baseMth}'s class and may contain type
     * variables, through the first super type that compares as {@link TypeCompareEnum#NARROW_BY_GENERIC}
     * to that class and yields a concrete type.
     *
     * @return the resolved type. Returns {@code null} if {@code baseType} already equals
     *         {@code currentType}, has no type variables, or no super type produces a concrete type
     *         different from {@code currentType}.
     */
    @Nullable
    private ArgType resolveBaseType(MethodNode mth, IMethodDetails baseMth, List<ArgType> superTypes, ArgType currentType, ArgType baseType) {
        if (currentType.equals(baseType) || !baseType.containsTypeVariable()) {
            return null;
        }
        TypeCompare typeCompare = mth.root().getTypeUpdate().getTypeCompare();
        ArgType baseCls = baseMth.getMethodInfo().getDeclClass().getType();
        for (ArgType superType : superTypes) {
            if (typeCompare.compareTypes(superType, baseCls) != TypeCompareEnum.NARROW_BY_GENERIC) {
                continue;
            }
            ArgType resolved = mth.root().getTypeUtils().replaceClassGenerics(superType, baseType);
            if (resolved != null && !resolved.containsTypeVariable() && !resolved.equals(currentType)) {
                return resolved;
            }
        }
        return null;
    }

    /**
     * After a type fix, renames sibling methods that now have the same alias and argument signature as
     * {@code mth}. Siblings marked {@link AFlag#DONT_RENAME} or carrying an override attribute are left
     * unchanged and get a warning comment instead.
     */
    private void fixMethodSignatureCollisions(MethodNode mth) {
        String mthName = mth.getMethodInfo().getAlias();
        String newSignature = MethodInfo.makeShortId(mthName, mth.getArgTypes(), null);
        for (MethodNode otherMth : mth.getParentClass().getMethods()) {
            if (!otherMth.getAlias().equals(mthName) || otherMth == mth) {
                continue;
            }
            if (!otherMth.getMethodInfo().makeSignature(true, false).equals(newSignature)) {
                continue;
            }
            if (otherMth.contains(AFlag.DONT_RENAME) || otherMth.contains(AType.METHOD_OVERRIDE)) {
                otherMth.addWarnComment("Can't rename method to resolve collision");
                continue;
            }
            otherMth.getMethodInfo().setAlias(OverrideMethodVisitor.makeNewAlias(otherMth));
            otherMth.addAttr(new RenameReasonAttr("avoid collision after fix types in other method"));
        }
    }

    /**
     * Generates {@code <alias><k>} for the smallest {@code k >= 2} that is unused in the parent class.
     * A candidate counts as used if it is found by {@code searchMethodByShortName} or equals the current
     * alias of any method. The alias check also catches aliases assigned earlier by this visitor, so two
     * renamed methods cannot end up with the same new alias.
     */
    private static String makeNewAlias(MethodNode mth) {
        ClassNode cls = mth.getParentClass();
        String baseName = mth.getAlias();
        for (int k = 2; ; k++) {
            String alias = baseName + k;
            if (cls.searchMethodByShortName(alias) == null && !OverrideMethodVisitor.isAliasUsed(cls, alias)) {
                return alias;
            }
        }
    }

    /** Returns {@code true} if any method of {@code cls} currently has {@code alias} as its alias. */
    private static boolean isAliasUsed(ClassNode cls, String alias) {
        for (MethodNode m : cls.getMethods()) {
            if (m.getAlias().equals(alias)) {
                return true;
            }
        }
        return false;
    }
}

