/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.layers.embeddings;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.EmbeddingLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasConstraintUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasInitilizationUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Imports a Keras {@code Embedding} layer as a DL4J {@link EmbeddingLayer}.
 *
 * <p>The built layer uses identity activation, no bias, and the weight init, distribution,
 * constraint, dropout and L1/L2 settings read from the Keras config. The embedding matrix is
 * the only trainable parameter and is stored under the DL4J parameter key {@code "W"}.
 */
public class KerasEmbedding extends KerasLayer {
    private static final Logger log = LoggerFactory.getLogger(KerasEmbedding.class);

    /** DL4J parameter key under which the embedding matrix is stored. */
    private static final String WEIGHT_KEY = "W";

    private final int NUM_TRAINABLE_PARAMS = 1;
    private boolean hasZeroMasking;

    /**
     * Creates an instance without a layer configuration.
     *
     * @throws UnsupportedKerasConfigurationException propagated from the superclass constructor
     */
    public KerasEmbedding() throws UnsupportedKerasConfigurationException {
    }

    /**
     * Builds the embedding layer from a Keras layer config, enforcing training-related config.
     *
     * @param layerConfig Keras layer configuration map
     * @throws InvalidKerasConfigurationException if the config is malformed
     * @throws UnsupportedKerasConfigurationException if the config uses unsupported features
     */
    public KerasEmbedding(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(layerConfig, true);
    }

    /**
     * Builds the embedding layer from a Keras layer config.
     *
     * @param layerConfig           Keras layer configuration map
     * @param enforceTrainingConfig whether unsupported training-related settings should be
     *                              rejected rather than ignored (passed through to the superclass
     *                              and the weight-init lookup)
     * @throws InvalidKerasConfigurationException if the config is malformed, e.g. {@code input_dim}
     *                                            is missing or not a number
     * @throws UnsupportedKerasConfigurationException if the config uses unsupported features
     */
    public KerasEmbedding(Map<String, Object> layerConfig, boolean enforceTrainingConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        super(layerConfig, enforceTrainingConfig);
        int inputDim = this.getInputDimFromConfig(layerConfig);

        // Extend the inherited input shape by one slot and put input_dim in position 1.
        // NOTE(review): this assumes the Keras shape is [inputLength]; entries past index 1 stay 0.
        // Skip the rewrite when no (or an empty) input shape was configured instead of
        // failing with a NullPointerException / ArrayIndexOutOfBoundsException.
        int[] inputShapeOld = this.inputShape;
        if (inputShapeOld != null && inputShapeOld.length > 0) {
            this.inputShape = new int[inputShapeOld.length + 1];
            this.inputShape[0] = inputShapeOld[0];
            this.inputShape[1] = inputDim;
        }

        this.hasZeroMasking = KerasLayerUtils.getZeroMaskingFromConfig(layerConfig, this.conf);
        if (this.hasZeroMasking) {
            log.warn("Masking in keras and DL4J work differently. We do not completely support mask_zero flag on Embedding layers. Zero Masking for the Embedding layer only works with unidirectional LSTM for now. If you want to have this behaviour for your imported model in DL4J, apply masking as a pre-processing step to your input. See https://deeplearning4j.org/usingrnns#masking for more on this.");
        }

        Pair<WeightInit, Distribution> init = KerasInitilizationUtils.getWeightInitFromConfig(layerConfig, this.conf.getLAYER_FIELD_EMBEDDING_INIT(), enforceTrainingConfig, this.conf, this.kerasMajorVersion);
        WeightInit weightInit = init.getFirst();
        Distribution distribution = init.getSecond();
        LayerConstraint embeddingConstraint = KerasConstraintUtils.getConstraintsFromConfig(layerConfig, this.conf.getLAYER_FIELD_EMBEDDINGS_CONSTRAINT(), this.conf, this.kerasMajorVersion);

        EmbeddingLayer.Builder builder = new EmbeddingLayer.Builder()
                .name(this.layerName)
                .nIn(inputDim)
                .nOut(KerasLayerUtils.getNOutFromConfig(layerConfig, this.conf))
                .dropOut(this.dropout)
                .activation(Activation.IDENTITY)
                .weightInit(weightInit)
                .biasInit(0.0)
                .l1(this.weightL1Regularization)
                .l2(this.weightL2Regularization)
                .hasBias(false);
        if (distribution != null) {
            builder.dist(distribution);
        }
        if (embeddingConstraint != null) {
            builder.constrainWeights(new LayerConstraint[]{embeddingConstraint});
        }
        this.layer = builder.build();
    }

    /**
     * Returns the built DL4J layer.
     *
     * @return the underlying {@link EmbeddingLayer}
     */
    public EmbeddingLayer getEmbeddingLayer() {
        return (EmbeddingLayer) this.layer;
    }

    /**
     * Computes the output type of this layer for the given input type.
     *
     * <p>If an input preprocessor applies to the input, the preprocessor's output type is fed
     * to the embedding layer instead of the raw input type.
     *
     * @param inputType exactly one input type
     * @return the embedding layer's output type
     * @throws InvalidKerasConfigurationException if not exactly one input type is given
     */
    @Override
    public InputType getOutputType(InputType ... inputType) throws InvalidKerasConfigurationException {
        if (inputType.length != 1) {
            throw new InvalidKerasConfigurationException("Keras Embedding layer accepts only one input (received " + inputType.length + ")");
        }
        InputType input = inputType[0];
        InputPreProcessor preprocessor = this.getInputPreprocessor(input);
        InputType effectiveInput = preprocessor != null ? preprocessor.getOutputType(input) : input;
        return this.getEmbeddingLayer().getOutputType(-1, effectiveInput);
    }

    /**
     * Returns the number of trainable parameters: the embedding matrix only.
     *
     * @return {@code NUM_TRAINABLE_PARAMS}
     */
    @Override
    public int getNumParams() {
        return this.NUM_TRAINABLE_PARAMS;
    }

    /**
     * Sets the embedding weights from the Keras weight map.
     *
     * <p>When zero masking is enabled, row 0 of the embedding matrix is overwritten with zeros.
     * {@code putRow} writes into the array taken from {@code weights}, so the caller's array is
     * modified. Entries other than the embedding weights are not imported. A warning is logged
     * listing them. The caller's map itself is not modified.
     *
     * @param weights Keras parameter name to array
     * @throws InvalidKerasConfigurationException if the embedding weights are missing
     */
    @Override
    public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {
        this.weights = new HashMap<>();
        String embeddingKey = this.conf.getLAYER_FIELD_EMBEDDING_WEIGHTS();
        if (!weights.containsKey(embeddingKey)) {
            throw new InvalidKerasConfigurationException("Parameter " + embeddingKey + " does not exist in weights");
        }
        INDArray kernel = weights.get(embeddingKey);
        if (this.hasZeroMasking) {
            kernel.putRow(0, Nd4j.zeros((int) kernel.columns()));
        }
        this.weights.put(WEIGHT_KEY, kernel);

        // Copy the keys: Map.keySet() is a live view, and removing from it would delete the
        // entry from the caller's map. Any key besides the single embedding parameter is unknown.
        Set<String> unknownParamNames = new HashSet<>(weights.keySet());
        unknownParamNames.remove(embeddingKey);
        if (!unknownParamNames.isEmpty()) {
            log.warn("Attempting to set weights for unknown parameters: {}", String.join(", ", unknownParamNames));
        }
    }

    /**
     * Reads the {@code input_dim} field from the inner Keras layer config.
     *
     * @param layerConfig Keras layer configuration map
     * @return the configured input dimension
     * @throws InvalidKerasConfigurationException if the field is missing or is not a number
     */
    private int getInputDimFromConfig(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException {
        Map<String, Object> innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, this.conf);
        String field = this.conf.getLAYER_FIELD_INPUT_DIM();
        if (!innerConfig.containsKey(field)) {
            throw new InvalidKerasConfigurationException("Keras Embedding layer config missing " + field + " field");
        }
        Object value = innerConfig.get(field);
        // Accept any numeric type the JSON parser may have produced, not only Integer.
        if (!(value instanceof Number)) {
            throw new InvalidKerasConfigurationException("Keras Embedding layer config field " + field + " is not a number: " + value);
        }
        return ((Number) value).intValue();
    }

    /** @return the number of trainable parameters (always 1) */
    public int getNUM_TRAINABLE_PARAMS() {
        return this.NUM_TRAINABLE_PARAMS;
    }

    /** @return whether the Keras config enabled {@code mask_zero} */
    public boolean isHasZeroMasking() {
        return this.hasZeroMasking;
    }

    /** @param hasZeroMasking whether row 0 of the embedding matrix is zeroed in {@link #setWeights} */
    public void setHasZeroMasking(boolean hasZeroMasking) {
        this.hasZeroMasking = hasZeroMasking;
    }

    @Override
    public String toString() {
        return "KerasEmbedding(NUM_TRAINABLE_PARAMS=" + this.getNUM_TRAINABLE_PARAMS() + ", hasZeroMasking=" + this.isHasZeroMasking() + ")";
    }

    /**
     * Compares only {@code NUM_TRAINABLE_PARAMS} and {@code hasZeroMasking}. State inherited
     * from {@link KerasLayer}, including the built layer, is not considered.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof KerasEmbedding)) {
            return false;
        }
        KerasEmbedding other = (KerasEmbedding) o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (this.getNUM_TRAINABLE_PARAMS() != other.getNUM_TRAINABLE_PARAMS()) {
            return false;
        }
        return this.isHasZeroMasking() == other.isHasZeroMasking();
    }

    /** Symmetry hook for {@link #equals}: subclasses may override to refuse equality. */
    protected boolean canEqual(Object other) {
        return other instanceof KerasEmbedding;
    }

    /** Hash over the same fields compared by {@link #equals}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + this.getNUM_TRAINABLE_PARAMS();
        result = result * prime + (this.isHasZeroMasking() ? 79 : 97);
        return result;
    }
}

