/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.word2vec.wordstore;

import com.google.common.util.concurrent.AtomicDouble;
import it.unimi.dsi.util.XorShift64StarRandomGenerator;
import java.io.File;
import java.io.Serializable;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.movingwindow.Util;
import org.deeplearning4j.util.Index;
import org.deeplearning4j.util.SerializationUtils;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Base {@link VocabCache} implementation that keeps all vocabulary statistics in memory.
 *
 * <p>Held state:
 * <ul>
 *   <li>a word {@link Index};</li>
 *   <li>per-word and per-document frequency counters;</li>
 *   <li>the vocabulary ({@code vocabs});</li>
 *   <li>every known token ({@code tokens});</li>
 *   <li>code vectors keyed by integer index ({@code codes});</li>
 *   <li>training settings: vector length, AdaGrad flag, learning rate, RNG, seed and a
 *       {@code negative} value (presumably the negative-sampling setting; confirm
 *       against callers).</li>
 * </ul>
 *
 * <p>The whole instance can be written with {@link #saveVocab()} and read back with
 * {@link #loadVocab()}. Both use the relative path {@code ser}, resolved against the
 * process working directory.
 */
public abstract class BaseLookupCache implements VocabCache, Serializable {
    /** Relative path used by {@link #saveVocab()}, {@link #vocabExists()} and {@link #loadVocab()}. */
    private static final String SERIALIZED_VOCAB_PATH = "ser";

    protected Index wordIndex = new Index();
    protected boolean useAdaGrad = false;
    protected Counter<String> wordFrequencies = Util.parallelCounter();
    protected Counter<String> docFrequencies = Util.parallelCounter();
    protected Map<String, VocabWord> vocabs = new ConcurrentHashMap<String, VocabWord>();
    protected Map<String, VocabWord> tokens = new ConcurrentHashMap<String, VocabWord>();
    protected Map<Integer, INDArray> codes = new ConcurrentHashMap<Integer, INDArray>();
    protected int vectorLength = 50;
    // Transient: not written by saveVocab(), and not copied by loadVocab().
    protected transient RandomGenerator rng = new XorShift64StarRandomGenerator(123L);
    protected AtomicInteger totalWordOccurrences = new AtomicInteger(0);
    protected AtomicDouble lr = new AtomicDouble(0.1);
    protected long seed = 123L;
    protected int numDocs = 0;
    protected double negative = 0.0;

    /**
     * Creates a cache with the given settings and registers the {@code "UNK"} placeholder.
     *
     * <p>The placeholder is added as a token with count 1.0, added to the word index at
     * position 0, and put into the vocabulary.
     *
     * <p>NOTE(review): this invokes the overridable methods {@link #addToken},
     * {@link #addWordToIndex} and {@link #putVocabWord} from the constructor. Subclass
     * overrides therefore run before the subclass's own field initialisers.
     *
     * @param vectorLength length of the word vectors
     * @param useAdaGrad   whether AdaGrad is enabled
     * @param lr           initial learning rate
     * @param gen          random generator to use
     * @param negative     value stored in {@code negative}
     */
    public BaseLookupCache(int vectorLength, boolean useAdaGrad, double lr, RandomGenerator gen, double negative) {
        this.vectorLength = vectorLength;
        this.useAdaGrad = useAdaGrad;
        this.lr.set(lr);
        this.rng = gen;
        // Assign every setting before calling the overridable registration methods below,
        // so an override never observes a half-configured instance.
        this.negative = negative;
        addToken(new VocabWord(1.0, "UNK"));
        addWordToIndex(0, "UNK");
        putVocabWord("UNK");
    }

    /**
     * Returns the words currently in the vocabulary.
     *
     * @return a live view of the vocabulary map's key set (not a copy)
     */
    @Override
    public synchronized Collection<String> words() {
        return vocabs.keySet();
    }

    /**
     * Replaces the RNG with a {@link MersenneTwister} seeded from {@code seed}.
     * In this class, only the RNG is replaced.
     */
    @Override
    public void resetWeights() {
        rng = new MersenneTwister(seed);
    }

    /** Increments the count of {@code word} by one; see {@link #incrementWordCount(String, int)}. */
    @Override
    public void incrementWordCount(String word) {
        incrementWordCount(word, 1);
    }

    /**
     * Increments the count of {@code word} by {@code increment}.
     *
     * <p>The frequency counter and the running total of word occurrences are always updated.
     * The stored {@link VocabWord} is updated only when {@code word} is already a known
     * token. Unknown words are not added to {@code tokens}.
     *
     * @param word      the word to count
     * @param increment amount to add
     */
    @Override
    public void incrementWordCount(String word, int increment) {
        // Count by the requested amount, so wordFrequency() reflects the same total that
        // is added to the token and to totalWordOccurrences.
        wordFrequencies.incrementCount(word, (double) increment);
        if (hasToken(word)) {
            tokenFor(word).increment(increment);
        }
        // addAndGet is a single atomic update; a get()+set() pair can drop concurrent increments.
        totalWordOccurrences.addAndGet(increment);
    }

    /**
     * Returns the recorded frequency of {@code word}.
     *
     * @return the frequency counter value, truncated to {@code int}
     */
    @Override
    public int wordFrequency(String word) {
        return (int) wordFrequencies.getCount(word);
    }

    /** Returns whether {@code word} is in the vocabulary (not merely a known token). */
    @Override
    public boolean containsWord(String word) {
        return vocabs.containsKey(word);
    }

    /** Returns the word stored at {@code index} in the word index. */
    @Override
    public String wordAtIndex(int index) {
        return (String) wordIndex.get(index);
    }

    /** Returns the position of {@code word} in the word index. */
    @Override
    public int indexOf(String word) {
        return wordIndex.indexOf(word);
    }

    /** Stores {@code code} under {@code codeIndex}, replacing any previous entry. */
    @Override
    public void putCode(int codeIndex, INDArray code) {
        codes.put(codeIndex, code);
    }

    /**
     * Returns the vocabulary entries.
     *
     * @return a live view of the vocabulary map's values (not a copy)
     */
    @Override
    public Collection<VocabWord> vocabWords() {
        return vocabs.values();
    }

    /** Returns the running total of word occurrences added through {@code incrementWordCount}. */
    @Override
    public int totalWordOccurrences() {
        return totalWordOccurrences.get();
    }

    /** Returns the vocabulary entry for {@code word}, or {@code null} if it is not in the vocabulary. */
    @Override
    public VocabWord wordFor(String word) {
        return vocabs.get(word);
    }

    /**
     * Adds {@code word} to the word index at {@code index}.
     *
     * <p>If the word has no frequency yet, its frequency is seeded with 1.
     */
    @Override
    public synchronized void addWordToIndex(int index, String word) {
        if (!wordFrequencies.containsKey(word)) {
            wordFrequencies.incrementCount(word, 1.0);
        }
        wordIndex.add(word, index);
    }

    /**
     * Promotes an existing token into the vocabulary.
     *
     * <p>The word is indexed at its token's index and then stored in the vocabulary.
     *
     * @param word the word to promote; it must already be a token
     * @throws IllegalStateException if {@code word} is not a known token
     */
    @Override
    public synchronized void putVocabWord(String word) {
        // Validate before touching the token. Reading its index first turned a missing
        // token into a NullPointerException instead of this descriptive error.
        if (!hasToken(word)) {
            throw new IllegalStateException("Unable to add token " + word + " when not already a token");
        }
        VocabWord token = tokenFor(word);
        // addWordToIndex already performs wordIndex.add(word, token.getIndex()).
        addWordToIndex(token.getIndex(), word);
        vocabs.put(word, token);
    }

    /** Returns the number of words in the vocabulary. */
    @Override
    public synchronized int numWords() {
        return vocabs.size();
    }

    /** Returns the document count for {@code word}, truncated to {@code int}. */
    @Override
    public int docAppearedIn(String word) {
        return (int) docFrequencies.getCount(word);
    }

    /** Adds {@code howMuch} to the document count of {@code word}. */
    @Override
    public void incrementDocCount(String word, int howMuch) {
        docFrequencies.incrementCount(word, (double) howMuch);
    }

    /** Overwrites the document count of {@code word} with {@code count}. */
    @Override
    public void setCountForDoc(String word, int count) {
        docFrequencies.setCount(word, (double) count);
    }

    /** Returns the total number of documents recorded. */
    @Override
    public int totalNumberOfDocs() {
        return numDocs;
    }

    /** Increments the total document count by one (not atomic). */
    @Override
    public void incrementTotalDocCount() {
        ++numDocs;
    }

    /** Increments the total document count by {@code by} (not atomic). */
    @Override
    public void incrementTotalDocCount(int by) {
        numDocs += by;
    }

    /**
     * Returns all known tokens.
     *
     * @return a live view of the token map's values (not a copy)
     */
    @Override
    public Collection<VocabWord> tokens() {
        return tokens.values();
    }

    /** Registers {@code word} as a token keyed by its word, replacing any previous entry. */
    @Override
    public void addToken(VocabWord word) {
        tokens.put(word.getWord(), word);
    }

    /** Returns the token for {@code word}, or {@code null} if none is registered. */
    @Override
    public VocabWord tokenFor(String word) {
        return tokens.get(word);
    }

    /** Returns whether {@code token} is a registered token. */
    @Override
    public boolean hasToken(String token) {
        return tokenFor(token) != null;
    }

    /** Sets the learning rate. */
    @Override
    public void setLearningRate(double lr) {
        this.lr.set(lr);
    }

    /** Serializes this whole instance to the relative path {@code ser}. */
    @Override
    public void saveVocab() {
        SerializationUtils.saveObject(this, new File(SERIALIZED_VOCAB_PATH));
    }

    /** Returns whether a file exists at the relative path {@code ser}. */
    @Override
    public boolean vocabExists() {
        return new File(SERIALIZED_VOCAB_PATH).exists();
    }

    /**
     * Replaces this cache's vocabulary state with the state read from the relative path
     * {@code ser}.
     *
     * <p>Restored: codes, vocabulary, vector length, word and document frequencies, word
     * index, tokens, document total and occurrence total. Training settings (learning rate,
     * AdaGrad flag, {@code negative}, seed) and the RNG are left unchanged.
     *
     * <p>NOTE(review): this deserializes a file at a fixed path. Only load files you trust.
     */
    @Override
    public void loadVocab() {
        BaseLookupCache cache = (BaseLookupCache) SerializationUtils.readObject(new File(SERIALIZED_VOCAB_PATH));
        codes = cache.codes;
        vocabs = cache.vocabs;
        vectorLength = cache.vectorLength;
        wordFrequencies = cache.wordFrequencies;
        wordIndex = cache.wordIndex;
        tokens = cache.tokens;
        // These are persisted by saveVocab() too. Without them, docAppearedIn(),
        // totalNumberOfDocs() and totalWordOccurrences() would disagree with the loaded vocab.
        docFrequencies = cache.docFrequencies;
        numDocs = cache.numDocs;
        totalWordOccurrences = cache.totalWordOccurrences;
    }

    /** Returns the current random generator. */
    public RandomGenerator getRng() {
        return rng;
    }

    /** Replaces the random generator. */
    public void setRng(RandomGenerator rng) {
        this.rng = rng;
    }

    /**
     * Base builder for concrete caches.
     *
     * <p>Holds configuration with its own defaults: vector length 100, learning rate 0.025,
     * seed 123. Subclasses implement {@link #build()}.
     */
    public static abstract class Builder {
        protected int vectorLength = 100;
        protected boolean useAdaGrad = false;
        protected double lr = 0.025;
        protected RandomGenerator gen = new XorShift64StarRandomGenerator(123L);
        protected long seed = 123L;
        protected double negative = 0.0;

        /** Sets the {@code negative} value passed to the cache. */
        public Builder negative(double negative) {
            this.negative = negative;
            return this;
        }

        /** Sets the word-vector length. */
        public Builder vectorLength(int vectorLength) {
            this.vectorLength = vectorLength;
            return this;
        }

        /** Enables or disables AdaGrad. */
        public Builder useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        /** Sets the initial learning rate. */
        public Builder lr(double lr) {
            this.lr = lr;
            return this;
        }

        /** Sets the random generator. */
        public Builder gen(RandomGenerator gen) {
            this.gen = gen;
            return this;
        }

        /** Sets the seed value. */
        public Builder seed(long seed) {
            this.seed = seed;
            return this;
        }

        /** Creates the configured cache. */
        public abstract BaseLookupCache build();
    }

    /**
     * Base for iterators yielding {@link INDArray} weights.
     *
     * <p>{@code currIndex} starts at 0 and is left for subclasses to advance. Removal is
     * not supported.
     */
    protected abstract class WeightIterator implements Iterator<INDArray> {
        protected int currIndex = 0;

        protected WeightIterator() {
        }

        /** Always throws; this iterator is read-only. */
        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }
    }
}

