/*
 * Decompiled with CFR 0.152.
 */
package jadx.core.utils;

import jadx.api.IDecompileScheduler;
import jadx.api.JadxDecompiler;
import jadx.api.JavaClass;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.utils.Utils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link IDecompileScheduler} that groups classes into decompilation batches.
 * <p>
 * Classes are ordered by a recursive dependency count (fewest first). Classes without direct
 * dependencies are packed into merged batches of up to {@link #MERGED_BATCH_SIZE} entries; every
 * other class gets its own batch, preceded by those top-level classes of its dependencies that
 * were not already scheduled in an earlier batch.
 */
public class DecompilerScheduler implements IDecompileScheduler {
    private static final Logger LOG = LoggerFactory.getLogger(DecompilerScheduler.class);
    /** Maximum number of dependency-free classes packed into a single batch. */
    private static final int MERGED_BATCH_SIZE = 16;
    /** Set to {@code true} to log batch statistics at the end of {@link #internalBatches}. */
    private static final boolean DUMP_STATS = false;
    private final JadxDecompiler decompiler;

    /**
     * @param decompiler used to map scheduled {@link ClassNode}s back to {@link JavaClass}es
     */
    public DecompilerScheduler(JadxDecompiler decompiler) {
        this.decompiler = decompiler;
    }

    /**
     * Builds decompilation batches for the given classes.
     * <p>
     * Scheduling is done on the underlying {@link ClassNode}s (see {@link #internalBatches}); the
     * result is mapped back with {@code Utils.collectionMapNoNull}, which presumably skips nodes
     * the decompiler has no {@link JavaClass} for.
     *
     * @param classes classes to schedule
     * @return batches in scheduling order
     */
    @Override
    public List<List<JavaClass>> buildBatches(List<JavaClass> classes) {
        long start = System.currentTimeMillis();
        List<List<ClassNode>> batches = this.internalBatches(Utils.collectionMap(classes, JavaClass::getClassNode));
        List<List<JavaClass>> result = Utils.collectionMap(batches,
                batch -> Utils.collectionMapNoNull(batch, this.decompiler::getJavaClassByNode));
        if (LOG.isDebugEnabled()) {
            LOG.debug("Build decompilation batches in {}ms", System.currentTimeMillis() - start);
        }
        return result;
    }

    /**
     * Splits classes into batches, scheduling classes with few (recursive) dependencies first.
     * <p>
     * A class with dependencies is placed in a batch together with the top-level classes of its
     * not-yet-scheduled dependencies (sorted by their direct dependency count), presumably so that
     * one worker processes a class and its dependencies together. Classes without direct
     * dependencies are merged into shared batches of up to {@link #MERGED_BATCH_SIZE} entries.
     * <p>
     * Every class appears in at most one batch. The result may also contain classes that are not
     * in {@code classes} but are reachable through their dependencies.
     *
     * @param classes classes to schedule
     * @return batches in scheduling order
     */
    public List<List<ClassNode>> internalBatches(List<ClassNode> classes) {
        Map<ClassNode, DepInfo> depsMap = new HashMap<>(classes.size());
        Set<ClassNode> visited = new HashSet<>();
        for (ClassNode classNode : classes) {
            visited.clear();
            this.sumDeps(classNode, depsMap, visited);
        }
        List<DepInfo> deps = new ArrayList<>(depsMap.values());
        // ascending recursive dependency count: "lighter" classes are scheduled first
        Collections.sort(deps);

        Set<ClassNode> added = new HashSet<>(classes.size());
        Comparator<ClassNode> cmpDepSize = Comparator.comparingInt(c -> c.getDependencies().size());
        List<List<ClassNode>> result = new ArrayList<>();
        List<ClassNode> mergedBatch = new ArrayList<>(MERGED_BATCH_SIZE);
        for (DepInfo depInfo : deps) {
            ClassNode cls = depInfo.getCls();
            // Mark cls as scheduled; skip it if an earlier batch already pulled it in as a
            // dependency (happens with cyclic dependencies).
            if (!added.add(cls)) {
                continue;
            }
            int depsSize = cls.getDependencies().size();
            if (depsSize == 0) {
                mergedBatch.add(cls);
                if (mergedBatch.size() >= MERGED_BATCH_SIZE) {
                    result.add(mergedBatch);
                    mergedBatch = new ArrayList<>(MERGED_BATCH_SIZE);
                }
                continue;
            }
            List<ClassNode> batch = new ArrayList<>(depsSize + 1);
            for (ClassNode dep : cls.getDependencies()) {
                ClassNode topDep = dep.getTopParentClass();
                // Mark immediately so that several dependencies sharing one top-level
                // parent add that parent only once.
                if (added.add(topDep)) {
                    batch.add(topDep);
                }
            }
            batch.sort(cmpDepSize);
            batch.add(cls);
            result.add(batch);
        }
        if (!mergedBatch.isEmpty()) {
            result.add(mergedBatch);
        }
        if (DUMP_STATS) {
            this.dumpBatchesStats(classes, result, deps);
        }
        return result;
    }

    /**
     * Computes a recursive dependency count for {@code cls}, used only as a scheduling heuristic.
     * <p>
     * The count is the number of direct dependencies plus the counts of dependencies not yet in
     * {@code visited}. Results are memoized in {@code depsMap}, so with cycles or shared
     * dependencies the value can depend on traversal order.
     * NOTE(review): recursive; very deep dependency chains could exhaust the stack.
     *
     * @param cls     class to evaluate
     * @param depsMap memoization map, filled with a {@link DepInfo} for every evaluated class
     * @param visited classes already seen in the current traversal (guards against cycles)
     * @return the computed (or memoized) dependency count
     */
    public int sumDeps(ClassNode cls, Map<ClassNode, DepInfo> depsMap, Set<ClassNode> visited) {
        visited.add(cls);
        DepInfo depInfo = depsMap.get(cls);
        if (depInfo != null) {
            return depInfo.getDepsCount();
        }
        List<ClassNode> deps = cls.getDependencies();
        int count = deps.size();
        for (ClassNode dep : deps) {
            if (!visited.contains(dep)) {
                count += this.sumDeps(dep, depsMap, visited);
            }
        }
        depsMap.put(cls, new DepInfo(cls, count));
        return count;
    }

    /** Logs summary statistics about the built batches (enabled by {@link #DUMP_STATS}). */
    private void dumpBatchesStats(List<ClassNode> classes, List<List<ClassNode>> result, List<DepInfo> deps) {
        double avg = result.stream().mapToInt(List::size).average().orElse(-1.0);
        int maxSingleDeps = classes.stream().mapToInt(c -> c.getDependencies().size()).max().orElse(-1);
        int maxRecursiveDeps = deps.stream().mapToInt(DepInfo::getDepsCount).max().orElse(-1);
        LOG.info("Batches stats:\n input classes: {},\n batches: {},\n average batch size: {},\n"
                        + " max single deps count: {},\n max recursive deps count: {}",
                classes.size(), result.size(), avg, maxSingleDeps, maxRecursiveDeps);
    }

    /**
     * A class paired with its recursive dependency count.
     * Ordering compares only the count (not consistent with {@code equals}); it is used for sorting.
     */
    private static final class DepInfo implements Comparable<DepInfo> {
        private final ClassNode cls;
        private final int depsCount;

        private DepInfo(ClassNode cls, int depsCount) {
            this.cls = cls;
            this.depsCount = depsCount;
        }

        public ClassNode getCls() {
            return this.cls;
        }

        public int getDepsCount() {
            return this.depsCount;
        }

        @Override
        public int compareTo(@NotNull DepInfo o) {
            return Integer.compare(this.depsCount, o.depsCount);
        }
    }
}

