/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.embeddings.learning.impl.elements;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.apache.commons.lang3.RandomUtils;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.learning.ElementsLearningAlgorithm;
import org.deeplearning4j.models.embeddings.learning.impl.elements.BatchItem;
import org.deeplearning4j.models.embeddings.learning.impl.elements.BatchSequences;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.api.ops.impl.nlp.SkipGramRound;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.DeviceLocalNDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Skip-gram learning algorithm for sequence elements.
 *
 * <p>Each element of an (optionally subsampled) sequence is paired with every other element
 * inside a randomly shrunk context window. Each pair is trained with the native
 * {@link SkipGramRound} op, using hierarchical softmax and/or negative sampling as selected by the
 * {@link VectorsConfiguration}.
 *
 * <p>Pairs are trained in one of two ways:
 * <ul>
 *   <li>immediately, one pair at a time; or</li>
 *   <li>when a {@link BatchSequences} is supplied and the configured batch size exceeds 1, queued
 *       into that container (presumably drained by the caller through
 *       {@link #iterateSample(List)} — confirm against the caller).</li>
 * </ul>
 *
 * <p>The lookup table passed to {@link #configure} must be an {@link InMemoryLookupTable}.
 *
 * @param <T> sequence element type
 */
public class SkipGram<T extends SequenceElement>
implements ElementsLearningAlgorithm<T> {
    private static final Logger log = LoggerFactory.getLogger(SkipGram.class);

    /** Multiplier of the linear congruential step used for every random draw in this class. */
    private static final long RANDOM_MULTIPLIER = 25214903917L;
    /** Increment of the linear congruential step used for every random draw in this class. */
    private static final long RANDOM_INCREMENT = 11L;

    protected VocabCache<T> vocabCache;
    protected WeightLookupTable<T> lookupTable;
    protected VectorsConfiguration configuration;
    /** Default context window; replaced per sequence by a random entry of {@link #variableWindows} when set. */
    protected int window;
    protected boolean useAdaGrad;
    /** Negative-sample count; values greater than 0 enable negative sampling. */
    protected double negative;
    /** Subsampling threshold; values greater than 0 enable subsampling of frequent elements. */
    protected double sampling;
    protected int[] variableWindows;
    protected int vectorLength;
    /** Passed as the last argument of the native op (presumably its thread count — confirm in SkipGramRound). */
    protected int workers = Runtime.getRuntime().availableProcessors();
    protected DeviceLocalNDArray syn0;
    protected DeviceLocalNDArray syn1;
    protected DeviceLocalNDArray syn1Neg;
    protected DeviceLocalNDArray table;
    protected DeviceLocalNDArray expTable;
    /**
     * Per-thread list of pending aggregate ops.
     *
     * <p>It is flushed by {@code learnSequence} once it reaches the configured batch size, and by
     * {@link #finish()}. Code in this class only creates and flushes the list; it never appends
     * to it.
     */
    protected ThreadLocal<List<Aggregate>> batches = new ThreadLocal<>();

    public int getWorkers() {
        return this.workers;
    }

    public void setWorkers(int workers) {
        this.workers = workers;
    }

    /** Returns the calling thread's pending aggregate list, or {@code null} if none was created yet. */
    public List<Aggregate> getBatch() {
        return this.batches.get();
    }

    @Override
    public String getCodeName() {
        return "SkipGram";
    }

    /**
     * Binds this algorithm to its vocabulary, weights and configuration, and snapshots the weight
     * matrices into device-local arrays.
     *
     * <p>If negative sampling is configured but the lookup table has no {@code syn1Neg} yet, the
     * table's weights are reset so that it gets created.
     *
     * @throws NullPointerException if any argument is {@code null}
     * @throws ClassCastException if {@code lookupTable} is not an {@link InMemoryLookupTable}
     */
    @Override
    public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable, @NonNull VectorsConfiguration configuration) {
        if (vocabCache == null) {
            throw new NullPointerException("vocabCache is marked @NonNull but is null");
        }
        if (lookupTable == null) {
            throw new NullPointerException("lookupTable is marked @NonNull but is null");
        }
        if (configuration == null) {
            throw new NullPointerException("configuration is marked @NonNull but is null");
        }
        this.vocabCache = vocabCache;
        this.lookupTable = lookupTable;
        this.configuration = configuration;
        InMemoryLookupTable<T> inMemory = (InMemoryLookupTable<T>) lookupTable;
        if (configuration.getNegative() > 0.0 && inMemory.getSyn1Neg() == null) {
            log.info("Initializing syn1Neg...");
            inMemory.setUseHS(configuration.isUseHierarchicSoftmax());
            inMemory.setNegative(configuration.getNegative());
            inMemory.resetWeights(false);
        }
        this.syn0 = new DeviceLocalNDArray(inMemory.getSyn0());
        this.syn1 = new DeviceLocalNDArray(inMemory.getSyn1());
        this.syn1Neg = new DeviceLocalNDArray(inMemory.getSyn1Neg());
        double[] exp = inMemory.getExpTable();
        // The exp table is stored with the same data type as syn0 so the native op sees matching types.
        this.expTable = new DeviceLocalNDArray(Nd4j.create(exp, new long[]{exp.length}, this.syn0.get().dataType()));
        this.table = new DeviceLocalNDArray(inMemory.getTable());
        this.window = configuration.getWindow();
        this.useAdaGrad = configuration.isUseAdaGrad();
        this.negative = configuration.getNegative();
        this.sampling = configuration.getSampling();
        this.variableWindows = configuration.getVariableWindows();
        this.vectorLength = configuration.getLayersSize();
    }

    /** No pre-training pass is needed for skip-gram. */
    @Override
    public void pretrain(SequenceIterator<T> iterator) {
    }

    /**
     * Randomly drops frequent elements from {@code sequence}.
     *
     * <p>Each element is kept with a probability computed from its frequency relative to
     * {@code sampling * totalWordOccurrences}. The sequence id and labels are copied to the result.
     *
     * @param sequence   sequence to subsample
     * @param nextRandom shared random state; advanced once per element
     * @return a new subsampled sequence, or {@code sequence} itself when sampling is disabled
     *         ({@code sampling <= 0})
     * @throws NullPointerException if either argument is {@code null}
     */
    public Sequence<T> applySubsampling(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom) {
        if (sequence == null) {
            throw new NullPointerException("sequence is marked @NonNull but is null");
        }
        if (nextRandom == null) {
            throw new NullPointerException("nextRandom is marked @NonNull but is null");
        }
        if (!(this.sampling > 0.0)) {
            return sequence;
        }
        Sequence<T> result = new Sequence<>();
        result.setSequenceId(sequence.getSequenceId());
        if (sequence.getSequenceLabels() != null) {
            result.setSequenceLabels(sequence.getSequenceLabels());
        }
        if (sequence.getSequenceLabel() != null) {
            result.setSequenceLabel(sequence.getSequenceLabel());
        }
        // Loop-invariant: sampling threshold scaled by the total number of occurrences.
        double threshold = this.sampling * (double) this.vocabCache.totalWordOccurrences();
        for (T element : sequence.getElements()) {
            double frequency = element.getElementFrequency();
            double keepScore = (Math.sqrt(frequency / threshold) + 1.0) * threshold / frequency;
            advanceRandom(nextRandom);
            // Compare against the low 16 bits of the random state, scaled into [0, 1).
            if (keepScore < (double) (nextRandom.get() & 0xFFFFL) / 65536.0) {
                continue;
            }
            result.addElement(element);
        }
        return result;
    }

    /**
     * Trains one sequence, routing each (element, context) pair either directly to the native op
     * or into {@code batchSequences} when the configured batch size is greater than 1.
     *
     * @return the score reported by the last skip-gram step (the samplers in this class currently
     *         always report 0.0)
     * @throws NullPointerException if {@code sequence} or {@code nextRandom} is {@code null}
     */
    @Override
    public double learnSequence(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom, double learningRate, BatchSequences<T> batchSequences) {
        if (sequence == null) {
            throw new NullPointerException("sequence is marked @NonNull but is null");
        }
        if (nextRandom == null) {
            throw new NullPointerException("nextRandom is marked @NonNull but is null");
        }
        Sequence<T> tempSequence = this.sampling > 0.0 ? this.applySubsampling(sequence, nextRandom) : sequence;
        double score = 0.0;
        int currentWindow = this.chooseWindow();
        List<T> elements = tempSequence.getElements();
        for (int i = 0; i < elements.size(); ++i) {
            advanceRandom(nextRandom);
            score = this.skipGram(i, elements, windowOffset(nextRandom.get(), currentWindow), nextRandom, learningRate, currentWindow, batchSequences);
        }
        this.flushBatchesIfFull();
        return score;
    }

    /**
     * Trains one sequence, sending every (element, context) pair directly to the native op.
     *
     * @return the score reported by the last skip-gram step (the samplers in this class currently
     *         always report 0.0)
     * @throws NullPointerException if {@code sequence} or {@code nextRandom} is {@code null}
     */
    @Override
    public double learnSequence(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom, double learningRate) {
        if (sequence == null) {
            throw new NullPointerException("sequence is marked @NonNull but is null");
        }
        if (nextRandom == null) {
            throw new NullPointerException("nextRandom is marked @NonNull but is null");
        }
        Sequence<T> tempSequence = this.sampling > 0.0 ? this.applySubsampling(sequence, nextRandom) : sequence;
        double score = 0.0;
        int currentWindow = this.chooseWindow();
        List<T> elements = tempSequence.getElements();
        for (int i = 0; i < elements.size(); ++i) {
            advanceRandom(nextRandom);
            score = this.skipGram(i, elements, windowOffset(nextRandom.get(), currentWindow), nextRandom, learningRate, currentWindow);
        }
        this.flushBatchesIfFull();
        return score;
    }

    /** Executes and clears any aggregate ops still pending for the calling thread. */
    @Override
    public void finish() {
        List<Aggregate> pending = this.batches == null ? null : this.batches.get();
        if (pending != null && !pending.isEmpty()) {
            Nd4j.getExecutioner().exec(pending);
            pending.clear();
        }
    }

    @Override
    public boolean isEarlyTerminationHit() {
        return false;
    }

    /** Advances the shared random state by one linear congruential step (kept non-negative via abs). */
    private static void advanceRandom(AtomicLong nextRandom) {
        nextRandom.set(Math.abs(nextRandom.get() * RANDOM_MULTIPLIER + RANDOM_INCREMENT));
    }

    /**
     * Computes the random window-shrink offset {@code b}, in {@code [0, currentWindow)}.
     *
     * <p>The remainder is taken on the full long before narrowing. The former
     * {@code (int) random % currentWindow} cast first, because a cast binds tighter than
     * {@code %}. That could yield a negative offset, which widened the context beyond
     * {@code currentWindow}.
     */
    private static int windowOffset(long randomValue, int currentWindow) {
        if (currentWindow <= 0) {
            return 0;
        }
        return (int) Math.floorMod(randomValue, (long) currentWindow);
    }

    /** Picks the window for one sequence: a random entry of {@link #variableWindows} if set, else {@link #window}. */
    private int chooseWindow() {
        if (this.variableWindows != null && this.variableWindows.length != 0) {
            return this.variableWindows[RandomUtils.nextInt(0, this.variableWindows.length)];
        }
        return this.window;
    }

    /** Executes and clears the calling thread's pending aggregate ops once they reach the configured batch size. */
    private void flushBatchesIfFull() {
        List<Aggregate> pending = this.batches == null ? null : this.batches.get();
        if (pending != null && pending.size() >= this.configuration.getBatchSize()) {
            Nd4j.getExecutioner().exec(pending);
            pending.clear();
        }
    }

    /**
     * Trains element {@code i} against every in-bounds context element of its shrunk window.
     *
     * @param b offset that shrinks the window; positions {@code b .. 2*currentWindow - b} are visited
     */
    private double skipGram(int i, List<T> sentence, int b, AtomicLong nextRandom, double alpha, int currentWindow) {
        if (sentence.isEmpty()) {
            return 0.0;
        }
        T word = sentence.get(i);
        if (word == null) {
            return 0.0;
        }
        double score = 0.0;
        int end = currentWindow * 2 + 1 - b;
        for (int a = b; a < end; ++a) {
            if (a == currentWindow) {
                continue; // the centre position is the element itself
            }
            int c = i - currentWindow + a;
            if (c < 0 || c >= sentence.size()) {
                continue;
            }
            score = this.iterateSample(word, sentence.get(c), nextRandom, alpha, false, null);
        }
        return score;
    }

    /**
     * Batched variant of the method above.
     *
     * <p>Locked elements are skipped. Each pair draws a fresh random value. The pair is then
     * either trained immediately (batch size at most 1) or queued into {@code batchSequences}.
     */
    private double skipGram(int i, List<T> sentence, int b, AtomicLong nextRandom, double alpha, int currentWindow, BatchSequences<T> batchSequences) {
        if (sentence.isEmpty()) {
            return 0.0;
        }
        T word = sentence.get(i);
        if (word == null || word.isLocked()) {
            return 0.0;
        }
        double score = 0.0;
        int batchSize = this.configuration.getBatchSize();
        int end = currentWindow * 2 + 1 - b;
        for (int a = b; a < end; ++a) {
            if (a == currentWindow) {
                continue; // the centre position is the element itself
            }
            int c = i - currentWindow + a;
            if (c < 0 || c >= sentence.size()) {
                continue;
            }
            T lastWord = sentence.get(c);
            advanceRandom(nextRandom);
            if (batchSize <= 1) {
                score = this.iterateSample(word, lastWord, nextRandom, alpha, false, null);
            } else {
                batchSequences.put(word, lastWord, nextRandom.get(), alpha);
            }
        }
        return score;
    }

    /** Returns {@code true} for the reserved "STOP"/"UNK" labels (null-safe). */
    private static boolean isReservedLabel(String label) {
        return "STOP".equals(label) || "UNK".equals(label);
    }

    /**
     * Decides whether a pair must not be trained.
     *
     * <p>A pair is skipped when:
     * <ul>
     *   <li>either element is null;</li>
     *   <li>the context index is negative, outside inference;</li>
     *   <li>both elements share an index; or</li>
     *   <li>either element carries a reserved label.</li>
     * </ul>
     */
    private static boolean isSkippedPair(SequenceElement w1, SequenceElement lastWord, boolean isInference) {
        if (w1 == null || lastWord == null) {
            return true;
        }
        if (lastWord.getIndex() < 0 && !isInference) {
            return true;
        }
        if (w1.getIndex() == lastWord.getIndex()) {
            return true;
        }
        return isReservedLabel(w1.getLabel()) || isReservedLabel(lastWord.getLabel());
    }

    /**
     * Copies {@code element}'s Huffman codes and points into {@code codes} and {@code points}.
     *
     * <p>Entries whose point lies outside {@code [0, numWords)} are left at zero. Both arrays must
     * be at least {@code getCodeLength()} long.
     */
    private void fillHuffmanPath(SequenceElement element, byte[] codes, int[] points) {
        int vocabSize = this.vocabCache.numWords();
        for (int i = 0; i < element.getCodeLength(); ++i) {
            byte code = element.getCodes().get(i);
            int point = element.getPoints().get(i);
            if (point >= vocabSize || point < 0) {
                continue;
            }
            codes[i] = code;
            points[i] = point;
        }
    }

    /** Returns the lookup table as the in-memory implementation this class requires. */
    private InMemoryLookupTable<T> inMemoryTable() {
        return (InMemoryLookupTable<T>) this.lookupTable;
    }

    /** Lazily creates {@code syn1Neg} when negative sampling is enabled but no array wrapper exists yet. */
    private void ensureNegativeInitialized() {
        if (this.negative > 0.0 && this.syn1Neg == null) {
            this.inMemoryTable().initNegative();
            this.syn1Neg = new DeviceLocalNDArray(this.inMemoryTable().getSyn1Neg());
        }
    }

    /**
     * Runs one native skip-gram round for a single (element, context) pair.
     *
     * @param w1              element whose Huffman path (HS) and index (negative sampling) are used
     * @param lastWord        context element whose row is passed as the op's first input
     * @param nextRandom      shared random state; advanced once before the op is built
     * @param alpha           learning rate
     * @param isInference     when {@code true}, a negative context index is tolerated
     * @param inferenceVector optional vector forwarded to the op; an empty array is passed when
     *                        {@code null}
     * @return always 0.0 (no score is computed here)
     * @throws IllegalStateException if neither hierarchical softmax nor negative sampling is enabled
     */
    public double iterateSample(T w1, T lastWord, AtomicLong nextRandom, double alpha, boolean isInference, INDArray inferenceVector) {
        if (isSkippedPair(w1, lastWord, isInference)) {
            return 0.0;
        }
        double score = 0.0;
        boolean useHS = this.configuration.isUseHierarchicSoftmax();
        boolean useNegative = this.configuration.getNegative() > 0.0;
        int[] idxSyn1;
        byte[] codes;
        if (useHS) {
            idxSyn1 = new int[w1.getCodeLength()];
            codes = new byte[w1.getCodeLength()];
            this.fillHuffmanPath(w1, codes, idxSyn1);
        } else {
            idxSyn1 = new int[0];
            codes = new byte[0];
        }
        int target = w1.getIndex();
        int contextIndex = lastWord.getIndex();
        this.ensureNegativeInitialized();
        if (this.batches.get() == null) {
            this.batches.set(new ArrayList<>());
        }
        advanceRandom(nextRandom);
        SkipGramRound sg;
        if (useHS && useNegative) {
            // createFromArray wraps the values as data; Nd4j.create(int[]) would treat them as a SHAPE.
            // Codes stay byte-typed, matching the HS-only constructor and the batched path.
            sg = new SkipGramRound(Nd4j.scalar(contextIndex), Nd4j.scalar(target), this.syn0.get(), this.syn1.get(), this.syn1Neg.get(), this.expTable.get(), this.table.get(), (int) this.negative, Nd4j.createFromArray(idxSyn1), Nd4j.createFromArray(codes), Nd4j.scalar(alpha), Nd4j.scalar(nextRandom.get()), inferenceVector != null ? inferenceVector : Nd4j.empty(this.syn0.get().dataType()), this.configuration.isPreciseMode(), this.workers);
        } else if (useHS) {
            sg = new SkipGramRound(contextIndex, this.syn0.get(), this.syn1.get(), this.expTable.get(), idxSyn1, codes, alpha, nextRandom.get(), inferenceVector != null ? inferenceVector : Nd4j.empty(this.syn0.get().dataType()));
        } else if (useNegative) {
            sg = new SkipGramRound(contextIndex, target, this.syn0.get(), this.syn1Neg.get(), this.expTable.get(), this.table.get(), (int) this.negative, alpha, nextRandom.get(), inferenceVector != null ? inferenceVector : Nd4j.empty(this.syn0.get().dataType()));
        } else {
            throw new IllegalStateException("SkipGram requires hierarchical softmax and/or negative sampling (negative > 0) to be enabled");
        }
        Nd4j.getExecutioner().exec((CustomOp) sg);
        return score;
    }

    /**
     * Runs one native skip-gram round over a batch of queued pairs.
     *
     * <p>Pairs rejected by the skip filter keep zeroed target/starter/alpha rows; their random
     * value is still recorded. For hierarchical softmax, the code and index rows are padded with
     * -1 up to the longest code length in the batch.
     *
     * @param items queued pairs; an empty list is a no-op
     * @return always 0.0 (no score is computed here)
     */
    public double iterateSample(List<BatchItem<T>> items) {
        double score = 0.0;
        if (items.isEmpty()) {
            return score;
        }
        boolean useHS = this.configuration.isUseHierarchicSoftmax();
        boolean useNegative = this.negative > 0.0;
        int size = items.size();
        int[] targets = new int[size];
        int[] starters = new int[size];
        double[] alphas = new double[size];
        long[] randomValues = new long[size];
        int maxCols = 1;
        for (BatchItem<T> item : items) {
            T word = item.getWord();
            if (word != null) {
                maxCols = Math.max(maxCols, word.getCodeLength());
            }
        }
        byte[][] codes = useHS ? new byte[size][maxCols] : null;
        int[][] indices = useHS ? new int[size][maxCols] : null;
        this.ensureNegativeInitialized();
        for (int cnt = 0; cnt < size; ++cnt) {
            BatchItem<T> item = items.get(cnt);
            T w1 = item.getWord();
            T lastWord = item.getLastWord();
            randomValues[cnt] = item.getRandomValue();
            if (isSkippedPair(w1, lastWord, false)) {
                continue;
            }
            targets[cnt] = lastWord.getIndex();
            starters[cnt] = w1.getIndex();
            alphas[cnt] = item.getAlpha();
            if (useHS) {
                int codeLength = w1.getCodeLength();
                byte[] pathCodes = new byte[codeLength];
                int[] pathPoints = new int[codeLength];
                this.fillHuffmanPath(w1, pathCodes, pathPoints);
                for (int i = 0; i < maxCols; ++i) {
                    boolean inPath = i < codeLength;
                    codes[cnt][i] = inPath ? pathCodes[i] : -1;
                    indices[cnt][i] = inPath ? pathPoints[i] : -1;
                }
            }
        }
        DataType weightsType = this.syn0.get().dataType();
        SkipGramRound sg = new SkipGramRound(
                Nd4j.createFromArray(targets),
                useNegative ? Nd4j.createFromArray(starters) : Nd4j.empty(DataType.INT),
                this.syn0.get(),
                useHS ? this.syn1.get() : Nd4j.empty(weightsType),
                useNegative ? this.syn1Neg.get() : Nd4j.empty(weightsType),
                this.expTable.get(),
                useNegative ? this.table.get() : Nd4j.empty(weightsType),
                (int) this.negative,
                useHS ? Nd4j.createFromArray(indices) : Nd4j.empty(DataType.INT),
                useHS ? Nd4j.createFromArray(codes) : Nd4j.empty(DataType.BYTE),
                Nd4j.createFromArray(alphas),
                Nd4j.createFromArray(randomValues),
                Nd4j.empty(weightsType),
                this.configuration.isPreciseMode(),
                this.workers);
        Nd4j.getExecutioner().exec((CustomOp) sg);
        return score;
    }

    public DeviceLocalNDArray getSyn0() {
        return this.syn0;
    }

    public DeviceLocalNDArray getSyn1() {
        return this.syn1;
    }

    public DeviceLocalNDArray getSyn1Neg() {
        return this.syn1Neg;
    }

    public DeviceLocalNDArray getTable() {
        return this.table;
    }

    public DeviceLocalNDArray getExpTable() {
        return this.expTable;
    }

    public void setSyn0(DeviceLocalNDArray syn0) {
        this.syn0 = syn0;
    }

    public void setSyn1(DeviceLocalNDArray syn1) {
        this.syn1 = syn1;
    }

    public void setSyn1Neg(DeviceLocalNDArray syn1Neg) {
        this.syn1Neg = syn1Neg;
    }

    public void setTable(DeviceLocalNDArray table) {
        this.table = table;
    }

    public void setExpTable(DeviceLocalNDArray expTable) {
        this.expTable = expTable;
    }
}

