/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.sequencevectors.graph.walkers.impl;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.deeplearning4j.models.sequencevectors.graph.enums.SamplingMode;
import org.deeplearning4j.models.sequencevectors.graph.primitives.IGraph;
import org.deeplearning4j.models.sequencevectors.graph.primitives.Vertex;
import org.deeplearning4j.models.sequencevectors.graph.walkers.GraphWalker;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link GraphWalker} that emits, for each vertex of the source graph, a {@link Sequence} labelled with
 * that vertex's value and containing values of its connected vertices.
 *
 * <p>When {@code walkLength} is 0 every connected vertex is emitted. Otherwise at most {@code walkLength}
 * neighbours are chosen according to the {@link SamplingMode}:
 * <ul>
 *   <li>{@code MAX_POPULARITY} takes the most connected neighbours.</li>
 *   <li>{@code MEDIAN_POPULARITY} takes a window around the middle of the connectivity ranking.</li>
 *   <li>{@code MIN_POPULARITY} takes the least connected neighbours.</li>
 *   <li>{@code RANDOM} takes a random subset drawn with the walker's seeded RNG.</li>
 * </ul>
 * When {@code depth > 1}, each chosen neighbour is walked recursively. Elements of its sequence whose label is
 * not yet present are merged in.
 *
 * @param <V> element type stored in the graph's vertices
 */
public class NearestVertexWalker<V extends SequenceElement>
implements GraphWalker<V> {
    private static final Logger log = LoggerFactory.getLogger(NearestVertexWalker.class);
    protected IGraph<V, ?> sourceGraph;
    protected int walkLength = 0;
    protected long seed = 0L;
    protected SamplingMode samplingMode = SamplingMode.RANDOM;
    /** Vertex indices in the order they are visited by {@link #next()}; shuffled by {@link #reset(boolean)}. */
    protected int[] order;
    protected Random rng;
    protected int depth;
    /** Index into {@link #order} of the next vertex to walk. */
    private final AtomicInteger position = new AtomicInteger(0);

    protected NearestVertexWalker() {
    }

    /** Returns {@code true} while there are vertices in {@link #order} that have not been walked yet. */
    @Override
    public boolean hasNext() {
        return this.position.get() < this.order.length;
    }

    /** Walks the next vertex in {@link #order}, starting at recursion depth 1. */
    @Override
    public Sequence<V> next() {
        return this.walk(this.sourceGraph.getVertex(this.order[this.position.getAndIncrement()]), 1);
    }

    /**
     * Rewinds the walker to the first vertex.
     *
     * @param shuffle if {@code true}, the visiting order is Fisher-Yates shuffled with the walker's RNG
     */
    @Override
    public void reset(boolean shuffle) {
        this.position.set(0);
        if (shuffle) {
            log.debug("Calling shuffle() on entries...");
            for (int i = this.order.length - 1; i > 0; --i) {
                int j = this.rng.nextInt(i + 1);
                int temp = this.order[j];
                this.order[j] = this.order[i];
                this.order[i] = temp;
            }
        }
    }

    /**
     * Builds the sequence for a single vertex.
     *
     * @param node   vertex whose neighbourhood is sampled; its value becomes the sequence label
     * @param cDepth current recursion depth ({@link #next()} starts at 1)
     * @return values of the sampled neighbours, plus merged deeper-level elements when {@code depth > 1}
     * @throws ND4JIllegalStateException if the configured sampling mode is not recognised
     */
    protected Sequence<V> walk(Vertex<V> node, int cDepth) {
        Sequence<V> sequence = new Sequence<>();
        List<Vertex<V>> vertices = this.sourceGraph.getConnectedVertices(node.vertexID());
        sequence.setSequenceLabel(node.getValue());

        if (this.walkLength == 0) {
            // unlimited walk: every connected vertex is used as-is, without going deeper
            for (Vertex<V> vertex : vertices) {
                sequence.addElement(vertex.getValue());
            }
            return sequence;
        }

        switch (this.samplingMode) {
            case MAX_POPULARITY: {
                List<Vertex<V>> ranked = this.rankByConnectivity(vertices);
                // bounded by the neighbour count so small neighbourhoods don't overrun the list
                int limit = Math.min(this.walkLength, ranked.size());
                for (int i = 0; i < limit; ++i) {
                    this.appendVertex(sequence, ranked.get(i), cDepth);
                }
                break;
            }
            case MEDIAN_POPULARITY: {
                List<Vertex<V>> ranked = this.rankByConnectivity(vertices);
                // window of up to walkLength entries centred on the middle of the ranking, clamped at 0
                int start = Math.max(0, ranked.size() / 2 - this.walkLength / 2);
                for (int i = start, taken = 0; i < ranked.size() && taken < this.walkLength; ++i, ++taken) {
                    this.appendVertex(sequence, ranked.get(i), cDepth);
                }
                break;
            }
            case MIN_POPULARITY: {
                List<Vertex<V>> ranked = this.rankByConnectivity(vertices);
                // walk the ranking backwards, starting from the least connected neighbour
                for (int i = ranked.size() - 1, taken = 0; i >= 0 && taken < this.walkLength; --i, ++taken) {
                    this.appendVertex(sequence, ranked.get(i), cDepth);
                }
                break;
            }
            case RANDOM: {
                if (vertices.size() <= this.walkLength) {
                    for (Vertex<V> vertex : vertices) {
                        this.appendVertex(sequence, vertex, cDepth);
                    }
                } else {
                    // partial Fisher-Yates shuffle with the walker's seeded RNG: after step i the first
                    // i + 1 slots hold a uniformly random set of distinct neighbours (reproducible via seed)
                    List<Vertex<V>> pool = new ArrayList<>(vertices);
                    for (int i = 0; i < this.walkLength; ++i) {
                        int j = i + this.rng.nextInt(pool.size() - i);
                        Collections.swap(pool, i, j);
                        this.appendVertex(sequence, pool.get(i), cDepth);
                    }
                }
                break;
            }
            default: {
                throw new ND4JIllegalStateException("Unknown sampling mode was passed in: [" + this.samplingMode + "]");
            }
        }
        return sequence;
    }

    /**
     * Returns a copy of {@code vertices} sorted by descending number of connected vertices.
     * The input list (which comes from the graph) is left untouched.
     */
    private List<Vertex<V>> rankByConnectivity(List<Vertex<V>> vertices) {
        List<Vertex<V>> ranked = new ArrayList<>(vertices);
        Collections.sort(ranked, new VertexComparator<>(this.sourceGraph));
        return ranked;
    }

    /**
     * Adds {@code vertex}'s value to {@code sequence}. While {@code cDepth} is below the configured depth,
     * it also merges in the elements of the vertex's own walk whose label is not yet present.
     */
    private void appendVertex(Sequence<V> sequence, Vertex<V> vertex, int cDepth) {
        sequence.addElement(vertex.getValue());
        if (this.depth <= 1 || cDepth >= this.depth) {
            return;
        }
        // cDepth itself is not modified, so every sibling vertex is expanded to the same depth
        Sequence<V> nextDepth = this.walk(vertex, cDepth + 1);
        for (V element : nextDepth.getElements()) {
            if (sequence.getElementByLabel(element.getLabel()) == null) {
                sequence.addElement(element);
            }
        }
    }

    /** Always {@code true}: every produced sequence carries the walked vertex's value as its label. */
    @Override
    public boolean isLabelEnabled() {
        return true;
    }

    @Override
    public IGraph<V, ?> getSourceGraph() {
        return this.sourceGraph;
    }

    /**
     * Orders vertices by descending number of connected vertices in {@code graph}.
     * Each comparison queries the graph's connections for both vertices.
     */
    protected class VertexComparator<V extends SequenceElement, E extends Number>
    implements Comparator<Vertex<V>> {
        private final IGraph<V, E> graph;

        public VertexComparator(IGraph<V, E> graph) {
            if (graph == null) {
                throw new NullPointerException("graph");
            }
            this.graph = graph;
        }

        @Override
        public int compare(Vertex<V> o1, Vertex<V> o2) {
            // arguments swapped on purpose: higher connectivity sorts first
            return Integer.compare(this.graph.getConnectedVertices(o2.vertexID()).size(), this.graph.getConnectedVertices(o1.vertexID()).size());
        }
    }

    /**
     * Fluent builder for {@link NearestVertexWalker}.
     *
     * <p>Defaults: {@code walkLength = 0} (all neighbours), {@code RANDOM} sampling, {@code depth = 1}
     * (no recursion), {@code seed = 0}.
     *
     * @param <V> element type stored in the graph's vertices
     */
    public static class Builder<V extends SequenceElement> {
        protected int walkLength = 0;
        protected IGraph<V, ?> sourceGraph;
        protected SamplingMode samplingMode = SamplingMode.RANDOM;
        protected long seed;
        protected int depth = 1;

        public Builder(@NonNull IGraph<V, ?> graph) {
            if (graph == null) {
                throw new NullPointerException("graph");
            }
            this.sourceGraph = graph;
        }

        /** Seed for the RNG used to shuffle the visiting order and for {@code RANDOM} sampling. */
        public Builder<V> setSeed(long seed) {
            this.seed = seed;
            return this;
        }

        /** Maximum number of neighbours sampled per vertex; 0 means "all connected vertices". */
        public Builder<V> setWalkLength(int length) {
            this.walkLength = length;
            return this;
        }

        /** Recursion depth; values {@code <= 1} disable expansion of neighbours. */
        public Builder<V> setDepth(int depth) {
            this.depth = depth;
            return this;
        }

        public Builder<V> setSamplingMode(@NonNull SamplingMode mode) {
            if (mode == null) {
                throw new NullPointerException("mode");
            }
            this.samplingMode = mode;
            return this;
        }

        /** Creates the walker with a shuffled visiting order covering every vertex index of the graph. */
        public NearestVertexWalker<V> build() {
            NearestVertexWalker<V> walker = new NearestVertexWalker<>();
            walker.sourceGraph = this.sourceGraph;
            walker.walkLength = this.walkLength;
            walker.samplingMode = this.samplingMode;
            walker.depth = this.depth;
            walker.seed = this.seed;
            walker.order = new int[this.sourceGraph.numVertices()];
            for (int i = 0; i < walker.order.length; ++i) {
                walker.order[i] = i;
            }
            walker.rng = new Random(this.seed);
            walker.reset(true);
            return walker;
        }
    }
}

