/*
 * Decompiled with CFR 0.152.
 */
package com.medallia.word2vec;

import com.google.common.base.Optional;
import com.google.common.base.Predicate;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableMultiset;
import com.google.common.collect.ImmutableSortedMultiset;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multiset;
import com.google.common.collect.Multisets;
import com.google.common.primitives.Doubles;
import com.medallia.word2vec.Word2VecModel;
import com.medallia.word2vec.Word2VecTrainerBuilder;
import com.medallia.word2vec.huffman.HuffmanCoding;
import com.medallia.word2vec.neuralnetwork.NeuralNetworkConfig;
import com.medallia.word2vec.neuralnetwork.NeuralNetworkTrainer;
import com.medallia.word2vec.util.AC;
import com.medallia.word2vec.util.ProfilingTimer;
import java.util.List;
import java.util.Map;
import org.apache.commons.logging.Log;

class Word2VecTrainer {
    private final int minFrequency;
    private final Optional<Multiset<String>> vocab;
    private final NeuralNetworkConfig neuralNetworkConfig;

    Word2VecTrainer(Integer minFrequency, Optional<Multiset<String>> vocab, NeuralNetworkConfig neuralNetworkConfig) {
        this.vocab = vocab;
        this.minFrequency = minFrequency;
        this.neuralNetworkConfig = neuralNetworkConfig;
    }

    private static Multiset<String> count(Iterable<String> tokens) {
        HashMultiset counts = HashMultiset.create();
        for (String token : tokens) {
            counts.add((Object)token);
        }
        return counts;
    }

    private ImmutableMultiset<String> filterAndSort(final Multiset<String> counts) {
        return Multisets.copyHighestCountFirst((Multiset)ImmutableSortedMultiset.copyOf((Iterable)Multisets.filter(counts, (Predicate)new Predicate<String>(){

            public boolean apply(String s) {
                return counts.count((Object)s) >= Word2VecTrainer.this.minFrequency;
            }
        })));
    }

    Word2VecModel train(Log log, Word2VecTrainerBuilder.TrainingProgressListener listener, Iterable<List<String>> sentences) throws InterruptedException {
        try (ProfilingTimer timer = ProfilingTimer.createLoggingSubtasks(log, "Training word2vec", new Object[0]);){
            NeuralNetworkTrainer.NeuralNetworkModel model;
            Map<String, HuffmanCoding.HuffmanNode> huffmanNodes;
            ImmutableMultiset<String> vocab;
            Multiset counts;
            try (AC ac = timer.start("Acquiring word frequencies", new Object[0]);){
                listener.update(Word2VecTrainerBuilder.TrainingProgressListener.Stage.ACQUIRE_VOCAB, 0.0);
                counts = this.vocab.isPresent() ? (Multiset)this.vocab.get() : Word2VecTrainer.count(Iterables.concat(sentences));
            }
            try (AC ac = timer.start("Filtering and sorting vocabulary", new Object[0]);){
                listener.update(Word2VecTrainerBuilder.TrainingProgressListener.Stage.FILTER_SORT_VOCAB, 0.0);
                vocab = this.filterAndSort((Multiset<String>)counts);
            }
            AC task = timer.start("Create Huffman encoding", new Object[0]);
            Object object = null;
            try {
                huffmanNodes = new HuffmanCoding(vocab, listener).encode();
            }
            catch (Throwable throwable) {
                object = throwable;
                throw throwable;
            }
            finally {
                if (task != null) {
                    if (object != null) {
                        try {
                            task.close();
                        }
                        catch (Throwable throwable) {
                            ((Throwable)object).addSuppressed(throwable);
                        }
                    } else {
                        task.close();
                    }
                }
            }
            try (AC task2 = timer.start("Training model %s", this.neuralNetworkConfig);){
                model = this.neuralNetworkConfig.createTrainer(vocab, huffmanNodes, listener).train(sentences);
            }
            object = new Word2VecModel((Iterable<String>)vocab.elementSet(), model.layerSize(), Doubles.concat((double[][])model.vectors()));
            return object;
        }
    }
}

