/*
 * Decompiled with CFR 0.152.
 */
package me.coley.cafedude.transform;

import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import java.util.regex.Pattern;
import me.coley.cafedude.classfile.ClassFile;
import me.coley.cafedude.classfile.Descriptor;
import me.coley.cafedude.classfile.Field;
import me.coley.cafedude.classfile.Method;
import me.coley.cafedude.classfile.annotation.Annotation;
import me.coley.cafedude.classfile.annotation.ClassElementValue;
import me.coley.cafedude.classfile.annotation.ElementValue;
import me.coley.cafedude.classfile.annotation.EnumElementValue;
import me.coley.cafedude.classfile.annotation.PrimitiveElementValue;
import me.coley.cafedude.classfile.annotation.TargetInfo;
import me.coley.cafedude.classfile.annotation.TypeAnnotation;
import me.coley.cafedude.classfile.annotation.Utf8ElementValue;
import me.coley.cafedude.classfile.attribute.AnnotationDefaultAttribute;
import me.coley.cafedude.classfile.attribute.AnnotationsAttribute;
import me.coley.cafedude.classfile.attribute.Attribute;
import me.coley.cafedude.classfile.attribute.AttributeContexts;
import me.coley.cafedude.classfile.attribute.BootstrapMethodsAttribute;
import me.coley.cafedude.classfile.attribute.CodeAttribute;
import me.coley.cafedude.classfile.attribute.ConstantValueAttribute;
import me.coley.cafedude.classfile.attribute.DefaultAttribute;
import me.coley.cafedude.classfile.attribute.EnclosingMethodAttribute;
import me.coley.cafedude.classfile.attribute.ExceptionsAttribute;
import me.coley.cafedude.classfile.attribute.InnerClassesAttribute;
import me.coley.cafedude.classfile.attribute.LocalVariableTableAttribute;
import me.coley.cafedude.classfile.attribute.LocalVariableTypeTableAttribute;
import me.coley.cafedude.classfile.attribute.ModuleAttribute;
import me.coley.cafedude.classfile.attribute.NestHostAttribute;
import me.coley.cafedude.classfile.attribute.NestMembersAttribute;
import me.coley.cafedude.classfile.attribute.ParameterAnnotationsAttribute;
import me.coley.cafedude.classfile.attribute.PermittedClassesAttribute;
import me.coley.cafedude.classfile.attribute.RecordAttribute;
import me.coley.cafedude.classfile.attribute.SignatureAttribute;
import me.coley.cafedude.classfile.attribute.SourceFileAttribute;
import me.coley.cafedude.classfile.behavior.AttributeHolder;
import me.coley.cafedude.classfile.constant.ConstPoolEntry;
import me.coley.cafedude.classfile.constant.CpClass;
import me.coley.cafedude.classfile.constant.CpInt;
import me.coley.cafedude.classfile.constant.CpUtf8;
import me.coley.cafedude.io.AttributeContext;
import me.coley.cafedude.transform.Transformer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Strips attributes that are declared in an illegal context, or that reference constant pool
 * entries of the wrong kind, from a class and its fields, methods and {@code Code} attributes.
 * Afterwards, {@code Dynamic}/{@code InvokeDynamic} pool entries that were referenced only by the
 * stripped attributes are replaced with a dummy {@code CpInt(0)}.
 * <p>
 * Only a subset of the JVMS constraints on attributes is checked. Any attribute that fails a
 * check, or throws while being checked, is removed rather than repaired.
 */
public class IllegalStrippingTransformer extends Transformer {
    /** Pseudo CP index registered to make validation of the current attribute fail. */
    private static final int FORCE_FAIL = -1;

    // Constant pool tags (JVMS 4.4).
    private static final int TAG_UTF8 = 1;
    private static final int TAG_INTEGER = 3;
    private static final int TAG_FLOAT = 4;
    private static final int TAG_LONG = 5;
    private static final int TAG_DOUBLE = 6;
    private static final int TAG_CLASS = 7;
    private static final int TAG_STRING = 8;
    private static final int TAG_NAME_AND_TYPE = 12;
    private static final int TAG_METHOD_HANDLE = 15;
    private static final int TAG_DYNAMIC = 17;
    private static final int TAG_INVOKE_DYNAMIC = 18;
    private static final int TAG_MODULE = 19;
    private static final int TAG_PACKAGE = 20;

    // Method access flags (JVMS 4.6).
    private static final int ACC_NATIVE = 0x0100;
    private static final int ACC_ABSTRACT = 0x0400;

    private static final Predicate<Integer> IS_UTF8 = tag -> tag == TAG_UTF8;
    private static final Predicate<Integer> IS_CLASS = tag -> tag == TAG_CLASS;
    private static final Predicate<Integer> IS_NAME_AND_TYPE = tag -> tag == TAG_NAME_AND_TYPE;
    private static final Predicate<Integer> IS_METHOD_HANDLE = tag -> tag == TAG_METHOD_HANDLE;
    private static final Predicate<Integer> IS_MODULE = tag -> tag == TAG_MODULE;
    private static final Predicate<Integer> IS_PACKAGE = tag -> tag == TAG_PACKAGE;
    /** ConstantValue may reference Integer, Float, Long, Double or String entries (not Class). */
    private static final Predicate<Integer> IS_CONSTANT_VALUE = tag -> tag == TAG_INTEGER
            || tag == TAG_FLOAT || tag == TAG_LONG || tag == TAG_DOUBLE || tag == TAG_STRING;
    /** Primitive annotation element values reference Integer, Float, Long or Double entries. */
    private static final Predicate<Integer> IS_PRIMITIVE = tag -> tag >= TAG_INTEGER && tag <= TAG_DOUBLE;
    /** Loadable constants accepted as bootstrap method arguments. */
    private static final Predicate<Integer> IS_LOADABLE = tag -> (tag >= TAG_INTEGER && tag <= TAG_STRING)
            || (tag >= TAG_METHOD_HANDLE && tag <= TAG_DYNAMIC);

    /** Characters accepted by {@link #matchWordUtf8()}; compiled once instead of per call. */
    private static final Pattern WORD_UTF8 = Pattern.compile("[<>;/$\\w]+");

    private static final Logger logger = LoggerFactory.getLogger(IllegalStrippingTransformer.class);

    public IllegalStrippingTransformer(ClassFile clazz) {
        super(clazz);
    }

    /**
     * Removes invalid attributes from the class, its fields and its methods. Then neutralizes
     * Dynamic/InvokeDynamic pool entries that only the removed attributes referenced.
     */
    @Override
    public void transform() {
        logger.info("Transforming '{}'", this.clazz.getName());
        // Snapshot of the CP indices referenced before anything is stripped.
        Set<Integer> cpAccesses = this.clazz.cpAccesses();
        this.clazz.getAttributes().removeIf(attribute -> !this.isValidWrapped(this.clazz, attribute));
        for (Field field : this.clazz.getFields()) {
            field.getAttributes().removeIf(attribute -> !this.isValidWrapped(field, attribute));
        }
        for (Method method : this.clazz.getMethods()) {
            method.getAttributes().removeIf(attribute -> !this.isValidWrapped(method, attribute));
        }
        // Indices referenced before but not after stripping were used only by removed data.
        cpAccesses.removeAll(this.clazz.cpAccesses());
        int max = this.pool.size();
        for (int index : cpAccesses) {
            // NOTE(review): this bound also skips index max - 1; confirm against ConstPool.size() semantics.
            if (index == 0 || index >= max - 1) {
                continue;
            }
            ConstPoolEntry cpe = this.pool.get(index);
            if (cpe == null) {
                continue;
            }
            int tag = cpe.getTag();
            if (tag == TAG_DYNAMIC || tag == TAG_INVOKE_DYNAMIC) {
                logger.debug("Removing now unused CP entry: {}={}", index, cpe.getClass().getSimpleName());
                this.pool.set(index, new CpInt(0));
            }
        }
    }

    /**
     * Runs {@link #isValid} and treats any exception thrown while checking as "invalid".
     *
     * @return {@code true} to keep the attribute, {@code false} to drop it
     */
    private boolean isValidWrapped(AttributeHolder holder, Attribute attribute) {
        try {
            return this.isValid(holder, attribute);
        } catch (Exception ex) {
            // The exception is passed last so SLF4J logs it as the cause.
            logger.warn("Encountered exception when parsing attribute '{}' in context '{}', dropping it",
                    attribute.getClass().getName(), holder.getHolderType().name(), ex);
            return false;
        }
    }

    /**
     * Checks that an attribute is allowed in its holder's context and that every CP index it
     * uses points to an entry of the expected kind. For {@code Code} attributes, invalid nested
     * attributes are removed in place as a side effect.
     *
     * @return {@code true} if the attribute should be kept
     */
    private boolean isValid(AttributeHolder holder, Attribute attribute) {
        Map<Integer, Predicate<Integer>> expectedTypeMasks = new HashMap<>();
        Map<Integer, Predicate<ConstPoolEntry>> cpEntryValidators = new HashMap<>();
        int maxCpIndex = this.pool.size();
        int nameIndex = attribute.getNameIndex();
        // attribute_name_index must reference the pool; index 0 is never a valid entry.
        if (nameIndex < 1 || nameIndex > maxCpIndex) {
            return false;
        }
        if (attribute instanceof DefaultAttribute) {
            return true;
        }
        String name = this.pool.getUtf(nameIndex);
        AttributeContext context = holder.getHolderType();
        Collection<AttributeContext> allowedContexts = AttributeContexts.getAllowedContexts(name);
        if (!allowedContexts.contains(context)) {
            logger.debug("Found '{}' declared in illegal context {}, allowed contexts: {}",
                    name, context.name(), allowedContexts);
            return false;
        }
        switch (name) {
            case "ConstantValue": {
                int valueIndex = ((ConstantValueAttribute) attribute).getConstantValueIndex();
                constrain(expectedTypeMasks, valueIndex, IS_CONSTANT_VALUE);
                break;
            }
            case "RuntimeInvisibleAnnotations":
            case "RuntimeVisibleAnnotations":
            case "RuntimeInvisibleTypeAnnotations":
            case "RuntimeVisibleTypeAnnotations": {
                for (Annotation anno : ((AnnotationsAttribute) attribute).getAnnotations()) {
                    this.addAnnotationValidation(holder, expectedTypeMasks, cpEntryValidators, anno);
                }
                break;
            }
            case "RuntimeInvisibleParameterAnnotations":
            case "RuntimeVisibleParameterAnnotations": {
                if (context != AttributeContext.METHOD) {
                    return false;
                }
                ParameterAnnotationsAttribute paramAnnotations = (ParameterAnnotationsAttribute) attribute;
                Method method = (Method) holder;
                String desc = this.pool.getUtf(method.getTypeIndex());
                int parameterCount = Descriptor.from(desc).getParameterCount();
                if (paramAnnotations.getParameterAnnotations().keySet().stream()
                        .anyMatch(key -> key < 0 || key >= parameterCount)) {
                    logger.debug("Out of bounds parameter-annotation indices used on method {}",
                            this.pool.getUtf(method.getNameIndex()));
                    return false;
                }
                for (List<? extends Annotation> annotationList : paramAnnotations.getParameterAnnotations().values()) {
                    for (Annotation anno : annotationList) {
                        this.addAnnotationValidation(holder, expectedTypeMasks, cpEntryValidators, anno);
                    }
                }
                break;
            }
            case "AnnotationDefault": {
                ElementValue elementValue = ((AnnotationDefaultAttribute) attribute).getElementValue();
                this.addElementValueValidation(holder, expectedTypeMasks, cpEntryValidators, elementValue);
                break;
            }
            case "NestHost": {
                int hostIndex = ((NestHostAttribute) attribute).getHostClassIndex();
                this.requireClass(expectedTypeMasks, cpEntryValidators, hostIndex);
                break;
            }
            case "NestMembers": {
                for (int memberIndex : ((NestMembersAttribute) attribute).getMemberClassIndices()) {
                    this.requireClass(expectedTypeMasks, cpEntryValidators, memberIndex);
                }
                break;
            }
            case "EnclosingMethod": {
                EnclosingMethodAttribute enclosingMethod = (EnclosingMethodAttribute) attribute;
                this.requireClass(expectedTypeMasks, cpEntryValidators, enclosingMethod.getClassIndex());
                // method_index may be 0; only the class reference is mandatory.
                constrainOptional(expectedTypeMasks, enclosingMethod.getMethodIndex(), IS_NAME_AND_TYPE);
                break;
            }
            case "Exceptions": {
                for (int exceptionTypeIndex : ((ExceptionsAttribute) attribute).getExceptionIndexTable()) {
                    this.requireClass(expectedTypeMasks, cpEntryValidators, exceptionTypeIndex);
                }
                break;
            }
            case "InnerClasses": {
                for (InnerClassesAttribute.InnerClass innerClass : ((InnerClassesAttribute) attribute).getInnerClasses()) {
                    // inner_class_info_index is mandatory; the outer class and inner name may be 0.
                    this.requireClass(expectedTypeMasks, cpEntryValidators, innerClass.getInnerClassInfoIndex());
                    constrainOptional(expectedTypeMasks, innerClass.getOuterClassInfoIndex(), IS_CLASS);
                    constrainOptional(expectedTypeMasks, innerClass.getInnerNameIndex(), IS_UTF8);
                }
                break;
            }
            case "Code": {
                if (context != AttributeContext.METHOD) {
                    return false;
                }
                Method method = (Method) holder;
                if ((method.getAccess() & (ACC_ABSTRACT | ACC_NATIVE)) != 0) {
                    String methodName = this.pool.getUtf(method.getNameIndex());
                    // JVMS 4.7.3 exempts class/interface initialization methods from this rule.
                    if (!"<clinit>".equals(methodName)) {
                        logger.debug("Illegal 'Code' attribute on abstract or native method {}", methodName);
                        return false;
                    }
                }
                CodeAttribute code = (CodeAttribute) attribute;
                // Check nested attributes one by one so that a single bad one does not
                // discard the whole Code attribute.
                code.getAttributes().removeIf(sub -> !this.isValidWrapped(code, sub));
                for (CodeAttribute.ExceptionTableEntry handler : code.getExceptionTable()) {
                    int catchTypeIndex = handler.getCatchTypeIndex();
                    // catch_type 0 is allowed and catches every exception type.
                    if (catchTypeIndex != 0) {
                        this.requireClass(expectedTypeMasks, cpEntryValidators, catchTypeIndex);
                    }
                }
                break;
            }
            case "Signature": {
                int signatureIndex = ((SignatureAttribute) attribute).getSignatureIndex();
                constrain(expectedTypeMasks, signatureIndex, IS_UTF8);
                constrain(cpEntryValidators, signatureIndex, this.matchNonEmptyUtf8());
                break;
            }
            case "SourceFile": {
                int sourceFileIndex = ((SourceFileAttribute) attribute).getSourceFileNameIndex();
                constrain(expectedTypeMasks, sourceFileIndex, IS_UTF8);
                constrain(cpEntryValidators, sourceFileIndex, this.matchNonEmptyUtf8());
                break;
            }
            case "Module": {
                ModuleAttribute module = (ModuleAttribute) attribute;
                constrain(expectedTypeMasks, module.getModuleIndex(), IS_MODULE);
                // Version indices are optional (0 = no version recorded).
                constrainOptional(expectedTypeMasks, module.getVersionIndex(), IS_UTF8);
                for (ModuleAttribute.Requires requires : module.getRequires()) {
                    constrain(expectedTypeMasks, requires.getIndex(), IS_MODULE);
                    constrainOptional(expectedTypeMasks, requires.getVersionIndex(), IS_UTF8);
                }
                for (ModuleAttribute.Exports exports : module.getExports()) {
                    constrain(expectedTypeMasks, exports.getIndex(), IS_PACKAGE);
                    for (int moduleIndex : exports.getToIndices()) {
                        constrain(expectedTypeMasks, moduleIndex, IS_MODULE);
                    }
                }
                for (ModuleAttribute.Opens opens : module.getOpens()) {
                    constrain(expectedTypeMasks, opens.getIndex(), IS_PACKAGE);
                    for (int moduleIndex : opens.getToIndices()) {
                        constrain(expectedTypeMasks, moduleIndex, IS_MODULE);
                    }
                }
                for (ModuleAttribute.Provides provides : module.getProvides()) {
                    constrain(expectedTypeMasks, provides.getIndex(), IS_CLASS);
                    for (int implIndex : provides.getWithIndices()) {
                        constrain(expectedTypeMasks, implIndex, IS_CLASS);
                    }
                }
                for (int use : module.getUses()) {
                    constrain(expectedTypeMasks, use, IS_CLASS);
                }
                break;
            }
            case "BootstrapMethods": {
                BootstrapMethodsAttribute bootstrapMethods = (BootstrapMethodsAttribute) attribute;
                for (BootstrapMethodsAttribute.BootstrapMethod bsm : bootstrapMethods.getBootstrapMethods()) {
                    constrain(expectedTypeMasks, bsm.getBsmMethodref(), IS_METHOD_HANDLE);
                    for (int arg : bsm.getArgs()) {
                        constrain(expectedTypeMasks, arg, IS_LOADABLE);
                    }
                }
                break;
            }
            case "LocalVariableTable": {
                for (LocalVariableTableAttribute.VarEntry entry : ((LocalVariableTableAttribute) attribute).getEntries()) {
                    constrain(expectedTypeMasks, entry.getNameIndex(), IS_UTF8);
                    constrain(expectedTypeMasks, entry.getDescIndex(), IS_UTF8);
                    constrain(cpEntryValidators, entry.getNameIndex(), this.matchNonEmptyUtf8().and(this.matchWordUtf8()));
                    constrain(cpEntryValidators, entry.getDescIndex(), this.matchNonEmptyUtf8());
                }
                break;
            }
            case "LocalVariableTypeTable": {
                for (LocalVariableTypeTableAttribute.VarTypeEntry entry : ((LocalVariableTypeTableAttribute) attribute).getEntries()) {
                    constrain(expectedTypeMasks, entry.getNameIndex(), IS_UTF8);
                    constrain(expectedTypeMasks, entry.getSignatureIndex(), IS_UTF8);
                    constrain(cpEntryValidators, entry.getNameIndex(), this.matchNonEmptyUtf8().and(this.matchWordUtf8()));
                    constrain(cpEntryValidators, entry.getSignatureIndex(), this.matchNonEmptyUtf8());
                }
                break;
            }
            case "PermittedSubclasses": {
                for (int index : ((PermittedClassesAttribute) attribute).getClasses()) {
                    this.requireClass(expectedTypeMasks, cpEntryValidators, index);
                }
                break;
            }
            case "Record": {
                for (RecordAttribute.RecordComponent component : ((RecordAttribute) attribute).getComponents()) {
                    constrain(expectedTypeMasks, component.getNameIndex(), IS_UTF8);
                    constrain(cpEntryValidators, component.getNameIndex(), this.matchWordUtf8());
                    constrain(expectedTypeMasks, component.getDescIndex(), IS_UTF8);
                    constrain(cpEntryValidators, component.getDescIndex(), this.matchNonEmptyUtf8());
                }
                break;
            }
            case "LineNumberTable":
            case "SourceDebugExtension":
            case "Deprecated":
            case "Synthetic":
            default: {
                // No CP references are checked for these attributes.
                break;
            }
        }
        return this.checkConstraints(name, context, maxCpIndex, expectedTypeMasks, cpEntryValidators);
    }

    /**
     * Evaluates the constraints collected by {@link #isValid} against the constant pool.
     *
     * @return {@code true} if every referenced index is in bounds and satisfies its constraints
     */
    private boolean checkConstraints(String name, AttributeContext context, int maxCpIndex,
                                     Map<Integer, Predicate<Integer>> expectedTypeMasks,
                                     Map<Integer, Predicate<ConstPoolEntry>> cpEntryValidators) {
        for (Map.Entry<Integer, Predicate<Integer>> entry : expectedTypeMasks.entrySet()) {
            int cpIndex = entry.getKey();
            if (cpIndex == FORCE_FAIL) {
                logger.debug("Invalid '{}' attribute on {}, failed a structural consistency check!",
                        name, context.name());
                return false;
            }
            // Optional (zero) indices are never registered, so 0 is always out of bounds here.
            if (cpIndex < 1 || cpIndex > maxCpIndex) {
                logger.debug("Invalid '{}' attribute on {}, contains CP reference to index out of CP bounds!",
                        name, context.name());
                return false;
            }
            ConstPoolEntry cpEntry = this.pool.get(cpIndex);
            if (cpEntry == null) {
                logger.debug("No CP entry at index '{}' in Attribute '{}' on {}", cpIndex, name, context.name());
                return false;
            }
            if (!entry.getValue().test(cpEntry.getTag())) {
                logger.debug("Invalid '{}' attribute on {}, contains CP reference to index with wrong type!",
                        name, context.name());
                return false;
            }
            Predicate<ConstPoolEntry> validator = cpEntryValidators.get(cpIndex);
            if (validator != null && !validator.test(cpEntry)) {
                logger.debug("Invalid '{}' attribute, contains CP reference to item that does not match criteria at index: {}",
                        name, cpIndex);
                return false;
            }
        }
        return true;
    }

    /**
     * Adds a rule for a CP index. If the index already has a rule, the new one is combined with
     * it, so both must hold.
     */
    private static <T> void constrain(Map<Integer, Predicate<T>> rules, int cpIndex, Predicate<T> rule) {
        rules.merge(cpIndex, rule, (existing, added) -> existing.and(added));
    }

    /** Like {@link #constrain}, but an index of 0 ("absent") is accepted by not registering it. */
    private static void constrainOptional(Map<Integer, Predicate<Integer>> expectedTypeMasks,
                                          int cpIndex, Predicate<Integer> tagMask) {
        if (cpIndex != 0) {
            constrain(expectedTypeMasks, cpIndex, tagMask);
        }
    }

    /** Requires a CP index to be a CONSTANT_Class whose name is a non-empty UTF8 entry. */
    private void requireClass(Map<Integer, Predicate<Integer>> expectedTypeMasks,
                              Map<Integer, Predicate<ConstPoolEntry>> cpEntryValidators, int cpIndex) {
        constrain(expectedTypeMasks, cpIndex, IS_CLASS);
        constrain(cpEntryValidators, cpIndex, this.matchClassType());
    }

    /**
     * Registers constraints for an annotation's type, its element names and element values.
     * For type annotations, it also checks the target against the holder's structure.
     */
    private void addAnnotationValidation(AttributeHolder holder, Map<Integer, Predicate<Integer>> expectedTypeMasks,
                                         Map<Integer, Predicate<ConstPoolEntry>> cpEntryValidators, Annotation anno) {
        constrain(expectedTypeMasks, anno.getTypeIndex(), IS_UTF8);
        constrain(cpEntryValidators, anno.getTypeIndex(), this.matchUtf8ClassType());
        for (Map.Entry<Integer, ElementValue> entry : anno.getValues().entrySet()) {
            int elementNameIndex = entry.getKey();
            constrain(expectedTypeMasks, elementNameIndex, IS_UTF8);
            constrain(cpEntryValidators, elementNameIndex, this.matchUtf8ClassType());
            this.addElementValueValidation(holder, expectedTypeMasks, cpEntryValidators, entry.getValue());
        }
        if (anno instanceof TypeAnnotation) {
            TargetInfo targetInfo = ((TypeAnnotation) anno).getTargetInfo();
            if (!this.isTargetConsistent(holder, targetInfo)) {
                expectedTypeMasks.put(FORCE_FAIL, tag -> false);
            }
        }
    }

    /**
     * Checks that a type-annotation target refers to a structure that exists on its holder.
     * Only supertype and catch targets are checked; all other targets are accepted.
     */
    private boolean isTargetConsistent(AttributeHolder holder, TargetInfo targetInfo) {
        switch (targetInfo.getTargetTypeKind()) {
            case SUPERTYPE_TARGET: {
                if (!(holder instanceof ClassFile)) {
                    return false;
                }
                TargetInfo.SuperTypeTargetInfo superType = (TargetInfo.SuperTypeTargetInfo) targetInfo;
                if (superType.isExtends()) {
                    return true;
                }
                int interfaceIndex = superType.getSuperTypeIndex();
                return interfaceIndex >= 0 && interfaceIndex < ((ClassFile) holder).getInterfaceIndices().size();
            }
            case CATCH_TARGET: {
                if (!(holder instanceof CodeAttribute)) {
                    return false;
                }
                int handlerIndex = ((TargetInfo.CatchTargetInfo) targetInfo).getExceptionTableIndex();
                return handlerIndex >= 0 && handlerIndex < ((CodeAttribute) holder).getExceptionTable().size();
            }
            default:
                return true;
        }
    }

    /**
     * Registers CP constraints for class, enum, string and primitive element values.
     * NOTE(review): array and nested-annotation element values are not descended into.
     */
    private void addElementValueValidation(AttributeHolder holder, Map<Integer, Predicate<Integer>> expectedTypeMasks,
                                           Map<Integer, Predicate<ConstPoolEntry>> cpEntryValidators, ElementValue elementValue) {
        if (elementValue instanceof ClassElementValue) {
            int classIndex = ((ClassElementValue) elementValue).getClassIndex();
            constrain(expectedTypeMasks, classIndex, IS_UTF8);
            constrain(cpEntryValidators, classIndex, this.matchUtf8ClassType());
        } else if (elementValue instanceof EnumElementValue) {
            EnumElementValue enumElementValue = (EnumElementValue) elementValue;
            constrain(expectedTypeMasks, enumElementValue.getNameIndex(), IS_UTF8);
            constrain(expectedTypeMasks, enumElementValue.getTypeIndex(), IS_UTF8);
            constrain(cpEntryValidators, enumElementValue.getTypeIndex(), this.matchUtf8ClassType());
        } else if (elementValue instanceof Utf8ElementValue) {
            constrain(expectedTypeMasks, ((Utf8ElementValue) elementValue).getUtfIndex(), IS_UTF8);
        } else if (elementValue instanceof PrimitiveElementValue) {
            constrain(expectedTypeMasks, ((PrimitiveElementValue) elementValue).getValueIndex(), IS_PRIMITIVE);
        }
    }

    /** Matches a CpClass whose name entry satisfies {@link #matchUtf8ClassType()}. */
    private Predicate<ConstPoolEntry> matchClassType() {
        return e -> e instanceof CpClass && this.matchUtf8ClassType().test(this.pool.get(((CpClass) e).getIndex()));
    }

    /** Matches a UTF8 entry usable as a type name; currently the same as non-empty UTF8. */
    private Predicate<ConstPoolEntry> matchUtf8ClassType() {
        return this.matchNonEmptyUtf8();
    }

    /** Matches a CpUtf8 with at least one character. */
    private Predicate<ConstPoolEntry> matchNonEmptyUtf8() {
        return e -> e instanceof CpUtf8 && ((CpUtf8) e).getText().length() > 0;
    }

    /** Matches a CpUtf8 consisting only of word characters and {@code < > ; / $}. */
    private Predicate<ConstPoolEntry> matchWordUtf8() {
        return e -> e instanceof CpUtf8 && WORD_UTF8.matcher(((CpUtf8) e).getText()).matches();
    }
}

