/*
 * Decompiled with CFR 0.152.
 */
package hex.word2vec;

import hex.word2vec.HBWTree;
import hex.word2vec.Word2Vec;
import hex.word2vec.Word2VecModel;
import java.util.Iterator;
import water.DKV;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.parser.BufferedString;
import water.util.ArrayUtils;
import water.util.IcedHashMap;
import water.util.IcedHashMapGeneric;
import water.util.IcedLong;

/**
 * Trains word vectors over a single column of tokens as an {@link MRTask}.
 *
 * <p>Each {@link #map} call splits its chunk into sentences. NA cells act as sentence separators,
 * tokens missing from the vocabulary are dropped, and frequent tokens may be randomly
 * sub-sampled. For every remaining word it updates {@code _syn0} (one row of {@code _wordVecSize}
 * floats per vocabulary word) and {@code _syn1} (one row per hierarchical-softmax tree node),
 * using either the Skip-Gram or the CBOW model. The arrays are mutated in place with no locking.
 * Copies that differ by reference (presumably those coming from other nodes) are merged in
 * {@link #reduce} as an average weighted by the number of words each side processed.
 */
public class WordVectorTrainer
extends MRTask<WordVectorTrainer> {
    /** Maximum number of kept words collected into one sentence. */
    private static final int MAX_SENTENCE_LEN = 1000;
    /** Number of pre-computed sigmoid samples. */
    private static final int EXP_TABLE_SIZE = 1000;
    /** Dot products outside (-MAX_EXP, MAX_EXP) are treated as saturated and produce no update. */
    private static final int MAX_EXP = 6;
    /** Index scale for the sigmoid table; integer division (= 83), as in the reference C code. */
    private static final int EXP_TABLE_SCALE = EXP_TABLE_SIZE / MAX_EXP / 2;
    private static final float[] EXP_TABLE = WordVectorTrainer.calcExpTable();
    /** The learning rate never decays below this fraction of the initial rate. */
    private static final float LEARNING_RATE_MIN_FACTOR = 1.0E-4f;
    /** Number of words processed between two learning-rate refreshes. */
    private static final int LEARNING_RATE_UPDATE_INTERVAL = 10000;
    /** Exclusive upper bound of the random draw used by word sub-sampling (0xFFFF). */
    private static final int SUBSAMPLING_RAND_MAX = 65535;
    private final Job<Word2VecModel> _job;
    private final Word2Vec.WordModel _wordModel;
    private final int _wordVecSize;
    private final int _windowSize;
    private final int _epochs;
    private final float _initLearningRate;
    private final float _sentSampleRate;
    private final long _vocabWordCount;
    private final Key<Word2VecModel.Vocabulary> _vocabKey;
    private final Key<Word2VecModel.WordCounts> _wordCountsKey;
    private final Key<HBWTree> _treeKey;
    /** Words processed by earlier training passes; drives learning-rate decay. */
    private final long _prevTotalProcessedWords;
    /** Word vectors: row {@code w} spans {@code [w * _wordVecSize, (w + 1) * _wordVecSize)}. */
    float[] _syn0;
    /** Tree-node vectors, laid out like {@code _syn0} but indexed by tree point. */
    float[] _syn1;
    /** Words processed by this task (summed over tasks in {@link #reduce}). */
    long _processedWords = 0L;
    /** Per-node running word count used to estimate the global learning rate. */
    IcedLong _nodeProcessedWords;
    private transient IcedHashMapGeneric<BufferedString, Integer> _vocab;
    private transient IcedHashMap<BufferedString, IcedLong> _wordCounts;
    private transient int[][] _HBWTCode;
    private transient int[][] _HBWTPoint;
    private float _curLearningRate;
    private long _seed = System.nanoTime();

    /**
     * Creates a trainer for one pass over the data.
     *
     * @param job   job to report progress to
     * @param input current model state (parameters, vector arrays, DKV keys of vocabulary/counts/tree)
     */
    public WordVectorTrainer(Job<Word2VecModel> job, Word2VecModel.Word2VecModelInfo input) {
        super(null);
        this._job = job;
        this._treeKey = input._treeKey;
        this._vocabKey = input._vocabKey;
        this._wordCountsKey = input._wordCountsKey;
        this._wordModel = input.getParams()._word_model;
        this._wordVecSize = input.getParams()._vec_size;
        this._windowSize = input.getParams()._window_size;
        this._sentSampleRate = input.getParams()._sent_sample_rate;
        this._epochs = input.getParams()._epochs;
        this._initLearningRate = input.getParams()._init_learning_rate;
        this._vocabWordCount = input._vocabWordCount;
        this._prevTotalProcessedWords = input._totalProcessedWords;
        this._syn0 = input._syn0;
        this._syn1 = input._syn1;
        this._curLearningRate = WordVectorTrainer.calcLearningRate(this._initLearningRate, this._epochs, this._prevTotalProcessedWords, this._vocabWordCount);
    }

    /** Loads the vocabulary, word counts and Huffman tree from the DKV and resets the node word counter. */
    @Override
    protected void setupLocal() {
        this._vocab = ((Word2VecModel.Vocabulary)DKV.getGet(this._vocabKey))._data;
        this._wordCounts = ((Word2VecModel.WordCounts)DKV.getGet(this._wordCountsKey))._data;
        HBWTree t = (HBWTree)DKV.getGet(this._treeKey);
        this._HBWTCode = t._code;
        this._HBWTPoint = t._point;
        this._nodeProcessedWords = new IcedLong(0L);
    }

    /** Pre-computes sigmoid(x) for x sampled uniformly over [-MAX_EXP, MAX_EXP). */
    private static float[] calcExpTable() {
        float[] expTable = new float[EXP_TABLE_SIZE];
        for (int i = 0; i < EXP_TABLE_SIZE; ++i) {
            float e = (float)Math.exp(((float)i / (float)EXP_TABLE_SIZE * 2.0f - 1.0f) * (float)MAX_EXP);
            expTable[i] = e / (e + 1.0f);
        }
        return expTable;
    }

    /** Table-based sigmoid; callers must ensure {@code -MAX_EXP < f < MAX_EXP}. */
    private static float sigmoid(float f) {
        return EXP_TABLE[(int)((f + (float)MAX_EXP) * (float)EXP_TABLE_SCALE)];
    }

    /**
     * Trains on every sentence of the chunk and refreshes the learning rate every
     * {@link #LEARNING_RATE_UPDATE_INTERVAL} words.
     */
    @Override
    public void map(Chunk chk) {
        final int winSize = this._windowSize;
        final int vecSize = this._wordVecSize;
        final boolean isCbow = this._wordModel == Word2Vec.WordModel.CBOW;
        final boolean isSkipGram = this._wordModel == Word2Vec.WordModel.SkipGram;
        float[] neu1 = new float[vecSize];
        float[] neu1e = new float[vecSize];
        ChunkSentenceIterator sentIter = new ChunkSentenceIterator(chk);
        int wordCount = 0;
        while (sentIter.hasNext()) {
            int sentLen = sentIter.nextLength();
            int[] sentence = sentIter.next();
            for (int sentIdx = 0; sentIdx < sentLen; ++sentIdx) {
                int curWord = sentence[sentIdx];
                int bagSize = 0;
                if (isCbow) {
                    for (int j = 0; j < vecSize; ++j) {
                        neu1[j] = 0.0f;
                        neu1e[j] = 0.0f;
                    }
                }
                // Random window shrink: only positions [winSizeMod, 2*winSize + 1 - winSizeMod) are used.
                int winSizeMod = this.cheapRandInt(winSize);
                int winEnd = winSize * 2 + 1 - winSizeMod;
                for (int winIdx = winSizeMod; winIdx < winEnd; ++winIdx) {
                    if (winIdx == winSize) {
                        continue; // the centre position is the current word itself
                    }
                    int winWordSentIdx = sentIdx - winSize + winIdx;
                    if (winWordSentIdx < 0 || winWordSentIdx >= sentLen) {
                        continue;
                    }
                    int winWord = sentence[winWordSentIdx];
                    if (isSkipGram) {
                        this.skipGram(curWord, winWord, neu1e);
                    } else {
                        int rowOffset = winWord * vecSize;
                        for (int j = 0; j < vecSize; ++j) {
                            neu1[j] += this._syn0[rowOffset + j];
                        }
                        ++bagSize;
                    }
                }
                if (isCbow && bagSize > 0) {
                    this.CBOW(curWord, sentence, sentIdx, sentLen, winSizeMod, bagSize, neu1, neu1e);
                }
                if (++wordCount % LEARNING_RATE_UPDATE_INTERVAL == 0) {
                    this._nodeProcessedWords._val += (long)LEARNING_RATE_UPDATE_INTERVAL;
                    long totalProcessedWordsEst = this._prevTotalProcessedWords + this._nodeProcessedWords._val;
                    this._curLearningRate = WordVectorTrainer.calcLearningRate(this._initLearningRate, this._epochs, totalProcessedWordsEst, this._vocabWordCount);
                }
            }
        }
        this._processedWords = wordCount;
        this._nodeProcessedWords._val += (long)(wordCount % LEARNING_RATE_UPDATE_INTERVAL);
        this._job.update(1L);
    }

    /**
     * Merges another task's result. Distinct vector arrays are averaged, weighted by the number of
     * words each side processed. If neither side processed any word, the arrays are left as they
     * are, which avoids a 0/0 weight that would fill the model with NaN.
     */
    @Override
    public void reduce(WordVectorTrainer other) {
        long total = this._processedWords + other._processedWords;
        this._processedWords = total;
        if (this._syn0 != other._syn0) {
            if (total > 0L) {
                float c = (float)other._processedWords / (float)total;
                ArrayUtils.add(1.0f - c, this._syn0, c, other._syn0);
                ArrayUtils.add(1.0f - c, this._syn1, c, other._syn1);
            }
            this._nodeProcessedWords._val += other._nodeProcessedWords._val;
        }
    }

    /** Skip-Gram step: predicts {@code curWord} from the context word's vector and updates that vector. */
    private void skipGram(int curWord, int winWord, float[] neu1e) {
        int vecSize = this._wordVecSize;
        int l1 = winWord * vecSize;
        for (int i = 0; i < vecSize; ++i) {
            neu1e[i] = 0.0f;
        }
        this.hierarchicalSoftmaxSG(curWord, l1, neu1e);
        for (int i = 0; i < vecSize; ++i) {
            this._syn0[i + l1] += neu1e[i];
        }
    }

    /**
     * Walks the Huffman path of {@code targetWord}. It accumulates the input-vector gradient
     * (at {@code _syn0[l1..]}) into {@code neu1e} and updates each node vector in {@code _syn1}.
     */
    private void hierarchicalSoftmaxSG(int targetWord, int l1, float[] neu1e) {
        int vecSize = this._wordVecSize;
        int[] code = this._HBWTCode[targetWord];
        int[] point = this._HBWTPoint[targetWord];
        float alpha = this._curLearningRate;
        for (int i = 0; i < code.length; ++i) {
            int l2 = point[i] * vecSize;
            float f = 0.0f;
            for (int j = 0; j < vecSize; ++j) {
                f += this._syn0[j + l1] * this._syn1[j + l2];
            }
            if (f <= -(float)MAX_EXP || f >= (float)MAX_EXP) {
                continue; // saturated sigmoid: no gradient
            }
            float gradient = ((float)(1 - code[i]) - sigmoid(f)) * alpha;
            for (int j = 0; j < vecSize; ++j) {
                neu1e[j] += gradient * this._syn1[j + l2];
            }
            for (int j = 0; j < vecSize; ++j) {
                this._syn1[j + l2] += gradient * this._syn0[j + l1];
            }
        }
    }

    /**
     * CBOW step: averages the context vectors ({@code neu1}), runs hierarchical softmax for
     * {@code curWord}, then adds the accumulated error {@code neu1e} to every context word.
     */
    private void CBOW(int curWord, int[] sentence, int sentIdx, int sentLen, int winSizeMod, int bagSize, float[] neu1, float[] neu1e) {
        int vecSize = this._wordVecSize;
        int winSize = this._windowSize;
        // Must mirror the forward window used in map(). The previous bound (2*winSize + 1 - winSize)
        // only back-propagated into left-context words.
        int curWinSize = winSize * 2 + 1 - winSizeMod;
        for (int i = 0; i < vecSize; ++i) {
            neu1[i] /= (float)bagSize;
        }
        this.hierarchicalSoftmaxCBOW(curWord, neu1, neu1e);
        for (int winIdx = winSizeMod; winIdx < curWinSize; ++winIdx) {
            if (winIdx == winSize) {
                continue;
            }
            int winWordSentIdx = sentIdx - winSize + winIdx;
            if (winWordSentIdx < 0 || winWordSentIdx >= sentLen) {
                continue;
            }
            int rowOffset = sentence[winWordSentIdx] * vecSize;
            for (int i = 0; i < vecSize; ++i) {
                this._syn0[rowOffset + i] += neu1e[i];
            }
        }
    }

    /** Hierarchical softmax for CBOW: like the Skip-Gram variant but driven by the averaged context {@code neu1}. */
    private void hierarchicalSoftmaxCBOW(int targetWord, float[] neu1, float[] neu1e) {
        int vecSize = this._wordVecSize;
        int[] code = this._HBWTCode[targetWord];
        int[] point = this._HBWTPoint[targetWord];
        float alpha = this._curLearningRate;
        for (int i = 0; i < code.length; ++i) {
            int l2 = point[i] * vecSize;
            float f = 0.0f;
            for (int j = 0; j < vecSize; ++j) {
                f += neu1[j] * this._syn1[j + l2];
            }
            if (f <= -(float)MAX_EXP || f >= (float)MAX_EXP) {
                continue; // saturated sigmoid: no gradient
            }
            float gradient = ((float)(1 - code[i]) - sigmoid(f)) * alpha;
            for (int j = 0; j < vecSize; ++j) {
                neu1e[j] += gradient * this._syn1[j + l2];
            }
            for (int j = 0; j < vecSize; ++j) {
                this._syn1[j + l2] += gradient * neu1[j];
            }
        }
    }

    /**
     * Linearly decays the learning rate with the number of processed words, floored at
     * {@code initLearningRate * LEARNING_RATE_MIN_FACTOR}.
     */
    private static float calcLearningRate(float initLearningRate, int epochs, long totalProcessed, long vocabWordCount) {
        float rate = initLearningRate * (1.0f - (float)totalProcessed / (float)((long)epochs * vocabWordCount + 1L));
        float minRate = initLearningRate * LEARNING_RATE_MIN_FACTOR;
        return rate < minRate ? minRate : rate;
    }

    /** Copies the trained arrays into {@code modelInfo} and adds this pass's word count to its total. */
    public void updateModelInfo(Word2VecModel.Word2VecModelInfo modelInfo) {
        modelInfo._syn0 = this._syn0;
        modelInfo._syn1 = this._syn1;
        modelInfo._totalProcessedWords += this._processedWords;
    }

    /** Xorshift step; returns a value in {@code [0, max)}. {@code max} must be positive. */
    private int cheapRandInt(int max) {
        this._seed ^= this._seed << 21;
        this._seed ^= this._seed >>> 35;
        this._seed ^= this._seed << 4;
        int r = (int)this._seed % max;
        return r > 0 ? r : -r;
    }

    /**
     * Decides whether a frequent vocabulary word should be dropped from the current sentence.
     * This always returns false when sub-sampling is disabled ({@code _sentSampleRate <= 0}),
     * in which case no random number is drawn.
     */
    private boolean isSubsampledOut(BufferedString word) {
        if (!(this._sentSampleRate > 0.0f)) {
            return false;
        }
        long count = ((IcedLong)this._wordCounts.get(word))._val;
        float threshold = this._sentSampleRate * (float)this._vocabWordCount;
        float keepScore = (float)((Math.sqrt((float)count / threshold) + 1.0) * (double)threshold / (double)count);
        return keepScore * 65536.0f < (float)this.cheapRandInt(SUBSAMPLING_RAND_MAX);
    }

    /**
     * Iterates over the sentences of a chunk. A sentence ends at an NA cell, at the chunk end, or
     * after {@link #MAX_SENTENCE_LEN} kept words. The returned array is reused between calls and is
     * terminated with {@code -1}; its valid length is given by {@link #nextLength()}.
     */
    private class ChunkSentenceIterator
    implements Iterator<int[]> {
        private final Chunk _chk;
        private int _pos = 0;
        private int _len = -1;
        private final int[] _sent = new int[MAX_SENTENCE_LEN + 1];

        private ChunkSentenceIterator(Chunk chk) {
            this._chk = chk;
        }

        @Override
        public boolean hasNext() {
            return this.nextLength() >= 0;
        }

        /** Length of the pending sentence (reading it if needed), or -1 when the chunk is exhausted. */
        private int nextLength() {
            if (this._len >= 0) {
                return this._len;
            }
            if (this._pos >= this._chk._len) {
                return -1;
            }
            this._len = 0;
            BufferedString tmp = new BufferedString();
            while (this._pos < this._chk._len && !this._chk.isNA(this._pos) && this._len < MAX_SENTENCE_LEN) {
                BufferedString str = this._chk.atStr(tmp, this._pos);
                if (WordVectorTrainer.this._vocab.containsKey(str) && !WordVectorTrainer.this.isSubsampledOut(str)) {
                    this._sent[this._len++] = (Integer)WordVectorTrainer.this._vocab.get(str);
                }
                ++this._pos;
            }
            this._sent[this._len] = -1;
            // Skip the NA sentence separator only. When the sentence was cut at MAX_SENTENCE_LEN,
            // _pos points at an unread word that must start the next sentence.
            if (this._pos < this._chk._len && this._chk.isNA(this._pos)) {
                ++this._pos;
            }
            return this._len;
        }

        @Override
        public int[] next() {
            if (!this.hasNext()) {
                throw new java.util.NoSuchElementException("No more sentences in chunk");
            }
            this._len = -1;
            return this._sent;
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException("Remove is not supported");
        }
    }
}

