package name.remal.reflection;

import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;

import static name.remal.ArrayUtils.contains;
import static name.remal.reflection.BootstrapClassLoader.getBootstrapClassLoader;
import static name.remal.reflection.ExtendedURLClassLoader.LoadingOrder.*;

/**
 * A {@link URLClassLoader} whose delegation strategy is selected per class name.
 *
 * <p>For every class that is not already loaded by this loader and is not provided by the loader
 * returned from {@code BootstrapClassLoader.getBootstrapClassLoader()}, the configured
 * {@link LoadingOrderFactory} is asked for a {@link LoadingOrder}, which decides whether the parent
 * loader, this loader's own URLs, or both (and in which sequence) are consulted.
 *
 * <p>The class registers itself as parallel capable and synchronizes class loading on
 * {@link #getClassLoadingLock(String)}.
 */
public class ExtendedURLClassLoader extends URLClassLoader {

    static {
        ClassLoader.registerAsParallelCapable();
    }

    /**
     * Lookup strategies understood by {@link #loadClass(String, boolean)}.
     */
    public enum LoadingOrder {
        /** Ask the parent loader first, then fall back to this loader's own URLs. */
        PARENT_FIRST,
        /** Search this loader's own URLs first, then fall back to the parent loader. */
        THIS_FIRST,
        /** Ask the parent loader only. */
        PARENT_ONLY,
        /** Search this loader's own URLs only. */
        THIS_ONLY
    }

    /**
     * Chooses the {@link LoadingOrder} to apply to a particular class name.
     */
    @FunctionalInterface
    public interface LoadingOrderFactory {
        /**
         * @param className the name of the class that is about to be loaded
         * @return the lookup strategy to use for that class
         */
        @NotNull
        LoadingOrder getLoadingOrder(@NotNull String className);
    }

    /** Strategy selector consulted once per class that still has to be located. */
    @NotNull
    private final LoadingOrderFactory loadingOrderFactory;

    /**
     * Primary constructor; every other constructor ends up here.
     *
     * @param loadingOrderFactory selects the lookup strategy for each class name
     * @param urls                search path of this loader; duplicates are dropped
     * @param parent              parent loader, or {@code null} to use the system class loader
     */
    public ExtendedURLClassLoader(@NotNull LoadingOrderFactory loadingOrderFactory, @NotNull URL[] urls, @Nullable ClassLoader parent) {
        super(uniqueURLs(urls), parent != null ? parent : getSystemClassLoader());
        this.loadingOrderFactory = loadingOrderFactory;
    }

    /**
     * Uses the same {@code loadingOrder} for every class.
     *
     * @param loadingOrder lookup strategy applied to all class names
     * @param urls         search path of this loader; duplicates are dropped
     * @param parent       parent loader, or {@code null} to use the system class loader
     */
    public ExtendedURLClassLoader(@NotNull LoadingOrder loadingOrder, @NotNull URL[] urls, @Nullable ClassLoader parent) {
        this(className -> loadingOrder, urls, parent);
    }

    /**
     * Same as the array-based constructor, with the URLs supplied as an {@link Iterable}.
     *
     * @param loadingOrderFactory selects the lookup strategy for each class name
     * @param urls                search path of this loader; duplicates are dropped
     * @param parent              parent loader, or {@code null} to use the system class loader
     */
    public ExtendedURLClassLoader(@NotNull LoadingOrderFactory loadingOrderFactory, @NotNull Iterable<URL> urls, @Nullable ClassLoader parent) {
        this(loadingOrderFactory, iterableUrlsToArray(urls), parent);
    }

    /**
     * Uses the same {@code loadingOrder} for every class, with the URLs supplied as an {@link Iterable}.
     *
     * @param loadingOrder lookup strategy applied to all class names
     * @param urls         search path of this loader; duplicates are dropped
     * @param parent       parent loader, or {@code null} to use the system class loader
     */
    public ExtendedURLClassLoader(@NotNull LoadingOrder loadingOrder, @NotNull Iterable<URL> urls, @Nullable ClassLoader parent) {
        this(className -> loadingOrder, urls, parent);
    }

    /**
     * Creates a loader that applies {@link LoadingOrder#PARENT_FIRST} to every class.
     *
     * @param urls   search path of this loader; duplicates are dropped
     * @param parent parent loader, or {@code null} to use the system class loader
     */
    public ExtendedURLClassLoader(@NotNull URL[] urls, @Nullable ClassLoader parent) {
        this(PARENT_FIRST, urls, parent);
    }

    /**
     * Creates a loader that applies {@link LoadingOrder#PARENT_FIRST} to every class.
     *
     * @param urls   search path of this loader; duplicates are dropped
     * @param parent parent loader, or {@code null} to use the system class loader
     */
    public ExtendedURLClassLoader(@NotNull Iterable<URL> urls, @Nullable ClassLoader parent) {
        this(PARENT_FIRST, urls, parent);
    }

    /**
     * Uses the system class loader as parent.
     *
     * @param loadingOrderFactory selects the lookup strategy for each class name
     * @param urls                search path of this loader; duplicates are dropped
     */
    public ExtendedURLClassLoader(@NotNull LoadingOrderFactory loadingOrderFactory, @NotNull URL[] urls) {
        this(loadingOrderFactory, urls, getSystemClassLoader());
    }

    /**
     * Uses the system class loader as parent and one {@code loadingOrder} for every class.
     *
     * @param loadingOrder lookup strategy applied to all class names
     * @param urls         search path of this loader; duplicates are dropped
     */
    public ExtendedURLClassLoader(@NotNull LoadingOrder loadingOrder, @NotNull URL[] urls) {
        this(loadingOrder, urls, getSystemClassLoader());
    }

    /**
     * Uses the system class loader as parent.
     *
     * @param loadingOrderFactory selects the lookup strategy for each class name
     * @param urls                search path of this loader; duplicates are dropped
     */
    public ExtendedURLClassLoader(@NotNull LoadingOrderFactory loadingOrderFactory, @NotNull Iterable<URL> urls) {
        this(loadingOrderFactory, urls, getSystemClassLoader());
    }

    /**
     * Uses the system class loader as parent and one {@code loadingOrder} for every class.
     *
     * @param loadingOrder lookup strategy applied to all class names
     * @param urls         search path of this loader; duplicates are dropped
     */
    public ExtendedURLClassLoader(@NotNull LoadingOrder loadingOrder, @NotNull Iterable<URL> urls) {
        this(loadingOrder, urls, getSystemClassLoader());
    }

    /**
     * Uses the system class loader as parent and {@link LoadingOrder#PARENT_FIRST} for every class.
     *
     * @param urls search path of this loader; duplicates are dropped
     */
    public ExtendedURLClassLoader(@NotNull URL[] urls) {
        this(urls, getSystemClassLoader());
    }

    /**
     * Uses the system class loader as parent and {@link LoadingOrder#PARENT_FIRST} for every class.
     *
     * @param urls search path of this loader; duplicates are dropped
     */
    public ExtendedURLClassLoader(@NotNull Iterable<URL> urls) {
        this(urls, getSystemClassLoader());
    }

    /**
     * Loads the class without resolving it; shorthand for {@code loadClass(name, false)}.
     *
     * @param name the class name to load
     * @return the loaded class
     * @throws ClassNotFoundException if no consulted source provides the class
     */
    @NotNull
    @Override
    public Class<?> loadClass(@NotNull String name) throws ClassNotFoundException {
        return loadClass(name, false);
    }

    /**
     * Loads a class using the following sequence, stopping at the first hit:
     * <ol>
     * <li>classes already loaded by this loader,</li>
     * <li>{@link #findBootstrapClassOrNull(String)},</li>
     * <li>the parent loader and/or this loader's URLs, as dictated by the {@link LoadingOrder}
     * obtained from the configured {@link LoadingOrderFactory}.</li>
     * </ol>
     *
     * @param name    the class name to load
     * @param resolve whether to call {@link #resolveClass(Class)} on the result
     * @return the loaded class
     * @throws ClassNotFoundException if none of the consulted sources provides the class
     * @throws IllegalStateException  if the factory yields a loading order that is not handled
     */
    @NotNull
    @Override
    protected Class<?> loadClass(@NotNull String name, boolean resolve) throws ClassNotFoundException {
        synchronized (getClassLoadingLock(name)) {
            Class<?> result = findLoadedClass(name);

            if (null == result) {
                result = findBootstrapClassOrNull(name);
            }

            if (null == result) {
                // The factory is consulted only once the cheaper lookups above have failed.
                result = findClassByLoadingOrder(name, loadingOrderFactory.getLoadingOrder(name));
            }

            if (null == result) {
                throw new ClassNotFoundException(name);
            }

            if (resolve) {
                resolveClass(result);
            }

            return result;
        }
    }

    /**
     * Consults the parent loader and/or this loader's URLs in the sequence given by {@code loadingOrder}.
     *
     * @param name         the class name to look up
     * @param loadingOrder the strategy to apply; a {@code null} value is reported as unsupported
     * @return the class, or {@code null} if the consulted sources do not provide it
     * @throws IllegalStateException if {@code loadingOrder} is {@code null} or not handled
     */
    @Nullable
    private Class<?> findClassByLoadingOrder(@NotNull String name, LoadingOrder loadingOrder) {
        // switch on a null enum would raise NullPointerException, so null is routed to the
        // IllegalStateException below instead.
        if (null != loadingOrder) {
            Class<?> candidate;
            switch (loadingOrder) {
                case PARENT_FIRST:
                    candidate = findParentClassOrNull(name);
                    return null != candidate ? candidate : findClassOrNull(name);
                case THIS_FIRST:
                    candidate = findClassOrNull(name);
                    return null != candidate ? candidate : findParentClassOrNull(name);
                case PARENT_ONLY:
                    return findParentClassOrNull(name);
                case THIS_ONLY:
                    return findClassOrNull(name);
                default:
                    break;
            }
        }
        throw new IllegalStateException("Unsupported " + LoadingOrder.class.getSimpleName() + ": " + loadingOrder);
    }

    /**
     * Looks the class up through the loader returned by {@code getBootstrapClassLoader()}.
     *
     * @param className the class name to look up
     * @return the class, or {@code null} if that loader reports {@link ClassNotFoundException}
     */
    @Nullable
    protected Class<?> findBootstrapClassOrNull(@NotNull String className) {
        try {
            return getBootstrapClassLoader().loadClass(className);
        } catch (ClassNotFoundException ignored) {
            // Absence is an expected outcome here; the caller moves on to the next source.
            return null;
        }
    }

    /**
     * Looks the class up via {@link #findClass(String)}, i.e. without delegating to the parent.
     *
     * @param className the class name to look up
     * @return the class, or {@code null} if {@code findClass} reports {@link ClassNotFoundException}
     */
    @Nullable
    protected Class<?> findClassOrNull(@NotNull String className) {
        try {
            return findClass(className);
        } catch (ClassNotFoundException ignored) {
            // Absence is an expected outcome here; the caller moves on to the next source.
            return null;
        }
    }

    /**
     * Looks the class up through the parent loader.
     *
     * @param className the class name to look up
     * @return the class, or {@code null} if there is no parent or the parent reports
     * {@link ClassNotFoundException}
     */
    @Nullable
    protected Class<?> findParentClassOrNull(@NotNull String className) {
        final ClassLoader parentLoader = getParent();
        if (null == parentLoader) {
            return null;
        }
        try {
            return parentLoader.loadClass(className);
        } catch (ClassNotFoundException ignored) {
            // Absence is an expected outcome here; the caller moves on to the next source.
            return null;
        }
    }

    /**
     * Appends {@code url} to the search path unless {@code contains(getURLs(), url)} reports that
     * it is already present. The check and the append run under this instance's monitor.
     *
     * @param url the URL to append
     */
    @Override
    public synchronized void addURL(@NotNull URL url) {
        if (contains(getURLs(), url)) {
            return;
        }
        super.addURL(url);
    }

    /**
     * Drops duplicate entries from {@code urls} using {@link Stream#distinct()}.
     *
     * @param urls the URLs to filter
     * @return a new array holding the distinct elements
     */
    @NotNull
    private static URL[] uniqueURLs(@NotNull URL[] urls) {
        final Stream<URL> distinctUrls = Stream.of(urls).distinct();
        return distinctUrls.toArray(URL[]::new);
    }

    /**
     * Copies the elements of {@code urls} into a new array.
     *
     * @param urls the URLs to copy; a {@link List} is converted directly, anything else is iterated
     * @return a new array with the elements in iteration order
     */
    @NotNull
    private static URL[] iterableUrlsToArray(@NotNull Iterable<URL> urls) {
        final List<URL> collected;
        if (urls instanceof List) {
            collected = (List<URL>) urls;
        } else {
            collected = new ArrayList<>();
            for (URL element : urls) {
                collected.add(element);
            }
        }
        return collected.toArray(new URL[0]);
    }

}
