/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.layers;

import java.util.ArrayList;
import java.util.Map;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Keras-import pseudo-layer that represents a model's training loss as a DL4J {@link LossLayer}.
 *
 * <p>The layer has no input shape of its own and exactly one inbound layer. The Keras loss name
 * is translated with {@code mapLossFunction}. If the name is not supported and training
 * configuration is not enforced, the loss falls back to
 * {@link LossFunctions.LossFunction#SQUARED_LOSS} and a warning is logged.
 */
public class KerasLoss
extends KerasLayer {
    private static final Logger log = LoggerFactory.getLogger(KerasLoss.class);

    /** Class name assigned to every instance of this pseudo-layer. */
    public static final String KERAS_CLASS_NAME_LOSS = "Loss";

    /**
     * Creates a loss layer with training configuration enforced, so an unsupported loss is an error.
     *
     * @param layerName        name given to this layer and to the underlying {@link LossLayer}
     * @param inboundLayerName name of the single layer feeding this loss layer
     * @param kerasLoss        Keras loss function identifier
     * @throws UnsupportedKerasConfigurationException if {@code kerasLoss} cannot be mapped
     */
    public KerasLoss(String layerName, String inboundLayerName, String kerasLoss) throws UnsupportedKerasConfigurationException {
        this(layerName, inboundLayerName, kerasLoss, true);
    }

    /**
     * Creates a loss layer.
     *
     * @param layerName             name given to this layer and to the underlying {@link LossLayer}
     * @param inboundLayerName      name of the single layer feeding this loss layer
     * @param kerasLoss             Keras loss function identifier
     * @param enforceTrainingConfig if {@code true}, an unsupported loss is rethrown; otherwise a
     *                              warning is logged and squared loss (MSE) is used instead
     * @throws UnsupportedKerasConfigurationException if {@code kerasLoss} cannot be mapped and
     *                                                {@code enforceTrainingConfig} is {@code true}
     */
    public KerasLoss(String layerName, String inboundLayerName, String kerasLoss, boolean enforceTrainingConfig) throws UnsupportedKerasConfigurationException {
        this.className = KERAS_CLASS_NAME_LOSS;
        this.layerName = layerName;
        // A loss layer has no shape of its own; it consumes whatever the inbound layer produces.
        this.inputShape = null;
        this.dimOrder = KerasLayer.DimOrder.NONE;
        this.inboundLayerNames = new ArrayList<>();
        this.inboundLayerNames.add(inboundLayerName);

        LossFunctions.LossFunction loss;
        try {
            loss = KerasLoss.mapLossFunction(kerasLoss);
        } catch (UnsupportedKerasConfigurationException e) {
            if (enforceTrainingConfig) {
                throw e;
            }
            // Best-effort import: keep the model loadable by substituting a default loss.
            log.warn("Unsupported Keras loss function. Replacing with MSE.");
            loss = LossFunctions.LossFunction.SQUARED_LOSS;
        }
        this.layer = ((LossLayer.Builder) new LossLayer.Builder(loss).name(layerName)).build();
    }

    /**
     * Private, empty constructor that is not called anywhere in this class.
     * NOTE(review): looks like a placeholder or decompilation leftover; confirm before removing.
     */
    private KerasLoss(Map<String, Object> layerConfig) {
    }

    /**
     * Private, empty constructor that is not called anywhere in this class.
     * NOTE(review): looks like a placeholder or decompilation leftover; confirm before removing.
     */
    private KerasLoss(Map<String, Object> layerConfig, boolean enforceTrainingConfig) {
    }

    /**
     * Returns the DL4J loss layer built by the constructor.
     *
     * @return the underlying {@link LossLayer}
     */
    public LossLayer getLossLayer() {
        return (LossLayer)this.layer;
    }

    /**
     * Computes this layer's output type by delegating to the underlying {@link LossLayer}.
     *
     * @param inputType input types of this layer; exactly one is required
     * @return the output type reported by the loss layer for the single input
     * @throws InvalidKerasConfigurationException if the number of inputs is not exactly one
     */
    @Override
    public InputType getOutputType(InputType ... inputType) throws InvalidKerasConfigurationException {
        // Reject zero inputs as well as several; indexing [0] on an empty array would otherwise
        // fail with ArrayIndexOutOfBoundsException instead of the declared exception.
        if (inputType.length != 1) {
            throw new InvalidKerasConfigurationException("Keras Loss layer accepts only one input (received " + inputType.length + ")");
        }
        // NOTE(review): -1 is passed as the layer index; presumably "index unknown" here. Confirm.
        return this.getLossLayer().getOutputType(-1, inputType[0]);
    }
}

