/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.word2vec;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.compression.AbstractStorage;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Read-only {@link WordVectors} implementation whose vectors are held in an
 * {@link AbstractStorage} keyed by vocabulary index, with an optional per-device
 * in-memory cache bounded by a byte budget.
 *
 * <p>Only direct lookups (vector retrieval, pairwise {@link #similarity}) are supported;
 * nearest-neighbour, accuracy and lookup-table queries throw
 * {@link UnsupportedOperationException}. Instances are created via {@link Builder}.
 */
public class StaticWord2Vec
implements WordVectors {
    private static final Logger log = LoggerFactory.getLogger(StaticWord2Vec.class);
    /** One vector cache per device, indexed by device id; populated in {@link #init()}. */
    private final List<Map<Integer, INDArray>> cacheWrtDevice = new ArrayList<>();
    private AbstractStorage<Integer> storage;
    /** Byte budget for each per-device cache; a value of 0 (the default) disables caching. */
    private long cachePerDevice = 0L;
    private VocabCache<VocabWord> vocabCache;
    /** Word whose vector substitutes for out-of-vocabulary words; {@code null} = no substitution. */
    private String unk = null;

    private StaticWord2Vec() {
    }

    @Override
    public String getUNK() {
        return this.unk;
    }

    @Override
    public void setUNK(String newUNK) {
        this.unk = newUNK;
    }

    /**
     * Validates that storage and vocabulary sizes agree and creates one empty cache per device.
     *
     * @throws RuntimeException if the number of stored vectors differs from the vocabulary size
     */
    protected void init() {
        if (this.storage.size() != (long) this.vocabCache.numWords()) {
            throw new RuntimeException("Number of words in Vocab isn't matching number of stored Vectors. vocab: [" + this.vocabCache.numWords() + "]; storage: [" + this.storage.size() + "]");
        }
        int devices = Nd4j.getAffinityManager().getNumberOfDevices();
        for (int i = 0; i < devices; ++i) {
            this.cacheWrtDevice.add(new ConcurrentHashMap<>());
        }
    }

    @Override
    public boolean hasWord(String word) {
        return this.vocabCache.containsWord(word);
    }

    @Override
    public Collection<String> wordsNearest(INDArray words, int top) {
        throw new UnsupportedOperationException("Method isn't implemented. Please use usual Word2Vec implementation");
    }

    @Override
    public Collection<String> wordsNearestSum(INDArray words, int top) {
        throw new UnsupportedOperationException("Method isn't implemented. Please use usual Word2Vec implementation");
    }

    @Override
    public Collection<String> wordsNearestSum(String word, int n) {
        throw new UnsupportedOperationException("Method isn't implemented. Please use usual Word2Vec implementation");
    }

    @Override
    public Collection<String> wordsNearestSum(Collection<String> positive, Collection<String> negative, int top) {
        throw new UnsupportedOperationException("Method isn't implemented. Please use usual Word2Vec implementation");
    }

    @Override
    public Map<String, Double> accuracy(List<String> questions) {
        throw new UnsupportedOperationException("Method isn't implemented. Please use usual Word2Vec implementation");
    }

    @Override
    public int indexOf(String word) {
        return this.vocabCache.indexOf(word);
    }

    @Override
    public List<String> similarWordsInVocabTo(String word, double accuracy) {
        throw new UnsupportedOperationException("Method isn't implemented. Please use usual Word2Vec implementation");
    }

    /**
     * Returns the vector for {@code word} as a fresh {@code double[]}.
     *
     * @return the vector values, or {@code null} if no vector is available for the word
     */
    @Override
    public double[] getWordVector(String word) {
        INDArray vector = this.getWordVectorMatrix(word);
        if (vector == null) {
            return null;
        }
        // dup() yields a compact copy, so data() holds exactly this vector's elements
        // even if the stored array were a view over a larger buffer.
        return vector.dup().data().asDouble();
    }

    /**
     * Returns a unit-length copy of the vector for {@code word}.
     *
     * @return the normalized vector, or {@code null} if no vector is available for the word
     */
    @Override
    public INDArray getWordVectorMatrixNormalized(String word) {
        INDArray vector = this.getWordVectorMatrix(word);
        if (vector == null) {
            return null;
        }
        // unitVec may scale its argument in place; normalize a copy so the
        // stored/cached vector is left untouched.
        return Transforms.unitVec(vector.dup());
    }

    /**
     * Looks up the vector for {@code word}. The UNK vector is used if the word is unknown
     * and UNK is set.
     *
     * <p>When caching is enabled, the returned array may be the same instance handed to
     * other callers on this device, so callers must not modify it in place.
     *
     * @return the vector, or {@code null} if the word is unknown and no UNK is configured,
     *         or if the storage holds no entry for the resolved index
     */
    @Override
    public INDArray getWordVectorMatrix(String word) {
        final int idx;
        if (this.hasWord(word)) {
            idx = this.vocabCache.indexOf(word);
        } else if (this.getUNK() != null) {
            idx = this.vocabCache.indexOf(this.getUNK());
        } else {
            return null;
        }

        Map<Integer, INDArray> cache = null;
        if (this.cachePerDevice > 0L) {
            cache = this.cacheWrtDevice.get(Nd4j.getAffinityManager().getDeviceForCurrentThread());
            // ConcurrentHashMap never stores null values, so null means "not cached".
            INDArray cached = cache.get(idx);
            if (cached != null) {
                return cached;
            }
        }

        INDArray array = this.storage.get(idx);
        if (array == null) {
            // NOTE(review): presumably reached when idx is not a stored key (e.g. an UNK
            // word missing from the vocab) — depends on the storage implementation.
            return null;
        }
        if (cache != null) {
            // Approximate cache footprint: assumes all cached vectors have this vector's size.
            long arrayBytes = array.length() * (long) array.data().getElementSize();
            if (arrayBytes * cache.size() + arrayBytes < this.cachePerDevice) {
                cache.put(idx, array);
            }
        }
        return array;
    }

    /**
     * Stacks the vectors of all labels that have one (directly or via UNK) into a matrix,
     * one row per retained label. Labels without a vector are skipped.
     */
    @Override
    public INDArray getWordVectors(Collection<String> labels) {
        List<INDArray> words = new ArrayList<>(labels.size());
        for (String label : labels) {
            INDArray vector = this.getWordVectorMatrix(label);
            if (vector != null) {
                words.add(vector);
            }
        }
        return Nd4j.vstack(words);
    }

    /**
     * Returns {@code getWordVectors(labels).mean(1)}.
     */
    @Override
    public INDArray getWordVectorsMean(Collection<String> labels) {
        INDArray matrix = this.getWordVectors(labels);
        // NOTE(review): rows are words, so mean over dimension 1 averages within each word
        // rather than across words; confirm whether dimension 0 was intended.
        return matrix.mean(new int[]{1});
    }

    @Override
    public Collection<String> wordsNearest(Collection<String> positive, Collection<String> negative, int top) {
        throw new UnsupportedOperationException("Method isn't implemented. Please use usual Word2Vec implementation");
    }

    @Override
    public Collection<String> wordsNearest(String word, int n) {
        throw new UnsupportedOperationException("Method isn't implemented. Please use usual Word2Vec implementation");
    }

    /**
     * Cosine similarity between the vectors of two labels.
     *
     * @return 1.0 for identical labels that have a vector; {@code NaN} if either label is
     *         {@code null} or has no vector; otherwise the cosine similarity
     */
    @Override
    public double similarity(String label1, String label2) {
        if (label1 == null || label2 == null) {
            log.debug("LABELS: {}: {};{} vec2:{}", label1, label1 == null ? "null" : "exists", label2, label2 == null ? "null" : "exists");
            return Double.NaN;
        }
        // Null-check before dup(): getWordVectorMatrix returns null for unknown words.
        INDArray vec1 = this.getWordVectorMatrix(label1);
        INDArray vec2 = this.getWordVectorMatrix(label2);
        if (vec1 == null || vec2 == null) {
            log.debug("{}: {};{} vec2:{}", label1, vec1 == null ? "null" : "exists", label2, vec2 == null ? "null" : "exists");
            return Double.NaN;
        }
        if (label1.equals(label2)) {
            return 1.0;
        }
        // Work on copies: unitVec may normalize in place and the originals can be cached.
        INDArray unit1 = Transforms.unitVec(vec1.dup());
        INDArray unit2 = Transforms.unitVec(vec2.dup());
        return Transforms.cosineSim(unit1, unit2);
    }

    @Override
    public VocabCache vocab() {
        return this.vocabCache;
    }

    @Override
    public WeightLookupTable lookupTable() {
        throw new UnsupportedOperationException("Method isn't implemented. Please use usual Word2Vec implementation");
    }

    /** No-op: this implementation does not use {@link ModelUtils}. */
    @Override
    public void setModelUtils(ModelUtils utils) {
    }

    /**
     * Builder for {@link StaticWord2Vec}.
     */
    public static class Builder {
        private AbstractStorage<Integer> storage;
        private long cachePerDevice = 0L;
        private VocabCache<VocabWord> vocabCache;

        /**
         * @param storage    vector storage keyed by vocabulary index
         * @param vocabCache vocabulary whose size must match the storage size
         */
        public Builder(AbstractStorage<Integer> storage, VocabCache<VocabWord> vocabCache) {
            this.storage = storage;
            this.vocabCache = vocabCache;
        }

        /**
         * Sets the per-device cache budget in bytes; 0 (the default) disables caching.
         */
        public Builder setCachePerDevice(long bytes) {
            this.cachePerDevice = bytes;
            return this;
        }

        /**
         * Builds and initializes the instance.
         *
         * @throws NullPointerException if storage or vocabCache is null
         * @throws RuntimeException     if storage and vocabulary sizes differ
         */
        public StaticWord2Vec build() {
            Objects.requireNonNull(this.storage, "storage");
            Objects.requireNonNull(this.vocabCache, "vocabCache");
            StaticWord2Vec word2Vec = new StaticWord2Vec();
            word2Vec.cachePerDevice = this.cachePerDevice;
            word2Vec.storage = this.storage;
            word2Vec.vocabCache = this.vocabCache;
            word2Vec.init();
            return word2Vec;
        }
    }
}

