/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.layers;

import java.util.ArrayList;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLossUtils;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Keras-import layer that represents a training loss as a DL4J {@link LossLayer}.
 *
 * <p>The layer has exactly one inbound layer, no input shape of its own and no dimension ordering.
 * The loss function is resolved from its Keras name via {@link KerasLossUtils#mapLossFunction}.
 */
public class KerasLoss
extends KerasLayer {
    private static final Logger log = LoggerFactory.getLogger(KerasLoss.class);
    /** Class name recorded for this layer; also the only state used by equals/hashCode/toString. */
    private static final String KERAS_CLASS_NAME_LOSS = "Loss";

    /**
     * Creates a loss layer while enforcing the training configuration, so an unsupported Keras
     * loss causes an exception rather than a fallback.
     *
     * @param layerName        name of the created layer
     * @param inboundLayerName name of the single layer feeding this loss
     * @param kerasLoss        Keras name of the loss function
     * @throws UnsupportedKerasConfigurationException if the loss cannot be mapped
     * @throws InvalidKerasConfigurationException     propagated from the loss mapping
     */
    public KerasLoss(String layerName, String inboundLayerName, String kerasLoss) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        this(layerName, inboundLayerName, kerasLoss, true);
    }

    /**
     * Creates a loss layer.
     *
     * @param layerName             name of the created layer
     * @param inboundLayerName      name of the single layer feeding this loss
     * @param kerasLoss             Keras name of the loss function
     * @param enforceTrainingConfig if true, an unsupported loss rethrows the mapping exception;
     *                              if false, a warning is logged and squared loss is used instead
     * @throws UnsupportedKerasConfigurationException if the loss cannot be mapped and
     *                                                {@code enforceTrainingConfig} is true
     * @throws InvalidKerasConfigurationException     propagated from the loss mapping
     */
    public KerasLoss(String layerName, String inboundLayerName, String kerasLoss, boolean enforceTrainingConfig) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        LossFunctions.LossFunction loss;
        this.className = KERAS_CLASS_NAME_LOSS;
        this.layerName = layerName;
        this.inputShape = null;
        this.dimOrder = KerasLayer.DimOrder.NONE;
        this.inboundLayerNames = new ArrayList<>();
        this.inboundLayerNames.add(inboundLayerName);
        try {
            loss = KerasLossUtils.mapLossFunction(kerasLoss, this.conf);
        }
        catch (UnsupportedKerasConfigurationException e) {
            if (enforceTrainingConfig) {
                throw e;
            }
            // Best-effort import: keep the model loadable with a default loss instead of failing.
            log.warn("Unsupported Keras loss function '{}'. Replacing with MSE.", kerasLoss);
            loss = LossFunctions.LossFunction.SQUARED_LOSS;
        }
        this.layer = ((LossLayer.Builder)new LossLayer.Builder(loss).name(layerName)).build();
    }

    /**
     * Returns the DL4J loss layer built by the constructor.
     *
     * @return the underlying {@link LossLayer}
     */
    public LossLayer getLossLayer() {
        return (LossLayer)this.layer;
    }

    /**
     * Computes the output type of this loss layer.
     *
     * @param inputType exactly one input type
     * @return the output type reported by the underlying {@link LossLayer}
     * @throws InvalidKerasConfigurationException if the number of inputs is not exactly one
     */
    @Override
    public InputType getOutputType(InputType ... inputType) throws InvalidKerasConfigurationException {
        // Also reject zero inputs; previously that case failed with ArrayIndexOutOfBoundsException.
        int numInputs = inputType == null ? 0 : inputType.length;
        if (numInputs != 1) {
            throw new InvalidKerasConfigurationException("Keras Loss layer accepts only one input (received " + numInputs + ")");
        }
        // NOTE(review): -1 is passed as the layer index, presumably because no index is known here — confirm.
        return this.getLossLayer().getOutputType(-1, inputType[0]);
    }

    /**
     * Returns the Keras class name recorded for this layer.
     *
     * @return the constant {@code "Loss"}
     */
    public String getKERAS_CLASS_NAME_LOSS() {
        return KERAS_CLASS_NAME_LOSS;
    }

    @Override
    public String toString() {
        return "KerasLoss(KERAS_CLASS_NAME_LOSS=" + this.getKERAS_CLASS_NAME_LOSS() + ")";
    }

    /**
     * Compares only the Keras class name. Unless a subclass overrides
     * {@link #getKERAS_CLASS_NAME_LOSS()}, all KerasLoss instances therefore compare equal.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof KerasLoss)) {
            return false;
        }
        KerasLoss other = (KerasLoss)o;
        if (!other.canEqual(this)) {
            return false;
        }
        String thisName = this.getKERAS_CLASS_NAME_LOSS();
        String otherName = other.getKERAS_CLASS_NAME_LOSS();
        return thisName == null ? otherName == null : thisName.equals(otherName);
    }

    /**
     * Symmetry hook for {@link #equals(Object)}: subclasses can restrict which types they equal.
     *
     * @param other object being compared
     * @return true if {@code other} is a KerasLoss
     */
    protected boolean canEqual(Object other) {
        return other instanceof KerasLoss;
    }

    /** Hash consistent with {@link #equals(Object)}: derived only from the Keras class name. */
    @Override
    public int hashCode() {
        final int prime = 59;
        String name = this.getKERAS_CLASS_NAME_LOSS();
        return prime + (name == null ? 43 : name.hashCode());
    }
}

