/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.nlp;

import ai.djl.modality.nlp.Vocabulary;
import ai.djl.util.Utils;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * A {@link Vocabulary} backed by an in-memory table that maps each token to its index and
 * frequency.
 *
 * <p>Tokens are indexed in the order in which they are first added: every sentence token first,
 * then the reserved tokens (which include the unknown token when one is configured). Reserved
 * tokens have their frequency pinned to {@link Integer#MAX_VALUE}, so min-frequency pruning never
 * removes them.
 */
public class DefaultVocabulary implements Vocabulary {
    /** Token text mapped to its index and frequency. */
    private final Map<String, TokenInfo> tokens = new ConcurrentHashMap<>();
    /** Reverse lookup: position {@code i} holds the token whose index is {@code i}. */
    private List<String> indexToToken;
    /** Tokens that are always given the maximum frequency. */
    private final Set<String> reservedTokens;
    /** Token returned for unknown indices and used for unknown words; may be {@code null}. */
    private final String unknownToken;

    /**
     * Creates a vocabulary from a single list of tokens using the default builder settings.
     *
     * @param tokens the tokens to add, in index order
     */
    public DefaultVocabulary(List<String> tokens) {
        this(builder().add(tokens));
    }

    /**
     * Creates a vocabulary from the given builder.
     *
     * <p>The builder's reserved-token set is used directly (it is not copied), and the unknown
     * token, when present, is added to it. This constructor does not perform the validation done
     * by {@link Builder#build()}.
     *
     * @param builder the configured builder
     */
    public DefaultVocabulary(Builder builder) {
        reservedTokens = builder.reservedTokens;
        unknownToken = builder.unknownToken;
        if (unknownToken != null) {
            reservedTokens.add(unknownToken);
        }

        // Sentence tokens first so they keep the order in which they were supplied.
        for (List<String> sentence : builder.sentences) {
            for (String token : sentence) {
                addToken(token);
            }
        }
        // Reserved tokens are appended after the sentence tokens.
        for (String reserved : reservedTokens) {
            addToken(reserved);
        }

        if (pruneTokens(builder.minFrequency, builder.maxTokens)) {
            // Pruning may have left gaps in the index sequence; renumber densely.
            initializeIndexToTokenReplacingIndices();
        } else {
            initializeIndexToTokenKeepingIndices();
        }
    }

    /**
     * Registers one occurrence of a token.
     *
     * <p>A token seen for the first time is assigned the current table size as its index. Reserved
     * tokens get frequency {@link Integer#MAX_VALUE}; other tokens have their frequency
     * incremented, saturating at {@link Integer#MAX_VALUE}.
     *
     * @param token the token to register
     */
    private void addToken(String token) {
        int nextIndex = tokens.size();
        tokens.compute(
                token,
                (key, existing) -> {
                    TokenInfo info = existing;
                    if (info == null) {
                        info = new TokenInfo();
                        info.index = nextIndex;
                    }
                    if (reservedTokens.contains(key)) {
                        info.frequency = Integer.MAX_VALUE;
                    } else if (info.frequency < Integer.MAX_VALUE) {
                        // Guard against overflow for extremely frequent tokens.
                        ++info.frequency;
                    }
                    return info;
                });
    }

    /**
     * Removes tokens according to the frequency and size limits.
     *
     * @param minFrequency tokens with a lower frequency are removed; ignored unless greater than 1
     * @param maxSize maximum number of tokens to keep; ignored unless positive
     * @return {@code true} if a pruning pass was applied (min-frequency pruning counts as applied
     *     whenever it is enabled, even if it removed nothing)
     */
    private boolean pruneTokens(int minFrequency, int maxSize) {
        boolean pruned = false;

        if (minFrequency > 1) {
            tokens.values().removeIf(info -> info.frequency < minFrequency);
            pruned = true;
        }

        if (maxSize > 0 && tokens.size() > maxSize) {
            Comparator<TokenInfo> mostFrequentFirst =
                    Comparator.comparingInt((TokenInfo info) -> info.frequency).reversed();
            // No tie-breaker is applied: among equally frequent tokens, the ones that survive
            // are determined by the order in which the map's entries are encountered.
            // Collect the surplus keys before touching the map so the stream's source is not
            // modified while the pipeline is running.
            List<String> surplus =
                    tokens.entrySet().stream()
                            .sorted(Map.Entry.comparingByValue(mostFrequentFirst))
                            .skip(maxSize)
                            .map(Map.Entry::getKey)
                            .collect(Collectors.toList());
            for (String token : surplus) {
                tokens.remove(token);
            }
            pruned = true;
        }
        return pruned;
    }

    /**
     * Builds the index-to-token list using the indices already stored in the token table.
     *
     * <p>Only valid when no token has been removed, so that the stored indices are exactly
     * {@code 0..size-1}.
     */
    private void initializeIndexToTokenKeepingIndices() {
        // Fixed-size list pre-filled with nulls so entries can be placed by index.
        indexToToken = Arrays.asList(new String[tokens.size()]);
        for (Map.Entry<String, TokenInfo> entry : tokens.entrySet()) {
            indexToToken.set(Math.toIntExact(entry.getValue().index), entry.getKey());
        }
    }

    /**
     * Builds the index-to-token list by ordering the remaining tokens by their current index and
     * then renumbering them consecutively from zero.
     */
    private void initializeIndexToTokenReplacingIndices() {
        indexToToken =
                tokens.entrySet().stream()
                        .sorted(
                                Comparator.comparingLong(
                                        (Map.Entry<String, TokenInfo> entry) ->
                                                entry.getValue().index))
                        .map(Map.Entry::getKey)
                        .collect(Collectors.toList());
        for (int i = 0; i < indexToToken.size(); ++i) {
            tokens.get(indexToToken.get(i)).index = i;
        }
    }

    /**
     * Checks whether the token is in the vocabulary.
     *
     * @param token the token to look up
     * @return {@code true} if the token has an entry in the token table
     */
    @Override
    public boolean contains(String token) {
        return tokens.containsKey(token);
    }

    /**
     * Returns the token stored at the given index.
     *
     * @param index the index to look up
     * @return the token at {@code index}, or the unknown token (possibly {@code null}) when the
     *     index is negative or not less than the vocabulary size
     */
    @Override
    public String getToken(long index) {
        if (index < 0L || index >= indexToToken.size()) {
            return unknownToken;
        }
        return indexToToken.get((int) index);
    }

    /**
     * Returns the index of the given token.
     *
     * @param token the token to look up
     * @return the token's index, or the unknown token's index when the token is not in the
     *     vocabulary and an unknown token is configured
     * @throws IllegalStateException if the token is not in the vocabulary and no unknown token is
     *     configured, or the unknown token itself is missing from the vocabulary
     */
    @Override
    public long getIndex(String token) {
        TokenInfo info = tokens.get(token);
        if (info != null) {
            return info.index;
        }
        if (unknownToken == null) {
            throw new IllegalStateException("Unexpected token in getIndex. Define an unknownToken for the vocabulary to enable support for unknown tokens.");
        }
        TokenInfo unknownInfo = tokens.get(unknownToken);
        if (unknownInfo == null) {
            // Reachable when the vocabulary was constructed directly from a builder whose
            // maxTokens was too small to hold every reserved token.
            throw new IllegalStateException(
                    "The unknown token \"" + unknownToken + "\" is not present in the vocabulary.");
        }
        return unknownInfo.index;
    }

    /**
     * Returns the number of tokens in the vocabulary.
     *
     * @return the size of the token table
     */
    @Override
    public long size() {
        return tokens.size();
    }

    /**
     * Creates a new builder for a {@code DefaultVocabulary}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Per-token bookkeeping: how often the token was added and where it is indexed. */
    private static final class TokenInfo {
        int frequency;
        /** {@code -1} until an index has been assigned. */
        long index = -1L;

        private TokenInfo() {
        }
    }

    /** Builder that collects the sentences and options used to construct a vocabulary. */
    public static final class Builder {
        final List<List<String>> sentences = new ArrayList<>();
        final Set<String> reservedTokens = new HashSet<>();
        int minFrequency = -1;
        int maxTokens = -1;
        String unknownToken;

        private Builder() {
        }

        /**
         * Sets the minimum frequency a token needs in order to be kept.
         *
         * @param minFrequency the minimum frequency; values of 1 or less disable this pruning
         * @return this builder
         */
        public Builder optMinFrequency(int minFrequency) {
            this.minFrequency = minFrequency;
            return this;
        }

        /**
         * Sets the maximum number of tokens to keep.
         *
         * @param maxTokens the maximum vocabulary size; non-positive values disable this pruning
         * @return this builder
         */
        public Builder optMaxTokens(int maxTokens) {
            this.maxTokens = maxTokens;
            return this;
        }

        /**
         * Sets the unknown token to {@code "<unk>"}.
         *
         * @return this builder
         */
        public Builder optUnknownToken() {
            return optUnknownToken("<unk>");
        }

        /**
         * Sets the unknown token.
         *
         * @param unknownToken the token used for unknown indices and unknown words
         * @return this builder
         */
        public Builder optUnknownToken(String unknownToken) {
            this.unknownToken = unknownToken;
            return this;
        }

        /**
         * Adds tokens to the set of reserved tokens.
         *
         * @param reservedTokens the tokens to reserve
         * @return this builder
         */
        public Builder optReservedTokens(Collection<String> reservedTokens) {
            this.reservedTokens.addAll(reservedTokens);
            return this;
        }

        /**
         * Adds one sentence (a list of tokens) to the vocabulary source.
         *
         * @param sentence the tokens to add
         * @return this builder
         */
        public Builder add(List<String> sentence) {
            sentences.add(sentence);
            return this;
        }

        /**
         * Adds several sentences to the vocabulary source.
         *
         * @param sentences the sentences to add
         * @return this builder
         */
        public Builder addAll(List<List<String>> sentences) {
            this.sentences.addAll(sentences);
            return this;
        }

        /**
         * Adds the lines of a text file as a single sentence, one token per line.
         *
         * @param path the file to read
         * @return this builder
         * @throws IOException if reading the file fails
         */
        public Builder addFromTextFile(Path path) throws IOException {
            // NOTE(review): the boolean flag is presumably "trim lines" - confirm against Utils.
            add(Utils.readLines(path, true));
            return this;
        }

        /**
         * Adds the lines read from a URL as a single sentence, one token per line.
         *
         * @param url the location to read
         * @return this builder
         * @throws IOException if opening or reading the stream fails
         */
        public Builder addFromTextFile(URL url) throws IOException {
            try (InputStream is = url.openStream()) {
                // NOTE(review): the boolean flag is presumably "trim lines" - confirm against Utils.
                add(Utils.readLines(is, true));
            }
            return this;
        }

        /**
         * Adds the tokens produced by applying a caller-supplied function to a URL.
         *
         * @param url the location handed to the function
         * @param lambda the function that turns the URL into a list of tokens
         * @return this builder
         */
        public Builder addFromCustomizedFile(URL url, Function<URL, List<String>> lambda) {
            return add(lambda.apply(url));
        }

        /**
         * Builds the vocabulary.
         *
         * @return the new {@link DefaultVocabulary}
         * @throws IllegalArgumentException if {@code maxTokens} is positive but smaller than the
         *     number of reserved tokens, counting the unknown token
         */
        public DefaultVocabulary build() {
            // The constructor adds the unknown token to the reserved set, so it has to be
            // counted here as well unless it is already one of the reserved tokens.
            int reservedCount = reservedTokens.size();
            if (unknownToken != null && !reservedTokens.contains(unknownToken)) {
                ++reservedCount;
            }
            if (maxTokens > 0 && maxTokens < reservedCount) {
                throw new IllegalArgumentException("The vocabulary maxTokens can not be smaller than the number of reserved tokens");
            }
            return new DefaultVocabulary(this);
        }
    }
}

