/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.samediff;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/**
 * Network-side (runtime) implementation of an output layer whose computation is defined with
 * SameDiff.
 * <p>
 * The SameDiff graph is built lazily by {@link #doInit()} the first time a forward or backward
 * pass is executed, using the configuration's {@code defineLayer(...)} method. Parameter values
 * are held in DL4J-owned arrays ({@link #paramTable}). Before every forward/backward pass they are
 * assigned into the matching SameDiff variables. After backprop, the SameDiff gradients are copied
 * into the DL4J gradient arrays ({@link #gradTable}).
 */
public class SameDiffOutputLayer
extends AbstractLayer<org.deeplearning4j.nn.conf.layers.samediff.SameDiffOutputLayer>
implements IOutputLayer {
    /** Name of the SameDiff placeholder that receives the layer input. */
    public static final String INPUT_KEY = "input";
    /** Name of the SameDiff placeholder that receives the labels (only created when labels are required). */
    public static final String LABELS_KEY = "labels";
    protected SameDiff sameDiff;
    /** The variable returned by {@code defineLayer(...)}; executed when computing the score. */
    protected SDVariable outputVar;
    protected String outputKey;
    protected INDArray labels;
    /** Flattened parameter view array, as set by {@link #setParamsViewArray(INDArray)}. */
    protected INDArray params;
    /** Flattened gradient view array, as set by {@link #setBackpropGradientsViewArray(INDArray)}. */
    protected INDArray gradients;
    protected Map<String, INDArray> paramTable;
    protected Map<String, INDArray> gradTable;

    public SameDiffOutputLayer(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
    }

    /** Cloning is not supported for SameDiff output layers. */
    public Layer clone() {
        throw new UnsupportedOperationException();
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    /** No-op: this layer has no weight noise to clear. */
    @Override
    public void clearNoiseWeightParams() {
    }

    /**
     * Returns the layer activations (the variable named by {@code activationsVertexName()}).
     * The {@code training} flag is not forwarded to the SameDiff graph.
     */
    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        return activateHelper(true, workspaceMgr);
    }

    /**
     * Runs the SameDiff graph forward.
     *
     * @param activations  if true, return the activations variable ({@code activationsVertexName()})
     *                     duplicated into the ACTIVATIONS workspace. If false, return the output
     *                     variable produced by {@code defineLayer(...)}, feeding labels when they are
     *                     required and set.
     * @param workspaceMgr workspace manager used to place the returned activations
     * @return the requested array
     */
    private INDArray activateHelper(boolean activations, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(false);
        org.deeplearning4j.nn.conf.layers.samediff.SameDiffOutputLayer bl = layerConf();
        String activationsName = activations ? bl.activationsVertexName() : null;
        // Activations are the input itself: no graph execution needed.
        if (activations && INPUT_KEY.equals(activationsName)) {
            return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input);
        }
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            if (sameDiff == null) {
                doInit();
            }
            copyParamsToSameDiff();

            Map<String, INDArray> phMap = new HashMap<>();
            phMap.put(INPUT_KEY, input);
            if (!activations && bl.labelsRequired() && labels != null) {
                phMap.put(LABELS_KEY, labels);
            }

            String outName = activations ? activationsName : outputVar.getVarName();
            INDArray out;
            try {
                out = sameDiff.outputSingle(phMap, outName);
            } finally {
                // Release placeholder/op-input references even if execution failed.
                sameDiff.clearPlaceholders(true);
                sameDiff.clearOpInputs();
            }

            if (activations) {
                Preconditions.checkNotNull(out, "Activations (result) array for variable \"%s\" was null - error during execution or this variable (as defined by method activationsVertexName()) does not exist", activationsName);
                return workspaceMgr.dup(ArrayType.ACTIVATIONS, out);
            }
            return out;
        }
    }

    /**
     * Computes parameter gradients and the gradient with respect to the layer input via SameDiff.
     * The {@code epsilon} argument is not used: this is an output layer, so the loss is defined
     * inside the SameDiff graph.
     *
     * @return the parameter gradients (backed by {@link #gradTable}) and dL/dInput duplicated into
     *         the ACTIVATION_GRAD workspace
     * @throws IllegalStateException if labels are required but not set
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(true);
        org.deeplearning4j.nn.conf.layers.samediff.SameDiffOutputLayer bl = layerConf();
        Preconditions.checkState(!bl.labelsRequired() || labels != null, "Cannot execute backprop: Labels are not set. If labels are not required for this SameDiff output layer, override SameDiffOutputLayer.labelsRequired() to return false instead");

        DefaultGradient g = new DefaultGradient();
        INDArray dLdIn;
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            if (sameDiff == null) {
                doInit();
            }
            if (!sameDiff.hasGradientFunction()) {
                sameDiff.createGradFunction(new String[]{INPUT_KEY});
            }
            copyParamsToSameDiff();

            List<String> gradVarNames = new ArrayList<>();
            for (String s : paramTable.keySet()) {
                gradVarNames.add(sameDiff.getVariable(s).getGradient().getVarName());
            }
            gradVarNames.add(sameDiff.grad(INPUT_KEY).getVarName());

            Map<String, INDArray> phMap = new HashMap<>();
            phMap.put(INPUT_KEY, input);
            // The labels placeholder only exists when labels are required (see doInit()).
            if (bl.labelsRequired() && labels != null) {
                phMap.put(LABELS_KEY, labels);
            }

            try {
                sameDiff.execBackwards(phMap, gradVarNames);
                for (String s : paramTable.keySet()) {
                    INDArray sdGrad = sameDiff.grad(s).getArr();
                    INDArray dl4jGrad = gradTable.get(s);
                    dl4jGrad.assign(sdGrad);
                    g.gradientForVariable().put(s, dl4jGrad);
                }
                dLdIn = sameDiff.grad(INPUT_KEY).getArr();
            } finally {
                // Release placeholder/op-input references even if execution failed.
                sameDiff.clearPlaceholders(true);
                sameDiff.clearOpInputs();
            }
        }
        return new Pair<Gradient, INDArray>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn));
    }

    /** Assigns each DL4J parameter array into the SameDiff variable of the same name. */
    private void copyParamsToSameDiff() {
        for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
            sameDiff.assignArray(e.getValue(), sameDiff.getVariable(e.getKey()));
        }
    }

    @Override
    public INDArray params() {
        return params;
    }

    @Override
    public INDArray getParam(String param) {
        return paramTable.get(param);
    }

    /** Returns the number of parameters, or 0 if no parameter view array has been set. */
    @Override
    public long numParams() {
        return params == null ? 0L : params.length();
    }

    /**
     * Copies {@code val} into the existing parameter array for {@code key}.
     *
     * @throws IllegalArgumentException if the key is unknown, {@code val} is null, or the shapes differ
     */
    @Override
    public void setParam(String key, INDArray val) {
        if (paramTable == null || !paramTable.containsKey(key)) {
            throw new IllegalArgumentException("Cannot set parameter, invalid/unknown parameter key: " + key);
        }
        if (val == null) {
            throw new IllegalArgumentException("Cannot set parameter \"" + key + "\": value is null");
        }
        INDArray current = paramTable.get(key);
        if (!Arrays.equals(current.shape(), val.shape())) {
            throw new IllegalArgumentException("Cannot set parameter \"" + key + "\", invalid shape: parameter array has shape " + Arrays.toString(current.shape()) + ", trying to set parameter of shape " + Arrays.toString(val.shape()));
        }
        // Assign in place so the flattened params view (and any SameDiff association) sees the new values.
        current.assign(val);
    }

    /** Only {@code null} is accepted; setting a flattened parameter array directly is unsupported. */
    @Override
    public void setParams(INDArray params) {
        if (params != null) {
            throw new UnsupportedOperationException("Not supported");
        }
    }

    @Override
    protected void setParams(INDArray params, char order) {
        setParams(params);
    }

    @Override
    public void setParamsViewArray(INDArray params) {
        this.params = params;
    }

    @Override
    public INDArray getGradientsViewArray() {
        return gradients;
    }

    /** Stores the flattened gradient view and splits it into per-parameter gradient arrays. */
    @Override
    public void setBackpropGradientsViewArray(INDArray gradients) {
        this.gradients = gradients;
        this.gradTable = layerConf().initializer().getGradientsFromFlattened(conf(), gradients);
    }

    /**
     * On first call, adopts {@code paramTable} as the layer's parameter map. On later calls, copies
     * each entry into the existing parameter arrays via {@link #setParam(String, INDArray)}.
     */
    @Override
    public void setParamTable(Map<String, INDArray> paramTable) {
        if (this.paramTable == null) {
            this.paramTable = paramTable;
        } else {
            for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
                setParam(e.getKey(), e.getValue());
            }
        }
    }

    @Override
    public Map<String, INDArray> paramTable() {
        return paramTable(false);
    }

    /** Returns the parameter map; {@code backpropParamsOnly} is ignored. */
    @Override
    public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
        return paramTable;
    }

    /**
     * Builds the SameDiff graph for this layer:
     * <ul>
     * <li>an input placeholder shaped like the current input, with a variable first dimension;</li>
     * <li>a labels placeholder (only when labels are required), shaped like the current labels, or
     *     {@code [-1, -1]} if labels are not yet set;</li>
     * <li>one variable per configured parameter shape;</li>
     * <li>the layer output from {@code defineLayer(...)}.</li>
     * </ul>
     * The DL4J parameter arrays are then associated with the corresponding variables.
     */
    protected void doInit() {
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            org.deeplearning4j.nn.conf.layers.samediff.SameDiffOutputLayer bl = layerConf();
            sameDiff = SameDiff.create();
            Map<String, INDArray> p = paramTable();

            long[] inputShape = input.shape().clone();
            inputShape[0] = -1L;
            SDVariable inputVar = sameDiff.placeHolder(INPUT_KEY, dataType, inputShape);

            SDVariable labelVar = null;
            if (bl.labelsRequired()) {
                long[] labelShape = labels == null ? new long[]{-1L, -1L} : labels.shape().clone();
                labelShape[0] = -1L;
                labelVar = sameDiff.placeHolder(LABELS_KEY, dataType, labelShape);
            }

            Map<String, long[]> paramShapes = bl.getLayerParams().getParamShapes();
            Map<String, SDVariable> sdParams = new LinkedHashMap<>();
            for (Map.Entry<String, long[]> e : paramShapes.entrySet()) {
                sdParams.put(e.getKey(), sameDiff.var(e.getKey(), dataType, e.getValue()));
            }

            SDVariable layerOutput = bl.defineLayer(sameDiff, inputVar, labelVar, sdParams);
            Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null");
            outputVar = layerOutput;

            for (Map.Entry<String, INDArray> e : p.entrySet()) {
                sameDiff.associateArrayWithVariable(e.getValue(), sameDiff.getVariable(e.getKey()));
            }
            outputKey = layerOutput.getVarName();
        }
    }

    @Override
    public boolean needsLabels() {
        return layerConf().labelsRequired();
    }

    /**
     * Executes the output variable and returns {@code (out[0] + fullNetRegTerm) / minibatchSize}.
     * <p>
     * NOTE(review): the regularization term is divided by the minibatch size together with the
     * loss. Whether that is intended depends on whether {@code defineLayer} returns a summed or an
     * averaged loss. Confirm before changing.
     */
    @Override
    public double computeScore(double fullNetRegTerm, boolean training, LayerWorkspaceMgr workspaceMgr) {
        return (activateHelper(false, workspaceMgr).getDouble(0L) + fullNetRegTerm) / (double) input.size(0);
    }

    @Override
    public INDArray computeScoreForExamples(double fullNetRegTerm, LayerWorkspaceMgr workspaceMgr) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public double f1Score(DataSet data) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public double f1Score(INDArray examples, INDArray labels) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public int numLabels() {
        return 0;
    }

    @Override
    public void fit(DataSetIterator iter) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public int[] predict(INDArray examples) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public List<String> predict(DataSet dataSet) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(INDArray examples, INDArray labels) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(DataSet data) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public void fit(INDArray examples, int[] labels) {
        throw new UnsupportedOperationException("Not supported");
    }

    @Override
    public INDArray getLabels() {
        return labels;
    }

    @Override
    public void setLabels(INDArray labels) {
        this.labels = labels;
    }
}

