/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.embeddings.learning.impl.elements;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.learning.ElementsLearningAlgorithm;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.glove.AbstractCoOccurrences;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.AdaGrad;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * GloVe (Global Vectors) training algorithm plugged in as an
 * {@link ElementsLearningAlgorithm}.
 *
 * <p>Life cycle as implemented here:
 * <ol>
 *   <li>{@link #configure} stores the vocabulary, lookup table and configuration and
 *       allocates the bias vector and both AdaGrad states;</li>
 *   <li>{@link #pretrain} builds and fits the co-occurrence statistics;</li>
 *   <li>the first call to {@link #learnSequence} runs <em>all</em> epochs over the
 *       co-occurrence pairs using {@code workers} threads and then raises the
 *       early-termination flag, so later calls return immediately.</li>
 * </ol>
 *
 * <p>Worker threads update {@code syn0}, the bias vector and the AdaGrad states without
 * locking; only the hand-out of pairs from the shared iterator is serialized.
 *
 * @param <T> element type being embedded
 */
public class GloVe<T extends SequenceElement> implements ElementsLearningAlgorithm<T> {
    private static final Logger log = LoggerFactory.getLogger(GloVe.class);

    private VocabCache<T> vocabCache;
    private AbstractCoOccurrences<T> coOccurrences;
    private WeightLookupTable<T> lookupTable;
    private VectorsConfiguration configuration;
    /** Set once training has completed; makes further {@link #learnSequence} calls no-ops. */
    private final AtomicBoolean isTerminate = new AtomicBoolean(false);
    /** Weight matrix taken from the lookup table in {@link #configure}. */
    private INDArray syn0;
    /** Co-occurrence count above which the weighting function is no longer applied. */
    private double xMax;
    private boolean shuffle;
    private boolean symmetric;
    /** Exponent of the weighting function {@code (score / xMax) ^ alpha}. */
    protected double alpha = 0.75;
    /** {@code 0.0} means "take the value from the configuration" (see {@link #configure}). */
    protected double learningRate = 0.0;
    protected int maxmemory = 0;
    /** Number of co-occurrence pairs a worker pulls from the shared iterator at once. */
    protected int batchSize = 1000;
    private AdaGrad weightAdaGrad;
    private AdaGrad biasAdaGrad;
    private INDArray bias;
    private int workers = Runtime.getRuntime().availableProcessors();
    private int vectorLength;

    /**
     * @return the fixed algorithm identifier {@code "GloVe"}
     */
    @Override
    public String getCodeName() {
        return "GloVe";
    }

    /**
     * Called when training is over; this implementation only logs.
     */
    @Override
    public void finish() {
        log.info("GloVe finalizer...");
    }

    /**
     * Binds this algorithm to its vocabulary, lookup table and configuration, and allocates
     * the bias vector and the AdaGrad states for weights and biases.
     *
     * @param vocabCache    vocabulary; its word count sizes the weight AdaGrad state
     * @param lookupTable   must be an {@link InMemoryLookupTable}; its {@code syn0} is trained in place
     * @param configuration supplies layer size, and the learning rate when none was set explicitly
     * @throws NullPointerException if any argument is {@code null}
     * @throws ClassCastException   if {@code lookupTable} is not an {@link InMemoryLookupTable}
     */
    @Override
    public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable, @NonNull VectorsConfiguration configuration) {
        if (vocabCache == null) {
            throw new NullPointerException("vocabCache");
        }
        if (lookupTable == null) {
            throw new NullPointerException("lookupTable");
        }
        if (configuration == null) {
            throw new NullPointerException("configuration");
        }
        this.vocabCache = vocabCache;
        this.lookupTable = lookupTable;
        this.configuration = configuration;
        this.syn0 = ((InMemoryLookupTable)lookupTable).getSyn0();
        this.vectorLength = configuration.getLayersSize();
        // A learning rate of exactly 0.0 is the "not set" marker used by the Builder.
        if (this.learningRate == 0.0) {
            this.learningRate = configuration.getLearningRate();
        }
        this.weightAdaGrad = new AdaGrad(new int[]{this.vocabCache.numWords() + 1, this.vectorLength}, this.learningRate);
        this.bias = Nd4j.create((int)this.syn0.rows());
        this.biasAdaGrad = new AdaGrad(this.bias.shape(), this.learningRate);
        log.info("GloVe params: {Max Memory: [{}], Learning rate: [{}], Alpha: [{}], xMax: [{}], Symmetric: [{}], Shuffle: [{}]}",
                this.maxmemory, this.learningRate, this.alpha, this.xMax, this.symmetric, this.shuffle);
    }

    /**
     * Builds the co-occurrence statistics from the given sequences and fits them.
     * Requires {@link #configure} to have been called first (window size and vocabulary
     * are read from the stored configuration and vocab cache).
     *
     * @param iterator source of training sequences
     * @throws NullPointerException if {@code iterator} is {@code null}
     */
    @Override
    public void pretrain(@NonNull SequenceIterator<T> iterator) {
        if (iterator == null) {
            throw new NullPointerException("iterator");
        }
        this.coOccurrences = new AbstractCoOccurrences.Builder()
                .symmetric(this.symmetric)
                .windowSize(this.configuration.getWindow())
                .iterate(iterator)
                .workers(this.workers)
                .vocabCache(this.vocabCache)
                .maxMemory(this.maxmemory)
                .build();
        this.coOccurrences.fit();
    }

    /**
     * Runs the complete GloVe training (all configured epochs) on the first invocation and
     * then marks the algorithm as terminated; every later invocation returns immediately.
     *
     * <p>The arguments are only null-checked: training iterates over the co-occurrence pairs
     * collected by {@link #pretrain}, not over {@code sequence}, and uses the learning rate
     * stored in this instance rather than the {@code learningRate} parameter.
     *
     * @param sequence     ignored apart from the null check
     * @param nextRandom   ignored apart from the null check
     * @param learningRate ignored
     * @return always {@code 0.0}
     * @throws NullPointerException if {@code sequence} or {@code nextRandom} is {@code null}
     * @throws RuntimeException     wrapping {@link InterruptedException} if the calling thread
     *                              is interrupted while waiting for the workers
     */
    @Override
    public synchronized double learnSequence(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom, double learningRate) {
        if (sequence == null) {
            throw new NullPointerException("sequence");
        }
        if (nextRandom == null) {
            throw new NullPointerException("nextRandom");
        }
        if (this.isTerminate.get()) {
            return 0.0;
        }
        AtomicLong pairsCount = new AtomicLong(0L);
        Counter<Integer> errorCounter = new Counter<Integer>();
        for (int epoch = 0; epoch < this.configuration.getEpochs(); ++epoch) {
            // One iterator per epoch, shared by every worker of that epoch.
            Iterator<Pair<Pair<T, T>, Double>> pairs = this.coOccurrences.iterator();
            ArrayList<GloveCalculationsThread> threads = new ArrayList<GloveCalculationsThread>();
            for (int id = 0; id < this.workers; ++id) {
                GloveCalculationsThread worker = new GloveCalculationsThread(epoch, id, pairs, pairsCount, errorCounter);
                threads.add(worker);
                worker.start();
            }
            for (GloveCalculationsThread worker : threads) {
                try {
                    worker.join();
                }
                catch (InterruptedException e) {
                    // Restore the flag so code further up the stack can still observe the interrupt.
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e);
                }
            }
            log.info("Processed [{}] pairs, Error was [{}]", pairsCount.get(), errorCounter.getCount(epoch));
        }
        this.isTerminate.set(true);
        return 0.0;
    }

    /**
     * @return {@code true} once {@link #learnSequence} has completed its training run
     */
    @Override
    public synchronized boolean isEarlyTerminationHit() {
        return this.isTerminate.get();
    }

    /**
     * Performs one GloVe update for a single co-occurrence pair.
     *
     * <p>The error term is {@code diff = w1 . w2 + b1 + b2 - log(score)}, weighted by
     * {@code f(score) = (score / xMax) ^ alpha} when {@code score <= xMax} and by 1 otherwise.
     * Both word vectors and both biases are then updated through AdaGrad.
     *
     * @param element1 first element of the pair
     * @param element2 second element of the pair
     * @param score    co-occurrence weight of the pair; callers pass strictly positive values
     * @return {@code 0.5 * f(score) * diff * diff}, the pair's contribution to the cost
     * @throws IllegalArgumentException if either element's index is outside {@code syn0}'s rows
     */
    private double iterateSample(T element1, T element2, double score) {
        this.requireValidIndex(element1);
        this.requireValidIndex(element2);
        int index1 = element1.getIndex();
        int index2 = element2.getIndex();
        INDArray w1Vector = this.syn0.slice(index1);
        INDArray w2Vector = this.syn0.slice(index2);

        // The bias and log terms belong to the error regardless of which weighting branch is
        // taken below; previously they were only added when score <= xMax.
        double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
        prediction += this.bias.getDouble(index1) + this.bias.getDouble(index2) - Math.log(score);

        double fDiff = score > this.xMax
                ? prediction
                : Math.pow(score / this.xMax, this.alpha) * prediction;
        if (Double.isNaN(fDiff)) {
            fDiff = Nd4j.EPS_THRESHOLD;
        }
        double gradient = fDiff * this.learningRate;
        this.update(element1, w1Vector, w2Vector, gradient);
        this.update(element2, w2Vector, w1Vector, gradient);
        return 0.5 * fDiff * prediction;
    }

    /** Rejects elements whose index does not address a row of {@code syn0}. */
    private void requireValidIndex(T element) {
        int index = element.getIndex();
        if (index < 0 || index >= this.syn0.rows()) {
            throw new IllegalArgumentException("Illegal index for word " + element.getLabel());
        }
    }

    /**
     * Applies the AdaGrad-scaled step to one word vector (in place) and to its bias.
     *
     * @param element1      element whose vector and bias are updated
     * @param wordVector    row of {@code syn0} belonging to {@code element1}; modified in place
     * @param contextVector vector of the other element of the pair
     * @param gradient      weighted error already multiplied by the learning rate
     */
    private void update(T element1, INDArray wordVector, INDArray contextVector, double gradient) {
        int index = element1.getIndex();

        INDArray vectorGradient = contextVector.mul(gradient);
        INDArray vectorStep = this.weightAdaGrad.getGradient(vectorGradient, index, this.syn0.shape());
        wordVector.subi(vectorStep);

        double currentBias = this.bias.getDouble(index);
        double biasStep = this.biasAdaGrad.getGradient(gradient, index, this.bias.shape());
        this.bias.putScalar(index, currentBias - biasStep);
    }

    /**
     * Fluent builder for {@link GloVe}. Unset options keep the defaults declared on the fields.
     *
     * @param <T> element type of the algorithm being built
     */
    public static class Builder<T extends SequenceElement> {
        protected double xMax = 100.0;
        protected double alpha = 0.75;
        protected double learningRate = 0.0;
        protected boolean shuffle = false;
        protected boolean symmetric = false;
        protected int maxmemory = 0;
        protected int batchSize = 1000;

        /**
         * @param batchSize number of co-occurrence pairs each worker pulls per batch
         * @return this builder
         */
        public Builder<T> batchSize(int batchSize) {
            this.batchSize = batchSize;
            return this;
        }

        /**
         * @param eta initial learning rate; {@code 0.0} defers to the vectors configuration
         * @return this builder
         */
        public Builder<T> learningRate(double eta) {
            this.learningRate = eta;
            return this;
        }

        /**
         * @param alpha exponent of the weighting function
         * @return this builder
         */
        public Builder<T> alpha(double alpha) {
            this.alpha = alpha;
            return this;
        }

        /**
         * @param gbytes memory limit handed unchanged to the co-occurrence builder
         * @return this builder
         */
        public Builder<T> maxMemory(int gbytes) {
            this.maxmemory = gbytes;
            return this;
        }

        /**
         * @param xMax cutoff of the weighting function
         * @return this builder
         */
        public Builder<T> xMax(double xMax) {
            this.xMax = xMax;
            return this;
        }

        /**
         * @param reallyShuffle whether each batch of pairs is shuffled before processing
         * @return this builder
         */
        public Builder<T> shuffle(boolean reallyShuffle) {
            this.shuffle = reallyShuffle;
            return this;
        }

        /**
         * @param reallySymmetric value forwarded to the co-occurrence builder's {@code symmetric} option
         * @return this builder
         */
        public Builder<T> symmetric(boolean reallySymmetric) {
            this.symmetric = reallySymmetric;
            return this;
        }

        /**
         * @return a new, not yet configured {@link GloVe} carrying this builder's settings
         */
        public GloVe<T> build() {
            GloVe<T> ret = new GloVe<T>();
            ret.symmetric = this.symmetric;
            ret.shuffle = this.shuffle;
            ret.xMax = this.xMax;
            ret.alpha = this.alpha;
            ret.learningRate = this.learningRate;
            ret.maxmemory = this.maxmemory;
            ret.batchSize = this.batchSize;
            return ret;
        }
    }

    /**
     * Worker that repeatedly pulls a batch of co-occurrence pairs from the iterator shared
     * with its sibling workers and trains on each pair.
     */
    private class GloveCalculationsThread extends Thread {
        private final int threadId;
        private final int epochId;
        private final Iterator<Pair<Pair<T, T>, Double>> coList;
        private final AtomicLong pairsCounter;
        private final Counter<Integer> errorCounter;

        /**
         * @param epochId      epoch this worker belongs to; used as the key in {@code errorCounter}
         * @param threadId     worker number, used only for the thread name
         * @param pairs        iterator shared by all workers of the epoch
         * @param pairsCounter shared count of processed pairs
         * @param errorCounter shared per-epoch error accumulator
         * @throws NullPointerException if any reference argument is {@code null}
         */
        public GloveCalculationsThread(int epochId, int threadId, @NonNull Iterator<Pair<Pair<T, T>, Double>> pairs, @NonNull AtomicLong pairsCounter, @NonNull Counter<Integer> errorCounter) {
            if (pairs == null) {
                throw new NullPointerException("pairs");
            }
            if (pairsCounter == null) {
                throw new NullPointerException("pairsCounter");
            }
            if (errorCounter == null) {
                throw new NullPointerException("errorCounter");
            }
            this.epochId = epochId;
            this.threadId = threadId;
            this.pairsCounter = pairsCounter;
            this.errorCounter = errorCounter;
            this.coList = pairs;
            this.setName("GloVe ELA t." + this.threadId);
        }

        /**
         * Pulls up to {@code batchSize} pairs from the shared iterator.
         *
         * <p>The {@code hasNext()}/{@code next()} sequence is executed under a lock on the
         * iterator because every worker of the epoch holds the same instance; without it two
         * workers could both pass {@code hasNext()} for the last remaining element.
         *
         * @return the next batch; empty once the iterator is exhausted
         */
        private ArrayList<Pair<Pair<T, T>, Double>> nextBatch() {
            // A non-positive batch size would otherwise never consume anything.
            int limit = Math.max(1, GloVe.this.batchSize);
            ArrayList<Pair<Pair<T, T>, Double>> batch = new ArrayList<Pair<Pair<T, T>, Double>>();
            synchronized (this.coList) {
                while (this.coList.hasNext() && batch.size() < limit) {
                    batch.add(this.coList.next());
                }
            }
            return batch;
        }

        @Override
        public void run() {
            ArrayList<Pair<Pair<T, T>, Double>> batch = this.nextBatch();
            while (!batch.isEmpty()) {
                if (GloVe.this.shuffle) {
                    Collections.shuffle(batch);
                }
                for (Pair<Pair<T, T>, Double> entry : batch) {
                    T element1 = entry.getFirst().getFirst();
                    T element2 = entry.getFirst().getSecond();
                    double weight = entry.getSecond();
                    if (weight <= 0.0) {
                        // Non-positive weights carry no signal (log undefined); count and skip.
                        this.pairsCounter.incrementAndGet();
                        continue;
                    }
                    // NOTE(review): errorCounter is shared between workers; whether Counter is
                    // safe for concurrent incrementCount calls is not visible from this file.
                    this.errorCounter.incrementCount(this.epochId, GloVe.this.iterateSample(element1, element2, weight));
                    if (this.pairsCounter.incrementAndGet() % 1000000L == 0L) {
                        log.info("Processed [{}] word pairs so far...", this.pairsCounter.get());
                    }
                }
                batch = this.nextBatch();
            }
        }
    }
}

