/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.scaleout.perform.models.word2vec;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.nn.conf.Configuration;
import org.deeplearning4j.scaleout.api.statetracker.StateTracker;
import org.deeplearning4j.scaleout.job.Job;
import org.deeplearning4j.scaleout.perform.WorkerPerformer;
import org.deeplearning4j.scaleout.perform.models.word2vec.Word2VecJobAggregator;
import org.deeplearning4j.scaleout.perform.models.word2vec.Word2VecPerformerFactory;
import org.deeplearning4j.scaleout.perform.models.word2vec.Word2VecResult;
import org.deeplearning4j.scaleout.perform.models.word2vec.Word2VecWork;
import org.deeplearning4j.scaleout.statetracker.hazelcast.HazelCastStateTracker;
import org.deeplearning4j.text.invertedindex.InvertedIndex;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class Word2VecPerformer
implements WorkerPerformer {
    private int vectorLength = 50;
    public static final String NAME_SPACE = "org.deeplearning4j.scaleout.perform.models.word2vec";
    public static final String VECTOR_LENGTH = "org.deeplearning4j.scaleout.perform.models.word2vec.length";
    public static final String ADAGRAD = "org.deeplearning4j.scaleout.perform.models.word2vec.adagrad";
    public static final String NEGATIVE = "org.deeplearning4j.scaleout.perform.models.word2vec.negative";
    public static final String NUM_WORDS = "org.deeplearning4j.scaleout.perform.models.word2vec.numwords";
    public static final String TABLE = "org.deeplearning4j.scaleout.perform.models.word2vec.table";
    public static final String WINDOW = "org.deeplearning4j.scaleout.perform.models.word2vec.window";
    public static final String ALPHA = "org.deeplearning4j.scaleout.perform.models.word2vec.alpha";
    public static final String MIN_ALPHA = "org.deeplearning4j.scaleout.perform.models.word2vec.minalpha";
    public static final String TOTAL_WORDS = "org.deeplearning4j.scaleout.perform.models.word2vec.totalwords";
    public static final String NUM_WORDS_SO_FAR = "org.deeplearning4j.scaleout.perform.models.word2vec.wordssofar";
    public static final String ITERATIONS = "org.deeplearning4j.scaleout.perform.models.word2vec.iterations";
    double[] expTable = new double[1000];
    static double MAX_EXP = 6.0;
    private boolean useAdaGrad = false;
    private double negative = 5.0;
    private int numWords = 1;
    private INDArray table;
    private int window = 5;
    private AtomicLong nextRandom = new AtomicLong(5L);
    private double alpha = 0.025;
    private double minAlpha = 0.01;
    private int totalWords = 1;
    private int iterations = 5;
    private StateTracker stateTracker;
    private static Logger log = LoggerFactory.getLogger(Word2VecPerformer.class);
    private int lastChecked = 0;

    public Word2VecPerformer(StateTracker stateTracker) {
        this.stateTracker = stateTracker;
    }

    public Word2VecPerformer() {
    }

    public void perform(Job job) {
        if (job.getWork() instanceof Word2VecWork) {
            double numWordsSoFar = this.stateTracker.count(NUM_WORDS_SO_FAR);
            Word2VecWork work = (Word2VecWork)job.getWork();
            if (work == null) {
                return;
            }
            List<List<VocabWord>> sentences = work.getSentences();
            double alpha2 = Math.max(this.minAlpha, this.alpha * (1.0 - 1.0 * numWordsSoFar / (double)this.totalWords));
            int totalNewWords = 0;
            for (List<VocabWord> sentence : sentences) {
                for (int i = 0; i < this.iterations; ++i) {
                    this.trainSentence(sentence, work, alpha2);
                }
                totalNewWords += sentence.size();
            }
            double newWords = (double)totalNewWords + numWordsSoFar;
            double diff = Math.abs(newWords - (double)this.lastChecked);
            if (diff >= 10000.0) {
                this.lastChecked = (int)newWords;
                log.info("Words so far " + newWords + " out of " + this.totalWords);
            }
            job.setResult((Serializable)((Object)Arrays.asList(work.addDeltas())));
            this.stateTracker.increment(NUM_WORDS_SO_FAR, (double)totalNewWords);
        } else if (job.getWork() instanceof Collection) {
            double numWordsSoFar = this.stateTracker.count(NUM_WORDS_SO_FAR);
            Collection coll = (Collection)((Object)job.getWork());
            double alpha2 = Math.max(this.minAlpha, this.alpha * (1.0 - 1.0 * numWordsSoFar / (double)this.totalWords));
            int totalNewWords = 0;
            ArrayList<Word2VecResult> deltas = new ArrayList<Word2VecResult>();
            for (Word2VecWork work : coll) {
                List<List<VocabWord>> sentences = work.getSentences();
                for (List<VocabWord> sentence : sentences) {
                    this.trainSentence(sentence, work, alpha2);
                    totalNewWords += sentence.size();
                    deltas.add(work.addDeltas());
                }
            }
            double newWords = (double)totalNewWords + numWordsSoFar;
            double diff = Math.abs(newWords - (double)this.lastChecked);
            if (diff >= 10000.0) {
                this.lastChecked = (int)newWords;
                log.info("Words so far " + newWords + " out of " + this.totalWords);
            }
            job.setResult((Serializable)deltas);
            this.stateTracker.increment(NUM_WORDS_SO_FAR, (double)totalNewWords);
        }
    }

    public void update(Object ... o) {
    }

    public void setup(Configuration conf) {
        this.vectorLength = conf.getInt(VECTOR_LENGTH, 50);
        this.useAdaGrad = conf.getBoolean(ADAGRAD, false);
        this.negative = conf.getFloat(NEGATIVE, 5.0f);
        this.numWords = conf.getInt(NUM_WORDS, 1);
        this.window = conf.getInt(WINDOW, 5);
        this.alpha = conf.getFloat(ALPHA, 0.025f);
        this.minAlpha = conf.getFloat(MIN_ALPHA, 0.01f);
        this.totalWords = conf.getInt(NUM_WORDS, 1);
        this.iterations = conf.getInt(ITERATIONS, 5);
        this.initExpTable();
        String connectionString = conf.get("org.deeplearning4j.scaleout.statetracker.connectionstring");
        log.info("Creating state tracker with connection string " + connectionString);
        if (this.stateTracker == null) {
            try {
                this.stateTracker = new HazelCastStateTracker(connectionString);
            }
            catch (Exception e) {
                e.printStackTrace();
            }
        }
        if (this.negative > 0.0) {
            try {
                ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(TABLE).getBytes());
                DataInputStream dis = new DataInputStream(bis);
                this.table = Nd4j.read((DataInputStream)dis);
            }
            catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    public static void configure(InMemoryLookupTable table, InvertedIndex index, Configuration conf) {
        conf.setInt(VECTOR_LENGTH, table.getVectorLength());
        conf.setBoolean(ADAGRAD, table.isUseAdaGrad());
        conf.setFloat(NEGATIVE, (float)table.getNegative());
        conf.setFloat(ALPHA, (float)table.getLr().get());
        conf.setInt(NUM_WORDS, index.totalWords());
        conf.set("org.deeplearning4j.scaleout.aggregator", Word2VecJobAggregator.class.getName());
        conf.set("org.deeplearning4j.scaleout.perform.workerperformer", Word2VecPerformerFactory.class.getName());
        table.resetWeights();
        if (table.getNegative() > 0.0) {
            ByteArrayOutputStream bis = new ByteArrayOutputStream();
            try {
                DataOutputStream ois = new DataOutputStream(bis);
                Nd4j.write((INDArray)table.getTable(), (DataOutputStream)ois);
            }
            catch (IOException e) {
                e.printStackTrace();
            }
            conf.set(TABLE, new String(bis.toByteArray()));
        }
    }

    public void trainSentence(List<VocabWord> sentence, Word2VecWork work, double alpha) {
        if (sentence == null || sentence.isEmpty()) {
            return;
        }
        for (int i = 0; i < sentence.size(); ++i) {
            if (sentence.get(i).getWord().endsWith("STOP")) continue;
            this.nextRandom.set(this.nextRandom.get() * 25214903917L + 11L);
            this.skipGram(i, sentence, (int)this.nextRandom.get() % this.window, work, alpha);
        }
    }

    public void skipGram(int i, List<VocabWord> sentence, int b, Word2VecWork work, double alpha) {
        VocabWord word = sentence.get(i);
        if (word == null || sentence.isEmpty()) {
            return;
        }
        int end = this.window * 2 + 1 - b;
        for (int a = b; a < end; ++a) {
            int c;
            if (a == this.window || (c = i - this.window + a) < 0 || c >= sentence.size()) continue;
            VocabWord lastWord = sentence.get(c);
            this.iterateSample(work, word, lastWord, alpha);
        }
    }

    public void iterateSample(Word2VecWork work, VocabWord w1, VocabWord w2, double alpha) {
        INDArray neu1e;
        INDArray l1;
        block21: {
            if (w2 == null || w2.getIndex() < 0) {
                return;
            }
            if (work.getVectors().get(w2.getWord()) == null) {
                log.warn("No vector found for word " + w2.getWord());
                return;
            }
            if (work.getVectors().get(w1.getWord()) == null) {
                log.warn("No vector found for word " + w1.getWord());
                return;
            }
            l1 = (INDArray)work.getVectors().get(w2.getWord()).getSecond();
            neu1e = Nd4j.create((int)this.vectorLength);
            for (int i = 0; i < w1.getCodeLength(); ++i) {
                int idx;
                int code = w1.getCodes().get(i);
                int point = w1.getPoints().get(i);
                if (work.getIndexes().get(point) == null) continue;
                if (work.getSyn1Vectors().get(work.getIndexes().get(point).getWord()) == null) {
                    log.warn("Syn1 vectors for " + work.getIndexes().get(point).getWord() + " was null");
                    continue;
                }
                INDArray syn1 = work.getSyn1Vectors().get(work.getIndexes().get(point).getWord());
                double dot = Nd4j.getBlasWrapper().dot(l1, syn1);
                if (dot < -MAX_EXP || dot >= MAX_EXP || (idx = (int)((dot + MAX_EXP) * ((double)this.expTable.length / MAX_EXP / 2.0))) >= this.expTable.length) continue;
                double f = this.expTable[idx];
                double g = ((double)(1 - code) - f) * (this.useAdaGrad ? w1.getGradient(i, alpha) : alpha);
                if (neu1e.data().dataType() == 0) {
                    Nd4j.getBlasWrapper().axpy(g, syn1, neu1e);
                    Nd4j.getBlasWrapper().axpy(g, l1, syn1);
                    continue;
                }
                Nd4j.getBlasWrapper().axpy((float)g, syn1, neu1e);
                Nd4j.getBlasWrapper().axpy((float)g, l1, syn1);
            }
            if (!(this.negative > 0.0)) break block21;
            int target = w1.getIndex();
            INDArray syn1Neg = (INDArray)work.getNegativeVectors().get(work.getIndexes().get(target).getWord()).getSecond();
            int d = 0;
            while ((double)d < this.negative + 1.0) {
                block24: {
                    double g;
                    int label;
                    block23: {
                        block22: {
                            if (d != 0) break block22;
                            label = 1;
                            break block23;
                        }
                        this.nextRandom.set(this.nextRandom.get() * 25214903917L + 11L);
                        target = this.table.getInt(new int[]{(int)(this.nextRandom.get() >> 16) % this.table.length()});
                        if (target == 0) {
                            target = (int)this.nextRandom.get() % (this.numWords - 1) + 1;
                        }
                        if (target == w1.getIndex()) break block24;
                        label = 0;
                    }
                    double f = Nd4j.getBlasWrapper().dot(l1, syn1Neg);
                    if (f > MAX_EXP) {
                        g = this.useAdaGrad ? w1.getGradient(target, label - 1) : (double)(label - 1) * alpha;
                    } else if (f < -MAX_EXP) {
                        g = (double)(label - 0) * (this.useAdaGrad ? w1.getGradient(target, alpha) : alpha);
                    } else {
                        double d2 = g = this.useAdaGrad ? w1.getGradient(target, (double)label - this.expTable[(int)((f + MAX_EXP) * ((double)this.expTable.length / MAX_EXP / 2.0))]) : ((double)label - this.expTable[(int)((f + MAX_EXP) * ((double)this.expTable.length / MAX_EXP / 2.0))]) * alpha;
                    }
                    if (syn1Neg.data().dataType() == 0) {
                        Nd4j.getBlasWrapper().axpy(g, neu1e, l1);
                    } else {
                        Nd4j.getBlasWrapper().axpy((float)g, neu1e, l1);
                    }
                    if (syn1Neg.data().dataType() == 0) {
                        Nd4j.getBlasWrapper().axpy(g, syn1Neg, l1);
                    } else {
                        Nd4j.getBlasWrapper().axpy((float)g, syn1Neg, l1);
                    }
                }
                ++d;
            }
        }
        if (neu1e.data().dataType() == 0) {
            Nd4j.getBlasWrapper().axpy(1.0, neu1e, l1);
        } else {
            Nd4j.getBlasWrapper().axpy(1.0f, neu1e, l1);
        }
    }

    private void initExpTable() {
        for (int i = 0; i < this.expTable.length; ++i) {
            double tmp = FastMath.exp((double)(((double)i / (double)this.expTable.length * 2.0 - 1.0) * MAX_EXP));
            this.expTable[i] = tmp / (tmp + 1.0);
        }
    }
}

