/*
 * Decompiled with CFR 0.152.
 */
package com.antgroup.antchain.myjava.model.lowlevel;

import com.antgroup.antchain.myjava.common.DominatorTree;
import com.antgroup.antchain.myjava.common.Graph;
import com.antgroup.antchain.myjava.common.GraphUtils;
import com.antgroup.antchain.myjava.model.BasicBlock;
import com.antgroup.antchain.myjava.model.Incoming;
import com.antgroup.antchain.myjava.model.Instruction;
import com.antgroup.antchain.myjava.model.MethodReference;
import com.antgroup.antchain.myjava.model.Phi;
import com.antgroup.antchain.myjava.model.Program;
import com.antgroup.antchain.myjava.model.TextLocation;
import com.antgroup.antchain.myjava.model.TryCatchBlock;
import com.antgroup.antchain.myjava.model.ValueType;
import com.antgroup.antchain.myjava.model.Variable;
import com.antgroup.antchain.myjava.model.instructions.BinaryBranchingCondition;
import com.antgroup.antchain.myjava.model.instructions.BinaryBranchingInstruction;
import com.antgroup.antchain.myjava.model.instructions.BoundCheckInstruction;
import com.antgroup.antchain.myjava.model.instructions.CastInstruction;
import com.antgroup.antchain.myjava.model.instructions.CloneArrayInstruction;
import com.antgroup.antchain.myjava.model.instructions.ConstructArrayInstruction;
import com.antgroup.antchain.myjava.model.instructions.ConstructInstruction;
import com.antgroup.antchain.myjava.model.instructions.ConstructMultiArrayInstruction;
import com.antgroup.antchain.myjava.model.instructions.DoubleConstantInstruction;
import com.antgroup.antchain.myjava.model.instructions.ExitInstruction;
import com.antgroup.antchain.myjava.model.instructions.FloatConstantInstruction;
import com.antgroup.antchain.myjava.model.instructions.InitClassInstruction;
import com.antgroup.antchain.myjava.model.instructions.IntegerConstantInstruction;
import com.antgroup.antchain.myjava.model.instructions.InvocationType;
import com.antgroup.antchain.myjava.model.instructions.InvokeInstruction;
import com.antgroup.antchain.myjava.model.instructions.JumpInstruction;
import com.antgroup.antchain.myjava.model.instructions.LongConstantInstruction;
import com.antgroup.antchain.myjava.model.instructions.MonitorEnterInstruction;
import com.antgroup.antchain.myjava.model.instructions.MonitorExitInstruction;
import com.antgroup.antchain.myjava.model.instructions.NullCheckInstruction;
import com.antgroup.antchain.myjava.model.instructions.NullConstantInstruction;
import com.antgroup.antchain.myjava.model.instructions.RaiseInstruction;
import com.antgroup.antchain.myjava.model.instructions.SwitchInstruction;
import com.antgroup.antchain.myjava.model.instructions.SwitchTableEntry;
import com.antgroup.antchain.myjava.model.lowlevel.CallSiteDescriptor;
import com.antgroup.antchain.myjava.model.lowlevel.CallSiteLocation;
import com.antgroup.antchain.myjava.model.lowlevel.Characteristics;
import com.antgroup.antchain.myjava.model.lowlevel.ExceptionHandlerDescriptor;
import com.antgroup.antchain.myjava.model.util.DefinitionExtractor;
import com.antgroup.antchain.myjava.model.util.ProgramUtils;
import com.antgroup.antchain.myjava.runtime.ExceptionHandling;
import com.antgroup.antchain.myjava.runtime.ShadowStack;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.teavm.hppc.IntHashSet;
import org.teavm.hppc.IntObjectHashMap;
import org.teavm.hppc.IntSet;

/**
 * Lowers the exception handling of a single method onto the shadow stack.
 *
 * <p>For every instruction that {@link #isCallInstruction(Characteristics, Instruction)} reports
 * as a call site, this contributor:
 * <ul>
 *   <li>emits {@code ShadowStack.registerCallSite(id)} immediately before it;</li>
 *   <li>after it, reads {@code ShadowStack.getExceptionHandlerId()} and branches either to the
 *       code following the call (when the id equals the call site id) or to the matching
 *       catch/finally handler. If no catch-all handler exists, it falls back to a synthetic
 *       block that returns a default value.</li>
 * </ul>
 * Basic blocks are split after such calls where needed. One {@link CallSiteDescriptor} is
 * appended to the supplied list per call site. Try/catch metadata is removed from every
 * processed block.
 */
public class ExceptionHandlingShadowStackContributor {
    private static final MethodReference FILL_STACK_TRACE = new MethodReference(ExceptionHandling.class, "fillStackTrace", StackTraceElement[].class);
    /**
     * Handler ids of a call site are allocated starting right after {@code callSiteId + HANDLER_ID_OFFSET}.
     * They therefore never equal the call site's own id, which is used as the "continue" switch case.
     */
    private static final int HANDLER_ID_OFFSET = 10000;
    private final Characteristics characteristics;
    private final List<CallSiteDescriptor> callSites;
    /** Lazily created block that returns a default value; see {@link #getDefaultExceptionHandler()}. */
    private BasicBlock defaultExceptionHandler;
    private final MethodReference method;
    private final Program program;
    private final DominatorTree dom;
    /** Maps a variable index to the basic block that defines it. */
    private final BasicBlock[] variableDefinitionPlaces;
    /** Set as soon as at least one call site has been instrumented. */
    private boolean hasExceptionHandlers;
    /** Variables with an index below this value are treated as method parameters. */
    private final int parameterCount;
    /**
     * Id assigned to the next instrumented call site; incremented once per call site.
     * NOTE(review): public, presumably so a caller can share the counter across methods. Confirm.
     */
    public int callSiteIdGen;
    /** When {@code true}, a block protected by more than one catch-all (finally) handler is rejected. */
    private static final boolean DISABLE_NESTED_FINALLY_IN_TRY_CATCH = true;

    /**
     * Prepares the contributor. The control-flow graph, the dominator tree and the variable
     * definition places are computed eagerly from {@code program}.
     *
     * @param characteristics decides which invoked methods are managed (and thus call sites)
     * @param callSites       list that receives one descriptor per instrumented call site
     * @param method          the method whose body {@code program} is
     * @param program         the program to instrument in place
     */
    public ExceptionHandlingShadowStackContributor(Characteristics characteristics, List<CallSiteDescriptor> callSites, MethodReference method, Program program) {
        this.characteristics = characteristics;
        this.callSites = callSites;
        this.method = method;
        this.program = program;
        Graph cfg = ProgramUtils.buildControlFlowGraph(program);
        this.dom = GraphUtils.buildDominatorTree(cfg);
        this.variableDefinitionPlaces = ProgramUtils.getVariableDefinitionPlaces(program);
        // One extra slot beyond the declared parameters (presumably the receiver / variable 0).
        this.parameterCount = method.parameterCount() + 1;
    }

    /**
     * Instruments every basic block that existed when this method was called.
     *
     * <p>Blocks carrying an exception variable get a leading {@code ExceptionHandling.catchException}
     * call. Afterwards, the incomings of phis outside exception handlers are redirected to the last
     * block each original block was split into.
     *
     * @return {@code true} if at least one call site was instrumented
     */
    public boolean contribute() {
        int blockCount = this.program.basicBlockCount();
        int[] blockMapping = new int[blockCount];
        for (int i = 0; i < blockCount; ++i) {
            blockMapping[i] = i;
        }
        // Snapshot phis before blocks are split, so new blocks are not visited.
        List<Phi> allPhis = new ArrayList<>();
        for (int i = 0; i < blockCount; ++i) {
            allPhis.addAll(this.program.basicBlockAt(i).getPhis());
        }
        HashSet<BasicBlock> exceptionHandlers = new HashSet<>();
        for (int i = 0; i < blockCount; ++i) {
            BasicBlock block = this.program.basicBlockAt(i);
            // Collect handlers before contributeToBasicBlock clears the try/catch metadata.
            for (TryCatchBlock tryCatch : block.getTryCatchBlocks()) {
                exceptionHandlers.add(tryCatch.getHandler());
            }
            if (block.getExceptionVariable() != null) {
                insertCatchExceptionCall(block);
            }
            int newIndex = this.contributeToBasicBlock(block);
            if (newIndex != i) {
                blockMapping[i] = newIndex;
                this.hasExceptionHandlers = true;
            }
        }
        // Handler phis were already fixed up per call site by fixOutgoingPhis.
        for (Phi phi : allPhis) {
            if (exceptionHandlers.contains(phi.getBasicBlock())) {
                continue;
            }
            for (Incoming incoming : phi.getIncomings()) {
                int mappedSource = blockMapping[incoming.getSource().getIndex()];
                incoming.setSource(this.program.basicBlockAt(mappedSource));
            }
        }
        return this.hasExceptionHandlers;
    }

    /**
     * Replaces the block's exception variable with an explicit
     * {@code ExceptionHandling.catchException()} call at the start of the block.
     */
    private static void insertCatchExceptionCall(BasicBlock block) {
        InvokeInstruction catchCall = new InvokeInstruction();
        catchCall.setType(InvocationType.SPECIAL);
        catchCall.setMethod(new MethodReference(ExceptionHandling.class, "catchException", Throwable.class));
        catchCall.setReceiver(block.getExceptionVariable());
        block.addFirst(catchCall);
        block.setExceptionVariable(null);
    }

    /**
     * Instruments all call sites of {@code block}, splitting it into a chain of blocks where needed.
     *
     * @param block an original (not yet split) basic block
     * @return the index of the last block in the resulting chain (the original index if no split happened)
     */
    private int contributeToBasicBlock(BasicBlock block) {
        // For each handler-phi receiver: index of the variable currently flowing into it (-1 = unknown).
        int[] currentJointSources = new int[this.program.variableCount()];
        Arrays.fill(currentJointSources, -1);
        IntObjectHashMap<int[]> jointReceiverMaps = this.buildJointReceiverMaps(block, currentJointSources);
        IntHashSet variablesDefinedHere = new IntHashSet();

        for (Phi phi : block.getPhis()) {
            recordDefinition(block, phi.getReceiver().getIndex(), jointReceiverMaps, currentJointSources, variablesDefinedHere);
        }

        DefinitionExtractor defExtractor = new DefinitionExtractor();
        List<BasicBlock> blocksToClearHandlers = new ArrayList<>();
        blocksToClearHandlers.add(block);
        BasicBlock initialBlock = block;
        for (Instruction insn : block) {
            insn.acceptVisitor(defExtractor);
            for (Variable definedVar : defExtractor.getDefinedVariables()) {
                recordDefinition(block, definedVar.getIndex(), jointReceiverMaps, currentJointSources, variablesDefinedHere);
            }
            if (!this.isCallInstruction(insn)) {
                continue;
            }

            BasicBlock next;
            boolean last = false;
            if (this.isSpecialCallInstruction(insn)) {
                // ExceptionHandling.throw* calls are treated as terminal: drop everything after them.
                next = null;
                while (insn.getNext() != null) {
                    insn.getNext().delete();
                }
                last = true;
            } else if (insn instanceof RaiseInstruction) {
                insn = replaceRaiseWithThrowCall((RaiseInstruction) insn);
                next = null;
            } else if (insn.getNext() instanceof JumpInstruction) {
                // Reuse the jump target as the continuation instead of creating a new block.
                next = ((JumpInstruction) insn.getNext()).getTarget();
                insn.getNext().delete();
                last = true;
            } else {
                // Move the remainder of the block into a fresh continuation block with the same handlers.
                next = this.program.createBasicBlock();
                next.getTryCatchBlocks().addAll(ProgramUtils.copyTryCatches(block, this.program));
                blocksToClearHandlers.add(next);
                while (insn.getNext() != null) {
                    Instruction nextInsn = insn.getNext();
                    nextInsn.delete();
                    next.add(nextInsn);
                }
            }

            TextLocation location = insn.getLocation();
            CallSiteLocation[] locations = CallSiteLocation.fromTextLocation(location, this.method);
            CallSiteDescriptor callSite = new CallSiteDescriptor(this.callSiteIdGen++, locations);
            this.callSites.add(callSite);
            // Order matters: "before" instructions allocate their variables first.
            List<Instruction> pre = this.setLocation(this.getInstructionsBeforeCallSite(callSite), location);
            List<Instruction> post = this.setLocation(
                    this.getInstructionsAfterCallSite(initialBlock, block, next, callSite, currentJointSources, variablesDefinedHere),
                    location);
            // The call instruction is now the last instruction of the block.
            block.getLastInstruction().insertPreviousAll(pre);
            block.addAll(post);
            this.hasExceptionHandlers = true;
            if (next == null || last) {
                break;
            }
            block = next;
            variablesDefinedHere.clear();
        }
        this.fixOutgoingPhis(initialBlock, block, currentJointSources, variablesDefinedHere);
        for (BasicBlock blockToClear : blocksToClearHandlers) {
            blockToClear.getTryCatchBlocks().clear();
        }
        return block.getIndex();
    }

    /**
     * For every try/catch of {@code block}, builds a map from a source variable index to the
     * receiver index of the handler phi it flows into (-1 when it flows into none). The map is
     * keyed by handler block index.
     *
     * <p>It also seeds {@code currentJointSources} with a source variable available on entry to
     * {@code block}: a parameter, or a variable whose definition block dominates {@code block}.
     */
    private IntObjectHashMap<int[]> buildJointReceiverMaps(BasicBlock block, int[] currentJointSources) {
        IntObjectHashMap<int[]> jointReceiverMaps = new IntObjectHashMap<>();
        for (TryCatchBlock tryCatch : block.getTryCatchBlocks()) {
            int[] jointReceiverMap = new int[this.program.variableCount()];
            Arrays.fill(jointReceiverMap, -1);
            for (Phi phi : tryCatch.getHandler().getPhis()) {
                List<Variable> sourceVariables = phi.getIncomings().stream()
                        .filter(incoming -> incoming.getSource() == tryCatch.getProtectedBlock())
                        .map(Incoming::getValue)
                        .collect(Collectors.toList());
                if (sourceVariables.isEmpty()) {
                    continue;
                }
                int receiverIndex = phi.getReceiver().getIndex();
                for (Variable sourceVar : sourceVariables) {
                    BasicBlock sourceVarDefinedAt = this.variableDefinitionPlaces[sourceVar.getIndex()];
                    boolean isParameter = sourceVar.getIndex() < this.parameterCount;
                    if (!isParameter
                            && (!this.dom.dominates(sourceVarDefinedAt.getIndex(), block.getIndex()) || block == sourceVarDefinedAt)) {
                        continue;
                    }
                    currentJointSources[receiverIndex] = sourceVar.getIndex();
                    if (sourceVarDefinedAt != block) {
                        break;
                    }
                }
                for (Variable sourceVar : sourceVariables) {
                    jointReceiverMap[sourceVar.getIndex()] = receiverIndex;
                }
            }
            jointReceiverMaps.put(tryCatch.getHandler().getIndex(), jointReceiverMap);
        }
        return jointReceiverMaps;
    }

    /**
     * Records that {@code variableIndex} is defined in the current block segment. Every handler
     * phi that this variable feeds now has it as its current joint source.
     */
    private static void recordDefinition(BasicBlock block, int variableIndex, IntObjectHashMap<int[]> jointReceiverMaps,
            int[] currentJointSources, IntHashSet variablesDefinedHere) {
        for (TryCatchBlock tryCatch : block.getTryCatchBlocks()) {
            int jointReceiver = jointReceiverMaps.get(tryCatch.getHandler().getIndex())[variableIndex];
            if (jointReceiver >= 0) {
                currentJointSources[jointReceiver] = variableIndex;
            }
        }
        variablesDefinedHere.add(variableIndex);
    }

    /**
     * Replaces a {@code raise} with a call to {@code ExceptionHandling.throwException}, keeping its
     * argument and location.
     *
     * @return the inserted invocation, which takes the raise's place in the block
     */
    private static InvokeInstruction replaceRaiseWithThrowCall(RaiseInstruction raiseInsn) {
        InvokeInstruction throwCall = new InvokeInstruction();
        throwCall.setMethod(new MethodReference(ExceptionHandling.class, "throwException", Throwable.class, Void.TYPE));
        throwCall.setType(InvocationType.SPECIAL);
        throwCall.setArguments(raiseInsn.getException());
        throwCall.setLocation(raiseInsn.getLocation());
        raiseInsn.replace(throwCall);
        return throwCall;
    }

    private boolean isCallInstruction(Instruction insn) {
        return isCallInstruction(this.characteristics, insn);
    }

    /**
     * Tells whether {@code insn} must be treated as a call site, i.e. may transfer control to an
     * exception handler.
     *
     * @return {@code true} for allocation, class-init, raise, monitor, null/bound-check and cast
     *         instructions, and for invocations accepted by {@link #isManagedMethodCall}
     */
    public static boolean isCallInstruction(Characteristics characteristics, Instruction insn) {
        if (insn instanceof InitClassInstruction
                || insn instanceof ConstructInstruction
                || insn instanceof ConstructArrayInstruction
                || insn instanceof ConstructMultiArrayInstruction
                || insn instanceof CloneArrayInstruction
                || insn instanceof RaiseInstruction
                || insn instanceof MonitorEnterInstruction
                || insn instanceof MonitorExitInstruction
                || insn instanceof NullCheckInstruction
                || insn instanceof BoundCheckInstruction
                || insn instanceof CastInstruction) {
            return true;
        }
        if (insn instanceof InvokeInstruction) {
            return isManagedMethodCall(characteristics, ((InvokeInstruction) insn).getMethod());
        }
        return false;
    }

    /**
     * Tells whether a call to {@code method} must be instrumented.
     *
     * @return {@code true} if {@code characteristics} considers the method managed, if it is
     *         {@code ExceptionHandling.fillStackTrace}, or if it is an {@code ExceptionHandling.throw*} method
     */
    public static boolean isManagedMethodCall(Characteristics characteristics, MethodReference method) {
        return characteristics.isManaged(method)
                || method.equals(FILL_STACK_TRACE)
                || isExceptionHandlingThrowMethod(method);
    }

    /** Returns {@code true} for methods of {@link ExceptionHandling} whose name starts with {@code throw}. */
    private static boolean isExceptionHandlingThrowMethod(MethodReference method) {
        return method.getClassName().equals(ExceptionHandling.class.getName()) && method.getName().startsWith("throw");
    }

    /** Returns {@code true} if {@code insn} invokes an {@code ExceptionHandling.throw*} method. */
    private boolean isSpecialCallInstruction(Instruction insn) {
        return insn instanceof InvokeInstruction
                && isExceptionHandlingThrowMethod(((InvokeInstruction) insn).getMethod());
    }

    /**
     * Assigns {@code location} to every instruction in the list; does nothing when it is {@code null}.
     *
     * @return the same list, for chaining
     */
    private List<Instruction> setLocation(List<Instruction> instructions, TextLocation location) {
        if (location != null) {
            for (Instruction instruction : instructions) {
                instruction.setLocation(location);
            }
        }
        return instructions;
    }

    /** Builds {@code ShadowStack.registerCallSite(<call site id>)}, to be placed right before the call. */
    private List<Instruction> getInstructionsBeforeCallSite(CallSiteDescriptor callSite) {
        List<Instruction> instructions = new ArrayList<>();
        Variable idVariable = this.program.createVariable();
        IntegerConstantInstruction idInsn = new IntegerConstantInstruction();
        idInsn.setConstant(callSite.getId());
        idInsn.setReceiver(idVariable);
        instructions.add(idInsn);
        InvokeInstruction registerInsn = new InvokeInstruction();
        registerInsn.setMethod(new MethodReference(ShadowStack.class, "registerCallSite", Integer.TYPE, Void.TYPE));
        registerInsn.setType(InvocationType.SPECIAL);
        registerInsn.setArguments(idVariable);
        instructions.add(registerInsn);
        return instructions;
    }

    /**
     * Builds the dispatch placed after a call. It reads {@code ShadowStack.getExceptionHandlerId()}
     * and branches to {@code next} (if non-null) or to one of {@code block}'s handlers. As a side
     * effect, it registers one {@link ExceptionHandlerDescriptor} per handler on {@code callSite}
     * and fixes handler phis via {@link #fixOutgoingPhis}.
     *
     * <p>The dispatch is emitted as a plain jump, a binary branch, or a switch, depending on how
     * many non-default cases remain.
     *
     * @throws RuntimeException if {@code block} has more than one catch-all handler while
     *         {@link #DISABLE_NESTED_FINALLY_IN_TRY_CATCH} is set
     */
    private List<Instruction> getInstructionsAfterCallSite(BasicBlock initialBlock, BasicBlock block, BasicBlock next, CallSiteDescriptor callSite, int[] currentJointSources, IntSet variablesDefinedHere) {
        Program blockProgram = block.getProgram();
        List<Instruction> instructions = new ArrayList<>();
        Variable handlerIdVariable = blockProgram.createVariable();
        InvokeInstruction getHandlerIdInsn = new InvokeInstruction();
        getHandlerIdInsn.setMethod(new MethodReference(ShadowStack.class, "getExceptionHandlerId", Integer.TYPE));
        getHandlerIdInsn.setType(InvocationType.SPECIAL);
        getHandlerIdInsn.setReceiver(handlerIdVariable);
        instructions.add(getHandlerIdInsn);

        SwitchInstruction switchInsn = new SwitchInstruction();
        switchInsn.setCondition(handlerIdVariable);
        if (next != null) {
            // Handler id equal to the call site id means "no exception": continue normally.
            SwitchTableEntry continueExecutionEntry = new SwitchTableEntry();
            continueExecutionEntry.setCondition(callSite.getId());
            continueExecutionEntry.setTarget(next);
            switchInsn.getEntries().add(continueExecutionEntry);
        }

        this.checkNoNestedFinally(block);

        boolean defaultExists = false;
        int nextHandlerId = callSite.getId() + HANDLER_ID_OFFSET;
        for (TryCatchBlock tryCatch : block.getTryCatchBlocks()) {
            ExceptionHandlerDescriptor handler = new ExceptionHandlerDescriptor(++nextHandlerId, tryCatch.getExceptionType());
            callSite.getHandlers().add(handler);
            if (tryCatch.getExceptionType() == null) {
                // A catch-all handler becomes the switch default.
                defaultExists = true;
                switchInsn.setDefaultTarget(tryCatch.getHandler());
                continue;
            }
            SwitchTableEntry catchEntry = new SwitchTableEntry();
            catchEntry.setTarget(tryCatch.getHandler());
            catchEntry.setCondition(handler.getId());
            switchInsn.getEntries().add(catchEntry);
        }
        this.fixOutgoingPhis(initialBlock, block, currentJointSources, variablesDefinedHere);
        if (!defaultExists) {
            switchInsn.setDefaultTarget(this.getDefaultExceptionHandler());
        }

        if (switchInsn.getEntries().isEmpty()) {
            // Only a default target remains: the handler id is never inspected, so drop its read too.
            instructions.clear();
            JumpInstruction jumpInstruction = new JumpInstruction();
            jumpInstruction.setTarget(switchInsn.getDefaultTarget());
            instructions.add(jumpInstruction);
        } else if (switchInsn.getEntries().size() == 1) {
            // A single case is lowered to "handlerId == constant ? case : default".
            SwitchTableEntry switchTableEntry = switchInsn.getEntries().get(0);
            IntegerConstantInstruction singleTestConstant = new IntegerConstantInstruction();
            singleTestConstant.setConstant(switchTableEntry.getCondition());
            singleTestConstant.setReceiver(blockProgram.createVariable());
            instructions.add(singleTestConstant);
            BinaryBranchingInstruction branching = new BinaryBranchingInstruction(BinaryBranchingCondition.EQUAL);
            branching.setConsequent(switchTableEntry.getTarget());
            branching.setAlternative(switchInsn.getDefaultTarget());
            branching.setFirstOperand(switchInsn.getCondition());
            branching.setSecondOperand(singleTestConstant.getReceiver());
            instructions.add(branching);
        } else {
            instructions.add(switchInsn);
        }
        return instructions;
    }

    /**
     * Rejects blocks protected by more than one catch-all (finally) handler while
     * {@link #DISABLE_NESTED_FINALLY_IN_TRY_CATCH} is enabled.
     *
     * @throws RuntimeException naming the location of the block's first (or second) instruction
     */
    private void checkNoNestedFinally(BasicBlock block) {
        int finallyBlocksCount = 0;
        for (TryCatchBlock tryCatch : block.getTryCatchBlocks()) {
            if (tryCatch.getExceptionType() == null) {
                ++finallyBlocksCount;
            }
        }
        if (DISABLE_NESTED_FINALLY_IN_TRY_CATCH && finallyBlocksCount > 1) {
            Instruction first = block.getFirstInstruction();
            TextLocation firstLocation = first.getLocation();
            if (firstLocation == null && first.getNext() != null) {
                firstLocation = first.getNext().getLocation();
            }
            throw new RuntimeException("Nested finally block in try or catch block is disabled " + firstLocation);
        }
    }

    /**
     * Updates handler phis after {@code block}'s tail has moved to {@code newBlock}.
     *
     * <p>For each phi of {@code block}'s handlers that has a known joint source, incomings from
     * {@code block} carrying that source are handled in one of two ways. If the source was
     * defined in the current segment, they are redirected to {@code newBlock}. Otherwise, they are
     * duplicated with {@code newBlock} as an additional source.
     */
    private void fixOutgoingPhis(BasicBlock block, BasicBlock newBlock, int[] currentJointSources, IntSet variablesDefinedHere) {
        for (TryCatchBlock tryCatch : block.getTryCatchBlocks()) {
            for (Phi phi : tryCatch.getHandler().getPhis()) {
                int value = currentJointSources[phi.getReceiver().getIndex()];
                if (value < 0) {
                    continue;
                }
                List<Incoming> additionalIncomings = new ArrayList<>();
                for (int i = 0; i < phi.getIncomings().size(); ++i) {
                    Incoming incoming = phi.getIncomings().get(i);
                    boolean affected = incoming.getSource() == block
                            && block != newBlock
                            && incoming.getValue().getIndex() == value;
                    if (!affected) {
                        continue;
                    }
                    if (variablesDefinedHere.contains(value)) {
                        incoming.setSource(newBlock);
                    } else {
                        Incoming incomingCopy = new Incoming();
                        incomingCopy.setSource(newBlock);
                        incomingCopy.setValue(incoming.getValue());
                        additionalIncomings.add(incomingCopy);
                    }
                }
                // Appended after the scan so the loop above does not see the copies.
                phi.getIncomings().addAll(additionalIncomings);
            }
        }
    }

    /**
     * Returns (creating on first use) a block that exits the method with a default value. It is
     * used when a call site has no catch-all handler.
     */
    private BasicBlock getDefaultExceptionHandler() {
        if (this.defaultExceptionHandler == null) {
            this.defaultExceptionHandler = this.program.createBasicBlock();
            Variable result = this.createReturnValueInstructions(this.defaultExceptionHandler);
            ExitInstruction exit = new ExitInstruction();
            exit.setValueToReturn(result);
            this.defaultExceptionHandler.add(exit);
        }
        return this.defaultExceptionHandler;
    }

    /**
     * Appends to {@code block} a constant instruction matching the method's return type. The
     * constant is an unset int/long/float/double constant (presumably zero) for primitives, and
     * {@code null} otherwise.
     *
     * @return the variable holding the constant, or {@code null} for {@code void} methods
     */
    private Variable createReturnValueInstructions(BasicBlock block) {
        ValueType returnType = this.method.getReturnType();
        if (returnType == ValueType.VOID) {
            return null;
        }
        Variable variable = this.program.createVariable();
        if (returnType instanceof ValueType.Primitive) {
            switch (((ValueType.Primitive) returnType).getKind()) {
                case BOOLEAN:
                case BYTE:
                case SHORT:
                case CHARACTER:
                case INTEGER: {
                    IntegerConstantInstruction intConstant = new IntegerConstantInstruction();
                    intConstant.setReceiver(variable);
                    block.add(intConstant);
                    return variable;
                }
                case LONG: {
                    LongConstantInstruction longConstant = new LongConstantInstruction();
                    longConstant.setReceiver(variable);
                    block.add(longConstant);
                    return variable;
                }
                case FLOAT: {
                    FloatConstantInstruction floatConstant = new FloatConstantInstruction();
                    floatConstant.setReceiver(variable);
                    block.add(floatConstant);
                    return variable;
                }
                case DOUBLE: {
                    DoubleConstantInstruction doubleConstant = new DoubleConstantInstruction();
                    doubleConstant.setReceiver(variable);
                    block.add(doubleConstant);
                    return variable;
                }
            }
        }
        // Reference types (and any primitive kind not listed above) default to null.
        NullConstantInstruction nullConstant = new NullConstantInstruction();
        nullConstant.setReceiver(variable);
        block.add(nullConstant);
        return variable;
    }
}

