/*
 * Decompiled with CFR 0.152.
 */
package edu.columbia.cs.psl.phosphor.instrumenter;

import edu.columbia.cs.psl.phosphor.Configuration;
import edu.columbia.cs.psl.phosphor.SourceSinkManager;
import edu.columbia.cs.psl.phosphor.TaintUtils;
import edu.columbia.cs.psl.phosphor.org.objectweb.asm.ClassVisitor;
import edu.columbia.cs.psl.phosphor.org.objectweb.asm.Label;
import edu.columbia.cs.psl.phosphor.org.objectweb.asm.MethodVisitor;
import edu.columbia.cs.psl.phosphor.org.objectweb.asm.Opcodes;
import edu.columbia.cs.psl.phosphor.org.objectweb.asm.Type;
import edu.columbia.cs.psl.phosphor.runtime.MultiTainter;
import edu.columbia.cs.psl.phosphor.runtime.Taint;
import edu.columbia.cs.psl.phosphor.struct.ControlTaintTagStack;
import edu.columbia.cs.psl.phosphor.struct.SerializationWrapper;

/**
 * Rewrites {@code java/io/ObjectInputStream}, {@code java/io/ObjectOutputStream} and {@code java/io/ObjectStreamClass}
 * so that taint tags are written into, and read back out of, Java serialization streams.
 * {@link #isApplicable(String)} reports whether a class should be rewritten: only these three classes, and only when
 * multi-tainting is enabled.
 *
 * <p>Rewrites performed by {@link #visitMethod}:
 * <ul>
 *   <li>{@code writeObject} variants: the object in local 1 is first passed through {@link #wrapIfNecessary}.</li>
 *   <li>{@code readObject} variants: the returned object is passed through {@link #unwrapIfNecessary}.</li>
 *   <li>Tagged primitive {@code writeX} methods: a non-null, non-empty taint tag in local 1 is emitted with
 *       {@code writeObject} before the primitive itself.</li>
 *   <li>Tagged primitive {@code readX} methods: a tag object that may precede the primitive is read and stored in the
 *       {@code taint} field of the returned container.</li>
 *   <li>Every method of {@code ObjectStreamClass}: calls to tagged {@code ObjectOutputStream.write*} methods are
 *       redirected to the untagged method with the taint arguments dropped, which bypasses the tag-writing prologue
 *       installed for tagged primitive writes.</li>
 * </ul>
 *
 * <p>Integer literals passed to the {@code visit*Insn} calls are JVM opcodes; the mnemonic is noted beside each.
 */
public class SerializationFixingCV extends ClassVisitor implements Opcodes {
    private static final String INPUT_STREAM_NAME = "java/io/ObjectInputStream";
    private static final String OUTPUT_STREAM_NAME = "java/io/ObjectOutputStream";
    private static final String STREAM_CLASS_NAME = "java/io/ObjectStreamClass";
    /** Internal name and field descriptor of the private {@code ObjectInputStream.bin} field's type. */
    private static final String BLOCK_DATA_INPUT_STREAM_NAME = "java/io/ObjectInputStream$BlockDataInputStream";
    private static final String BLOCK_DATA_INPUT_STREAM_DESC = "Ljava/io/ObjectInputStream$BlockDataInputStream;";
    /** Descriptor shared by {@link #wrapIfNecessary} and {@link #unwrapIfNecessary}. */
    private static final String OBJECT_TO_OBJECT_DESC = "(Ljava/lang/Object;)Ljava/lang/Object;";
    /** Suffix carried by the taint-tracking variants of instrumented methods. */
    private static final String TAGGED_SUFFIX = "$$PHOSPHORTAGGED";
    /** Serialization stream marker that introduces a new object (0x73). */
    private static final byte TC_OBJECT = 115;
    /** Serialization stream marker for a null reference (0x70). */
    private static final byte TC_NULL = 112;
    /** Internal name of the class currently being visited. */
    private final String className;

    public SerializationFixingCV(ClassVisitor cv, String className) {
        super(Configuration.ASM_VERSION, cv);
        this.className = className;
    }

    /**
     * Returns whether {@code className} should be rewritten by this visitor.
     *
     * @param className internal name of the class about to be instrumented
     * @return {@code true} only when multi-tainting is enabled and the class is {@code ObjectInputStream},
     *         {@code ObjectOutputStream} or {@code ObjectStreamClass}
     */
    public static boolean isApplicable(String className) {
        return Configuration.MULTI_TAINTING && (INPUT_STREAM_NAME.equals(className) || OUTPUT_STREAM_NAME.equals(className) || STREAM_CLASS_NAME.equals(className));
    }

    /**
     * Chooses the rewriting visitor for a method from the visited class and the method name.
     *
     * @return the delegate visitor wrapped in the matching rewriter, or unwrapped for methods not listed here
     */
    @Override
    public MethodVisitor visitMethod(int access, String name, String desc, String signature, String[] exceptions) {
        MethodVisitor mv = super.visitMethod(access, name, desc, signature, exceptions);
        if (STREAM_CLASS_NAME.equals(this.className)) {
            // Every ObjectStreamClass method gets its tagged ObjectOutputStream.write* calls redirected.
            return new StreamClassMV(mv);
        }
        switch (name) {
            case "writeObject":
            case "writeObject$$PHOSPHORTAGGED":
            case "writeObject0$$PHOSPHORTAGGED":
                return new ObjectWriteMV(mv);
            case "readObject":
            case "readObject$$PHOSPHORTAGGED":
            case "readObject0$$PHOSPHORTAGGED":
                return new ObjectReadMV(mv);
            case "writeInt$$PHOSPHORTAGGED":
            case "writeLong$$PHOSPHORTAGGED":
            case "writeBoolean$$PHOSPHORTAGGED":
            case "writeShort$$PHOSPHORTAGGED":
            case "writeDouble$$PHOSPHORTAGGED":
            case "writeByte$$PHOSPHORTAGGED":
            case "writeChar$$PHOSPHORTAGGED":
            case "writeFloat$$PHOSPHORTAGGED":
                return new PrimitiveWriteMV(mv);
            case "readInt$$PHOSPHORTAGGED":
            case "readLong$$PHOSPHORTAGGED":
            case "readBoolean$$PHOSPHORTAGGED":
            case "readShort$$PHOSPHORTAGGED":
            case "readDouble$$PHOSPHORTAGGED":
            case "readByte$$PHOSPHORTAGGED":
            case "readChar$$PHOSPHORTAGGED":
            case "readFloat$$PHOSPHORTAGGED":
            case "readUnsignedByte$$PHOSPHORTAGGED":
            case "readUnsignedShort$$PHOSPHORTAGGED":
                // The tagged read returns a container object whose "taint" field receives the tag.
                return new PrimitiveReadMV(mv, Type.getReturnType(desc));
            default:
                return mv;
        }
    }

    /**
     * Runtime hook called at the start of instrumented {@code writeObject} methods.
     *
     * <p>Wraps a {@link Boolean}, {@link Byte}, {@link Character} or {@link Short} whose taint tag is non-null and
     * non-empty in a {@link SerializationWrapper}. Any other value, including {@code null}, is returned unchanged.
     * NOTE(review): Integer/Long/Float/Double are deliberately not handled here; confirm how their tags are preserved.
     *
     * @param obj the object about to be serialized
     * @return the wrapper, or {@code obj} itself
     */
    public static Object wrapIfNecessary(Object obj) {
        boolean wrappableType = obj instanceof Boolean || obj instanceof Byte || obj instanceof Character || obj instanceof Short;
        if (!wrappableType) {
            return obj;
        }
        Taint tag = MultiTainter.getTaint(obj);
        if (tag == null || tag.isEmpty()) {
            return obj;
        }
        if (obj instanceof Boolean) {
            return SerializationWrapper.wrap((Boolean) obj);
        } else if (obj instanceof Byte) {
            return SerializationWrapper.wrap((Byte) obj);
        } else if (obj instanceof Character) {
            return SerializationWrapper.wrap((Character) obj);
        }
        return SerializationWrapper.wrap((Short) obj);
    }

    /**
     * Runtime hook applied to values returned from instrumented {@code readObject} methods.
     *
     * @param obj the deserialized object
     * @return the unwrapped value if {@code obj} is a {@link SerializationWrapper}, otherwise {@code obj}
     */
    public static Object unwrapIfNecessary(Object obj) {
        if (obj instanceof SerializationWrapper) {
            return ((SerializationWrapper) obj).unwrap();
        }
        return obj;
    }

    /** Passes every value returned from a {@code readObject} variant through {@link #unwrapIfNecessary}. */
    private static class ObjectReadMV extends MethodVisitor {
        ObjectReadMV(MethodVisitor mv) {
            super(Configuration.ASM_VERSION, mv);
        }

        @Override
        public void visitInsn(int opcode) {
            if (TaintUtils.isReturnOpcode(opcode)) {
                // Replace the object on top of the stack with its unwrapped form just before returning.
                super.visitMethodInsn(184, Type.getInternalName(SerializationFixingCV.class), "unwrapIfNecessary", OBJECT_TO_OBJECT_DESC, false); // INVOKESTATIC
            }
            super.visitInsn(opcode);
        }
    }

    /** Replaces the argument in local 1 of a {@code writeObject} variant with {@code wrapIfNecessary(arg)}. */
    private static class ObjectWriteMV extends MethodVisitor {
        ObjectWriteMV(MethodVisitor mv) {
            super(Configuration.ASM_VERSION, mv);
        }

        @Override
        public void visitCode() {
            super.visitCode();
            super.visitVarInsn(25, 1); // ALOAD 1
            super.visitMethodInsn(184, Type.getInternalName(SerializationFixingCV.class), "wrapIfNecessary", OBJECT_TO_OBJECT_DESC, false); // INVOKESTATIC
            super.visitVarInsn(58, 1); // ASTORE 1
        }
    }

    /**
     * Instruments a tagged primitive {@code readX} method of {@code ObjectInputStream}. The method reads the tag object
     * that may precede the primitive and stores it in the {@code taint} field of the returned container.
     */
    private static class PrimitiveReadMV extends MethodVisitor {
        /** Type of the container returned by the tagged read; it owns the {@code taint} field that is written. */
        private final Type returnType;

        PrimitiveReadMV(MethodVisitor mv, Type returnType) {
            super(Configuration.ASM_VERSION, mv);
            this.returnType = returnType;
        }

        /** Emits {@code this.bin}: ALOAD 0; GETFIELD ObjectInputStream.bin. */
        private void loadBin() {
            super.visitVarInsn(25, 0); // ALOAD 0
            super.visitFieldInsn(180, INPUT_STREAM_NAME, "bin", BLOCK_DATA_INPUT_STREAM_DESC); // GETFIELD
        }

        /** Emits INVOKEVIRTUAL of a {@code BlockDataInputStream} method on the receiver already on the stack. */
        private void invokeBin(String method, String desc) {
            super.visitMethodInsn(182, BLOCK_DATA_INPUT_STREAM_NAME, method, desc, false); // INVOKEVIRTUAL
        }

        /**
         * Emits a prologue that pushes the tag preceding the primitive, or {@code null}, onto the operand stack. The
         * value stays there for the rest of the method and is consumed by {@link #visitInsn} at the return.
         */
        @Override
        public void visitCode() {
            super.visitCode();
            Label peekMarker = new Label();
            Label noTag = new Label();
            Label tagPushed = new Label();
            Label objectMarker = new Label();
            Label nullMarker = new Label();

            // In block-data mode with unread bytes left in the current block, the primitive comes next: no tag.
            loadBin();
            invokeBin("getBlockDataMode", "()Z");
            super.visitJumpInsn(153, peekMarker); // IFEQ: not in block-data mode
            loadBin();
            invokeBin("currentBlockRemaining", "()I");
            super.visitJumpInsn(154, noTag); // IFNE: unread block bytes remain

            // Peek at the next raw byte with block-data mode switched off, then restore the previous mode.
            // NOTE(review): outside block-data mode the peeked byte may belong to an untagged primitive; a value
            // whose first byte equals a marker below would be misread as a tag -- confirm this cannot happen.
            super.visitLabel(peekMarker);
            loadBin();
            super.visitInsn(89); // DUP                        -> bin, bin
            invokeBin("getBlockDataMode", "()Z");              // -> bin, oldMode
            loadBin();
            super.visitInsn(3); // ICONST_0
            invokeBin("setBlockDataMode", "(Z)Z");
            super.visitInsn(87); // POP                        -> bin, oldMode
            loadBin();
            invokeBin("peek", "()I");                          // -> bin, oldMode, next
            super.visitInsn(91); // DUP_X2                     -> next, bin, oldMode, next
            super.visitInsn(87); // POP                        -> next, bin, oldMode
            invokeBin("setBlockDataMode", "(Z)Z");             // -> next, result
            super.visitInsn(87); // POP                        -> next

            // Only an object or null marker introduces a tag.
            // NOTE(review): a repeated tag instance is serialized as a back-reference (TC_REFERENCE, 0x71), which is
            // not recognized here -- confirm tags are never written twice to the same stream.
            super.visitInsn(89); // DUP                        -> next, next
            super.visitIntInsn(16, TC_OBJECT); // BIPUSH
            super.visitJumpInsn(159, objectMarker); // IF_ICMPEQ -> next
            super.visitIntInsn(16, TC_NULL); // BIPUSH
            super.visitJumpInsn(159, nullMarker); // IF_ICMPEQ -> (empty)
            super.visitJumpInsn(167, noTag); // GOTO
            super.visitLabel(objectMarker);
            super.visitInsn(87); // POP the leftover copy of the marker byte
            super.visitLabel(nullMarker);
            super.visitVarInsn(25, 0); // ALOAD 0
            super.visitMethodInsn(182, INPUT_STREAM_NAME, "readObject", "()Ljava/lang/Object;", false); // INVOKEVIRTUAL
            super.visitTypeInsn(192, Configuration.TAINT_TAG_INTERNAL_NAME); // CHECKCAST
            super.visitJumpInsn(167, tagPushed); // GOTO
            super.visitLabel(noTag);
            super.visitInsn(1); // ACONST_NULL
            super.visitLabel(tagPushed);
        }

        /**
         * Before each return, stores the tag pushed by {@link #visitCode} into {@code result.taint}.
         * NOTE(review): relies on that tag still sitting directly beneath the return value on every return path.
         */
        @Override
        public void visitInsn(int opcode) {
            if (TaintUtils.isReturnOpcode(opcode)) {
                super.visitInsn(90); // DUP_X1   tag, result -> result, tag, result
                super.visitInsn(95); // SWAP                 -> result, result, tag
                super.visitFieldInsn(181, this.returnType.getInternalName(), "taint", Configuration.TAINT_TAG_DESC); // PUTFIELD -> result
            }
            super.visitInsn(opcode);
        }
    }

    /**
     * Instruments a tagged primitive {@code writeX} method of {@code ObjectOutputStream}. The method emits the taint
     * tag in local 1 via {@code writeObject} before the original body, unless the tag is null or empty.
     */
    private static class PrimitiveWriteMV extends MethodVisitor {
        PrimitiveWriteMV(MethodVisitor mv) {
            super(Configuration.ASM_VERSION, mv);
        }

        @Override
        public void visitCode() {
            super.visitCode();
            Label skipTag = new Label();
            super.visitVarInsn(25, 1); // ALOAD 1 (the primitive's taint tag)
            super.visitJumpInsn(198, skipTag); // IFNULL
            super.visitVarInsn(25, 1); // ALOAD 1
            super.visitMethodInsn(182, Configuration.TAINT_TAG_INTERNAL_NAME, "isEmpty", "()Z", false); // INVOKEVIRTUAL
            super.visitJumpInsn(154, skipTag); // IFNE: tag is empty
            super.visitVarInsn(25, 0); // ALOAD 0
            super.visitVarInsn(25, 1); // ALOAD 1
            super.visitMethodInsn(182, OUTPUT_STREAM_NAME, "writeObject", "(Ljava/lang/Object;)V", false); // INVOKEVIRTUAL
            super.visitLabel(skipTag);
        }
    }

    /**
     * Inside {@code ObjectStreamClass}, redirects calls to tagged {@code ObjectOutputStream.write*(Taint, prim)} and
     * {@code write*(Taint, prim, ControlTaintTagStack)} to the untagged method, discarding the taint arguments.
     */
    private static class StreamClassMV extends MethodVisitor {
        StreamClassMV(MethodVisitor mv) {
            super(Configuration.ASM_VERSION, mv);
        }

        @Override
        public void visitMethodInsn(int opcode, String owner, String name, String desc, boolean isInterface) {
            if (OUTPUT_STREAM_NAME.equals(owner) && name.startsWith("write")) {
                Type[] args = Type.getArgumentTypes(desc);
                // args[1] is inspected below, so at least (taint, value) must be present.
                if (args.length >= 2 && Type.getType(Configuration.TAINT_TAG_DESC).equals(args[0])) {
                    boolean hasControlStack = args.length == 3 && args[2].equals(Type.getType(ControlTaintTagStack.class));
                    if (args.length == 2 || hasControlStack) {
                        boolean widePrimitive = Type.DOUBLE_TYPE.equals(args[1]) || Type.LONG_TYPE.equals(args[1]);
                        if (hasControlStack) {
                            super.visitInsn(87); // POP the control-flow tag stack
                        }
                        // Drop the taint beneath the value: stream, taint, value -> stream, value.
                        super.visitInsn(widePrimitive ? 93 : 90); // DUP2_X1 : DUP_X1 -> stream, value, taint, value
                        super.visitInsn(widePrimitive ? 88 : 87); // POP2 : POP       -> stream, value, taint
                        super.visitInsn(87); // POP                                  -> stream, value
                        String untaintedMethod = name.replace(TAGGED_SUFFIX, "");
                        String untaintedDesc = SourceSinkManager.remapMethodDescToRemoveTaints(desc);
                        super.visitMethodInsn(opcode, owner, untaintedMethod, untaintedDesc, isInterface);
                        return;
                    }
                }
            }
            super.visitMethodInsn(opcode, owner, name, desc, isInterface);
        }
    }
}

