/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.layers;

import java.util.ArrayList;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.conf.layers.RnnLossLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Keras-import wrapper that represents a standalone loss "layer".
 *
 * <p>An instance holds a single {@link ILossFunction} and the name of the one inbound layer it is
 * attached to. The concrete DL4J loss layer ({@link LossLayer}, {@link RnnLossLayer} or
 * {@link CnnLossLayer}) is only created once the input type is known, see
 * {@link #getLossLayer(InputType)}.
 */
public class KerasLoss extends KerasLayer {
    private static final Logger log = LoggerFactory.getLogger(KerasLoss.class);

    /** Class name assigned to {@code className} for every loss layer created by this class. */
    private final String KERAS_CLASS_NAME_LOSS = "Loss";

    /** Loss function used when building the DL4J loss layer. */
    private ILossFunction loss;

    /**
     * Creates a loss layer that enforces the training configuration, i.e. an unsupported Keras
     * loss is reported as an error instead of being replaced.
     *
     * @param layerName        name of this loss layer
     * @param inboundLayerName name of the single layer feeding into this loss
     * @param kerasLoss        Keras loss identifier, resolved via {@code KerasLossUtils.mapLossFunction}
     * @throws UnsupportedKerasConfigurationException if the loss cannot be mapped
     * @throws InvalidKerasConfigurationException     declared for compatibility with the four-argument constructor
     */
    public KerasLoss(String layerName, String inboundLayerName, String kerasLoss) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        this(layerName, inboundLayerName, kerasLoss, true);
    }

    /**
     * Creates a loss layer.
     *
     * @param layerName             name of this loss layer
     * @param inboundLayerName      name of the single layer feeding into this loss
     * @param kerasLoss             Keras loss identifier, resolved via {@code KerasLossUtils.mapLossFunction}
     * @param enforceTrainingConfig when {@code true}, an unsupported loss is rethrown; when
     *                              {@code false}, a warning is logged and a fallback loss is used
     * @throws UnsupportedKerasConfigurationException if the loss is unsupported and
     *                                                {@code enforceTrainingConfig} is {@code true}
     * @throws InvalidKerasConfigurationException     if the loss mapping reports an invalid configuration
     */
    public KerasLoss(String layerName, String inboundLayerName, String kerasLoss, boolean enforceTrainingConfig) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        this.className = this.KERAS_CLASS_NAME_LOSS;
        this.layerName = layerName;
        this.inputShape = null;
        this.dimOrder = KerasLayer.DimOrder.NONE;
        this.inboundLayerNames = new ArrayList<>();
        this.inboundLayerNames.add(inboundLayerName);
        try {
            this.loss = KerasLossUtils.mapLossFunction(kerasLoss, this.conf);
        }
        catch (UnsupportedKerasConfigurationException e) {
            if (enforceTrainingConfig) {
                throw e;
            }
            // Best-effort import: keep going with a fallback loss instead of failing.
            // NOTE(review): the message says "MSE" while the constant used is SQUARED_LOSS;
            // confirm that these are meant to be the same loss.
            log.warn("Unsupported Keras loss function. Replacing with MSE.");
            this.loss = LossFunctions.LossFunction.SQUARED_LOSS.getILossFunction();
        }
    }

    /**
     * Builds the DL4J loss layer matching the given input type, stores it in {@code layer} and
     * returns it. Every variant is named after this layer and uses the identity activation.
     *
     * @param type input type arriving at this loss layer (feed-forward, recurrent or convolutional)
     * @return the newly built loss layer
     * @throws UnsupportedKerasConfigurationException if {@code type} is {@code null} or is not one
     *                                                of the three supported input types
     */
    public FeedForwardLayer getLossLayer(InputType type) throws UnsupportedKerasConfigurationException {
        if (type instanceof InputType.InputTypeFeedForward) {
            this.layer = ((LossLayer.Builder)((LossLayer.Builder)new LossLayer.Builder(this.loss).name(this.layerName)).activation(Activation.IDENTITY)).build();
        } else if (type instanceof InputType.InputTypeRecurrent) {
            this.layer = ((RnnLossLayer.Builder)((RnnLossLayer.Builder)new RnnLossLayer.Builder(this.loss).name(this.layerName)).activation(Activation.IDENTITY)).build();
        } else if (type instanceof InputType.InputTypeConvolutional) {
            this.layer = ((CnnLossLayer.Builder)((CnnLossLayer.Builder)new CnnLossLayer.Builder(this.loss).name(this.layerName)).activation(Activation.IDENTITY)).build();
        } else {
            // String concatenation (rather than type.toString()) keeps a null type from
            // surfacing as a NullPointerException instead of the declared exception.
            throw new UnsupportedKerasConfigurationException("Unsupported output layer type, got: " + type);
        }
        return (FeedForwardLayer)this.layer;
    }

    /**
     * Computes the output type of this loss layer for exactly one input type.
     *
     * @param inputType the input types; exactly one element is required
     * @return the output type reported by the loss layer built for {@code inputType[0]}
     * @throws InvalidKerasConfigurationException     if the number of inputs is not exactly one
     * @throws UnsupportedKerasConfigurationException if the input type is not supported
     */
    @Override
    public InputType getOutputType(InputType ... inputType) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        // Treat a null or empty array as "zero inputs" so the caller gets a configuration
        // error rather than an index or null-pointer failure.
        int received = inputType == null ? 0 : inputType.length;
        if (received != 1) {
            throw new InvalidKerasConfigurationException("Keras Loss layer accepts only one input (received " + received + ")");
        }
        return this.getLossLayer(inputType[0]).getOutputType(-1, inputType[0]);
    }

    /**
     * @return the class name constant used for loss layers
     */
    public String getKERAS_CLASS_NAME_LOSS() {
        return this.KERAS_CLASS_NAME_LOSS;
    }

    /**
     * @return the loss function currently held by this layer
     */
    public ILossFunction getLoss() {
        return this.loss;
    }

    /**
     * Replaces the loss function. A layer already built by {@link #getLossLayer(InputType)} is
     * not rebuilt by this call.
     *
     * @param loss the new loss function
     */
    public void setLoss(ILossFunction loss) {
        this.loss = loss;
    }

    @Override
    public String toString() {
        return "KerasLoss(KERAS_CLASS_NAME_LOSS=" + this.getKERAS_CLASS_NAME_LOSS() + ", loss=" + this.getLoss() + ")";
    }

    /**
     * Two instances are equal when their class-name constant and their loss function are equal.
     * State inherited from {@link KerasLayer} is not compared.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof KerasLoss)) {
            return false;
        }
        KerasLoss other = (KerasLoss)o;
        if (!other.canEqual(this)) {
            return false;
        }
        String thisName = this.getKERAS_CLASS_NAME_LOSS();
        String otherName = other.getKERAS_CLASS_NAME_LOSS();
        if (thisName == null ? otherName != null : !thisName.equals(otherName)) {
            return false;
        }
        ILossFunction thisLoss = this.getLoss();
        ILossFunction otherLoss = other.getLoss();
        return thisLoss == null ? otherLoss == null : thisLoss.equals(otherLoss);
    }

    /**
     * Hook that lets subclasses refuse equality with instances of this class, keeping
     * {@link #equals(Object)} symmetric across the hierarchy.
     *
     * @param other the object attempting the comparison
     * @return {@code true} if {@code other} is a {@code KerasLoss}
     */
    protected boolean canEqual(Object other) {
        return other instanceof KerasLoss;
    }

    /**
     * Hash code over the same two properties compared by {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        final int PRIME = 59;
        int result = 1;
        String name = this.getKERAS_CLASS_NAME_LOSS();
        result = result * PRIME + (name == null ? 43 : name.hashCode());
        ILossFunction lossFunction = this.getLoss();
        result = result * PRIME + (lossFunction == null ? 43 : lossFunction.hashCode());
        return result;
    }
}

