/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.embeddings.inmemory;

import com.google.common.util.concurrent.AtomicDouble;
import it.unimi.dsi.util.XorShift64StarRandomGenerator;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.plot.Tsne;
import org.deeplearning4j.plot.dropwizard.RenderApplication;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.FloatBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * In-memory {@link WeightLookupTable} whose weights are held in ND4J arrays.
 *
 * <p>{@code syn0} holds one vector per vocabulary word plus one extra row. {@code syn1} holds the
 * hierarchical-softmax weights and has the same shape as {@code syn0}. When {@code negative > 0},
 * {@code syn1Neg} holds the negative-sampling weights, and {@code table} holds the word indices
 * that negatives are drawn from.
 */
public class InMemoryLookupTable implements WeightLookupTable {
    /**
     * Buffer type code for double precision. The decompiled original compared
     * {@code data().dataType()} against the literal 0 for double and 1 for float.
     * NOTE(review): assumed to mirror ND4J's DataBuffer double constant; confirm.
     */
    private static final int DOUBLE_DATA_TYPE = 0;

    protected INDArray syn0;
    protected INDArray syn1;
    protected int vectorLength = 50;
    protected transient RandomGenerator rng = new XorShift64StarRandomGenerator(123L);
    protected AtomicDouble lr = new AtomicDouble(0.1);
    /** Sigmoid values precomputed over [-MAX_EXP, MAX_EXP); filled by {@link #initExpTable()}. */
    protected double[] expTable = new double[1000];
    protected static double MAX_EXP = 6.0;
    /** Seed used when a fresh MersenneTwister is created in the resetWeights methods. */
    protected long seed = 123L;
    /** Negative-sampling table of word indices; built by {@link #makeTable(int, double)}. */
    protected INDArray table;
    protected INDArray syn1Neg;
    protected boolean useAdaGrad;
    /** Number of negative samples per positive example; 0 disables negative sampling. */
    protected double negative = 0.0;
    protected VocabCache vocab;
    protected Map<Integer, INDArray> codes = new ConcurrentHashMap<Integer, INDArray>();

    /**
     * Creates a lookup table. Weights are not allocated until one of the resetWeights methods is
     * called.
     *
     * @param vocab        vocabulary whose words index the rows of {@code syn0}
     * @param vectorLength length of each word vector
     * @param useAdaGrad   whether gradients are routed through {@code VocabWord.getGradient}
     * @param lr           initial learning rate (used by {@link #iterate})
     * @param gen          random generator used for weight initialisation
     * @param negative     number of negative samples; 0 disables negative sampling
     */
    public InMemoryLookupTable(VocabCache vocab, int vectorLength, boolean useAdaGrad, double lr, RandomGenerator gen, double negative) {
        this.vocab = vocab;
        this.vectorLength = vectorLength;
        this.useAdaGrad = useAdaGrad;
        this.lr.set(lr);
        this.rng = gen;
        this.negative = negative;
        this.initExpTable();
    }

    @Override
    public int layerSize() {
        return this.vectorLength;
    }

    /**
     * Initialises weights lazily.
     *
     * <p>{@code syn0} and {@code syn1} are (re)created when they are absent or when {@code reset}
     * is true. The negative-sampling state is always rebuilt.
     *
     * @param reset if true, existing weights are discarded and re-initialised
     */
    @Override
    public void resetWeights(boolean reset) {
        if (this.rng == null) {
            this.rng = new MersenneTwister(this.seed);
        }
        if (this.syn0 == null || reset) {
            this.initSyn0();
        }
        if (this.syn1 == null || reset) {
            this.syn1 = Nd4j.create((int[]) this.syn0.shape());
        }
        this.initNegative();
    }

    /**
     * Unconditionally re-initialises all weights, using a fresh MersenneTwister seeded with
     * {@link #seed}.
     */
    @Override
    public void resetWeights() {
        this.rng = new MersenneTwister(this.seed);
        this.initSyn0();
        this.syn1 = Nd4j.create((int[]) this.syn0.shape());
        this.initNegative();
    }

    /**
     * Allocates {@code syn0} with {@code numWords() + 1} rows of random values. Each value is
     * shifted by -0.5 and divided by {@code vectorLength}. A separately drawn random vector is then
     * stored for the word "UNK".
     */
    private void initSyn0() {
        this.syn0 = Nd4j.rand(new int[] {this.vocab.numWords() + 1, this.vectorLength}, this.rng)
                .subi((Number) 0.5)
                .divi((Number) this.vectorLength);
        INDArray randUnk = Nd4j.rand(1, this.vectorLength, this.rng)
                .subi((Number) 0.5)
                .divi((Number) this.vectorLength);
        this.putVector("UNK", randUnk);
    }

    /**
     * Plots {@code syn0} in two dimensions with the given t-SNE instance, then launches the render
     * application.
     *
     * @param tsne configured t-SNE instance
     * @throws RuntimeException wrapping any {@link IOException} thrown by the plot
     */
    @Override
    public void plotVocab(Tsne tsne) {
        try {
            tsne.plot(this.syn0, 2, this.vocabWords());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        try {
            RenderApplication.main(null);
        } catch (Exception e) {
            // Rendering is best-effort: report the failure but do not abort the caller.
            e.printStackTrace();
        }
    }

    /**
     * Plots {@code syn0} in two dimensions with a default t-SNE configuration (no normalisation,
     * final momentum 0.8f, 1000 iterations).
     *
     * @throws RuntimeException wrapping any {@link IOException} thrown by the plot
     */
    @Override
    public void plotVocab() {
        Tsne tsne = new Tsne.Builder().normalize(false).setFinalMomentum((double) 0.8f).setMaxIter(1000).build();
        try {
            tsne.plot(this.syn0, 2, this.vocabWords());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Copies the vocabulary's words, in iteration order, into a new list used as plot labels. */
    private ArrayList<String> vocabWords() {
        ArrayList<String> words = new ArrayList<String>();
        for (String word : this.vocab.words()) {
            words.add(word);
        }
        return words;
    }

    @Override
    public void putCode(int codeIndex, INDArray code) {
        this.codes.put(codeIndex, code);
    }

    /**
     * Returns the {@code syn1} rows at the given indices.
     *
     * @param codes row indices into {@code syn1}
     * @return the selected rows
     */
    @Override
    public INDArray loadCodes(int[] codes) {
        return this.syn1.getRows(codes);
    }

    /** Allocates {@code syn1Neg} (zeros, shaped like {@code syn0}) and builds the sampling table when negative sampling is enabled. */
    protected void initNegative() {
        if (this.negative > 0.0) {
            this.syn1Neg = Nd4j.zeros((int[]) this.syn0.shape());
            this.makeTable(10000, 0.75);
        }
    }

    /** Fills {@link #expTable} with sigmoid(x) for x evenly spaced over [-MAX_EXP, MAX_EXP). */
    protected void initExpTable() {
        for (int i = 0; i < this.expTable.length; ++i) {
            double tmp = FastMath.exp(((double) i / (double) this.expTable.length * 2.0 - 1.0) * MAX_EXP);
            this.expTable[i] = tmp / (tmp + 1.0);
        }
    }

    /**
     * Maps x in [-MAX_EXP, MAX_EXP] to its {@link #expTable} slot.
     *
     * <p>Callers must clamp or bounds-check the result: x == MAX_EXP maps to
     * {@code expTable.length}, one past the end.
     */
    private int expTableIndex(double x) {
        return (int) ((x + MAX_EXP) * ((double) this.expTable.length / MAX_EXP / 2.0));
    }

    /**
     * In-place BLAS axpy, {@code y := alpha * x + y}. The float overload is used unless
     * {@code syn0} is backed by a double buffer.
     */
    private void axpyInPlace(double alpha, INDArray x, INDArray y) {
        if (this.syn0.data().dataType() == DOUBLE_DATA_TYPE) {
            Nd4j.getBlasWrapper().axpy(alpha, x, y);
        } else {
            Nd4j.getBlasWrapper().axpy((float) alpha, x, y);
        }
    }

    /**
     * Runs one skip-gram update of context word {@code w2} against target {@code w1}.
     *
     * <p>The update applies hierarchical softmax over {@code w1}'s Huffman codes and, when
     * {@code negative > 0}, negative sampling as well.
     *
     * @param w1         target word, whose codes and points drive the update
     * @param w2         context word, whose {@code syn0} row is updated; ignored if null or unindexed
     * @param nextRandom linear congruential random state, advanced in place for negative draws
     * @param alpha      learning rate
     * @throws IllegalStateException if one of {@code w1}'s points is outside {@code syn0}'s rows
     */
    @Override
    public void iterateSample(VocabWord w1, VocabWord w2, AtomicLong nextRandom, double alpha) {
        if (w2 == null || w2.getIndex() < 0) {
            return;
        }
        INDArray l1 = this.syn0.slice(w2.getIndex());
        // Error accumulated for l1; applied once at the end.
        INDArray neu1e = Nd4j.create(this.vectorLength);

        for (int i = 0; i < w1.getCodeLength(); ++i) {
            int code = w1.getCodes().get(i);
            int point = w1.getPoints().get(i);
            if (point >= this.syn0.rows() || point < 0) {
                throw new IllegalStateException("Illegal point " + point);
            }
            INDArray syn1 = this.syn1.slice(point);
            double dot = Nd4j.getBlasWrapper().dot(l1, syn1);
            // Outside the table's range the sigmoid is saturated; skip, as word2vec does.
            if (dot < -MAX_EXP || dot >= MAX_EXP) {
                continue;
            }
            int idx = this.expTableIndex(dot);
            if (idx >= this.expTable.length) {
                continue;
            }
            double err = (1 - code) - this.expTable[idx];
            double g = this.useAdaGrad ? w1.getGradient(i, err) : err * alpha;
            this.axpyInPlace(g, syn1, neu1e);
            this.axpyInPlace(g, l1, syn1);
        }

        if (this.negative > 0.0) {
            this.sampleNegatives(w1, l1, neu1e, nextRandom, alpha);
        }

        this.axpyInPlace(1.0, neu1e, l1);
    }

    /**
     * Negative-sampling part of {@link #iterateSample}.
     *
     * <p>The positive pair (label 1) comes first, followed by {@code negative} randomly drawn
     * targets (label 0). Each step accumulates {@code g * syn1Neg[target]} into {@code neu1e} and
     * then adds {@code g * l1} to {@code syn1Neg[target]}. Draws equal to {@code w1} and
     * out-of-range targets are skipped.
     */
    private void sampleNegatives(VocabWord w1, INDArray l1, INDArray neu1e, AtomicLong nextRandom, double alpha) {
        for (int d = 0; d < this.negative + 1.0; ++d) {
            int target;
            int label;
            if (d == 0) {
                target = w1.getIndex();
                label = 1;
            } else {
                nextRandom.set(nextRandom.get() * 25214903917L + 11L);
                int idx = Math.abs((int) (nextRandom.get() >> 16) % this.table.length());
                target = this.table.getInt(new int[] {idx});
                if (target <= 0) {
                    target = (int) nextRandom.get() % (this.vocab.numWords() - 1) + 1;
                }
                if (target == w1.getIndex()) {
                    continue;
                }
                label = 0;
            }
            if (target >= this.syn1Neg.rows() || target < 0) {
                continue;
            }

            INDArray syn1NegRow = this.syn1Neg.slice(target);
            double f = Nd4j.getBlasWrapper().dot(l1, syn1NegRow);
            double g;
            if (f > MAX_EXP) {
                g = this.useAdaGrad ? w1.getGradient(target, (double) (label - 1)) : (double) (label - 1) * alpha;
            } else if (f < -MAX_EXP) {
                g = label * (this.useAdaGrad ? w1.getGradient(target, alpha) : alpha);
            } else {
                // f == MAX_EXP maps one past the end of expTable, so clamp to the last slot.
                int idx = Math.min(this.expTableIndex(f), this.expTable.length - 1);
                double err = label - this.expTable[idx];
                g = this.useAdaGrad ? w1.getGradient(target, err) : err * alpha;
            }
            // Accumulate the error for l1 first, then update the output weights (word2vec order).
            this.axpyInPlace(g, syn1NegRow, neu1e);
            this.axpyInPlace(g, l1, syn1NegRow);
        }
    }

    public boolean isUseAdaGrad() {
        return this.useAdaGrad;
    }

    public void setUseAdaGrad(boolean useAdaGrad) {
        this.useAdaGrad = useAdaGrad;
    }

    public double getNegative() {
        return this.negative;
    }

    public void setNegative(double negative) {
        this.negative = negative;
    }

    /**
     * Runs one hierarchical-softmax update of context word {@code w2} against {@code w1}, using
     * the table's current learning rate.
     *
     * @param w1 target word, whose codes and points drive the update
     * @param w2 context word, whose {@code syn0} row is updated; ignored if null or unindexed
     * @throws IllegalStateException if one of {@code w1}'s points is outside {@code syn0}'s rows
     */
    @Override
    public void iterate(VocabWord w1, VocabWord w2) {
        if (w2 == null || w2.getIndex() < 0) {
            return;
        }
        INDArray l1 = this.syn0.slice(w2.getIndex());
        INDArray neu1e = Nd4j.create(this.vectorLength);
        double alpha = this.lr.get();
        for (int i = 0; i < w1.getCodeLength(); ++i) {
            int code = w1.getCodes().get(i);
            int point = w1.getPoints().get(i);
            if (point >= this.syn0.rows() || point < 0) {
                throw new IllegalStateException("Illegal point " + point);
            }
            INDArray syn1 = this.syn1.slice(point);
            double dot = Nd4j.getBlasWrapper().dot(l1, syn1);
            if (dot < -MAX_EXP || dot >= MAX_EXP) {
                continue;
            }
            int idx = this.expTableIndex(dot);
            if (idx >= this.expTable.length) {
                continue;
            }
            double err = (1 - code) - this.expTable[idx];
            double g = err * (this.useAdaGrad ? w1.getGradient(i, alpha) : alpha);
            this.axpyInPlace(g, syn1, neu1e);
            this.axpyInPlace(g, l1, syn1);
        }
        this.axpyInPlace(1.0, neu1e, l1);
    }

    /**
     * Builds the negative-sampling {@link #table}.
     *
     * <p>Each vocabulary index occupies a share of the table proportional to
     * {@code wordFrequency^power}. Once the vocabulary runs out (no word at the next index), the
     * remaining slots keep the last valid index instead of running past the end.
     *
     * @param tableSize number of table slots
     * @param power     exponent applied to word frequencies (0.75 in {@link #initNegative()})
     */
    protected void makeTable(int tableSize, double power) {
        int vocabSize = this.syn0.rows();
        this.table = Nd4j.create((DataBuffer) new FloatBuffer(tableSize));
        double trainWordsPow = 0.0;
        for (String word : this.vocab.words()) {
            trainWordsPow += Math.pow(this.vocab.wordFrequency(word), power);
        }
        int wordIdx = 0;
        double d1 = Math.pow(this.vocab.wordFrequency(this.vocab.wordAtIndex(wordIdx)), power) / trainWordsPow;
        for (int i = 0; i < tableSize; ++i) {
            this.table.putScalar(i, wordIdx);
            double mul = (double) i / (double) tableSize;
            if (mul > d1) {
                String nextWord = this.vocab.wordAtIndex(wordIdx + 1);
                // Advance only while another word exists; otherwise stay on the last one.
                if (nextWord != null) {
                    ++wordIdx;
                    d1 += Math.pow(this.vocab.wordFrequency(nextWord), power) / trainWordsPow;
                }
            }
            if (wordIdx >= vocabSize) {
                wordIdx = vocabSize - 1;
            }
        }
    }

    /**
     * Overwrites the {@code syn0} row of {@code word} with {@code vector}.
     *
     * @throws IllegalArgumentException if {@code word} or {@code vector} is null
     */
    @Override
    public void putVector(String word, INDArray vector) {
        if (word == null) {
            throw new IllegalArgumentException("No null words allowed");
        }
        if (vector == null) {
            throw new IllegalArgumentException("No null vectors allowed");
        }
        int idx = this.vocab.indexOf(word);
        this.syn0.slice(idx).assign(vector);
    }

    public INDArray getTable() {
        return this.table;
    }

    public void setTable(INDArray table) {
        this.table = table;
    }

    public INDArray getSyn1Neg() {
        return this.syn1Neg;
    }

    public void setSyn1Neg(INDArray syn1Neg) {
        this.syn1Neg = syn1Neg;
    }

    /**
     * Returns the {@code syn0} row of {@code word}. Words not found in the vocabulary fall back to
     * the row of "UNK".
     *
     * @return the word's vector, or null if {@code word} is null
     */
    @Override
    public INDArray vector(String word) {
        if (word == null) {
            return null;
        }
        int idx = this.vocab.indexOf(word);
        if (idx < 0) {
            idx = this.vocab.indexOf("UNK");
        }
        return this.syn0.getRow(idx);
    }

    @Override
    public void setLearningRate(double lr) {
        this.lr.set(lr);
    }

    /** Returns an iterator over the slices of {@code syn0}, in row order. */
    @Override
    public Iterator<INDArray> vectors() {
        return new WeightIterator();
    }

    public INDArray getSyn0() {
        return this.syn0;
    }

    public void setSyn0(INDArray syn0) {
        this.syn0 = syn0;
    }

    public INDArray getSyn1() {
        return this.syn1;
    }

    public void setSyn1(INDArray syn1) {
        this.syn1 = syn1;
    }

    public int getVectorLength() {
        return this.vectorLength;
    }

    public void setVectorLength(int vectorLength) {
        this.vectorLength = vectorLength;
    }

    public AtomicDouble getLr() {
        return this.lr;
    }

    public void setLr(AtomicDouble lr) {
        this.lr = lr;
    }

    public VocabCache getVocab() {
        return this.vocab;
    }

    public void setVocab(VocabCache vocab) {
        this.vocab = vocab;
    }

    public Map<Integer, INDArray> getCodes() {
        return this.codes;
    }

    public void setCodes(Map<Integer, INDArray> codes) {
        this.codes = codes;
    }

    /** Fluent builder for {@link InMemoryLookupTable}; a vocabulary cache is required. */
    public static class Builder {
        protected int vectorLength = 100;
        protected boolean useAdaGrad = false;
        protected double lr = 0.025;
        protected RandomGenerator gen = new XorShift64StarRandomGenerator(123L);
        protected long seed = 123L;
        protected double negative = 0.0;
        protected VocabCache vocabCache;

        public Builder cache(VocabCache vocab) {
            this.vocabCache = vocab;
            return this;
        }

        public Builder negative(double negative) {
            this.negative = negative;
            return this;
        }

        public Builder vectorLength(int vectorLength) {
            this.vectorLength = vectorLength;
            return this;
        }

        public Builder useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        public Builder lr(double lr) {
            this.lr = lr;
            return this;
        }

        public Builder gen(RandomGenerator gen) {
            this.gen = gen;
            return this;
        }

        /** Sets the seed used when the table re-creates its MersenneTwister during a weight reset. */
        public Builder seed(long seed) {
            this.seed = seed;
            return this;
        }

        /**
         * Builds the table.
         *
         * @throws IllegalStateException if no vocabulary cache was set
         */
        public WeightLookupTable build() {
            if (this.vocabCache == null) {
                throw new IllegalStateException("Vocab cache must be specified");
            }
            InMemoryLookupTable ret = new InMemoryLookupTable(this.vocabCache, this.vectorLength, this.useAdaGrad, this.lr, this.gen, this.negative);
            // The constructor takes no seed; propagate it so seed(long) actually takes effect.
            ret.seed = this.seed;
            return ret;
        }
    }

    /** Iterates over the slices of {@code syn0}, in row order; removal is unsupported. */
    protected class WeightIterator implements Iterator<INDArray> {
        protected int currIndex = 0;

        protected WeightIterator() {
        }

        @Override
        public boolean hasNext() {
            return this.currIndex < InMemoryLookupTable.this.syn0.rows();
        }

        @Override
        public INDArray next() {
            if (!this.hasNext()) {
                throw new NoSuchElementException();
            }
            INDArray ret = InMemoryLookupTable.this.syn0.slice(this.currIndex);
            ++this.currIndex;
            return ret;
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }
    }
}

