/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.scaleout.perform.models.word2vec;

import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.InMemoryLookupCache;
import org.deeplearning4j.scaleout.perform.models.word2vec.Word2VecResult;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * A unit of distributed word2vec work: a batch of sentences together with private copies of the
 * lookup-table rows needed for those sentences.
 *
 * <p>For every word encountered (each sentence word and, when it has {@code points}, the words
 * resolved from those points) the constructor copies the word's syn0 row, syn1 slice and, if the
 * table has one, syn1Neg slice twice: once as a working copy exposed through the getters and once
 * as an untouched original. {@link #addDeltas()} reports the difference between the two.
 *
 * <p>All per-word maps are keyed by {@link VocabWord#getWord()}, so repeated occurrences of the
 * same word share a single entry.
 */
public class Word2VecWork
implements Serializable {
    /** Working syn0 rows keyed by word string, paired with the owning {@link VocabWord}. */
    private Map<String, Pair<VocabWord, INDArray>> vectors = new ConcurrentHashMap<String, Pair<VocabWord, INDArray>>();
    /** Working syn1Neg slices keyed by word string; empty when the table has no syn1Neg. */
    private Map<String, Pair<VocabWord, INDArray>> negativeVectors = new ConcurrentHashMap<String, Pair<VocabWord, INDArray>>();
    /** The sentences this work item covers (held by reference). */
    private List<List<VocabWord>> sentences;
    /** Every captured word keyed by its vocabulary index. */
    private Map<Integer, VocabWord> indexes = new ConcurrentHashMap<Integer, VocabWord>();
    /** Unmodified syn0 rows used as the baseline for syn0 deltas. */
    private Map<String, INDArray> originalVectors = new ConcurrentHashMap<String, INDArray>();
    /** Unmodified syn1 slices used as the baseline for syn1 deltas. */
    private Map<String, INDArray> originalSyn1Vectors = new ConcurrentHashMap<String, INDArray>();
    /** Unmodified syn1Neg slices used as the baseline for negative-sampling deltas. */
    private Map<String, INDArray> originalNegative = new ConcurrentHashMap<String, INDArray>();
    /** Working syn1 slices keyed by word string. */
    private Map<String, INDArray> syn1Vectors = new ConcurrentHashMap<String, INDArray>();

    /**
     * Copies the table rows for every word in {@code sentences} and for the words their points
     * resolve to.
     *
     * @param table     lookup table whose syn0 / syn1 / syn1Neg rows are copied
     * @param cache     vocabulary used to map each point index back to a {@link VocabWord}
     * @param sentences sentences to process; stored by reference, not copied
     * @throws IllegalArgumentException if a sentence word, or the word looked up for one of its
     *                                  points, is {@code null}
     */
    public Word2VecWork(InMemoryLookupTable table, InMemoryLookupCache cache, List<List<VocabWord>> sentences) {
        this.sentences = sentences;
        for (List<VocabWord> sentence : sentences) {
            for (VocabWord word : sentence) {
                this.addWord(word, table);
                if (word.getPoints() == null) {
                    continue;
                }
                // Iterates codeLength entries of points; assumes points holds at least that many.
                for (int i = 0; i < word.getCodeLength(); ++i) {
                    VocabWord pointWord = cache.wordFor(cache.wordAtIndex(word.getPoints().get(i)));
                    this.addWord(pointWord, table);
                }
            }
        }
    }

    /**
     * Records {@code word} and stores a working copy plus an original copy of each of its rows.
     *
     * @param word  word to capture
     * @param table table the rows are copied from
     * @throws IllegalArgumentException if {@code word} is {@code null}
     */
    private void addWord(VocabWord word, InMemoryLookupTable table) {
        if (word == null) {
            throw new IllegalArgumentException("Word must not be null!");
        }
        int index = word.getIndex();
        String key = word.getWord();
        this.indexes.put(index, word);

        // Each row is dup()'d twice so the working and original copies never alias each other
        // or the shared table.
        INDArray syn0Row = table.getSyn0().getRow(index);
        this.vectors.put(key, new Pair<VocabWord, INDArray>(word, syn0Row.dup()));
        this.originalVectors.put(key, syn0Row.dup());

        INDArray syn1Row = table.getSyn1().slice(index);
        this.syn1Vectors.put(key, syn1Row.dup());
        this.originalSyn1Vectors.put(key, syn1Row.dup());

        if (table.getSyn1Neg() != null) {
            INDArray negRow = table.getSyn1Neg().slice(index);
            this.originalNegative.put(key, negRow.dup());
            this.negativeVectors.put(key, new Pair<VocabWord, INDArray>(word, negRow.dup()));
        }
    }

    /**
     * Computes, for every sentence word, how far its working vectors have moved from the
     * originals captured at construction time.
     *
     * <p>The working copies are not modified. Calling this method more than once, or with a word
     * that occurs several times, therefore yields the same deltas.
     *
     * <p>NOTE(review): only sentence words are iterated, so words captured solely as points get no
     * syn1 / syn1Neg delta here. Confirm this is intended.
     *
     * @return syn0, syn1 and (when syn1Neg rows were captured) negative-sampling deltas keyed by
     *         word string
     */
    public Word2VecResult addDeltas() {
        HashMap<String, INDArray> syn0Change = new HashMap<String, INDArray>();
        HashMap<String, INDArray> syn1Change = new HashMap<String, INDArray>();
        HashMap<String, INDArray> negativeChange = new HashMap<String, INDArray>();
        boolean hasNegative = !this.negativeVectors.isEmpty();
        for (List<VocabWord> sentence : this.sentences) {
            for (VocabWord word : sentence) {
                String key = word.getWord();
                syn0Change.put(key, this.vectors.get(key).getSecond().sub(this.originalVectors.get(key)));
                syn1Change.put(key, this.syn1Vectors.get(key).sub(this.originalSyn1Vectors.get(key)));
                if (hasNegative) {
                    // sub, not subi: the in-place form overwrote the working vector with the delta,
                    // so a repeated word (or a second call) subtracted the original twice.
                    negativeChange.put(key, this.negativeVectors.get(key).getSecond().sub(this.originalNegative.get(key)));
                }
            }
        }
        return new Word2VecResult(syn0Change, syn1Change, negativeChange);
    }

    /** @return the sentences covered by this work item */
    public List<List<VocabWord>> getSentences() {
        return this.sentences;
    }

    /** @param sentences replacement sentences (stored by reference) */
    public void setSentences(List<List<VocabWord>> sentences) {
        this.sentences = sentences;
    }

    /** @return working syn1Neg slices keyed by word string (the live map, not a copy) */
    public Map<String, Pair<VocabWord, INDArray>> getNegativeVectors() {
        return this.negativeVectors;
    }

    /** @param negativeVectors replacement working syn1Neg map */
    public void setNegativeVectors(Map<String, Pair<VocabWord, INDArray>> negativeVectors) {
        this.negativeVectors = negativeVectors;
    }

    /** @return working syn0 rows keyed by word string (the live map, not a copy) */
    public Map<String, Pair<VocabWord, INDArray>> getVectors() {
        return this.vectors;
    }

    /** @param vectors replacement working syn0 map */
    public void setVectors(Map<String, Pair<VocabWord, INDArray>> vectors) {
        this.vectors = vectors;
    }

    /** @return captured words keyed by vocabulary index (the live map, not a copy) */
    public Map<Integer, VocabWord> getIndexes() {
        return this.indexes;
    }

    /** @param indexes replacement index-to-word map */
    public void setIndexes(Map<Integer, VocabWord> indexes) {
        this.indexes = indexes;
    }

    /** @return original syn0 rows keyed by word string (the live map, not a copy) */
    public Map<String, INDArray> getOriginalVectors() {
        return this.originalVectors;
    }

    /** @param originalVectors replacement original syn0 map */
    public void setOriginalVectors(Map<String, INDArray> originalVectors) {
        this.originalVectors = originalVectors;
    }

    /** @return original syn1 slices keyed by word string (the live map, not a copy) */
    public Map<String, INDArray> getOriginalSyn1Vectors() {
        return this.originalSyn1Vectors;
    }

    /** @param originalSyn1Vectors replacement original syn1 map */
    public void setOriginalSyn1Vectors(Map<String, INDArray> originalSyn1Vectors) {
        this.originalSyn1Vectors = originalSyn1Vectors;
    }

    /** @return original syn1Neg slices keyed by word string (the live map, not a copy) */
    public Map<String, INDArray> getOriginalNegative() {
        return this.originalNegative;
    }

    /** @param originalNegative replacement original syn1Neg map */
    public void setOriginalNegative(Map<String, INDArray> originalNegative) {
        this.originalNegative = originalNegative;
    }

    /** @return working syn1 slices keyed by word string (the live map, not a copy) */
    public Map<String, INDArray> getSyn1Vectors() {
        return this.syn1Vectors;
    }

    /** @param syn1Vectors replacement working syn1 map */
    public void setSyn1Vectors(Map<String, INDArray> syn1Vectors) {
        this.syn1Vectors = syn1Vectors;
    }
}

