package com.newrelic.agent.instrumentation.context;

import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
import com.google.common.collect.Sets;
import com.newrelic.agent.Agent;
import com.newrelic.agent.bridge.AgentBridge;
import com.newrelic.agent.instrumentation.PointCut;
import com.newrelic.agent.instrumentation.classmatchers.OptimizedClassMatcher.Match;
import com.newrelic.agent.instrumentation.tracing.TraceClassVisitor;
import com.newrelic.agent.instrumentation.tracing.TraceDetails;
import com.newrelic.agent.util.asm.BenignClassReadException;
import com.newrelic.agent.util.asm.ClassResolver;
import com.newrelic.agent.util.asm.ClassResolvers;
import com.newrelic.agent.util.asm.Utils;
import com.newrelic.weave.UtilityClass;
import com.newrelic.weave.utils.WeaveUtils;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.commons.Method;

import java.security.ProtectionDomain;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.logging.Level;

/**
 * This class tracks information about a class passing through the {@link InstrumentationContextManager}. It keeps track
 * of the methods of the class that have been matched by different class transformers that are registered with the
 * manager.
 * <p>
 * Most collections are created lazily on first write, and their getters return empty immutable collections until then.
 * {@link #getBridgeMethods()} is the exception and may return {@code null}. No method in this class is synchronized.
 */
public class InstrumentationContext implements TraceDetailsList {

    private static final TraceInformation EMPTY_TRACE_INFO = new TraceInformation();

    /** Upper bound on the number of worker threads started by {@link #getMatchingClasses}. */
    private static final int MAX_MATCHING_THREADS = 8;

    protected final byte[] bytes;
    private boolean modified;
    // method -> titles of the instrumentation packages whose code was weaved into it; created lazily
    private Multimap<Method, String> weavedMethods;
    protected boolean print;
    private Set<Method> timedMethods;
    private Map<Method, PointCut> oldReflectionStyleInstrumentationMethods;
    private Map<Method, PointCut> oldInvokerStyleInstrumentationMethods;
    private TraceInformation tracedInfo;
    private Map<ClassMatchVisitorFactory, Match> matches;
    private String[] interfaces;
    private String[] supers;
    // bridge method -> implementation it delegates to; null until addBridgeMethod or a visiting match(...) runs
    private Map<Method, Method> bridgeMethods;
    private String className;
    private final Class<?> classBeingRedefined;
    private final ProtectionDomain protectionDomain;
    // extra resolvers registered through addClassResolver; created lazily
    private List<ClassResolver> classResolvers;
    private boolean generated;
    private boolean hasSource;

    /**
     * Creates a context for one class passing through the transformer chain.
     *
     * @param bytes the class file bytes; may be {@code null} (the class-matching path in this class passes
     *        {@code null} for all three arguments)
     * @param classBeingRedefined the class being redefined, or {@code null}
     * @param protectionDomain the protection domain of the class, or {@code null}
     */
    public InstrumentationContext(byte[] bytes, Class<?> classBeingRedefined, ProtectionDomain protectionDomain) {
        this.bytes = bytes;
        this.classBeingRedefined = classBeingRedefined;
        this.protectionDomain = protectionDomain;
        // TODO: load up supers and interfaces
    }

    /**
     * Returns the interface names set via {@link #setInterfaces(String[])}, or an empty array if none were set.
     */
    public String[] getInterfaces() {
        return null == interfaces ? new String[0] : interfaces;
    }

    /**
     * Returns the super class names set via {@link #setSupers(String[])}, or an empty array if none were set.
     */
    public String[] getSupers() {
        return null == supers ? new String[0] : supers;
    }

    public String getClassName() {
        return className;
    }

    public Class<?> getClassBeingRedefined() {
        return classBeingRedefined;
    }

    public ProtectionDomain getProtectionDomain() {
        return protectionDomain;
    }

    /**
     * Flags the class as modified. The flag is never cleared.
     */
    public void markAsModified() {
        this.modified = true;
    }

    public boolean isModified() {
        return modified;
    }

    /**
     * Returns the trace information collected so far, or a shared empty instance if nothing has been recorded.
     */
    public TraceInformation getTraceInformation() {
        return tracedInfo == null ? EMPTY_TRACE_INFO : tracedInfo;
    }

    /**
     * Returns true if trace information has been recorded and reports a match.
     */
    public boolean isTracerMatch() {
        return (tracedInfo != null && tracedInfo.isMatch());
    }

    /**
     * Adds a weaved method and marks the class as modified.
     *
     * @param method the method that received weaved code
     * @param instrumentationTitle The name of the instrumentation package from which the weaved code originated.
     */
    public void addWeavedMethod(Method method, String instrumentationTitle) {
        if (weavedMethods == null) {
            // set-valued multimap: the same package title is recorded at most once per method
            weavedMethods = Multimaps.newSetMultimap(Maps.<Method, Collection<String>>newHashMap(),
                    new Supplier<Set<String>>() {
                        @Override
                        public Set<String> get() {
                            return Sets.newHashSet();
                        }
                    });
        }
        weavedMethods.put(method, instrumentationTitle);
        modified = true;
    }

    /**
     * Returns the legacy point cut registered for {@code method}. Invoker-style registrations take precedence over
     * reflection-style ones.
     *
     * @return the point cut, or {@code null} if none was registered for the method
     */
    public PointCut getOldStylePointCut(Method method) {
        PointCut pc = getOldInvokerStyleInstrumentationMethods().get(method);
        if (pc == null) {
            pc = getOldReflectionStyleInstrumentationMethods().get(method);
        }
        return pc;
    }

    private Map<Method, PointCut> getOldInvokerStyleInstrumentationMethods() {
        return oldInvokerStyleInstrumentationMethods == null ? Collections.<Method, PointCut>emptyMap()
                : oldInvokerStyleInstrumentationMethods;
    }

    private Map<Method, PointCut> getOldReflectionStyleInstrumentationMethods() {
        return oldReflectionStyleInstrumentationMethods == null ? Collections.<Method, PointCut>emptyMap()
                : oldReflectionStyleInstrumentationMethods;
    }

    /**
     * Returns the methods that received weaved code (a live view once any method has been weaved).
     */
    public Set<Method> getWeavedMethods() {
        return weavedMethods == null ? Collections.<Method>emptySet() : weavedMethods.keySet();
    }

    /**
     * Returns methods that are timed with instrumentation injected by the new {@link TraceClassVisitor} or the old
     * GenericClassAdapter.
     */
    public Set<Method> getTimedMethods() {
        return timedMethods == null ? Collections.<Method>emptySet() : timedMethods;
    }

    /**
     * Returns the titles of the instrumentation packages whose code was weaved into {@code method}.
     *
     * @param method the method to look up
     * @return the package titles, or an empty collection if the method was not weaved; never {@code null}
     */
    public Collection<String> getMergeInstrumentationPackages(Method method) {
        if (weavedMethods == null) {
            return Collections.<String>emptySet();
        }
        // asMap().get yields null for an absent key; normalize so the result is never null in either case
        Collection<String> packages = weavedMethods.asMap().get(method);
        return packages == null ? Collections.<String>emptySet() : packages;
    }

    /**
     * Returns true if {@code method} is either timed or weaved.
     */
    public boolean isModified(Method method) {
        return (getTimedMethods().contains(method)) || (getWeavedMethods().contains(method));
    }

    /**
     * Adds methods that are timed with method tracers and marks the class as modified.
     */
    public void addTimedMethods(Method... methods) {
        if (timedMethods == null) {
            timedMethods = Sets.newHashSet();
        }
        Collections.addAll(timedMethods, methods);
        modified = true;
    }

    /**
     * Registers a reflection-style legacy point cut for {@code method} and marks the class as modified.
     */
    public void addOldReflectionStyleInstrumentationMethod(Method method, PointCut pointCut) {
        if (oldReflectionStyleInstrumentationMethods == null) {
            oldReflectionStyleInstrumentationMethods = Maps.newHashMap();
        }
        oldReflectionStyleInstrumentationMethods.put(method, pointCut);
        modified = true;
    }

    /**
     * Registers an invoker-style legacy point cut for {@code method} and marks the class as modified.
     */
    public void addOldInvokerStyleInstrumentationMethod(Method method, PointCut pointCut) {
        if (oldInvokerStyleInstrumentationMethods == null) {
            oldInvokerStyleInstrumentationMethods = Maps.newHashMap();
        }
        oldInvokerStyleInstrumentationMethods.put(method, pointCut);
        modified = true;
    }

    /**
     * Returns the matches recorded via {@link #putMatch}, or an empty map if there are none.
     */
    public Map<ClassMatchVisitorFactory, Match> getMatches() {
        return matches == null ? Collections.<ClassMatchVisitorFactory, Match>emptyMap() : matches;
    }

    /**
     * Chooses between the original and transformed bytes. A non-null {@code newBytes} marks the class as modified.
     *
     * @return {@code newBytes} if non-null, otherwise {@code originalBytes}
     */
    byte[] processTransformBytes(byte[] originalBytes, byte[] newBytes) {
        if (null != newBytes) {
            markAsModified();
            return newBytes;
        }
        return originalBytes;
    }

    public void putTraceAnnotation(Method method, TraceDetails traceDetails) {
        getOrCreateTraceInformation().putTraceAnnotation(method, traceDetails);
    }

    public void addIgnoreApdexMethod(String methodName, String methodDesc) {
        getOrCreateTraceInformation().addIgnoreApdexMethod(methodName, methodDesc);
    }

    public void addIgnoreTransactionMethod(String methodName, String methodDesc) {
        getOrCreateTraceInformation().addIgnoreTransactionMethod(methodName, methodDesc);
    }

    public void addIgnoreTransactionMethod(Method m) {
        getOrCreateTraceInformation().addIgnoreTransactionMethod(m);
    }

    /**
     * Records that {@code matcher} matched this class, replacing any earlier match for the same matcher.
     */
    public void putMatch(ClassMatchVisitorFactory matcher, Match match) {
        if (matches == null) {
            matches = Maps.newHashMap();
        }
        matches.put(matcher, match);
    }

    public void setInterfaces(String[] interfaces) {
        this.interfaces = interfaces;
    }

    public void setSupers(String[] superNames) {
        this.supers = superNames;
    }

    public void setClassName(String className) {
        this.className = className;
    }

    /**
     * Adds methods to be traced (timed) by instrumentation injected by the {@link TraceClassVisitor}.
     */
    public void addTracedMethods(Map<Method, TraceDetails> tracedMethods) {
        getOrCreateTraceInformation().pullAll(tracedMethods);
    }

    /**
     * Adds a method to be traced (timed) by instrumentation injected by the {@link TraceClassVisitor}.
     */
    @Override
    public void addTrace(Method method, TraceDetails traceDetails) {
        getOrCreateTraceInformation().putTraceAnnotation(method, traceDetails);
    }

    /**
     * Returns this context's trace information, creating it on first use.
     */
    private TraceInformation getOrCreateTraceInformation() {
        if (tracedInfo == null) {
            tracedInfo = new TraceInformation();
        }
        return tracedInfo;
    }

    /**
     * Runs the match visitors of all {@code classVisitorFactories} over the class in a single pass, with method code
     * skipped. Each factory is given the visitor chain built so far and may wrap it. A factory that returns
     * {@code null} is left out of the chain. If any visitor ran, bridge methods registered during that pass are
     * resolved to their implementations in a second pass; if none were registered, the bridge map becomes an empty
     * immutable map.
     *
     * @param loader the class loader passed to the factories
     * @param classBeingRedefined passed through to the factories
     * @param reader reader over the class file bytes
     * @param classVisitorFactories the matchers to apply
     */
    public void match(ClassLoader loader, Class<?> classBeingRedefined, ClassReader reader,
            Collection<ClassMatchVisitorFactory> classVisitorFactories) {

        ClassVisitor visitor = null;
        for (ClassMatchVisitorFactory factory : classVisitorFactories) {
            ClassVisitor nextVisitor = factory.newClassMatchVisitor(loader, classBeingRedefined, reader, visitor, this);
            if (nextVisitor != null) {
                visitor = nextVisitor;
            }
        }
        if (visitor != null) {
            reader.accept(visitor, ClassReader.SKIP_CODE);
            if (bridgeMethods != null) {
                // resolve bridge methods
                resolveBridgeMethods(reader);
            } else {
                bridgeMethods = ImmutableMap.<Method, Method>of();
            }
        }
    }

    /**
     * {@link ClassMatchVisitorFactory} implementations add bridge methods that they've matched to the
     * {@link #bridgeMethods} map. In that initial pass they just add the method but don't resolve the actual
     * implementation. In a second pass we visit the code to resolve the signature of the actual implementation.
     *
     * For example, if a class implements the generic {@link List} interface and specifies that the type is
     * {@link Integer}, the matchers will add the add(Object) method to our bridged method map, and this method will set
     * the value to the add(Integer) method which implements the add method.
     *
     * @param reader reader over the class file bytes
     * @see Opcodes#ACC_BRIDGE
     */
    private void resolveBridgeMethods(ClassReader reader) {
        ClassVisitor visitor = new ClassVisitor(WeaveUtils.ASM_API_LEVEL) {

            @Override
            public MethodVisitor visitMethod(int access, String name, String desc, String signature,
                    String[] exceptions) {
                final Method method = new Method(name, desc);
                if (bridgeMethods.containsKey(method)) {
                    return new MethodVisitor(WeaveUtils.ASM_API_LEVEL) {

                        @Override
                        public void visitMethodInsn(int opcode, String owner, String calledName, String calledDesc,
                                boolean itf) {
                            // the method invoked from the bridge body is its implementation; if the body makes
                            // several calls, the last one visited wins
                            bridgeMethods.put(method, new Method(calledName, calledDesc));
                            super.visitMethodInsn(opcode, owner, calledName, calledDesc, itf);
                        }

                    };
                }
                return null;
            }

        };

        reader.accept(visitor, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);
    }

    /**
     * Returns the subset of {@code classes} matched by at least one of {@code matchers}. Work is split into at most
     * {@link #MAX_MATCHING_THREADS} partitions, each handled by a newly started thread, and this method blocks until
     * all of them finish.
     * <p>
     * If the calling thread is interrupted while waiting, the interrupt status is restored and the (possibly partial)
     * result is returned. Worker threads may still add to it after that.
     *
     * @param matchers the matchers to apply
     * @param classes the classes to test; {@code null} or empty yields an empty set
     * @return a concurrent set of matching classes
     */
    public static Set<Class<?>> getMatchingClasses(final Collection<ClassMatchVisitorFactory> matchers,
            Class<?>... classes) {
        final Set<Class<?>> matchingClasses = Sets.newConcurrentHashSet();
        if (classes == null || classes.length == 0) {
            return matchingClasses;
        }

        int partitionCount = Math.min(classes.length, MAX_MATCHING_THREADS);
        // ceiling division so that every class lands in some partition
        int estimatedPerPartition = (classes.length + partitionCount - 1) / partitionCount;
        List<List<Class<?>>> partitionsClasses = Lists.partition(Arrays.asList(classes), estimatedPerPartition);

        final CountDownLatch countDownLatch = new CountDownLatch(partitionsClasses.size());
        for (final List<Class<?>> partitionClasses : partitionsClasses) {
            Runnable matchingRunnable = new Runnable() {
                @Override
                public void run() {
                    try {
                        for (Class<?> clazz : partitionClasses) {
                            if (isMatch(matchers, clazz)) {
                                // FIXME maybe we should check for the Weaved annotation. skip interfaces
                                matchingClasses.add(clazz);
                            }
                        }
                    } finally {
                        // always count down so the waiting caller cannot hang on a failed partition
                        countDownLatch.countDown();
                    }
                }
            };
            new Thread(matchingRunnable).start();
        }

        try {
            countDownLatch.await();
        } catch (InterruptedException e) {
            // restore the interrupt status so code further up the stack can still observe it
            Thread.currentThread().interrupt();
            Agent.LOG.log(Level.INFO, "Failed to wait for matching classes");
            Agent.LOG.log(Level.FINER, e, "Interrupted during class matching");
        }

        return matchingClasses;
    }

    /**
     * Returns true if at least one matcher records a match for {@code clazz}. Array classes, agent/weaver classes,
     * and classes that cannot be read return false.
     */
    private static boolean isMatch(Collection<ClassMatchVisitorFactory> matchers, Class<?> clazz) {
        if (clazz.isArray()) {
            // this one is funny. Apparently class arrays get cached. I assume we'll also see the regular class
            // and can ignore the array. The call to ClassLoader.getResource() for the array class will fail.
            return false;
        }
        String className = clazz.getName();
        if (className.startsWith("com.newrelic.api.agent") || className.startsWith("com.newrelic.agent.bridge") ||
                className.startsWith("com.newrelic.weave.") || className.startsWith("com.nr.agent") ||
                className.endsWith("_nr_ext") || className.endsWith("_nr_anon")) {
            return false;
        }

        ClassLoader loader = clazz.getClassLoader();
        if (loader == null) {
            // bootstrap classes report a null loader; fall back to the agent's loader
            loader = AgentBridge.getAgent().getClass().getClassLoader();
        }
        InstrumentationContext context = new InstrumentationContext(null, null, null);

        try {
            ClassReader reader = Utils.readClass(clazz);
            context.match(loader, clazz, reader, matchers);
            return !context.getMatches().isEmpty();
        } catch (BenignClassReadException ex) {
            return false;
        } catch (Exception ex) {
            // we often can't load our classes or lambdas because they're generated. Don't log this stuff.
            if (className.startsWith("com.newrelic") || className.startsWith("weave.")
                    || className.startsWith("com.nr.instrumentation") || className.startsWith("io.opentracing.")
                    || className.contains("$$Lambda$") || className.contains("LambdaForm$")
                    || className.contains("GeneratedConstructorAccessor") || className.contains("GeneratedMethodAccessor")
                    || className.contains("BoundMethodHandle$")) {
                return false;
            }
            if (clazz.isAnnotationPresent(UtilityClass.class)) {
                return false;
            }
            Agent.LOG.log(Level.FINER, "Unable to read {0}", className);
            Agent.LOG.log(Level.FINEST, ex, "Unable to read {0}", className);
            return false;
        }
    }

    /**
     * Add a bridged method to this context. The method is initially mapped to itself; {@link #match} later replaces
     * the value with the implementation the bridge delegates to.
     *
     * @param method the bridge method
     * @see Opcodes#ACC_BRIDGE
     */
    public void addBridgeMethod(Method method) {
        if (bridgeMethods == null) {
            bridgeMethods = Maps.newHashMap();
        }
        bridgeMethods.put(method, method);
    }

    /**
     * Returns a map of bridge methods. The key is the generic method definition from the matchers, and the value is the
     * actual method implementation with generic types. For example, a class may implement List<T> with a specific type
     * of Person. The class will implement a method called add(Person) and the JVM will generate the bridge method
     * add(Object) which will simply invoke add(Person). Generally speaking, we want to be able to create matchers that
     * match the loosely typed version of the method (the signature that's usually a bridge method), but we want to
     * instrument the typed version of the method because it can be invoked directly without passing through the bridge
     * implementation.
     * <p>
     * Unlike the other collection getters, this returns {@code null} when neither {@link #addBridgeMethod} nor a
     * {@link #match} pass that ran at least one visitor has happened yet.
     */
    public Map<Method, Method> getBridgeMethods() {
        return bridgeMethods;
    }

    /**
     * Returns true if any invoker-style or reflection-style legacy point cut has been registered.
     */
    public boolean isUsingLegacyInstrumentation() {
        return null != oldInvokerStyleInstrumentationMethods || null != oldReflectionStyleInstrumentationMethods;
    }

    /**
     * Returns true once any invoker-style legacy point cut has been registered.
     */
    public boolean hasModifiedClassStructure() {
        return null != oldInvokerStyleInstrumentationMethods;
    }

    /**
     * Adds a class resolver to the current context.
     *
     * @param classResolver the resolver to consult before the class loader in {@link #getClassResolver(ClassLoader)}
     */
    public void addClassResolver(ClassResolver classResolver) {
        if (this.classResolvers == null) {
            this.classResolvers = Lists.newArrayList();
        }
        this.classResolvers.add(classResolver);
    }

    /**
     * Returns a class resolver that will delegate to the class resolvers added with
     * {@link #addClassResolver(ClassResolver)}. If those fail to resolve the class the given classloader is used.
     * Calling this method does not change the set of registered resolvers.
     *
     * @param loader the class loader to fall back to
     * @return a resolver over the registered resolvers followed by {@code loader}, or just the loader's resolver if
     *         none were registered
     * @see ClassResolvers#getClassLoaderResolver(ClassLoader)
     */
    public ClassResolver getClassResolver(ClassLoader loader) {
        ClassResolver loaderResolver = ClassResolvers.getClassLoaderResolver(loader);
        if (classResolvers == null) {
            return loaderResolver;
        }
        // Work on a copy. Appending to classResolvers itself would leave a loader resolver behind on every call, so
        // repeated calls would pile up stale loader resolvers ahead of the current one.
        List<ClassResolver> delegates = Lists.newArrayList(classResolvers);
        delegates.add(loaderResolver);
        return ClassResolvers.getMultiResolver(delegates);
    }

    public void setGenerated(boolean isGenerated) {
        this.generated = isGenerated;
    }

    /**
     * Return true if the GeneratedClassDetector identified this class as a generated class.
     */
    public boolean isGenerated() {
        return generated;
    }

    public void setSourceAttribute(boolean hasSource) {
        this.hasSource = hasSource;
    }

    /**
     * Return true if the GeneratedClassDetector found that this class has a source attribute. Java class files are not
     * required to have a source attribute. When a class is created by a compiler, the source attribute generally
     * contains the name of the source file. When a class is generated by a bytecode tool, the attribute may contain
     * anything or may be absent.
     *
     * @return true if a source attribute was found on the class file.
     */
    public boolean hasSourceAttribute() {
        return hasSource;
    }
}
