/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.glove;

import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.AdaGrad;

/**
 * GloVe weight lookup table.
 *
 * <p>Holds the word vectors ({@code syn0}, inherited) plus one bias per row. Both are trained with
 * AdaGrad through {@link #iterateSample(VocabWord, VocabWord, double)}. The class also provides
 * loaders for whitespace-separated text vector files.
 */
public class GloveWeightLookupTable extends InMemoryLookupTable {
    /** AdaGrad state for the word-vector matrix {@code syn0}. */
    private AdaGrad weightAdaGrad;
    /** AdaGrad state for the per-row bias vector. */
    private AdaGrad biasAdaGrad;
    /** One bias value per row of {@code syn0}. */
    private INDArray bias;
    /**
     * Exponent of the weighting function {@code min(1, score / maxCount) ^ xMax}.
     * Despite the name, this is the exponent, not the count cutoff.
     */
    private double xMax = 0.75;
    /** Score at or above which a pair receives the full weight of 1 in the weighting function. */
    private double maxCount = 100.0;

    /**
     * Creates a GloVe lookup table.
     *
     * @param vocab vocabulary backing the table
     * @param vectorLength dimensionality of each word vector
     * @param useAdaGrad forwarded to {@link InMemoryLookupTable}
     * @param lr learning rate; also used as the AdaGrad master step size in {@link #resetWeights(boolean)}
     * @param gen random generator forwarded to {@link InMemoryLookupTable}
     * @param negative forwarded to {@link InMemoryLookupTable}
     * @param xMax exponent of the weighting function
     * @param maxCount score cutoff of the weighting function
     */
    public GloveWeightLookupTable(VocabCache vocab, int vectorLength, boolean useAdaGrad, double lr,
            RandomGenerator gen, double negative, double xMax, double maxCount) {
        super(vocab, vectorLength, useAdaGrad, lr, gen, negative);
        this.xMax = xMax;
        this.maxCount = maxCount;
    }

    /**
     * Lazily initialises, or re-initialises, the weights, biases and AdaGrad state.
     *
     * <p>Each piece of state is created when it is missing, or always when {@code reset} is true.
     * Freshly created vectors are drawn uniformly, shifted by -0.5 and divided by
     * {@code vectorLength}. An {@code "UNK"} vector is registered alongside them.
     *
     * @param reset when true, recreate all state even if it already exists
     */
    @Override
    public void resetWeights(boolean reset) {
        if (this.rng == null) {
            this.rng = new MersenneTwister(this.seed);
        }
        if (this.syn0 == null || reset) {
            this.syn0 = Nd4j.rand(new int[] {this.vocab.numWords() + 1, this.vectorLength}, (RandomGenerator) this.rng)
                    .subi(0.5)
                    .divi(this.vectorLength);
            INDArray randUnk = Nd4j.rand(1, this.vectorLength, (RandomGenerator) this.rng)
                    .subi(0.5)
                    .divi(this.vectorLength);
            this.putVector("UNK", randUnk);
        }
        if (this.weightAdaGrad == null || reset) {
            // Size the AdaGrad state after syn0, which update() indexes into and whose shape it passes
            // to getGradient. A syn0 loaded from disk may not have numWords() + 1 rows.
            this.weightAdaGrad = new AdaGrad(this.syn0.shape());
            this.weightAdaGrad.setMasterStepSize(this.lr.get());
        }
        if (this.bias == null || reset) {
            this.bias = Nd4j.create(this.syn0.rows());
        }
        if (this.biasAdaGrad == null || reset) {
            this.biasAdaGrad = new AdaGrad(this.bias.shape());
            this.biasAdaGrad.setMasterStepSize(this.lr.get());
        }
    }

    /** Re-initialises all weights, biases and AdaGrad state unconditionally. */
    @Override
    public void resetWeights() {
        this.resetWeights(true);
    }

    /**
     * Performs one training step on the pair {@code (w1, w2)}.
     *
     * <p>The residual is {@code weight * (w1.w2 + b1 + b2 - log(score))}, where
     * {@code weight = min(1, score / maxCount) ^ xMax}. Both word vectors and both biases are
     * updated with that residual via AdaGrad. A NaN residual (for example when {@code score <= 0})
     * is replaced by {@link Nd4j#EPS_THRESHOLD}.
     *
     * @param w1 first word; its index must be a valid row of {@code syn0}
     * @param w2 second word; its index must be a valid row of {@code syn0}
     * @param score value associated with the pair (presumably its co-occurrence count — the
     *     logarithm of it is taken, so it is expected to be positive)
     * @return the weighted residual used as the gradient scale
     * @throws IllegalArgumentException if either word's index is outside {@code syn0}
     */
    public double iterateSample(VocabWord w1, VocabWord w2, double score) {
        // Validate before slicing, so a bad index surfaces as the documented exception.
        this.checkIndex(w1);
        this.checkIndex(w2);
        INDArray w1Vector = this.syn0.slice(w1.getIndex());
        INDArray w2Vector = this.syn0.slice(w2.getIndex());

        double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector)
                + this.bias.getDouble(w1.getIndex())
                + this.bias.getDouble(w2.getIndex());
        // The min(...) clamps the weight to 1 for large scores, so no separate branch is needed.
        double weight = Math.pow(Math.min(1.0, score / this.maxCount), this.xMax);
        double fDiff = weight * (prediction - Math.log(score));
        if (Double.isNaN(fDiff)) {
            fDiff = Nd4j.EPS_THRESHOLD;
        }

        // Both updates must use pre-update vectors. Snapshot w1 before it is modified in place,
        // because it is the context vector for w2's update.
        INDArray w1Before = w1Vector.dup();
        this.update(w1, w1Vector, w2Vector, fDiff);
        this.update(w2, w2Vector, w1Before, fDiff);
        return fDiff;
    }

    /** Throws if {@code word}'s index is not a valid row of {@code syn0}. */
    private void checkIndex(VocabWord word) {
        int index = word.getIndex();
        if (index < 0 || index >= this.syn0.rows()) {
            throw new IllegalArgumentException("Illegal index for word " + word.getWord());
        }
    }

    /**
     * Applies one AdaGrad step to a word's vector (in place) and to its bias.
     *
     * @param w1 word whose row and bias are updated
     * @param wordVector view of {@code w1}'s row in {@code syn0}; modified in place
     * @param contextVector the other word's vector; the raw gradient is {@code contextVector * gradient}
     * @param gradient scalar residual from {@link #iterateSample(VocabWord, VocabWord, double)}
     */
    private void update(VocabWord w1, INDArray wordVector, INDArray contextVector, double gradient) {
        int index = w1.getIndex();
        INDArray rawGradient = contextVector.mul(gradient);
        INDArray step = this.weightAdaGrad.getGradient(rawGradient, index, this.syn0.shape());
        wordVector.subi(step);

        double biasStep = this.biasAdaGrad.getGradient(gradient, index, this.bias.shape());
        this.bias.putScalar(index, this.bias.getDouble(index) - biasStep);
    }

    /** @return the AdaGrad state for the word vectors */
    public AdaGrad getWeightAdaGrad() {
        return this.weightAdaGrad;
    }

    /** @param weightAdaGrad AdaGrad state for the word vectors */
    public void setWeightAdaGrad(AdaGrad weightAdaGrad) {
        this.weightAdaGrad = weightAdaGrad;
    }

    /** @return the AdaGrad state for the biases */
    public AdaGrad getBiasAdaGrad() {
        return this.biasAdaGrad;
    }

    /** @param biasAdaGrad AdaGrad state for the biases */
    public void setBiasAdaGrad(AdaGrad biasAdaGrad) {
        this.biasAdaGrad = biasAdaGrad;
    }

    /**
     * Builds a table whose {@code syn0} is read verbatim as a space-delimited numeric matrix.
     *
     * @param is stream containing only numbers, with one row per line
     * @param vocab vocabulary to attach
     * @param vectorLength dimensionality of the vectors
     * @return the populated table, with biases and AdaGrad state initialised for the loaded syn0
     * @throws IOException if the stream cannot be read
     */
    public static GloveWeightLookupTable loadRawArray(InputStream is, VocabCache vocab, int vectorLength)
            throws IOException {
        GloveWeightLookupTable ret = new Builder().cache(vocab).vectorLength(vectorLength).build();
        INDArray syn0 = Nd4j.readTxt(is, " ");
        ret.setSyn0(syn0);
        ret.resetWeights(false);
        return ret;
    }

    /**
     * Loads a text vector file with lines of the form {@code word v1 v2 ... vN}.
     *
     * <p>The vector length is taken from the first non-empty line. Lines whose number of values
     * differs from it are skipped. Each vector is placed at row {@code vocab.indexOf(word)}.
     * Words unknown to the vocabulary, or whose index is not below the number of parsed lines, are
     * dropped. The stream is closed when this method returns or throws.
     *
     * @param is UTF-8 encoded text stream
     * @param vocab vocabulary used to map words to rows
     * @return the populated table
     * @throws IOException if the stream cannot be read or contains no vectors
     */
    public static GloveWeightLookupTable load(InputStream is, VocabCache vocab) throws IOException {
        LineIterator iter = IOUtils.lineIterator(is, "UTF-8");
        try {
            GloveWeightLookupTable glove = null;
            Map<String, float[]> wordVectors = new HashMap<String, float[]>();
            while (iter.hasNext()) {
                String line = iter.nextLine().trim();
                if (line.isEmpty()) {
                    continue;
                }
                // Split on any whitespace run so repeated spaces do not yield empty tokens.
                String[] split = line.split("\\s+");
                if (glove == null) {
                    glove = new Builder().cache(vocab).vectorLength(split.length - 1).build();
                }
                String word = split[0];
                float[] vector = read(split, glove.getVectorLength());
                if (word.isEmpty() || vector.length < 1) {
                    continue;
                }
                wordVectors.put(word, vector);
            }
            if (glove == null) {
                throw new IOException("No word vectors found in input stream");
            }
            glove.setSyn0(weights(glove, wordVectors, vocab));
            glove.resetWeights(false);
            return glove;
        } finally {
            iter.close();
        }
    }

    /**
     * Packs the parsed vectors into a {@code data.size() x vectorLength} matrix indexed by vocab position.
     */
    private static INDArray weights(GloveWeightLookupTable glove, Map<String, float[]> data, VocabCache vocab) {
        int vectorLength = glove.getVectorLength();
        INDArray ret = Nd4j.create(data.size(), vectorLength);
        for (Map.Entry<String, float[]> entry : data.entrySet()) {
            int index = vocab.indexOf(entry.getKey());
            if (index < 0 || index >= data.size()) {
                continue;
            }
            DataBuffer buffer = Nd4j.createBuffer(entry.getValue());
            INDArray row = Nd4j.create(buffer);
            if (row.length() != vectorLength) {
                continue;
            }
            ret.putRow(index, row);
        }
        return ret;
    }

    /**
     * Parses the numeric tokens {@code split[1..]} into a float array.
     *
     * @return the parsed vector, or an empty array when the token count is not exactly {@code length}
     */
    private static float[] read(String[] split, int length) {
        if (split.length - 1 != length) {
            return new float[0];
        }
        float[] ret = new float[length];
        for (int i = 1; i < split.length; ++i) {
            ret[i - 1] = Float.parseFloat(split[i]);
        }
        return ret;
    }

    /** Not supported by GloVe; use {@link #iterateSample(VocabWord, VocabWord, double)} instead. */
    @Override
    public void iterateSample(VocabWord w1, VocabWord w2, AtomicLong nextRandom, double alpha) {
        throw new UnsupportedOperationException(
                "GloVe does not support this training step; use iterateSample(VocabWord, VocabWord, double)");
    }

    /** @return the exponent of the weighting function */
    public double getxMax() {
        return this.xMax;
    }

    /** @param xMax the exponent of the weighting function */
    public void setxMax(double xMax) {
        this.xMax = xMax;
    }

    /** @return the score cutoff of the weighting function */
    public double getMaxCount() {
        return this.maxCount;
    }

    /** @param maxCount the score cutoff of the weighting function */
    public void setMaxCount(double maxCount) {
        this.maxCount = maxCount;
    }

    /** @return the per-row bias vector */
    public INDArray getBias() {
        return this.bias;
    }

    /** @param bias the per-row bias vector */
    public void setBias(INDArray bias) {
        this.bias = bias;
    }

    /** Builder for {@link GloveWeightLookupTable}; adds {@code xMax} and {@code maxCount} to the base builder. */
    public static class Builder extends InMemoryLookupTable.Builder {
        private double xMax = 0.75;
        private double maxCount = 100.0;

        /** Sets the score cutoff of the weighting function (default 100). */
        public Builder maxCount(double maxCount) {
            this.maxCount = maxCount;
            return this;
        }

        /** Sets the exponent of the weighting function (default 0.75). */
        public Builder xMax(double xMax) {
            this.xMax = xMax;
            return this;
        }

        @Override
        public Builder cache(VocabCache vocab) {
            super.cache(vocab);
            return this;
        }

        @Override
        public Builder negative(double negative) {
            super.negative(negative);
            return this;
        }

        @Override
        public Builder vectorLength(int vectorLength) {
            super.vectorLength(vectorLength);
            return this;
        }

        @Override
        public Builder useAdaGrad(boolean useAdaGrad) {
            super.useAdaGrad(useAdaGrad);
            return this;
        }

        @Override
        public Builder lr(double lr) {
            super.lr(lr);
            return this;
        }

        @Override
        public Builder gen(RandomGenerator gen) {
            super.gen(gen);
            return this;
        }

        @Override
        public Builder seed(long seed) {
            super.seed(seed);
            return this;
        }

        @Override
        public GloveWeightLookupTable build() {
            return new GloveWeightLookupTable(this.vocabCache, this.vectorLength, this.useAdaGrad, this.lr,
                    this.gen, this.negative, this.xMax, this.maxCount);
        }
    }
}

