/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.core;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterBlock;
import ai.djl.nn.ParameterType;
import ai.djl.nn.core.AbstractIndexedEmbedding;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * A block that maps items of type {@code T} to integer indices and looks those indices up in a
 * trainable embedding matrix of shape {@code (numItems, embeddingSize)}.
 *
 * <p>Index {@code 0} is reserved for the fallthrough embedding (if any). Known items are assigned
 * indices starting at {@code 1}.
 *
 * @param <T> the type of item being embedded
 */
public abstract class Embedding<T> extends ParameterBlock implements AbstractIndexedEmbedding<T> {
    /** Serialization format version written by {@link #saveParameters(DataOutputStream)}. */
    private static final byte VERSION = 4;

    /** Width of each embedding vector (second dimension of the embedding matrix). */
    protected int embeddingSize;
    /** Whether the embedding parameter uses a row-sparse gradient. */
    protected boolean sparseGrad;
    /** Data type passed to the embedding operator. */
    protected DataType dataType;
    /** Item to index mapping. */
    protected Map<T, Integer> embedder;
    /** Index to item mapping (inverse of {@link #embedder}). */
    protected Map<Integer, T> unembedder;
    /** Number of rows in the embedding matrix, including the reserved row 0. */
    protected int numItems;
    /** Embedding consulted for items not present in {@link #embedder}; may be {@code null}. */
    protected AbstractIndexedEmbedding<T> fallthroughEmbedding;
    /** The embedding matrix parameter. */
    protected Parameter embedding;

    /**
     * Constructs an embedding from a builder.
     *
     * @param baseBuilder the builder holding the embedding configuration
     * @throws IllegalArgumentException if both a fallthrough and a default item were specified
     */
    protected Embedding(BaseBuilder<T, ?> baseBuilder) {
        this.embeddingSize = baseBuilder.embeddingSize;
        this.sparseGrad = baseBuilder.sparseGrad;
        this.dataType = baseBuilder.dataType;
        this.embedding = new Parameter("embedding", this, ParameterType.WEIGHT, true, this.sparseGrad ? SparseFormat.ROW_SPARSE : SparseFormat.DENSE);
        this.embedder = new ConcurrentHashMap<T, Integer>();
        this.unembedder = new ConcurrentHashMap<Integer, T>();
        if (baseBuilder.fallthrough != null && baseBuilder.defaultItem != null) {
            throw new IllegalArgumentException("You can not specify both a fallthrough and a defaultItem");
        }
        if (baseBuilder.fallthrough != null) {
            this.fallthroughEmbedding = baseBuilder.fallthrough;
        } else if (baseBuilder.defaultItem != null) {
            this.fallthroughEmbedding = new DefaultItem(baseBuilder.defaultItem);
        } else if (baseBuilder.useDefault) {
            this.fallthroughEmbedding = new DefaultEmbedding();
        }
        // Index 0 is reserved for the fallthrough, so known items are numbered from 1.
        this.numItems = 1;
        for (T item : baseBuilder.items) {
            this.embedder.put(item, this.numItems);
            this.unembedder.put(this.numItems, item);
            this.numItems++;
        }
        this.inputShapes = new Shape[]{new Shape(-1L)};
    }

    /**
     * Constructs a pretrained embedding with a row-sparse gradient.
     *
     * @param embedding the embedding matrix
     * @param items the items in the embedding; see {@link #Embedding(NDArray, List, boolean)}
     */
    public Embedding(NDArray embedding, List<T> items) {
        this(embedding, items, true);
    }

    /**
     * Constructs a pretrained embedding.
     *
     * <p>Row 0 of {@code embedding} is reserved (index 0 is the fallthrough index), so
     * {@code items.get(i)} is mapped to index {@code i + 1}. No fallthrough is configured by this
     * constructor. NOTE(review): this assumes {@code embedding} has {@code items.size() + 1} rows;
     * the row count is not validated here.
     *
     * @param embedding the embedding matrix; dimension 1 is taken as the embedding size
     * @param items the items in the embedding, in row order starting from row 1
     * @param sparseGrad whether to use a row-sparse gradient
     */
    public Embedding(NDArray embedding, List<T> items, boolean sparseGrad) {
        this.embeddingSize = Math.toIntExact(embedding.getShape().get(1));
        this.sparseGrad = sparseGrad;
        this.dataType = embedding.getDataType();
        this.embedding = new Parameter("embedding", this, ParameterType.WEIGHT, true, sparseGrad ? SparseFormat.ROW_SPARSE : SparseFormat.DENSE);
        this.embedding.setArray(embedding);
        this.numItems = Math.toIntExact(embedding.getShape().size(0));
        this.embedder = new ConcurrentHashMap<T, Integer>(this.numItems);
        this.unembedder = new ConcurrentHashMap<Integer, T>(this.numItems);
        // The list is 0-based while embedding indices start at 1 (index 0 is reserved).
        for (int i = 1; i <= items.size(); ++i) {
            T item = items.get(i - 1);
            this.embedder.put(item, i);
            this.unembedder.put(i, item);
        }
        this.inputShapes = new Shape[]{new Shape(-1L)};
    }

    /** {@inheritDoc} Appends the embedding size as a trailing dimension of the input shape. */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        return new Shape[]{inputShapes[0].addAll(new Shape(this.embeddingSize))};
    }

    /** {@inheritDoc} */
    @Override
    public List<Parameter> getDirectParameters() {
        return Collections.singletonList(this.embedding);
    }

    /**
     * {@inheritDoc}
     *
     * @throws IllegalArgumentException if {@code name} is not {@code "embedding"}
     */
    @Override
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        if ("embedding".equals(name)) {
            return new Shape(this.numItems, this.embeddingSize);
        }
        throw new IllegalArgumentException("Invalid parameter name");
    }

    /**
     * Looks up the embedding rows for the index array at the head of {@code inputs}.
     *
     * <p>A scalar (0-dimensional) input is treated as a single index, and the result is reshaped
     * to a single vector of length {@code embeddingSize}.
     */
    @Override
    public NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        NDList opInputs = this.opInputs(parameterStore, inputs);
        NDArrayEx ex = opInputs.head().getNDArrayInternal();
        NDList result = ex.embedding(opInputs, this.numItems, this.embeddingSize, this.sparseGrad, this.dataType, params);
        if (inputs.head().getShape().dimension() == 0) {
            result = new NDList(result.singletonOrThrow().reshape(this.embeddingSize));
        }
        return result;
    }

    /**
     * Writes this block in format {@link #VERSION}: version byte, input shapes, sparse flag, data
     * type name, the item-to-index table (length-prefixed encoded key plus index per entry), and
     * finally the embedding parameter.
     */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        os.writeByte(VERSION);
        this.saveInputShapes(os);
        os.writeBoolean(this.sparseGrad);
        os.writeUTF(this.dataType.toString());
        Set<Map.Entry<T, Integer>> embedderEntrySet = this.embedder.entrySet();
        os.writeInt(embedderEntrySet.size());
        for (Map.Entry<T, Integer> entry : embedderEntrySet) {
            byte[] encodedKey = this.encode(entry.getKey());
            os.writeInt(encodedKey.length);
            os.write(encodedKey);
            os.writeInt(entry.getValue());
        }
        this.embedding.save(os);
    }

    /**
     * Reads parameters written in format versions 1 through 4.
     *
     * <p>Versions 3 and 4 carry the item table. Versions 1 and 2 carry only the matrix and keep
     * the current item table. Version 2 (and version 3 when its extra flag is {@code false}) gets a
     * zero row prepended so that row 0 is the reserved fallthrough row.
     *
     * @throws MalformedModelException if the version is unsupported or the item table is truncated
     *     or corrupt
     */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException {
        byte version = is.readByte();
        boolean addMissingZero = false;
        if (version == VERSION || version == 3) {
            this.readInputShapes(is);
            if (version == 3) {
                // NOTE(review): presumably a "has reserved row 0" flag in the v3 format; when
                // false, a zero row is prepended below.
                addMissingZero = !is.readBoolean();
            }
            this.sparseGrad = is.readBoolean();
            this.dataType = DataType.valueOf(is.readUTF().toUpperCase(Locale.ENGLISH));
            this.embedder = new ConcurrentHashMap<T, Integer>();
            this.unembedder = new ConcurrentHashMap<Integer, T>();
            int embedderSize = is.readInt();
            for (int i = 0; i < embedderSize; ++i) {
                int encodedKeySize = is.readInt();
                if (encodedKeySize < 0) {
                    throw new MalformedModelException("Model data is malformed");
                }
                byte[] encodedKey = new byte[encodedKeySize];
                try {
                    // read(byte[]) may legally return fewer bytes; readFully does not.
                    is.readFully(encodedKey);
                } catch (EOFException e) {
                    throw new MalformedModelException("Model data is malformed", e);
                }
                int value = is.readInt();
                T key = this.decode(encodedKey);
                this.embedder.put(key, value);
                this.unembedder.put(value, key);
            }
        } else if (version == 2) {
            this.readInputShapes(is);
            addMissingZero = true;
        } else if (version != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        this.embedding.load(manager, is);
        Shape loadedShape = this.embedding.getArray().getShape();
        this.numItems = Math.toIntExact(loadedShape.get(0));
        this.embeddingSize = Math.toIntExact(loadedShape.get(1));
        if (addMissingZero) {
            ++this.numItems;
            this.embedding.setArray(NDArrays.concat(new NDList(manager.zeros(new Shape(1L, this.embeddingSize)), this.embedding.getArray())));
        }
    }

    /** {@inheritDoc} Only items in this embedding's own table count; the fallthrough is ignored. */
    @Override
    public boolean hasItem(T item) {
        return this.embedder.containsKey(item);
    }

    /**
     * Builds the operator inputs: the index array (scalar promoted to shape {@code (1)}) followed
     * by the embedding matrix on the same device as the indices.
     */
    private NDList opInputs(ParameterStore parameterStore, NDList inputs) {
        NDArray items = inputs.head();
        Device device = items.getDevice();
        NDList ret = new NDList(2);
        if (items.getShape().dimension() == 0) {
            ret.add(items.reshape(1L));
        } else {
            ret.add(items);
        }
        ret.add(parameterStore.getValue(this.embedding, device));
        return ret;
    }

    /** Returns an int array holding {@link #embed(Object)} of each item, in order. */
    @Override
    public NDArray embed(NDManager manager, T[] items) {
        return manager.create(Arrays.stream(items).mapToInt(this::embed).toArray());
    }

    /**
     * Returns the index of {@code item}, delegating to the fallthrough for unknown items.
     *
     * @throws IllegalArgumentException if the item is unknown and there is no fallthrough
     */
    @Override
    public int embed(T item) {
        Integer index = this.embedder.get(item);
        if (index != null) {
            return index;
        }
        if (this.fallthroughEmbedding != null) {
            return this.fallthroughEmbedding.embed(item);
        }
        throw new IllegalArgumentException("The provided item was not found");
    }

    /**
     * Returns the item for {@code index}. Index 0 is delegated to the fallthrough. Unknown
     * non-zero indices yield an empty {@link Optional}.
     *
     * @throws IllegalArgumentException if {@code index} is 0 and there is no fallthrough
     */
    @Override
    public Optional<T> unembed(int index) {
        if (index == 0) {
            if (this.fallthroughEmbedding == null) {
                throw new IllegalArgumentException("Index 0 is reserved for the fallThrough but no fallThrough is found");
            }
            return this.fallthroughEmbedding.unembed(index);
        }
        return Optional.ofNullable(this.unembedder.get(index));
    }

    /** Fallthrough that treats every unknown item as a fixed default item. */
    protected class DefaultItem implements AbstractIndexedEmbedding<T> {
        private final T defaultItem;

        /**
         * Creates the fallthrough.
         *
         * @param defaultItem the item substituted for unknown items
         */
        public DefaultItem(T defaultItem) {
            this.defaultItem = defaultItem;
        }

        @Override
        public byte[] encode(T input) throws IOException {
            return Embedding.this.encode(input);
        }

        @Override
        public T decode(byte[] byteArray) throws IOException {
            return Embedding.this.decode(byteArray);
        }

        @Override
        public boolean hasItem(T item) {
            return true;
        }

        /** Embeds {@code items.length} copies of the default item via the enclosing embedding. */
        @Override
        public NDArray embed(NDManager manager, T[] items) {
            // Copying items yields a T[] with the caller's runtime component type, avoiding an
            // unchecked Object[] -> T[] cast.
            T[] defaults = Arrays.copyOf(items, items.length);
            Arrays.fill(defaults, this.defaultItem);
            return Embedding.this.embed(manager, defaults);
        }

        /** Always returns the reserved fallthrough index 0. */
        @Override
        public int embed(T item) {
            return 0;
        }

        @Override
        public Optional<T> unembed(int index) {
            return Optional.of(this.defaultItem);
        }
    }

    /** Fallthrough that maps every unknown item to the reserved row 0 of the embedding matrix. */
    protected class DefaultEmbedding implements AbstractIndexedEmbedding<T> {
        /** Creates the fallthrough. */
        protected DefaultEmbedding() {
        }

        @Override
        public byte[] encode(T input) throws IOException {
            return Embedding.this.encode(input);
        }

        @Override
        public T decode(byte[] byteArray) throws IOException {
            return Embedding.this.decode(byteArray);
        }

        @Override
        public boolean hasItem(T item) {
            return true;
        }

        /**
         * Returns row 0 of the embedding matrix, attached to {@code manager}, repeated to shape
         * {@code (items.length, embeddingSize)}. NOTE(review): unlike
         * {@link Embedding#embed(NDManager, Object[])}, this yields vectors rather than indices.
         */
        @Override
        public NDArray embed(NDManager manager, T[] items) {
            int length = items.length;
            NDArray base = Embedding.this.embedding.getArray().get(0L);
            base.attach(manager);
            return base.repeat(new Shape(length, Embedding.this.embeddingSize));
        }

        /** Always returns the reserved fallthrough index 0. */
        @Override
        public int embed(T item) {
            return 0;
        }

        @Override
        public Optional<T> unembed(int index) {
            return Optional.empty();
        }
    }

    /**
     * Base builder for {@link Embedding} subclasses.
     *
     * @param <T> the type of item being embedded
     * @param <B> the concrete builder type, returned from setters for chaining
     */
    public static abstract class BaseBuilder<T, B extends BaseBuilder<T, B>> {
        protected Class<T> embeddingType;
        protected List<T> items = new ArrayList<T>();
        protected int embeddingSize;
        protected boolean useDefault = true;
        protected T defaultItem;
        protected AbstractIndexedEmbedding<T> fallthrough;
        protected boolean sparseGrad = true;
        protected DataType dataType = DataType.FLOAT32;

        protected BaseBuilder() {
        }

        /**
         * Returns the item class of the embedding.
         *
         * @return the item class, or {@code null} if not set
         */
        public Class<T> getEmbeddingType() {
            return this.embeddingType;
        }

        /**
         * Sets the item class of the embedding.
         *
         * @param embeddingType the item class
         * @return this builder
         */
        protected abstract B setType(Class<T> embeddingType);

        /**
         * Sets the known items; they receive indices 1, 2, ... in list order.
         *
         * @param items the items
         * @return this builder
         */
        public B setItems(List<T> items) {
            this.items = items;
            return this.self();
        }

        /**
         * Sets the width of each embedding vector.
         *
         * @param embeddingSize the embedding size
         * @return this builder
         */
        public B setEmbeddingSize(int embeddingSize) {
            this.embeddingSize = embeddingSize;
            return this.self();
        }

        /**
         * Sets whether unknown items fall through to the reserved row 0 when neither a fallthrough
         * nor a default item is given. Defaults to {@code true}.
         *
         * @param useDefault whether to use the default embedding
         * @return this builder
         */
        public B optUseDefault(boolean useDefault) {
            this.useDefault = useDefault;
            return this.self();
        }

        /**
         * Sets an item substituted for unknown items. Mutually exclusive with
         * {@link #optFallthrough}.
         *
         * @param defaultItem the default item
         * @return this builder
         */
        public B optDefaultItem(T defaultItem) {
            this.defaultItem = defaultItem;
            return this.self();
        }

        /**
         * Sets an embedding consulted for unknown items. Mutually exclusive with
         * {@link #optDefaultItem}.
         *
         * @param fallthrough the fallthrough embedding
         * @return this builder
         */
        public B optFallthrough(AbstractIndexedEmbedding<T> fallthrough) {
            this.fallthrough = fallthrough;
            return this.self();
        }

        /**
         * Sets whether to use a row-sparse gradient. Defaults to {@code true}.
         *
         * @param sparseGrad whether to use a sparse gradient
         * @return this builder
         */
        public B optSparseGrad(boolean sparseGrad) {
            this.sparseGrad = sparseGrad;
            return this.self();
        }

        /**
         * Sets the data type of the embedding. Defaults to {@link DataType#FLOAT32}.
         *
         * @param dataType the data type
         * @return this builder
         */
        public B optDataType(DataType dataType) {
            this.dataType = dataType;
            return this.self();
        }

        /**
         * Returns this builder typed as the concrete builder.
         *
         * @return this builder
         */
        protected abstract B self();
    }
}

