/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.samediff;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.deeplearning4j.nn.layers.samediff.DL4JSameDiffMemoryMgr;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;

public class SameDiffLayer
extends AbstractLayer<AbstractSameDiffLayer> {
    public static final String INPUT_KEY = "input";
    public static final String MASK_KEY = "mask";
    protected SameDiff sameDiff;
    protected SDVariable outputVar;
    protected ExternalErrorsFunction fn;
    protected String outputKey;
    protected INDArray params;
    protected INDArray gradients;
    protected Map<String, INDArray> paramTable;
    protected Map<String, INDArray> gradTable;

    public SameDiffLayer(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
    }

    public Layer clone() {
        throw new UnsupportedOperationException();
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    @Override
    public void clearNoiseWeightParams() {
    }

    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(false);
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces();){
            if (this.sameDiff == null) {
                this.doInit();
            }
        }
        org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer)this.layerConf();
        bl.validateInput(this.input);
        HashMap<String, INDArray> phMap = new HashMap<String, INDArray>();
        phMap.put(INPUT_KEY, this.input);
        if (this.maskArray != null) {
            phMap.put(MASK_KEY, this.maskArray);
        } else {
            phMap.put(MASK_KEY, ((AbstractSameDiffLayer)this.layerConf()).onesMaskForInput(this.input));
        }
        String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.FF_WORKING_MEM);
        String wsNameOutput = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
        WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.FF_WORKING_MEM);
        WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
        boolean actScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATIONS);
        Preconditions.checkState((actScopedOut || wsNameOutput != null ? 1 : 0) != 0, (String)"Activations must have a workspace or must be scoped out");
        DL4JSameDiffMemoryMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameOutput, confWorking, confOutput);
        InferenceSession is = (InferenceSession)this.sameDiff.getSessions().get(Thread.currentThread().getId());
        if (is == null) {
            is = new InferenceSession(this.sameDiff);
            this.sameDiff.getSessions().put(Thread.currentThread().getId(), is);
        }
        is.setMmgr((SessionMemMgr)mmgr);
        Map out = this.sameDiff.output(phMap, new String[]{this.outputKey});
        INDArray result = (INDArray)out.get(this.outputKey);
        if (!actScopedOut && !result.data().getParentWorkspace().getId().equals(wsNameOutput)) {
            result = workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
        } else if (actScopedOut && result.isAttached()) {
            result = result.detach();
        }
        this.sameDiff.clearPlaceholders(true);
        this.sameDiff.clearOpInputs();
        return result;
    }

    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(true);
        DefaultGradient g = new DefaultGradient();
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces();){
            if (this.sameDiff == null) {
                this.doInit();
            }
            if (!this.sameDiff.hasGradientFunction()) {
                this.sameDiff.createGradFunction(new String[]{INPUT_KEY});
            }
        }
        Map sessionMap = this.sameDiff.getFunction("grad").getSessions();
        if (!sessionMap.containsKey(Thread.currentThread().getId())) {
            sessionMap.put(Thread.currentThread().getId(), new InferenceSession(this.sameDiff.getFunction("grad")));
        }
        String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.BP_WORKING_MEM);
        String wsNameActGrad = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATION_GRAD);
        WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.BP_WORKING_MEM);
        WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATION_GRAD);
        boolean actGradScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATION_GRAD);
        Preconditions.checkState((actGradScopedOut || wsNameActGrad != null ? 1 : 0) != 0, (String)"Activation gradients must have a workspace or be scoped out");
        DL4JSameDiffMemoryMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameActGrad, confWorking, confOutput);
        ((InferenceSession)sessionMap.get(Thread.currentThread().getId())).setMmgr((SessionMemMgr)mmgr);
        org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer)this.layerConf();
        bl.validateInput(this.input);
        HashMap<String, INDArray> phMap = new HashMap<String, INDArray>();
        phMap.put(INPUT_KEY, this.input);
        phMap.put(this.fn.getGradPlaceholderName(), epsilon);
        if (this.maskArray != null) {
            phMap.put(MASK_KEY, this.maskArray);
        } else {
            phMap.put(MASK_KEY, ((AbstractSameDiffLayer)this.layerConf()).onesMaskForInput(this.input));
        }
        ArrayList<String> requiredGrads = new ArrayList<String>(this.paramTable.size() + 1);
        requiredGrads.add(INPUT_KEY);
        requiredGrads.addAll(this.paramTable.keySet());
        Map m = this.sameDiff.calculateGradients(phMap, requiredGrads);
        for (String s : this.paramTable.keySet()) {
            INDArray sdGrad = (INDArray)m.get(s);
            INDArray dl4jGrad = this.gradTable.get(s);
            dl4jGrad.assign(sdGrad);
            g.gradientForVariable().put(s, dl4jGrad);
        }
        INDArray dLdIn = (INDArray)m.get(INPUT_KEY);
        this.sameDiff.clearPlaceholders(true);
        this.sameDiff.clearOpInputs();
        Pair ret = new Pair((Object)g, (Object)workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn));
        return ret;
    }

    @Override
    public INDArray params() {
        return this.params;
    }

    @Override
    public INDArray getParam(String param) {
        return this.paramTable.get(param);
    }

    @Override
    public long numParams() {
        return this.params == null ? 0L : (long)((int)this.params.length());
    }

    @Override
    public void setParam(String key, INDArray val) {
        if (!this.paramTable.containsKey(key)) {
            throw new IllegalArgumentException("Cannot set parameter, invalid/unknown parameter key: " + key);
        }
        INDArray current = this.paramTable.get(key);
        if (!Arrays.equals(current.shape(), val.shape())) {
            throw new IllegalArgumentException("Cannot set parameter \"" + key + "\", invalid shape: parameter array has shape " + Arrays.toString(current.shape()) + ", trying to set parameter of shape " + Arrays.toString(val.shape()));
        }
    }

    @Override
    public void setParams(INDArray params) {
        if (params != null) {
            throw new UnsupportedOperationException("Not supported");
        }
    }

    @Override
    protected void setParams(INDArray params, char order) {
        this.setParams(params);
    }

    @Override
    public void setParamsViewArray(INDArray params) {
        this.params = params;
    }

    @Override
    public INDArray getGradientsViewArray() {
        return this.gradients;
    }

    @Override
    public void setBackpropGradientsViewArray(INDArray gradients) {
        this.gradients = gradients;
        this.gradTable = ((AbstractSameDiffLayer)this.layerConf()).initializer().getGradientsFromFlattened(this.conf(), gradients);
    }

    @Override
    public void setParamTable(Map<String, INDArray> paramTable) {
        if (this.paramTable == null) {
            this.paramTable = paramTable;
        } else {
            for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
                this.setParam(e.getKey(), e.getValue());
            }
        }
    }

    @Override
    public Map<String, INDArray> paramTable() {
        return this.paramTable(false);
    }

    @Override
    public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
        return this.paramTable;
    }

    protected void doInit() {
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces();){
            org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer)this.layerConf();
            this.sameDiff = SameDiff.create();
            this.sameDiff.setArrayHolders((ArrayHolder)new SingleThreadArrayHolder(), (ArrayHolder)new SingleThreadArrayHolder(), false);
            Map<String, INDArray> p = this.paramTable();
            long[] inputShape = (long[])this.input.shape().clone();
            inputShape[0] = -1L;
            SDVariable inputVar = this.sameDiff.placeHolder(INPUT_KEY, this.dataType, inputShape);
            Map<String, long[]> paramShapes = ((AbstractSameDiffLayer)this.layerConf()).getLayerParams().getParamShapes();
            LinkedHashMap<String, SDVariable> params = new LinkedHashMap<String, SDVariable>();
            for (String s : paramShapes.keySet()) {
                long[] ps = paramShapes.get(s);
                SDVariable v = this.sameDiff.var(s, this.dataType, ps);
                params.put(s, v);
            }
            long[] maskShape = ArrayUtil.nTimes((long)inputShape.length, (long)-1L);
            SDVariable mask = this.sameDiff.placeHolder(MASK_KEY, this.dataType, maskShape);
            SDVariable layerOutput = bl.defineLayer(this.sameDiff, inputVar, params, mask);
            Preconditions.checkNotNull((Object)layerOutput, (String)"Invalid output: layer output is null");
            this.outputVar = layerOutput;
            for (Map.Entry<String, INDArray> e : p.entrySet()) {
                this.sameDiff.associateArrayWithVariable(e.getValue(), this.sameDiff.getVariable(e.getKey()));
            }
            this.fn = this.sameDiff.f().externalErrors(new SDVariable[]{layerOutput});
            this.fn.outputVariable();
            this.outputKey = this.outputVar.name();
        }
    }

    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer)this.layerConf();
        this.maskArray = maskArray;
        this.maskState = currentMaskState;
        return bl.feedForwardMaskArray(maskArray, currentMaskState, minibatchSize);
    }
}

