/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.glove;

import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.AdaGrad;

/**
 * Weight lookup table used for GloVe training.
 *
 * <p>It holds the word vectors ({@code syn0}, inherited), one bias per row of {@code syn0}, and separate
 * AdaGrad state for the vectors and for the biases. The per-pair weighting applied in
 * {@link #iterateSample(SequenceElement, SequenceElement, double)} is
 * {@code min(1, score / maxCount) ^ xMax}. Despite its name, {@code xMax} is the exponent of that weighting
 * function, and {@code maxCount} is the value at which it saturates.
 */
@Deprecated
public class GloveWeightLookupTable<T extends SequenceElement>
extends InMemoryLookupTable<T> {
    /** AdaGrad history for the rows of {@code syn0}. */
    private AdaGrad weightAdaGrad;
    /** AdaGrad history for {@link #bias}. */
    private AdaGrad biasAdaGrad;
    /** One bias term per row of {@code syn0}. */
    private INDArray bias;
    /** Exponent of the weighting function (plays the role of the GloVe paper's alpha). */
    private double xMax = 0.75;
    /** Value at and above which the weighting function {@code min(1, score / maxCount) ^ xMax} equals 1. */
    private double maxCount = 100.0;

    /**
     * Creates a table; weights are not allocated until {@link #resetWeights(boolean)} is called.
     *
     * @param vocab        vocabulary backing the table
     * @param vectorLength dimensionality of each word vector
     * @param useAdaGrad   forwarded to {@link InMemoryLookupTable}
     * @param lr           learning rate; also used to create the AdaGrad instances
     * @param gen          random generator forwarded to {@link InMemoryLookupTable}
     * @param negative     forwarded to {@link InMemoryLookupTable}
     * @param xMax         exponent of the weighting function
     * @param maxCount     saturation point of the weighting function
     */
    public GloveWeightLookupTable(VocabCache<T> vocab, int vectorLength, boolean useAdaGrad, double lr, Random gen, double negative, double xMax, double maxCount) {
        super(vocab, vectorLength, useAdaGrad, lr, gen, negative);
        this.xMax = xMax;
        this.maxCount = maxCount;
    }

    /**
     * Allocates (or, when {@code reset} is true, re-allocates) the word vectors, biases and AdaGrad state.
     *
     * @param reset if true, every structure is recreated; otherwise only missing ({@code null}) ones are
     */
    @Override
    public void resetWeights(boolean reset) {
        if (this.rng == null) {
            this.rng = Nd4j.getRandom();
        }
        if (this.syn0 == null || reset) {
            // numWords + 1 rows, values shifted by -0.5 and scaled by 1 / vectorLength
            this.syn0 = Nd4j.rand(new int[] {this.vocab.numWords() + 1, this.vectorLength}, this.rng)
                    .subi(0.5)
                    .divi(this.vectorLength);
            INDArray randUnk = Nd4j.rand(1, this.vectorLength, this.rng).subi(0.5).divi(this.vectorLength);
            this.putVector("UNK", randUnk);
        }
        if (this.weightAdaGrad == null || reset) {
            // Size the history after syn0 itself: update() passes syn0.shape(), and syn0 may have been
            // installed via setSyn0 with a row count that differs from numWords + 1.
            this.weightAdaGrad = new AdaGrad(this.syn0.shape(), this.lr.get());
        }
        if (this.bias == null || reset) {
            this.bias = Nd4j.create((int) this.syn0.rows());
        }
        if (this.biasAdaGrad == null || reset) {
            this.biasAdaGrad = new AdaGrad(this.bias.shape(), this.lr.get());
        }
    }

    /** Re-allocates all weights; equivalent to {@code resetWeights(true)}. */
    @Override
    public void resetWeights() {
        this.resetWeights(true);
    }

    /**
     * Performs one GloVe training step for a word pair and returns the weighted residual.
     *
     * <p>The residual is {@code f(score) * (w1 . w2 + b1 + b2 - log(score))}, with
     * {@code f(score) = min(1, score / maxCount) ^ xMax}. A NaN residual is replaced by
     * {@code Nd4j.EPS_THRESHOLD}.
     *
     * @param w1    first word; its index must address a row of {@code syn0}
     * @param w2    second word; its index must address a row of {@code syn0}
     * @param score pair value (presumably the co-occurrence count X_12 — confirm against callers)
     * @return the weighted residual used as the gradient scale
     * @throws IllegalArgumentException if either word's index is outside {@code [0, syn0.rows())}
     */
    public double iterateSample(T w1, T w2, double score) {
        int index1 = w1.getIndex();
        int index2 = w2.getIndex();
        // Validate before slicing so a bad index surfaces as this exception, not as a failure inside slice().
        if (index1 < 0 || index1 >= this.syn0.rows()) {
            throw new IllegalArgumentException("Illegal index for word " + w1.getLabel());
        }
        if (index2 < 0 || index2 >= this.syn0.rows()) {
            throw new IllegalArgumentException("Illegal index for word " + w2.getLabel());
        }
        INDArray w1Vector = this.syn0.slice(index1);
        INDArray w2Vector = this.syn0.slice(index2);

        double diff = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector)
                + this.bias.getDouble(index1)
                + this.bias.getDouble(index2)
                - Math.log(score);
        // The weight already saturates at 1 once score >= maxCount, so no separate "large count" branch is needed.
        double weight = Math.pow(Math.min(1.0, score / this.maxCount), this.xMax);
        double fDiff = weight * diff;
        if (Double.isNaN(fDiff)) {
            fDiff = Nd4j.EPS_THRESHOLD;
        }

        // update() modifies w1Vector in place; snapshot it so both updates use the pre-step vectors.
        INDArray w1BeforeStep = w1Vector.dup();
        this.update(w1, w1Vector, w2Vector, fDiff);
        this.update(w2, w2Vector, w1BeforeStep, fDiff);
        return fDiff;
    }

    /**
     * Applies one AdaGrad step to a word's vector and bias.
     *
     * @param w1            word whose row/bias is being updated
     * @param wordVector    the word's row of {@code syn0}; modified in place
     * @param contextVector the other word's vector; its product with {@code gradient} is the vector gradient
     * @param gradient      scalar residual from {@link #iterateSample(SequenceElement, SequenceElement, double)}
     */
    private void update(T w1, INDArray wordVector, INDArray contextVector, double gradient) {
        int index = w1.getIndex();
        INDArray vectorGradient = contextVector.mul(gradient);
        INDArray step = this.weightAdaGrad.getGradient(vectorGradient, index, this.syn0.shape());
        // wordVector comes from syn0.slice(...); assuming slice() returns a view, this writes back into syn0.
        wordVector.subi(step);
        double currentBias = this.bias.getDouble(index);
        double biasStep = this.biasAdaGrad.getGradient(gradient, index, this.bias.shape());
        this.bias.putScalar(index, currentBias - biasStep);
    }

    /** @return AdaGrad state for the word vectors */
    public AdaGrad getWeightAdaGrad() {
        return this.weightAdaGrad;
    }

    /** @return AdaGrad state for the biases */
    public AdaGrad getBiasAdaGrad() {
        return this.biasAdaGrad;
    }

    /**
     * Loads word vectors from a text stream with one {@code word v1 v2 ... vN} entry per line (single-space separated).
     *
     * <p>The first non-empty line fixes the dimensionality N. Later lines with a different number of values
     * are skipped. Vectors are stored at the row given by {@code vocab.indexOf(word)}, and words outside the
     * vocabulary are ignored. The stream is read as UTF-8. The underlying iterator is closed even if
     * parsing fails.
     *
     * @param is    input stream to read
     * @param vocab vocabulary that maps words to row indices
     * @return the populated table, with its remaining weights allocated via {@code resetWeights(false)}
     * @throws IOException if reading fails or the stream contains no non-empty lines
     */
    public static GloveWeightLookupTable load(InputStream is, VocabCache<? extends SequenceElement> vocab) throws IOException {
        LineIterator iter = IOUtils.lineIterator(is, "UTF-8");
        GloveWeightLookupTable glove = null;
        Map<String, float[]> wordVectors = new HashMap<String, float[]>();
        try {
            while (iter.hasNext()) {
                String line = iter.nextLine().trim();
                if (line.isEmpty()) {
                    continue;
                }
                String[] split = line.split(" ");
                String word = split[0];
                if (glove == null) {
                    glove = new Builder().cache(vocab).vectorLength(split.length - 1).build();
                }
                int layerSize = glove.layerSize();
                // Skip malformed lines instead of zero-padding them or overflowing the value buffer.
                if (word.isEmpty() || layerSize < 1 || split.length - 1 != layerSize) {
                    continue;
                }
                wordVectors.put(word, GloveWeightLookupTable.read(split, layerSize));
            }
        } finally {
            iter.close();
        }
        if (glove == null) {
            throw new IOException("Input stream contains no word vectors");
        }
        glove.setSyn0(GloveWeightLookupTable.weights(glove, wordVectors, vocab));
        glove.resetWeights(false);
        return glove;
    }

    /**
     * Builds a {@code data.size() x layerSize} matrix whose row {@code vocab.indexOf(word)} holds that word's vector.
     *
     * <p>Entries with the wrong length, unknown words (negative index), and indices that do not fit in the
     * matrix are skipped; their rows stay as created by {@code Nd4j.create}.
     */
    private static INDArray weights(GloveWeightLookupTable glove, Map<String, float[]> data, VocabCache vocab) {
        int layerSize = glove.layerSize();
        int rows = data.size();
        INDArray ret = Nd4j.create(rows, layerSize);
        for (Map.Entry<String, float[]> entry : data.entrySet()) {
            float[] values = entry.getValue();
            int index = vocab.indexOf(entry.getKey());
            if (values.length != layerSize || index < 0 || index >= rows) {
                continue;
            }
            ret.putRow(index, Nd4j.create(Nd4j.createBuffer(values)));
        }
        return ret;
    }

    /**
     * Parses {@code split[1..]} into a float array of exactly {@code length} entries.
     *
     * <p>Missing values stay 0 and surplus tokens are ignored, so the array is never overrun.
     */
    private static float[] read(String[] split, int length) {
        float[] ret = new float[length];
        int count = Math.min(length, split.length - 1);
        for (int i = 0; i < count; ++i) {
            ret[i] = Float.parseFloat(split[i + 1]);
        }
        return ret;
    }

    /**
     * Not supported by GloVe; use {@link #iterateSample(SequenceElement, SequenceElement, double)} instead.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void iterateSample(T w1, T w2, AtomicLong nextRandom, double alpha) {
        throw new UnsupportedOperationException();
    }

    /** @return exponent of the weighting function */
    public double getxMax() {
        return this.xMax;
    }

    /** @param xMax new exponent of the weighting function */
    public void setxMax(double xMax) {
        this.xMax = xMax;
    }

    /** @return saturation point of the weighting function */
    public double getMaxCount() {
        return this.maxCount;
    }

    /** @param maxCount new saturation point of the weighting function */
    public void setMaxCount(double maxCount) {
        this.maxCount = maxCount;
    }

    /** @return per-row bias vector (the live array, not a copy) */
    public INDArray getBias() {
        return this.bias;
    }

    /** @param bias replacement per-row bias vector; stored by reference */
    public void setBias(INDArray bias) {
        this.bias = bias;
    }

    /**
     * Builder for {@link GloveWeightLookupTable}. It adds the GloVe weighting parameters to the
     * {@link InMemoryLookupTable.Builder} options and narrows every setter's return type for chaining.
     */
    public static class Builder<T extends SequenceElement>
    extends InMemoryLookupTable.Builder<T> {
        private double xMax = 0.75;
        private double maxCount = 100.0;

        /** Sets the saturation point of the weighting function (default 100). */
        public Builder<T> maxCount(double maxCount) {
            this.maxCount = maxCount;
            return this;
        }

        /** Sets the exponent of the weighting function (default 0.75). */
        public Builder<T> xMax(double xMax) {
            this.xMax = xMax;
            return this;
        }

        @Override
        public Builder<T> cache(VocabCache<T> vocab) {
            super.cache(vocab);
            return this;
        }

        @Override
        public Builder<T> negative(double negative) {
            super.negative(negative);
            return this;
        }

        @Override
        public Builder<T> vectorLength(int vectorLength) {
            super.vectorLength(vectorLength);
            return this;
        }

        @Override
        public Builder<T> useAdaGrad(boolean useAdaGrad) {
            super.useAdaGrad(useAdaGrad);
            return this;
        }

        @Override
        public Builder<T> lr(double lr) {
            super.lr(lr);
            return this;
        }

        @Override
        public Builder<T> gen(Random gen) {
            super.gen(gen);
            return this;
        }

        @Override
        public Builder<T> seed(long seed) {
            super.seed(seed);
            return this;
        }

        /** Creates the table from the configured options; weights are allocated later by resetWeights. */
        @Override
        public GloveWeightLookupTable<T> build() {
            return new GloveWeightLookupTable<T>(this.vocabCache, this.vectorLength, this.useAdaGrad, this.lr, this.gen, this.negative, this.xMax, this.maxCount);
        }
    }
}

