/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.embeddings.reader.impl;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import lombok.NonNull;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.reader.ModelUtils;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Counter;
import org.nd4j.linalg.util.MathUtils;
import org.nd4j.util.SetUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Basic {@link ModelUtils} implementation answering similarity, analogy and nearest-neighbour
 * queries against a {@link WeightLookupTable}.
 *
 * <p>When the lookup table is an {@link InMemoryLookupTable}, nearest-neighbour queries work
 * directly on its {@code syn0} matrix; otherwise they fall back to a linear scan of the
 * vocabulary computing a cosine similarity per word.
 *
 * @param <T> vocabulary element type
 */
public class BasicModelUtils<T extends SequenceElement>
implements ModelUtils<T> {
    private static final Logger log = LoggerFactory.getLogger(BasicModelUtils.class);
    public static final String EXISTS = "exists";
    public static final String CORRECT = "correct";
    public static final String WRONG = "wrong";
    /** Vocabulary of the lookup table passed to {@link #init}. */
    protected volatile VocabCache<T> vocabCache;
    /** Lookup table passed to {@link #init}. */
    protected volatile WeightLookupTable<T> lookupTable;
    /**
     * Whether the rows of the in-memory {@code syn0} matrix have already been L2-normalised in
     * place by {@link #wordsNearest(INDArray, int)}.
     */
    protected volatile boolean normalized = false;

    /**
     * Binds this instance to a lookup table (and its vocabulary) and resets the normalisation flag.
     *
     * @param lookupTable table to query; must not be null
     * @throws NullPointerException if {@code lookupTable} is null
     */
    @Override
    public void init(@NonNull WeightLookupTable<T> lookupTable) {
        if (lookupTable == null) {
            throw new NullPointerException("lookupTable is marked @NonNull but is null");
        }
        this.vocabCache = lookupTable.getVocabCache();
        this.lookupTable = lookupTable;
        this.normalized = false;
    }

    /**
     * Cosine similarity between the vectors of two labels.
     *
     * @param label1 first label; must not be null
     * @param label2 second label; must not be null
     * @return {@code 1.0} for identical labels, the cosine similarity otherwise, or
     *     {@code Double.NaN} if either label is unknown or has no vector
     * @throws NullPointerException if either label is null
     */
    @Override
    public double similarity(@NonNull String label1, @NonNull String label2) {
        if (label1 == null) {
            throw new NullPointerException("label1 is marked @NonNull but is null");
        }
        if (label2 == null) {
            throw new NullPointerException("label2 is marked @NonNull but is null");
        }
        if (!this.vocabCache.hasToken(label1)) {
            log.debug("Unknown token 1 requested: [{}]", label1);
            return Double.NaN;
        }
        if (!this.vocabCache.hasToken(label2)) {
            log.debug("Unknown token 2 requested: [{}]", label2);
            return Double.NaN;
        }
        // Null-check BEFORE dup(): calling dup() on a missing vector would throw an NPE.
        INDArray vec1 = this.lookupTable.vector(label1);
        INDArray vec2 = this.lookupTable.vector(label2);
        if (vec1 == null || vec2 == null) {
            log.debug("{}: {};{} vec2:{}", label1, vec1 == null ? "null" : EXISTS,
                    label2, vec2 == null ? "null" : EXISTS);
            return Double.NaN;
        }
        if (label1.equals(label2)) {
            return 1.0;
        }
        return Transforms.cosineSim(vec1.dup(), vec2.dup());
    }

    /**
     * Returns up to {@code n} words nearest to {@code label}, excluding {@code label} itself.
     *
     * @param label query word
     * @param n maximum number of results; non-positive values yield an empty list
     * @return nearest words in the order produced by
     *     {@link #wordsNearest(Collection, Collection, int)}
     */
    @Override
    public Collection<String> wordsNearest(String label, int n) {
        List<String> nearest = new ArrayList<>(
                this.wordsNearest(Arrays.asList(label), new ArrayList<String>(), n + 1));
        nearest.remove(label);
        int limit = Math.max(n, 0);
        if (nearest.size() > limit) {
            nearest = new ArrayList<>(nearest.subList(0, limit));
        }
        return nearest;
    }

    /**
     * Evaluates analogy questions and reports per-category accuracy.
     *
     * <p>Lines starting with {@code ':'} start a new category. Every other line is
     * {@code "a b c d"}, and is counted correct when the nearest word to {@code b + c - a} is
     * {@code d}. Blank or malformed lines (fewer than four tokens) are skipped. Questions that
     * appear before the first category header are not attributed to any category.
     *
     * @param questions analogy file lines
     * @return map from category header to accuracy in percent ({@code NaN} for a category with
     *     no questions)
     */
    @Override
    public Map<String, Double> accuracy(List<String> questions) {
        Map<String, Double> accuracy = new HashMap<>();
        Counter<String> right = new Counter<>();
        String analogyType = "";
        for (String s : questions) {
            if (s.startsWith(":")) {
                if (!analogyType.isEmpty()) {
                    accuracy.put(analogyType, accuracyPercent(right));
                }
                analogyType = s;
                // Reset on every header so counts never leak between categories.
                right.clear();
                continue;
            }
            String[] split = s.trim().split("\\s+");
            if (split.length < 4) {
                if (!s.trim().isEmpty()) {
                    log.warn("Skipping malformed analogy question: [{}]", s);
                }
                continue;
            }
            List<String> positive = Arrays.asList(split[1], split[2]);
            String predicted = split[3];
            List<String> negative = Arrays.asList(split[0]);
            Collection<String> nearest = this.wordsNearest(positive, negative, 1);
            // An empty result (e.g. an out-of-vocabulary word) counts as a wrong answer.
            String w = nearest.isEmpty() ? null : nearest.iterator().next();
            if (predicted.equals(w)) {
                right.incrementCount(CORRECT, 1.0);
            } else {
                right.incrementCount(WRONG, 1.0);
            }
        }
        if (!analogyType.isEmpty()) {
            accuracy.put(analogyType, accuracyPercent(right));
        }
        return accuracy;
    }

    /** Percentage of CORRECT among CORRECT + WRONG counts; NaN when both are zero. */
    private static double accuracyPercent(Counter<String> counts) {
        double correct = counts.getCount(CORRECT);
        double wrong = counts.getCount(WRONG);
        return 100.0 * correct / (correct + wrong);
    }

    /**
     * Returns vocabulary words whose string similarity to {@code word} is at least
     * {@code accuracy}.
     *
     * @param word reference word
     * @param accuracy minimum {@link MathUtils#stringSimilarity} score
     * @return matching vocabulary words
     */
    @Override
    public List<String> similarWordsInVocabTo(String word, double accuracy) {
        List<String> ret = new ArrayList<>();
        for (String s : this.vocabCache.words()) {
            if (MathUtils.stringSimilarity(new String[]{word, s}) >= accuracy) {
                ret.add(s);
            }
        }
        return ret;
    }

    /**
     * Returns up to {@code top} words nearest to the mean of the positive vectors and the negated
     * negative vectors, excluding the query words themselves.
     *
     * @param positive words contributing positively; must not be null
     * @param negative words contributing negatively; must not be null
     * @param top maximum number of results
     * @return nearest words; empty if any query word is not in the vocabulary or both collections
     *     are empty
     * @throws NullPointerException if either collection is null
     */
    @Override
    public Collection<String> wordsNearest(@NonNull Collection<String> positive, @NonNull Collection<String> negative, int top) {
        if (positive == null) {
            throw new NullPointerException("positive is marked @NonNull but is null");
        }
        if (negative == null) {
            throw new NullPointerException("negative is marked @NonNull but is null");
        }
        for (String p : SetUtils.union(new HashSet<String>(positive), new HashSet<String>(negative))) {
            if (!this.vocabCache.containsWord(p)) {
                return new ArrayList<String>();
            }
        }
        if (positive.isEmpty() && negative.isEmpty()) {
            // Nothing to average; avoids creating a zero-row matrix.
            return new ArrayList<String>();
        }
        INDArray words = Nd4j.create(positive.size() + negative.size(), (int) this.lookupTable.layerSize());
        long row = 0;
        for (String s : positive) {
            words.putRow(row++, this.lookupTable.vector(s));
        }
        for (String s : negative) {
            words.putRow(row++, this.lookupTable.vector(s).mul(-1));
        }
        INDArray mean = words.isMatrix() ? words.mean(0) : words;
        // Over-fetch so that filtering out the query words still leaves `top` results.
        Collection<String> tempRes = this.wordsNearest(mean, top + positive.size() + negative.size());
        List<String> realResults = new ArrayList<>();
        for (String word : tempRes) {
            if (realResults.size() >= top) {
                break;
            }
            if (!positive.contains(word) && !negative.contains(word)) {
                realResults.add(word);
            }
        }
        return realResults;
    }

    /**
     * Nearest words by the weighted-sum score of {@link #wordsNearestSum(INDArray, int)} for a
     * single word.
     *
     * @param word query word
     * @param n maximum number of results
     * @return nearest words, or an empty list if {@code word} has no vector
     */
    @Override
    public Collection<String> wordsNearestSum(String word, int n) {
        INDArray vec = this.lookupTable.vector(word);
        if (vec == null) {
            return new ArrayList<String>();
        }
        return this.wordsNearestSum(vec, n);
    }

    /**
     * Returns up to {@code top} words ranked by cosine similarity to {@code words}.
     *
     * <p>For an {@link InMemoryLookupTable}, the first call L2-normalises every row of
     * {@code syn0} IN PLACE (mutating the table's weights), guarded by double-checked locking on
     * {@code this}. The top {@code top + 20} candidates are then taken from the
     * unit-vector/{@code syn0} product. {@code UNK}/{@code STOP}/unknown indices are dropped, and
     * the rest are re-scored and sorted by descending similarity (NaN last). Other table types use
     * a linear cosine scan over the vocabulary.
     *
     * @param words query vector (NOTE(review): presumably shaped like a single syn0 row — confirm)
     * @param top maximum number of results
     * @return nearest words
     */
    @Override
    public Collection<String> wordsNearest(INDArray words, int top) {
        if (!(this.lookupTable instanceof InMemoryLookupTable)) {
            return this.nearestByCosineScan(words, top);
        }
        InMemoryLookupTable l = (InMemoryLookupTable) this.lookupTable;
        INDArray syn0 = l.getSyn0();
        this.normalizeRowsOnce(syn0);
        INDArray similarity = Transforms.unitVec(words).mmul(syn0.transpose());
        // Over-fetch by 20 to leave room for filtered special/unknown words.
        List<Double> highToLowSimList = this.getTopN(similarity, top + 20);
        List<WordSimilarity> result = new ArrayList<>();
        for (Double index : highToLowSimList) {
            String word = this.vocabCache.wordAtIndex(index.intValue());
            if (isSpecialOrMissing(word)) {
                continue;
            }
            INDArray otherVec = this.lookupTable.vector(word);
            double sim = Transforms.cosineSim(words, otherVec);
            result.add(new WordSimilarity(word, sim));
        }
        Collections.sort(result, new SimilarityComparator());
        return BasicModelUtils.getLabels(result, top);
    }

    /** Divides every row of {@code syn0} by its L2 norm, at most once per {@link #init}. */
    private void normalizeRowsOnce(INDArray syn0) {
        if (!this.normalized) {
            synchronized (this) {
                if (!this.normalized) {
                    syn0.diviColumnVector(syn0.norm2(1));
                    this.normalized = true;
                }
            }
        }
    }

    /**
     * Fallback for non-in-memory tables: scores every vocabulary word by cosine similarity
     * (stored as float precision) and keeps the {@code top} highest.
     * NOTE(review): ordering of the returned keys is whatever {@code Counter.keySet()} yields —
     * possibly unordered.
     */
    private Collection<String> nearestByCosineScan(INDArray words, int top) {
        Counter<String> distances = new Counter<>();
        for (String s : this.vocabCache.words()) {
            INDArray otherVec = this.lookupTable.vector(s);
            double sim = Transforms.cosineSim(words, otherVec);
            distances.incrementCount(s, (float) sim);
        }
        distances.keepTopNElements(top);
        return distances.keySet();
    }

    /** True for a null word or one of the reserved tokens {@code UNK} / {@code STOP}. */
    private static boolean isSpecialOrMissing(String word) {
        return word == null || word.equals("UNK") || word.equals("STOP");
    }

    /**
     * Returns the linear indices (as doubles) of the {@code N} largest values of {@code vec},
     * highest value first; NaN values rank below every number.
     *
     * @param vec values to rank
     * @param N number of indices to keep; non-positive values yield an empty list
     * @return indices ordered by descending value
     */
    private List<Double> getTopN(INDArray vec, int N) {
        if (N <= 0) {
            return new ArrayList<Double>();
        }
        ArrayComparator comparator = new ArrayComparator();
        long length = vec.length();
        // Bounded min-heap of [value, index] pairs; its head is the weakest kept candidate.
        int capacity = (int) Math.max(1L, Math.min((long) N, length));
        PriorityQueue<Double[]> queue = new PriorityQueue<>(capacity, comparator);
        for (long j = 0; j < length; ++j) {
            Double[] pair = new Double[]{vec.getDouble(j), (double) j};
            if (queue.size() < N) {
                queue.add(pair);
            } else if (comparator.compare(pair, queue.peek()) > 0) {
                queue.poll();
                queue.add(pair);
            }
        }
        List<Double> lowToHighSimLst = new ArrayList<>(queue.size());
        while (!queue.isEmpty()) {
            lowToHighSimLst.add(queue.poll()[1]);
        }
        return Lists.reverse(lowToHighSimLst);
    }

    /**
     * Returns up to {@code top} words ranked by a per-dimension weighted dot product.
     *
     * <p>For an {@link InMemoryLookupTable}, each row {@code i} of {@code syn0} is scored as
     * {@code sum_d syn0[i,d] * words[d] / ||syn0[:,d]||} (column L2 norms). Rows are then visited
     * in descending score order, skipping {@code UNK}/{@code STOP}/unknown indices, until
     * {@code top} words are collected or the ranking is exhausted. Other table types use a linear
     * cosine scan over the vocabulary.
     *
     * @param words query vector (NOTE(review): presumably of length layerSize — confirm)
     * @param top maximum number of results
     * @return nearest words
     */
    @Override
    public Collection<String> wordsNearestSum(INDArray words, int top) {
        if (!(this.lookupTable instanceof InMemoryLookupTable)) {
            return this.nearestByCosineScan(words, top);
        }
        InMemoryLookupTable l = (InMemoryLookupTable) this.lookupTable;
        INDArray syn0 = l.getSyn0();
        INDArray weights = syn0.norm2(0).rdivi(1).muli(words);
        INDArray distances = syn0.mulRowVector(weights).sum(1);
        // sorted[0] holds the row indices in descending score order.
        INDArray[] sorted = Nd4j.sortWithIndices(distances, 0, false);
        INDArray sort = sorted[0];
        List<String> ret = new ArrayList<>();
        int total = (int) sort.length();
        // Keep scanning past special words until `top` real words are found or the ranking ends.
        for (int i = 0; i < total && ret.size() < top; ++i) {
            String add = this.vocabCache.wordAtIndex(sort.getInt(new int[]{i}));
            if (!isSpecialOrMissing(add)) {
                ret.add(add);
            }
        }
        return ret;
    }

    /**
     * Nearest words to the sum of the positive vectors minus the sum of the negative vectors,
     * ranked as in {@link #wordsNearestSum(INDArray, int)}.
     *
     * @param positive words added to the query vector
     * @param negative words subtracted from the query vector
     * @param top maximum number of results
     * @return nearest words, or an empty list if any query word has no vector
     */
    @Override
    public Collection<String> wordsNearestSum(Collection<String> positive, Collection<String> negative, int top) {
        INDArray words = Nd4j.create((int) this.lookupTable.layerSize());
        for (String s : positive) {
            INDArray vec = this.lookupTable.vector(s);
            if (vec == null) {
                return new ArrayList<String>();
            }
            words.addi(vec);
        }
        for (String s : negative) {
            INDArray vec = this.lookupTable.vector(s);
            if (vec == null) {
                return new ArrayList<String>();
            }
            words.subi(vec);
        }
        return this.wordsNearestSum(words, top);
    }

    /**
     * Extracts the words of the first {@code limit} entries of {@code results}.
     *
     * @param results scored words, already in the desired order
     * @param limit maximum number of labels; non-positive values yield an empty list
     * @return words in input order
     */
    public static List<String> getLabels(List<WordSimilarity> results, int limit) {
        List<String> result = new ArrayList<>();
        for (WordSimilarity ws : results) {
            if (result.size() >= limit) {
                break;
            }
            result.add(ws.getWord());
        }
        return result;
    }

    /** Mutable (word, similarity) pair with value-based equality. */
    public static class WordSimilarity {
        private String word;
        private double similarity;

        public String getWord() {
            return this.word;
        }

        public double getSimilarity() {
            return this.similarity;
        }

        public void setWord(String word) {
            this.word = word;
        }

        public void setSimilarity(double similarity) {
            this.similarity = similarity;
        }

        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof WordSimilarity)) {
                return false;
            }
            WordSimilarity other = (WordSimilarity)o;
            if (!other.canEqual(this)) {
                return false;
            }
            String this$word = this.getWord();
            String other$word = other.getWord();
            if (this$word == null ? other$word != null : !this$word.equals(other$word)) {
                return false;
            }
            return Double.compare(this.getSimilarity(), other.getSimilarity()) == 0;
        }

        protected boolean canEqual(Object other) {
            return other instanceof WordSimilarity;
        }

        public int hashCode() {
            int PRIME = 59;
            int result = 1;
            String $word = this.getWord();
            result = result * 59 + ($word == null ? 43 : $word.hashCode());
            long $similarity = Double.doubleToLongBits(this.getSimilarity());
            result = result * 59 + (int)($similarity >>> 32 ^ $similarity);
            return result;
        }

        public String toString() {
            return "BasicModelUtils.WordSimilarity(word=" + this.getWord() + ", similarity=" + this.getSimilarity() + ")";
        }

        public WordSimilarity(String word, double similarity) {
            this.word = word;
            this.similarity = similarity;
        }
    }

    /**
     * Orders {@code [value, index]} pairs by ascending {@code value} (element 0 only), treating
     * NaN as smaller than every number.
     */
    public static class ArrayComparator
    implements Comparator<Double[]> {
        @Override
        public int compare(Double[] o1, Double[] o2) {
            boolean nan1 = Double.isNaN(o1[0]);
            boolean nan2 = Double.isNaN(o2[0]);
            if (nan1 && nan2) {
                return 0;
            }
            if (nan1) {
                return -1;
            }
            if (nan2) {
                return 1;
            }
            return Double.compare(o1[0], o2[0]);
        }
    }

    /** Orders {@link WordSimilarity} by descending similarity, with NaN similarities last. */
    public static class SimilarityComparator
    implements Comparator<WordSimilarity> {
        @Override
        public int compare(WordSimilarity o1, WordSimilarity o2) {
            boolean nan1 = Double.isNaN(o1.getSimilarity());
            boolean nan2 = Double.isNaN(o2.getSimilarity());
            if (nan1 && nan2) {
                return 0;
            }
            // In a descending ranking NaN must sort after every real score, never first.
            if (nan1) {
                return 1;
            }
            if (nan2) {
                return -1;
            }
            return Double.compare(o2.getSimilarity(), o1.getSimilarity());
        }
    }
}

