/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.embeddings.learning.impl.sequence;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.learning.SequenceLearningAlgorithm;
import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;

/**
 * Distributed Bag of Words ("DBOW") sequence learning algorithm.
 *
 * <p>For each element of a sequence, the sequence label is paired with that element and the
 * pair is handed to an internal {@link SkipGram} delegate via
 * {@code iterateSample(element, label, nextRandom, alpha)}. The delegate is configured with the
 * same vocabulary, lookup table and configuration as this instance.
 *
 * @param <T> type of the sequence elements and of the sequence label
 */
public class DBOW<T extends SequenceElement>
implements SequenceLearningAlgorithm<T> {
    /** Vocabulary supplied to {@link #configure}. */
    protected VocabCache<T> vocabCache;
    /** Weight lookup table supplied to {@link #configure}. */
    protected WeightLookupTable<T> lookupTable;
    /** Configuration supplied to {@link #configure}. */
    protected VectorsConfiguration configuration;
    /** Context window size, copied from the configuration. */
    protected int window;
    /** AdaGrad flag copied from the configuration (not read by this class itself). */
    protected boolean useAdaGrad;
    /** Negative-sampling value copied from the configuration (not read by this class itself). */
    protected double negative;
    /** Delegate that performs the per-(element, label) update. */
    protected SkipGram<T> skipGram = new SkipGram<T>();

    @Override
    public String getCodeName() {
        return "DBOW";
    }

    /**
     * Binds this algorithm to a vocabulary, lookup table and configuration. It also configures
     * the internal {@link SkipGram} delegate with the same three objects.
     *
     * @param vocabCache vocabulary to use
     * @param lookupTable weight lookup table to use
     * @param configuration source of window, AdaGrad and negative-sampling settings
     * @throws NullPointerException if any argument is {@code null}
     */
    @Override
    public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable, @NonNull VectorsConfiguration configuration) {
        if (vocabCache == null) {
            throw new NullPointerException("vocabCache");
        }
        if (lookupTable == null) {
            throw new NullPointerException("lookupTable");
        }
        if (configuration == null) {
            throw new NullPointerException("configuration");
        }
        this.vocabCache = vocabCache;
        this.lookupTable = lookupTable;
        // Previously never stored, leaving the protected field null for subclasses.
        this.configuration = configuration;
        this.window = configuration.getWindow();
        this.useAdaGrad = configuration.isUseAdaGrad();
        this.negative = configuration.getNegative();
        this.skipGram.configure(vocabCache, lookupTable, configuration);
    }

    /** DBOW needs no pre-training pass; this is a no-op. */
    @Override
    public void pretrain(SequenceIterator<T> iterator) {
    }

    /**
     * Trains the sequence label against every element of {@code sequence}, in order.
     *
     * @param sequence sequence whose elements and label are used
     * @param nextRandom shared random state, forwarded to the skip-gram delegate
     * @param learningRate learning rate forwarded to the skip-gram delegate
     * @throws NullPointerException if {@code sequence} or {@code nextRandom} is {@code null}
     * @throws IllegalStateException if the sequence is non-empty and its label is {@code null}
     */
    @Override
    public void learnSequence(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom, double learningRate) {
        if (sequence == null) {
            throw new NullPointerException("sequence");
        }
        if (nextRandom == null) {
            throw new NullPointerException("nextRandom");
        }
        int size = sequence.getElements().size();
        for (int i = 0; i < size; ++i) {
            // Recomputed per element because the delegate may advance nextRandom.
            this.dbow(i, sequence, this.reducedWindow(nextRandom), nextRandom, learningRate);
        }
    }

    /** DBOW never requests early termination. */
    @Override
    public boolean isEarlyTerminationHit() {
        return false;
    }

    /**
     * Derives a random window reduction in {@code [0, window)} from the current random state.
     *
     * <p>The plain {@code int % window} used previously could be negative, which widened the
     * window, and it threw {@link ArithmeticException} when {@code window == 0}.
     *
     * @param nextRandom shared random state (read only)
     * @return a value in {@code [0, window)}, or 0 when {@code window <= 0}
     */
    private int reducedWindow(AtomicLong nextRandom) {
        if (this.window <= 0) {
            return 0;
        }
        // |x % window| < window <= Integer.MAX_VALUE, so Math.abs cannot overflow here.
        return Math.abs((int) nextRandom.get() % this.window);
    }

    /**
     * Trains the sequence label against the element at position {@code i}.
     *
     * <p>The label is a single item, so no context-window walk is performed. The old walk
     * indexed a one-element label list with sentence offsets, which skipped the first element
     * and every element beyond the window.
     *
     * @param i index of the element to train against
     * @param sequence sequence providing the elements and the label
     * @param b random window reduction; not used by this implementation and kept for
     *          signature compatibility with callers and subclasses
     * @param nextRandom shared random state, forwarded to the skip-gram delegate
     * @param alpha learning rate, forwarded to the skip-gram delegate
     * @throws IndexOutOfBoundsException if {@code i} is not a valid element index
     * @throws IllegalStateException if the sequence label is {@code null}
     */
    protected void dbow(int i, Sequence<T> sequence, int b, AtomicLong nextRandom, double alpha) {
        List<T> sentence = sequence.getElements();
        T word = sentence.get(i);
        T label = sequence.getSequenceLabel();
        if (label == null) {
            throw new IllegalStateException("Label is NULL");
        }
        if (word == null) {
            return;
        }
        this.skipGram.iterateSample(word, label, nextRandom, alpha);
    }
}

