/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.samediff;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
import org.deeplearning4j.nn.conf.layers.samediff.BaseSameDiffLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/**
 * DL4J {@link Layer} whose forward pass is defined and executed by a {@link SameDiff} graph built
 * from a {@link BaseSameDiffLayer} configuration.
 *
 * <p>The SameDiff graph is created lazily on the first call to {@link #activate(boolean)}.
 * Training via backpropagation is not supported: {@link #backpropGradient(INDArray)} always throws.
 */
public class SameDiffLayer
extends AbstractLayer<AbstractSameDiffLayer> {
    /** Name of the SameDiff variable that receives this layer's input array. */
    public static final String INPUT_KEY = "input";
    /** Lazily created graph; {@code null} until {@link #doInit()} runs. */
    protected SameDiff sameDiff;
    /** Names of the SameDiff variables returned by the layer definition (populated in {@link #doInit()}). */
    protected List<String> outputKeys;
    /** Parameter view array, as supplied via {@link #setParamsViewArray(INDArray)}. */
    protected INDArray params;
    /** Gradient view array, as supplied via {@link #setBackpropGradientsViewArray(INDArray)}. */
    protected INDArray gradients;
    /** Parameter name to parameter array; may be {@code null} until {@link #setParamTable(Map)} is called. */
    protected Map<String, INDArray> paramTable;

    public SameDiffLayer(NeuralNetConfiguration conf) {
        super(conf);
    }

    /**
     * Cloning is not supported for this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Layer clone() {
        throw new UnsupportedOperationException();
    }

    /** @return always {@code false}; this layer has no pretraining phase */
    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    /** Intentionally a no-op for this layer. */
    @Override
    public void clearNoiseWeightParams() {
    }

    /**
     * Runs the forward pass through the SameDiff graph, building the graph on first use.
     *
     * @param training ignored by this implementation
     * @return the end result of executing the SameDiff graph on the current input
     */
    @Override
    public INDArray activate(boolean training) {
        if (this.sameDiff == null) {
            this.doInit();
        }
        this.sameDiff.associateArrayWithVariable(this.input, this.sameDiff.getVariable(INPUT_KEY));
        // Execute with workspaces scoped out so the result is not allocated in a caller's workspace
        // (NOTE(review): presumably so the returned array outlives that workspace - confirm).
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            return this.sameDiff.execAndEndResult();
        }
    }

    /** Identical to {@link #activate(boolean)}: there is no separate pre-activation output. */
    @Override
    public INDArray preOutput(boolean training) {
        return this.activate(training);
    }

    /**
     * Backpropagation is not supported for SameDiff layers.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        throw new UnsupportedOperationException("Fitting DL4J SameDiff layers via backpropagation is not yet supported");
    }

    /**
     * Computes the L2 penalty: sum over params of {@code 0.5 * l2 * ||w||_2^2}.
     * Params whose L2 coefficient is not strictly positive (including NaN) are skipped.
     *
     * @param backpropParamsOnly ignored by this implementation
     * @return the L2 penalty, or 0.0 when no parameter table has been set
     */
    @Override
    public double calcL2(boolean backpropParamsOnly) {
        Map<String, INDArray> table = this.paramTable();
        if (table == null) {
            return 0.0;
        }
        double l2Sum = 0.0;
        for (Map.Entry<String, INDArray> entry : table.entrySet()) {
            double l2 = this.conf.getL2ByParam(entry.getKey());
            // Negated comparison so that NaN coefficients are skipped as well.
            if (!(l2 > 0.0)) {
                continue;
            }
            double norm2 = entry.getValue().norm2Number().doubleValue();
            l2Sum += 0.5 * l2 * norm2 * norm2;
        }
        return l2Sum;
    }

    /**
     * Computes the L1 penalty: sum over params of {@code l1 * ||w||_1}.
     * Params whose L1 coefficient is not strictly positive (including NaN) are skipped.
     *
     * @param backpropParamsOnly ignored by this implementation
     * @return the L1 penalty, or 0.0 when no parameter table has been set
     */
    @Override
    public double calcL1(boolean backpropParamsOnly) {
        Map<String, INDArray> table = this.paramTable();
        if (table == null) {
            return 0.0;
        }
        double l1Sum = 0.0;
        for (Map.Entry<String, INDArray> entry : table.entrySet()) {
            double l1 = this.conf.getL1ByParam(entry.getKey());
            // Negated comparison so that NaN coefficients are skipped as well.
            if (!(l1 > 0.0)) {
                continue;
            }
            double norm1 = entry.getValue().norm1Number().doubleValue();
            l1Sum += l1 * norm1;
        }
        return l1Sum;
    }

    /** @return the parameter view array previously set via {@link #setParamsViewArray(INDArray)} */
    @Override
    public INDArray params() {
        return this.params;
    }

    /** @return the parameter array registered under {@code param}, or {@code null} if absent */
    @Override
    public INDArray getParam(String param) {
        return this.paramTable.get(param);
    }

    /**
     * Copies {@code val} into the existing parameter array registered under {@code key}.
     * The copy is done in place, so any view relationship with the params array is preserved.
     *
     * @throws IllegalArgumentException if {@code key} is unknown or the shapes differ
     */
    @Override
    public void setParam(String key, INDArray val) {
        if (!this.paramTable.containsKey(key)) {
            throw new IllegalArgumentException("Cannot set parameter, invalid/unknown parameter key: " + key);
        }
        INDArray current = this.paramTable.get(key);
        if (!Arrays.equals(current.shape(), val.shape())) {
            throw new IllegalArgumentException("Cannot set parameter \"" + key + "\", invalid shape: parameter array has shape " + Arrays.toString(current.shape()) + ", trying to set parameter of shape " + Arrays.toString(val.shape()));
        }
        // Previously the validated value was silently discarded; actually store it.
        current.assign(val);
    }

    /**
     * Only {@code null} is accepted (a no-op).
     *
     * @throws UnsupportedOperationException if {@code params} is non-null
     */
    @Override
    public void setParams(INDArray params) {
        if (params != null) {
            throw new UnsupportedOperationException("Not supported");
        }
    }

    /** Delegates to {@link #setParams(INDArray)}; {@code order} is ignored. */
    @Override
    protected void setParams(INDArray params, char order) {
        this.setParams(params);
    }

    @Override
    public void setParamsViewArray(INDArray params) {
        this.params = params;
    }

    @Override
    public INDArray getGradientsViewArray() {
        return this.gradients;
    }

    @Override
    public void setBackpropGradientsViewArray(INDArray gradients) {
        this.gradients = gradients;
    }

    /**
     * Stores {@code paramTable} directly on first call. On later calls, copies each entry into the
     * existing arrays via {@link #setParam(String, INDArray)}.
     */
    @Override
    public void setParamTable(Map<String, INDArray> paramTable) {
        if (this.paramTable == null) {
            this.paramTable = paramTable;
        } else {
            for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
                this.setParam(e.getKey(), e.getValue());
            }
        }
    }

    @Override
    public Map<String, INDArray> paramTable() {
        return this.paramTable(false);
    }

    /** @param backpropParamsOnly ignored; the full parameter table is always returned */
    @Override
    public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
        return this.paramTable;
    }

    /**
     * Builds the SameDiff graph.
     *
     * <p>Declares an input variable shaped like the current input and one variable per configured
     * parameter. It then lets the layer configuration define the graph and binds the existing
     * parameter arrays to their variables. Exactly one graph output is required.
     *
     * @throws IllegalStateException if no input is set or the layer does not define exactly one output
     */
    protected void doInit() {
        if (this.input == null) {
            throw new IllegalStateException("Cannot initialize SameDiff layer: no input has been set");
        }
        BaseSameDiffLayer bl = (BaseSameDiffLayer) this.layerConf();
        this.sameDiff = SameDiff.create();
        Map<String, INDArray> p = this.paramTable();
        int[] inputShape = this.input.shape().clone();
        SDVariable inputVar = this.sameDiff.var(INPUT_KEY, inputShape);
        Map<String, int[]> paramShapes = this.layerConf().getLayerParams().getParamShapes();
        LinkedHashMap<String, SDVariable> params = new LinkedHashMap<>();
        for (Map.Entry<String, int[]> e : paramShapes.entrySet()) {
            params.put(e.getKey(), this.sameDiff.var(e.getKey(), e.getValue()));
        }
        List<SDVariable> layerOutputs = bl.defineLayer(this.sameDiff, inputVar, params);
        if (layerOutputs == null || layerOutputs.size() != 1) {
            throw new IllegalStateException("Invalid outputs: " + layerOutputs);
        }
        for (Map.Entry<String, INDArray> e : p.entrySet()) {
            this.sameDiff.associateArrayWithVariable(e.getValue(), this.sameDiff.getVariable(e.getKey()));
        }
        this.outputKeys = new ArrayList<>();
        for (SDVariable sdv : layerOutputs) {
            this.outputKeys.add(sdv.getVarName());
        }
    }
}

