/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.word2vec.wordstore.inmemory;

import it.unimi.dsi.util.XorShift64StarRandomGenerator;
import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.plot.Tsne;
import org.deeplearning4j.plot.dropwizard.RenderApplication;
import org.deeplearning4j.text.movingwindow.Util;
import org.deeplearning4j.util.Index;
import org.deeplearning4j.util.SerializationUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * In-memory {@link VocabCache} that keeps the vocabulary, token statistics and the
 * two weight matrices ({@code syn0} and {@code syn1}) on the heap.
 *
 * <p>The instance can be written to / read back from a file named {@code cache.ser}
 * in the current working directory via {@link #saveVocab()} and {@link #loadVocab()}.
 * The random generator is {@code transient} and is therefore not part of the
 * serialized form.
 *
 * <p>No {@code serialVersionUID} is declared on purpose: adding one would change the
 * identifier of already-serialized caches.
 */
public class InMemoryLookupCache
implements VocabCache,
Serializable {
    /** File used by {@link #saveVocab()}, {@link #vocabExists()} and {@link #loadVocab()}. */
    private static final String VOCAB_FILE = "cache.ser";

    private Index wordIndex = new Index();
    private boolean useAdaGrad = false;
    private Counter<String> wordFrequencies = Util.parallelCounter();
    private Counter<String> docFrequencies = Util.parallelCounter();
    private Map<String, VocabWord> vocabs = new ConcurrentHashMap<String, VocabWord>();
    private Map<String, VocabWord> tokens = new ConcurrentHashMap<String, VocabWord>();
    private Map<Integer, INDArray> codes = new ConcurrentHashMap<Integer, INDArray>();
    private INDArray syn0;
    private INDArray syn1;
    private int vectorLength = 50;
    private transient RandomGenerator rng = new XorShift64StarRandomGenerator(123L);
    private AtomicInteger totalWordOccurrences = new AtomicInteger(0);
    private double lr = 0.1f;
    /** Precomputed logistic values, filled by {@link #initExpTable()}. */
    double[] expTable = new double[1000];
    /** Dot products outside {@code [-MAX_EXP, MAX_EXP)} are skipped by {@link #iterate}. */
    static double MAX_EXP = 6.0;
    private long seed = 123L;
    private int numDocs = 0;

    /**
     * Creates a cache with the given vector length and {@code useAdaGrad} enabled.
     *
     * @param vectorLength length of each word vector
     */
    public InMemoryLookupCache(int vectorLength) {
        // The delegated constructor chain already fills the exp table.
        this(vectorLength, true);
    }

    /**
     * Creates a cache whose {@code syn0} is a random {@code vocabSize x vectorLength} matrix.
     *
     * @param vectorLength length of each word vector
     * @param vocabSize    number of rows allocated in {@code syn0}
     */
    public InMemoryLookupCache(int vectorLength, int vocabSize) {
        this.vectorLength = vectorLength;
        this.syn0 = Nd4j.rand(vocabSize, vectorLength);
        // Without this the table stays all zeros and iterate() reads 0 for every lookup.
        this.initExpTable();
    }

    /**
     * Creates a cache with a default learning rate and generator, and registers the
     * {@code "UNK"} word at index 0.
     *
     * @param vectorLength length of each word vector
     * @param useAdaGrad   selects the final update branch used by {@link #iterate}
     */
    public InMemoryLookupCache(int vectorLength, boolean useAdaGrad) {
        this(vectorLength, useAdaGrad, 0.025f, new XorShift64StarRandomGenerator(123L));
        this.addWordToIndex(0, "UNK");
        this.wordIndex.add("UNK");
    }

    /**
     * Creates a fully configured cache and fills the exp table.
     *
     * @param vectorLength length of each word vector
     * @param useAdaGrad   selects the final update branch used by {@link #iterate}
     * @param lr           learning rate multiplied into each gradient
     * @param gen          random generator stored on the instance
     */
    public InMemoryLookupCache(int vectorLength, boolean useAdaGrad, double lr, RandomGenerator gen) {
        this.vectorLength = vectorLength;
        this.useAdaGrad = useAdaGrad;
        this.lr = lr;
        this.rng = gen;
        this.initExpTable();
    }

    /**
     * Same as the four-argument constructor, with a fixed-seed default generator.
     *
     * @param vectorLength length of each word vector
     * @param useAdaGrad   selects the final update branch used by {@link #iterate}
     * @param lr           learning rate multiplied into each gradient
     */
    public InMemoryLookupCache(int vectorLength, boolean useAdaGrad, double lr) {
        this(vectorLength, useAdaGrad, lr, new XorShift64StarRandomGenerator(123L));
    }

    /**
     * Fills {@link #expTable} with {@code e^x / (e^x + 1)} for {@code x} sampled evenly
     * over {@code [-MAX_EXP, MAX_EXP)}.
     */
    private void initExpTable() {
        int size = this.expTable.length;
        for (int i = 0; i < size; ++i) {
            double x = ((double)i / (double)size * 2.0 - 1.0) * MAX_EXP;
            double e = FastMath.exp(x);
            this.expTable[i] = e / (e + 1.0);
        }
    }

    /**
     * Performs one training update for the pair {@code (w1, w2)}.
     *
     * <p>For every code/point pair of {@code w1}, the dot product between the
     * {@code syn0} row of {@code w2} and the {@code syn1} row at that point is looked
     * up in the exp table, a gradient {@code g} is derived from it, and both an error
     * accumulator and the {@code syn1} row are updated. Finally the accumulator is
     * added to the {@code syn0} row, scaled by the mean gradient when
     * {@code useAdaGrad} is set and by 1 otherwise.
     *
     * @param w1 word whose codes and points drive the update
     * @param w2 word whose {@code syn0} row is updated; ignored when its index is negative
     * @throws IllegalStateException if a point of {@code w1} is outside {@code [0, syn0.rows())}
     */
    @Override
    public void iterate(VocabWord w1, VocabWord w2) {
        if (w2.getIndex() < 0) {
            return;
        }
        INDArray l1 = this.syn0.slice(w2.getIndex());
        INDArray neu1e = Nd4j.create(this.vectorLength);
        // Selects the double vs. float axpy overloads below.
        boolean doublePrecision = this.syn0.data().dataType().equals("double");
        double avgChange = 0.0;
        for (int i = 0; i < w1.getCodeLength(); ++i) {
            int code = w1.getCodes()[i];
            int point = w1.getPoints()[i];
            if (point >= this.syn0.rows() || point < 0) {
                throw new IllegalStateException("Illegal point " + point);
            }
            INDArray pointRow = this.syn1.slice(point);
            double dot = Nd4j.getBlasWrapper().dot(l1, pointRow);
            if (dot < -MAX_EXP || dot >= MAX_EXP) {
                continue;
            }
            // Map the dot product from [-MAX_EXP, MAX_EXP) onto a table slot.
            int idx = (int)((dot + MAX_EXP) * ((double)this.expTable.length / MAX_EXP / 2.0));
            if (idx >= this.expTable.length) {
                continue;
            }
            double f = this.expTable[idx];
            double g = ((double)(1 - code) - f) * this.lr;
            avgChange += g;
            if (doublePrecision) {
                Nd4j.getBlasWrapper().axpy(g, pointRow, neu1e);
                Nd4j.getBlasWrapper().axpy(g, l1, pointRow);
            } else {
                Nd4j.getBlasWrapper().axpy((float)g, pointRow, neu1e);
                Nd4j.getBlasWrapper().axpy((float)g, l1, pointRow);
            }
        }
        avgChange /= (double)w1.getCodes().length;
        if (this.useAdaGrad) {
            if (doublePrecision) {
                Nd4j.getBlasWrapper().axpy(avgChange, neu1e, l1);
            } else {
                Nd4j.getBlasWrapper().axpy((float)avgChange, neu1e, l1);
            }
        } else if (doublePrecision) {
            Nd4j.getBlasWrapper().axpy(1.0, neu1e, l1);
        } else {
            Nd4j.getBlasWrapper().axpy(1.0f, neu1e, l1);
        }
    }

    /**
     * @return the key set of the vocabulary map (a live view, not a copy)
     */
    @Override
    public synchronized Collection<String> words() {
        return this.vocabs.keySet();
    }

    /**
     * Re-seeds the generator and re-creates {@code syn0} (random values shifted by
     * -0.5 and divided by the vector length) and {@code syn1} (created with the same shape).
     */
    @Override
    public void resetWeights() {
        this.rng = new MersenneTwister(this.seed);
        this.syn0 = Nd4j.rand(new int[]{this.vocabs.size(), this.vectorLength}, this.rng).subi((Number)0.5).divi((Number)this.vectorLength);
        this.syn1 = Nd4j.create(this.syn0.shape());
    }

    /**
     * Increments the count of {@code word} by one.
     *
     * @param word the word to count
     */
    @Override
    public synchronized void incrementWordCount(String word) {
        this.incrementWordCount(word, 1);
    }

    /**
     * Increments the frequency of {@code word}, its token (if any) and the total
     * number of word occurrences by {@code increment}.
     *
     * @param word      the word to count
     * @param increment amount to add
     */
    @Override
    public synchronized void incrementWordCount(String word, int increment) {
        // Use the requested amount so wordFrequency() stays in step with the totals.
        this.wordFrequencies.incrementCount(word, (double)increment);
        // NOTE(review): when no token exists the freshly created VocabWord is not
        // stored anywhere by this method - confirm callers register it via addToken.
        VocabWord token = this.hasToken(word) ? this.tokenFor(word) : new VocabWord(increment, word);
        token.increment(increment);
        this.totalWordOccurrences.addAndGet(increment);
    }

    /**
     * @param word the word to look up
     * @return the recorded frequency of {@code word}, truncated to an int
     */
    @Override
    public int wordFrequency(String word) {
        return (int)this.wordFrequencies.getCount(word);
    }

    /**
     * @param word the word to look up
     * @return {@code true} if {@code word} is a key of the vocabulary map
     */
    @Override
    public boolean containsWord(String word) {
        return this.vocabs.containsKey(word);
    }

    /**
     * @param index position in the word index
     * @return the word stored at {@code index}
     */
    @Override
    public String wordAtIndex(int index) {
        return (String)this.wordIndex.get(index);
    }

    /**
     * @param word the word to look up
     * @return the position of {@code word} as reported by the word index
     */
    @Override
    public int indexOf(String word) {
        return this.wordIndex.indexOf(word);
    }

    /**
     * Stores {@code code} under {@code codeIndex}.
     *
     * @param codeIndex key of the code
     * @param code      array to store
     */
    @Override
    public void putCode(int codeIndex, INDArray code) {
        this.codes.put(codeIndex, code);
    }

    /**
     * @param codes row indices into {@code syn1}
     * @return the selected rows of {@code syn1}
     */
    @Override
    public INDArray loadCodes(int[] codes) {
        return this.syn1.getRows(codes);
    }

    /**
     * @return the values of the vocabulary map (a live view, not a copy)
     */
    @Override
    public Collection<VocabWord> vocabWords() {
        return this.vocabs.values();
    }

    /**
     * @return the running total maintained by {@link #incrementWordCount(String, int)}
     */
    @Override
    public int totalWordOccurrences() {
        return this.totalWordOccurrences.get();
    }

    /**
     * Overwrites the {@code syn0} row of {@code word} with {@code vector}.
     *
     * @param word   the word whose row is replaced
     * @param vector the new values
     * @throws IllegalArgumentException if {@code word} or {@code vector} is {@code null}
     */
    @Override
    public void putVector(String word, INDArray vector) {
        if (word == null) {
            throw new IllegalArgumentException("No null words allowed");
        }
        if (vector == null) {
            throw new IllegalArgumentException("No null vectors allowed");
        }
        int idx = this.indexOf(word);
        this.syn0.slice(idx).assign(vector);
    }

    /**
     * @param word the word to look up
     * @return the {@code syn0} row at the index of {@code word}, or {@code null} when
     *         {@code word} is {@code null}
     */
    @Override
    public INDArray vector(String word) {
        if (word == null) {
            return null;
        }
        return this.syn0.getRow(this.indexOf(word));
    }

    /**
     * @param word the word to look up
     * @return the vocabulary entry for {@code word}, or {@code null} if absent
     */
    @Override
    public VocabWord wordFor(String word) {
        return this.vocabs.get(word);
    }

    /**
     * Adds {@code word} to the word index at {@code index}; a word with no recorded
     * frequency is first given a frequency of one.
     *
     * @param index position to register the word at
     * @param word  the word to register
     */
    @Override
    public synchronized void addWordToIndex(int index, String word) {
        if (!this.wordFrequencies.containsKey(word)) {
            this.wordFrequencies.incrementCount(word, 1.0);
        }
        this.wordIndex.add(word, index);
    }

    /**
     * Promotes an existing token to a vocabulary word and indexes it at the token's index.
     *
     * @param word the word to promote
     * @throws IllegalStateException if {@code word} has not been added as a token
     */
    @Override
    public synchronized void putVocabWord(String word) {
        VocabWord token = this.tokenFor(word);
        // Validate before dereferencing, so a missing token yields the intended
        // IllegalStateException rather than a NullPointerException.
        if (token == null) {
            throw new IllegalStateException("Unable to add token " + word + " when not already a token");
        }
        this.addWordToIndex(token.getIndex(), word);
        this.vocabs.put(word, token);
        // NOTE(review): repeats the add already done by addWordToIndex; kept because
        // Index.add semantics are not visible here - confirm before removing.
        this.wordIndex.add(word, token.getIndex());
    }

    /**
     * @return the number of entries in the vocabulary map
     */
    @Override
    public synchronized int numWords() {
        return this.vocabs.size();
    }

    /**
     * @param word the word to look up
     * @return the recorded document count of {@code word}, truncated to an int
     */
    @Override
    public int docAppearedIn(String word) {
        return (int)this.docFrequencies.getCount(word);
    }

    /**
     * Adds {@code howMuch} to the document count of {@code word}.
     *
     * @param word    the word to update
     * @param howMuch amount to add
     */
    @Override
    public void incrementDocCount(String word, int howMuch) {
        this.docFrequencies.incrementCount(word, (double)howMuch);
    }

    /**
     * Sets the document count of {@code word} to {@code count}.
     *
     * @param word  the word to update
     * @param count the new document count
     */
    @Override
    public void setCountForDoc(String word, int count) {
        this.docFrequencies.setCount(word, (double)count);
    }

    /**
     * @return the total document counter
     */
    @Override
    public int totalNumberOfDocs() {
        return this.numDocs;
    }

    /** Increments the total document counter by one. */
    @Override
    public void incrementTotalDocCount() {
        ++this.numDocs;
    }

    /**
     * Increments the total document counter.
     *
     * @param by amount to add
     */
    @Override
    public void incrementTotalDocCount(int by) {
        this.numDocs += by;
    }

    /**
     * @return the values of the token map (a live view, not a copy)
     */
    @Override
    public Collection<VocabWord> tokens() {
        return this.tokens.values();
    }

    /**
     * Registers {@code word} as a token, keyed by its word string.
     *
     * @param word the token to register
     */
    @Override
    public void addToken(VocabWord word) {
        this.tokens.put(word.getWord(), word);
    }

    /**
     * @param word the word to look up
     * @return the token registered for {@code word}, or {@code null} if absent
     */
    @Override
    public VocabWord tokenFor(String word) {
        return this.tokens.get(word);
    }

    /**
     * @param token the word to look up
     * @return {@code true} if a token is registered for {@code token}
     */
    @Override
    public boolean hasToken(String token) {
        return this.tokenFor(token) != null;
    }

    /** Serializes this instance to {@code cache.ser} in the working directory. */
    @Override
    public void saveVocab() {
        SerializationUtils.saveObject(this, new File(VOCAB_FILE));
    }

    /**
     * @return {@code true} if {@code cache.ser} exists in the working directory
     */
    @Override
    public boolean vocabExists() {
        return new File(VOCAB_FILE).exists();
    }

    /**
     * Plots {@code syn0} with the supplied t-SNE instance, then starts the render
     * application. A failure while starting the renderer is printed and otherwise ignored.
     *
     * @param tsne the configured t-SNE instance to plot with
     * @throws RuntimeException wrapping any {@link IOException} raised while plotting
     */
    @Override
    public void plotVocab(Tsne tsne) {
        this.plotWith(tsne);
        try {
            RenderApplication.main(null);
        }
        catch (Exception e) {
            // Best effort: the plot itself has already been produced at this point.
            e.printStackTrace();
        }
    }

    /**
     * Plots {@code syn0} with a default t-SNE configuration (no normalization,
     * final momentum 0.8, 1000 iterations). Does not start the render application.
     *
     * @throws RuntimeException wrapping any {@link IOException} raised while plotting
     */
    @Override
    public void plotVocab() {
        Tsne tsne = new Tsne.Builder().normalize(false).setFinalMomentum((double)0.8f).setMaxIter(1000).build();
        this.plotWith(tsne);
    }

    /** Plots {@code syn0} in two dimensions, labelled with a snapshot of the vocabulary words. */
    private void plotWith(Tsne tsne) {
        try {
            ArrayList<String> labels = new ArrayList<String>(this.words());
            tsne.plot(this.syn0, 2, labels);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads {@code cache.ser} and adopts its vocabulary state: codes, vocabulary,
     * vector length, word/document frequencies, word index, tokens and the
     * occurrence/document totals. The weight matrices are not touched.
     */
    @Override
    public void loadVocab() {
        InMemoryLookupCache cache = (InMemoryLookupCache)SerializationUtils.readObject(new File(VOCAB_FILE));
        this.codes = cache.codes;
        this.vocabs = cache.vocabs;
        this.vectorLength = cache.vectorLength;
        this.wordFrequencies = cache.wordFrequencies;
        this.wordIndex = cache.wordIndex;
        this.tokens = cache.tokens;
        // Restore the remaining statistics so they stay consistent with the
        // word frequencies loaded above.
        this.docFrequencies = cache.docFrequencies;
        this.totalWordOccurrences = cache.totalWordOccurrences;
        this.numDocs = cache.numDocs;
    }

    /**
     * @return the current random generator; it is {@code transient}, so not part of the serialized form
     */
    public RandomGenerator getRng() {
        return this.rng;
    }

    /**
     * @param rng the random generator to store
     */
    public void setRng(RandomGenerator rng) {
        this.rng = rng;
    }

    /**
     * @return the {@code syn0} matrix (may be {@code null} if never initialised)
     */
    public INDArray getSyn0() {
        return this.syn0;
    }

    /**
     * @param syn0 the new {@code syn0} matrix
     */
    public void setSyn0(INDArray syn0) {
        this.syn0 = syn0;
    }

    /**
     * @return the {@code syn1} matrix (may be {@code null} if never initialised)
     */
    public INDArray getSyn1() {
        return this.syn1;
    }

    /**
     * @param syn1 the new {@code syn1} matrix
     */
    public void setSyn1(INDArray syn1) {
        this.syn1 = syn1;
    }
}

