/*
 * Decompiled with CFR 0.152.
 */
package com.medallia.word2vec.neuralnetwork;

import com.google.common.collect.Iterables;
import com.google.common.collect.Multiset;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.medallia.word2vec.Word2VecTrainerBuilder;
import com.medallia.word2vec.huffman.HuffmanCoding;
import com.medallia.word2vec.neuralnetwork.NeuralNetworkConfig;
import com.medallia.word2vec.util.CallableVoid;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Base class for training word2vec-style word vectors.
 *
 * <p>Holds the shared weight matrices ({@code syn0}, {@code syn1}, {@code syn1neg}), a unigram table
 * used to draw negative samples, and a precomputed sigmoid lookup table. Subclasses provide the actual
 * per-sentence update through {@link #createWorker} and {@link Worker#trainSentence}.
 *
 * <p>{@link #train} runs one {@link Worker} per sentence partition on a fixed thread pool. Workers read
 * and update the shared arrays and {@link #alpha} without locking (for example in
 * {@link Worker#handleNegativeSampling}).
 */
public abstract class NeuralNetworkTrainer {
    /** Filtered sentences are split into chunks of at most this many words before training. */
    private static final int MAX_SENTENCE_LENGTH = 1000;
    /** In negative sampling, dot products outside [-MAX_EXP, MAX_EXP] bypass the sigmoid table. */
    static final int MAX_EXP = 6;
    /** Number of entries in {@link #EXP_TABLE}. */
    static final int EXP_TABLE_SIZE = 1000;
    /** {@code EXP_TABLE[i]} is the logistic sigmoid of {@code (i / EXP_TABLE_SIZE * 2 - 1) * MAX_EXP}. */
    static final double[] EXP_TABLE = new double[EXP_TABLE_SIZE];
    /** Number of slots in the negative-sampling unigram table. */
    private static final int TABLE_SIZE = 100_000_000;

    private final Word2VecTrainerBuilder.TrainingProgressListener listener;
    final NeuralNetworkConfig config;
    final Map<String, HuffmanCoding.HuffmanNode> huffmanNodes;
    private final int vocabSize;
    final int layer1_size;
    final int window;
    /** Token total used for down-sampling and learning-rate decay; {@link #train} adds one per sentence. */
    int numTrainedTokens;
    /** Words processed so far, summed over all workers and all iterations. */
    protected final AtomicInteger actualWordCount;
    /** Current learning rate; workers lower it as training progresses. */
    volatile double alpha;
    /** Input word vectors; these are what {@link #train} returns as the model's vectors. */
    final double[][] syn0;
    /** Second weight matrix; allocated here, not read in this class (left to subclasses). */
    final double[][] syn1;
    /** Output weights updated by {@link Worker#handleNegativeSampling}. */
    private final double[][] syn1neg;
    /** Word indices laid out proportionally to count^0.75, sampled for negative examples. */
    private final int[] table;
    long startNano;

    /**
     * Allocates the weight matrices, randomly initializes {@code syn0} and builds the unigram table.
     *
     * @param config       training hyper-parameters
     * @param vocab        vocabulary multiset; its {@code size()} seeds {@link #numTrainedTokens}
     * @param huffmanNodes node per vocabulary word; its size defines the vocabulary size
     * @param listener     receives training progress updates
     */
    NeuralNetworkTrainer(NeuralNetworkConfig config, Multiset<String> vocab, Map<String, HuffmanCoding.HuffmanNode> huffmanNodes, Word2VecTrainerBuilder.TrainingProgressListener listener) {
        this.config = config;
        this.huffmanNodes = huffmanNodes;
        this.listener = listener;
        this.vocabSize = huffmanNodes.size();
        this.numTrainedTokens = vocab.size();
        this.layer1_size = config.layerSize;
        this.window = config.windowSize;
        this.actualWordCount = new AtomicInteger();
        this.alpha = config.initialLearningRate;
        this.syn0 = new double[this.vocabSize][this.layer1_size];
        this.syn1 = new double[this.vocabSize][this.layer1_size];
        this.syn1neg = new double[this.vocabSize][this.layer1_size];
        this.table = new int[TABLE_SIZE];
        this.initializeSyn0();
        this.initializeUnigramTable();
    }

    /**
     * Fills {@link #table} so that each node's position in {@code huffmanNodes.values()} iteration
     * order occupies a share of slots proportional to {@code count^0.75}.
     *
     * <p>NOTE(review): negative sampling compares table entries against {@code HuffmanNode.idx}, so this
     * assumes the map's iteration order matches {@code idx} — TODO confirm against HuffmanCoding.
     */
    private void initializeUnigramTable() {
        if (huffmanNodes.isEmpty()) {
            // Nothing to sample from; leave the table all zeros instead of failing on nodeIter.next().
            return;
        }
        long trainWordsPow = 0L;
        double power = 0.75;
        for (HuffmanCoding.HuffmanNode node : huffmanNodes.values()) {
            // Running sum is kept as a long, truncating the fractional part after every addition.
            trainWordsPow = (long) (trainWordsPow + Math.pow(node.count, power));
        }
        Iterator<HuffmanCoding.HuffmanNode> nodeIter = huffmanNodes.values().iterator();
        HuffmanCoding.HuffmanNode last = nodeIter.next();
        // d1 is the cumulative probability mass of the nodes visited so far.
        double d1 = Math.pow(last.count, power) / trainWordsPow;
        int i = 0;
        for (int a = 0; a < TABLE_SIZE; a++) {
            table[a] = i;
            if (a / (double) TABLE_SIZE > d1) {
                // Move to the next word only once its predecessor's share of the table is used up,
                // and never past the last valid index.
                i = Math.min(i + 1, vocabSize - 1);
                HuffmanCoding.HuffmanNode next = nodeIter.hasNext() ? nodeIter.next() : last;
                d1 += Math.pow(next.count, power) / trainWordsPow;
                last = next;
            }
        }
    }

    /** Initializes every {@code syn0} entry to a pseudo-random value in [-0.5, 0.5) / layer1_size. */
    private void initializeSyn0() {
        long nextRandom = 1L;
        for (int a = 0; a < vocabSize; a++) {
            nextRandom = incrementRandom(nextRandom);
            for (int b = 0; b < layer1_size; b++) {
                nextRandom = incrementRandom(nextRandom);
                syn0[a][b] = ((nextRandom & 0xFFFFL) / 65536.0 - 0.5) / layer1_size;
            }
        }
    }

    /** Advances the linear-congruential generator shared by initialization and the workers. */
    static long incrementRandom(long r) {
        return r * 25214903917L + 11L;
    }

    /**
     * Trains the network for {@code config.iterations} passes over {@code sentences}.
     *
     * <p>The sentences are split into roughly {@code config.numThreads} partitions; each pass runs one
     * worker per partition on a fixed thread pool and waits for all of them to finish.
     *
     * @param sentences tokenized sentences; iterated once for counting and once per pass
     * @return a model exposing {@code syn0} as the trained vectors (the live array, not a copy)
     * @throws InterruptedException if interrupted while waiting for workers
     * @throws IllegalStateException wrapping the cause if a worker fails
     */
    public NeuralNetworkModel train(Iterable<List<String>> sentences) throws InterruptedException {
        ListeningExecutorService ex = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(config.numThreads));
        int numSentences = Iterables.size(sentences);
        // Workers count one extra token per sentence (see Worker.run); keep the total consistent.
        numTrainedTokens += numSentences;
        Iterable<List<List<String>>> partitioned = Iterables.partition(sentences, numSentences / config.numThreads + 1);
        try {
            listener.update(Word2VecTrainerBuilder.TrainingProgressListener.Stage.TRAIN_NEURAL_NETWORK, 0.0);
            for (int iter = config.iterations; iter > 0; iter--) {
                List<Worker> tasks = new ArrayList<>();
                int i = 0;
                for (List<List<String>> batch : partitioned) {
                    tasks.add(createWorker(i, iter, batch));
                    i++;
                }
                List<ListenableFuture<?>> futures = new ArrayList<>(tasks.size());
                for (Worker task : tasks) {
                    // Explicit cast selects the Callable overload of submit, as the original did.
                    futures.add(ex.submit((Callable<?>) task));
                }
                try {
                    Futures.allAsList(futures).get();
                } catch (ExecutionException e) {
                    throw new IllegalStateException("Error training neural network", e.getCause());
                }
            }
            ex.shutdown();
        } finally {
            // Interrupts any still-running workers if we leave early (exception or interrupt).
            ex.shutdownNow();
        }
        return new NeuralNetworkModel() {

            @Override
            public int layerSize() {
                return NeuralNetworkTrainer.this.config.layerSize;
            }

            @Override
            public double[][] vectors() {
                return NeuralNetworkTrainer.this.syn0;
            }
        };
    }

    /**
     * Creates the worker that trains one partition of sentences for one pass.
     *
     * @param var1 per-worker random seed (the partition index)
     * @param var2 remaining-iteration counter, counting down from {@code config.iterations}
     * @param var3 sentences in this partition
     */
    abstract Worker createWorker(int var1, int var2, Iterable<List<String>> var3);

    static {
        for (int i = 0; i < EXP_TABLE_SIZE; i++) {
            double e = Math.exp((i / (double) EXP_TABLE_SIZE * 2.0 - 1.0) * MAX_EXP);
            // e / (e + 1) is the logistic sigmoid of the exponent above.
            EXP_TABLE[i] = e / (e + 1.0);
        }
    }

    /** Trains one partition of sentences for a single pass, sharing the outer trainer's weights. */
    abstract class Worker
    extends CallableVoid {
        /** Minimum number of newly counted words between learning-rate updates. */
        private static final int LEARNING_RATE_UPDATE_FREQUENCY = 10000;
        long nextRandom;
        final int iter;
        final Iterable<List<String>> batch;
        /** Words counted by this worker, including one extra per sentence. */
        int wordCount;
        /** Value of {@link #wordCount} at the last flush into {@code actualWordCount}. */
        int lastWordCount;
        final double[] neu1;
        final double[] neu1e;

        Worker(int randomSeed, int iter, Iterable<List<String>> batch) {
            this.neu1 = new double[layer1_size];
            this.neu1e = new double[layer1_size];
            this.nextRandom = randomSeed;
            this.iter = iter;
            this.batch = batch;
        }

        /**
         * Drops out-of-vocabulary words, randomly down-samples frequent words when
         * {@code config.downSampleRate > 0}, and trains on the remaining words in chunks of at most
         * {@code MAX_SENTENCE_LENGTH}.
         *
         * @throws InterruptedException if the thread is interrupted between chunks
         */
        @Override
        public void run() throws InterruptedException {
            for (List<String> sentence : batch) {
                List<String> filteredSentence = new ArrayList<>(sentence.size());
                for (String s : sentence) {
                    if (!huffmanNodes.containsKey(s)) {
                        continue;
                    }
                    wordCount++;
                    if (config.downSampleRate > 0.0) {
                        HuffmanCoding.HuffmanNode huffmanNode = huffmanNodes.get(s);
                        double threshold = config.downSampleRate * (double) numTrainedTokens;
                        double keepProbability = (Math.sqrt(huffmanNode.count / threshold) + 1.0) * threshold / huffmanNode.count;
                        nextRandom = incrementRandom(nextRandom);
                        if (keepProbability < (nextRandom & 0xFFFFL) / 65536.0) {
                            continue;
                        }
                    }
                    filteredSentence.add(s);
                }
                // One extra count per sentence; train() adds numSentences to numTrainedTokens to match.
                wordCount++;
                for (List<String> chunked : Iterables.partition(filteredSentence, MAX_SENTENCE_LENGTH)) {
                    if (Thread.currentThread().isInterrupted()) {
                        throw new InterruptedException("Interrupted while training word2vec model");
                    }
                    if (wordCount - lastWordCount > LEARNING_RATE_UPDATE_FREQUENCY) {
                        updateAlpha(iter);
                    }
                    trainSentence(chunked);
                }
            }
            actualWordCount.addAndGet(wordCount - lastWordCount);
        }

        /**
         * Flushes this worker's new word count into {@code actualWordCount}, lowers {@code alpha}
         * linearly with overall progress (floored at 1e-4 of the initial rate), and reports progress.
         */
        private void updateAlpha(int iter) {
            int currentActual = actualWordCount.addAndGet(wordCount - lastWordCount);
            lastWordCount = wordCount;
            // Multiply in double: an int product of iterations * tokens overflows on large corpora.
            // NOTE(review): actualWordCount itself is an int counter and can still overflow at that scale.
            double totalWords = (double) config.iterations * numTrainedTokens;
            alpha = config.initialLearningRate * Math.max(1.0 - currentActual / totalWords, 1.0E-4);
            listener.update(Word2VecTrainerBuilder.TrainingProgressListener.Stage.TRAIN_NEURAL_NETWORK, currentActual / (totalWords + 1.0));
        }

        /**
         * Runs one positive and {@code config.negativeSamples} negative logistic updates for
         * {@code huffmanNode}. Reads the hidden activation from {@link #neu1}, accumulates the input
         * gradient into {@link #neu1e}, and updates {@code syn1neg} in place.
         */
        void handleNegativeSampling(HuffmanCoding.HuffmanNode huffmanNode) {
            for (int d = 0; d <= config.negativeSamples; d++) {
                int target;
                int label;
                if (d == 0) {
                    target = huffmanNode.idx;
                    label = 1;
                } else {
                    nextRandom = incrementRandom(nextRandom);
                    target = table[(int) (((nextRandom >> 16) % TABLE_SIZE) + TABLE_SIZE) % TABLE_SIZE];
                    // Replace index 0 with a random index in [1, vocabSize); impossible when vocabSize == 1
                    // (the modulus would be zero), in which case target stays 0.
                    if (target == 0 && vocabSize > 1) {
                        target = (int) (((nextRandom % (vocabSize - 1)) + vocabSize - 1) % (vocabSize - 1)) + 1;
                    }
                    if (target == huffmanNode.idx) {
                        continue;
                    }
                    label = 0;
                }
                double[] row = syn1neg[target];
                double f = 0.0;
                for (int c = 0; c < layer1_size; c++) {
                    f += neu1[c] * row[c];
                }
                double g;
                if (f > MAX_EXP) {
                    g = (label - 1) * alpha;
                } else if (f < -MAX_EXP) {
                    g = label * alpha;
                } else {
                    g = (label - EXP_TABLE[(int) ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2))]) * alpha;
                }
                for (int c = 0; c < layer1_size; c++) {
                    neu1e[c] += g * row[c];
                }
                for (int c = 0; c < layer1_size; c++) {
                    row[c] += g * neu1[c];
                }
            }
        }

        /** Applies the subclass-specific training update to one chunk of filtered words. */
        abstract void trainSentence(List<String> var1);
    }

    /** Result of {@link NeuralNetworkTrainer#train}. */
    public static interface NeuralNetworkModel {
        /** @return dimensionality of each vector */
        public int layerSize();

        /** @return one vector per vocabulary entry */
        public double[][] vectors();
    }
}

