/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.samediff;

import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.TrainingConfig;
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.params.SameDiffParamInitializer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/**
 * Graph vertex that delegates its computation to a user-defined {@link SameDiffVertex}.
 *
 * <p>The underlying {@link SameDiff} graph is built lazily on the first forward (or backward) pass,
 * using the shapes of the arrays currently set as this vertex's inputs. For every configured input
 * {@code x}, a companion mask variable named {@code x + "_mask"} is defined; when no mask array is
 * supplied for an input, an all-ones mask (see {@link #createMask(DataType, long[])}) is used
 * instead.
 */
public class SameDiffGraphVertex
extends BaseGraphVertex {
    protected SameDiffVertex config;
    protected SameDiff sameDiff;
    protected SDVariable outputVar;
    protected ExternalErrorsFunction fn;
    protected String outputKey;
    protected Map<String, SDVariable> inputVars;
    protected INDArray[] maskArrays;
    protected INDArray params;
    protected INDArray gradients;
    protected Map<String, INDArray> paramTable;
    protected Map<String, INDArray> gradTable;
    private MaskState currentMaskState;
    private int minibatchSize;

    /**
     * Creates the vertex and carves the flattened parameter view into per-parameter arrays.
     *
     * @param config      user-defined SameDiff vertex configuration
     * @param graph       owning computation graph
     * @param name        vertex name
     * @param vertexIndex vertex index within the graph
     * @param paramsView  flattened view over this vertex's parameters
     * @param initParams  if true, the configuration initializes the parameter arrays in place
     * @param dataType    data type used for graph variables and default masks
     */
    public SameDiffGraphVertex(SameDiffVertex config, ComputationGraph graph, String name, int vertexIndex, INDArray paramsView, boolean initParams, DataType dataType) {
        super(graph, name, vertexIndex, null, null, dataType);
        this.config = config;
        SDVertexParams vp = config.getVertexParams();
        this.paramTable = SameDiffParamInitializer.getInstance().subsetAndReshape(vp.getParameterKeys(), vp.getParamShapes(), paramsView, null, config);
        if (initParams) {
            config.initializeParameters(this.paramTable);
        }
        this.params = paramsView;
    }

    @Override
    public String toString() {
        // Previously returned null, which breaks logging/string concatenation and the toString contract.
        return "SameDiffGraphVertex(id=" + this.getVertexIndex() + ",name=\"" + this.getVertexName() + "\")";
    }

    @Override
    public boolean hasLayer() {
        return false;
    }

    @Override
    public Layer getLayer() {
        return null;
    }

    /**
     * Runs the SameDiff graph forward and returns the vertex output.
     *
     * <p>Inputs and masks are duplicated before being bound to the graph, and the result is copied
     * into the {@link ArrayType#ACTIVATIONS} workspace.
     */
    @Override
    public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) {
        if (this.sameDiff == null) {
            this.doInit();
        }
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            this.config.validateInput(this.inputs);
            this.associateInputsAndMasks();
            this.associateParams();
            Map<String, INDArray> out = this.sameDiff.exec(null, new String[]{this.outputKey});
            INDArray result = out.get(this.outputKey);
            return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
        }
    }

    /**
     * Back-propagates {@code epsilon} through the SameDiff graph.
     *
     * <p>Parameter gradients are copied into the gradient view previously supplied via
     * {@link #setBackpropGradientsViewArray(INDArray)}; input gradients are copied into the
     * {@link ArrayType#ACTIVATION_GRAD} workspace.
     *
     * @return the parameter gradients and one gradient array per input
     */
    @Override
    public Pair<Gradient, INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr) {
        // Same lazy-init guard as doForward, so a missing graph is built rather than dereferenced as null.
        if (this.sameDiff == null) {
            this.doInit();
        }
        DefaultGradient g = new DefaultGradient();
        INDArray[] dLdIns;
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            this.config.validateInput(this.inputs);
            this.associateInputsAndMasks();
            this.fn.updateVariable(this.outputVar.getVarName(), this.epsilon.dup());
            this.associateParams();
            this.sameDiff.execBackwards(null);
            if (this.hasParams()) {
                Preconditions.checkState(this.gradTable != null,
                        "Gradient view array has not been set: setBackpropGradientsViewArray must be called before doBackward");
                for (String s : this.paramTable.keySet()) {
                    INDArray sdGrad = this.sameDiff.grad(s).getArr();
                    INDArray dl4jGrad = this.gradTable.get(s);
                    // Copy into the view so the flattened gradient array sees the update.
                    dl4jGrad.assign(sdGrad);
                    g.gradientForVariable().put(s, dl4jGrad);
                }
            }
            dLdIns = new INDArray[this.inputs.length];
            for (int i = 0; i < this.inputs.length; ++i) {
                String name = this.config.getVertexParams().getInputs().get(i);
                dLdIns[i] = this.sameDiff.grad(name).getArr();
            }
        }
        for (int i = 0; i < dLdIns.length; ++i) {
            dLdIns[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIns[i]);
        }
        return new Pair<Gradient, INDArray[]>(g, dLdIns);
    }

    /**
     * Stores the flattened gradient view and splits it into per-parameter gradient arrays.
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) {
        // Previously not stored, so getGradientsViewArray() always returned null.
        this.gradients = backpropGradientsViewArray;
        SDVertexParams vp = this.config.getVertexParams();
        this.gradTable = SameDiffParamInitializer.getInstance().subsetAndReshape(vp.getParameterKeys(), vp.getParamShapes(), backpropGradientsViewArray, null, this.config);
    }

    /**
     * Records the incoming masks (used by later forward/backward passes) and delegates mask
     * propagation to the vertex configuration.
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize) {
        this.maskArrays = maskArrays;
        this.currentMaskState = currentMaskState;
        this.minibatchSize = minibatchSize;
        return this.config.feedForwardMaskArrays(maskArrays, currentMaskState, minibatchSize);
    }

    /**
     * Builds the SameDiff graph: one variable per input (shaped like the current input array), one
     * constant mask per input, one variable per parameter, then the user-defined vertex body.
     */
    protected void doInit() {
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            this.sameDiff = SameDiff.create();
            this.inputVars = new LinkedHashMap<String, SDVariable>();
            Map<String, SDVariable> maskVars = new LinkedHashMap<String, SDVariable>();
            int i = 0;
            for (String s : this.config.getVertexParams().getInputs()) {
                long[] inputShape = this.inputs[i++].shape().clone();
                SDVariable inputVar = this.sameDiff.var(s, this.dataType, inputShape);
                this.inputVars.put(s, inputVar);
                SDVariable maskVar = this.sameDiff.constant(s + "_mask", createMask(this.dataType, inputShape));
                maskVars.put(s, maskVar);
            }
            Map<String, long[]> paramShapes = this.config.getVertexParams().getParamShapes();
            Map<String, SDVariable> params = new LinkedHashMap<String, SDVariable>();
            for (Map.Entry<String, long[]> e : paramShapes.entrySet()) {
                params.put(e.getKey(), this.sameDiff.var(e.getKey(), this.dataType, e.getValue()));
            }
            SDVariable layerOutput = this.config.defineVertex(this.sameDiff, this.inputVars, params, maskVars);
            Preconditions.checkNotNull(layerOutput, "Invalid output: layer output is null");
            this.outputVar = layerOutput;
            for (Map.Entry<String, INDArray> e : this.paramTable.entrySet()) {
                this.sameDiff.associateArrayWithVariable(e.getValue(), this.sameDiff.getVariable(e.getKey()));
            }
            this.fn = this.sameDiff.f().externalErrors(new SDVariable[]{layerOutput});
            // Return value unused; NOTE(review): presumably invoked for its side effect on the graph — confirm.
            this.fn.outputVariable();
            this.outputKey = this.outputVar.getVarName();
        }
    }

    /**
     * Binds a copy of every input array to its graph variable, and its mask (or an all-ones
     * default mask when none was supplied) to the {@code <input>_mask} variable.
     */
    private void associateInputsAndMasks() {
        for (int i = 0; i < this.inputs.length; ++i) {
            String name = this.config.getVertexParams().getInputs().get(i);
            String maskName = name + "_mask";
            this.sameDiff.associateArrayWithVariable(this.inputs[i].dup(), this.sameDiff.getVariable(name));
            if (this.maskArrays != null && this.maskArrays[i] != null) {
                this.sameDiff.associateArrayWithVariable(this.maskArrays[i].dup(), maskName);
            } else {
                this.sameDiff.associateArrayWithVariable(createMask(this.dataType, this.inputs[i].shape()), maskName);
            }
        }
    }

    /** Binds the (non-copied) parameter arrays to their graph variables, if there are any. */
    private void associateParams() {
        if (!this.hasParams()) {
            return;
        }
        for (Map.Entry<String, INDArray> e : this.paramTable.entrySet()) {
            this.sameDiff.associateArrayWithVariable(e.getValue(), e.getKey());
        }
    }

    /** Returns true when this vertex has at least one parameter array. */
    private boolean hasParams() {
        return this.paramTable != null && !this.paramTable.isEmpty();
    }

    @Override
    public void clearVertex() {
        this.clear();
    }

    @Override
    public Map<String, INDArray> paramTable(boolean backpropOnly) {
        return this.paramTable;
    }

    @Override
    public TrainingConfig getConfig() {
        return this.config;
    }

    @Override
    public INDArray params() {
        return this.params;
    }

    @Override
    public INDArray getGradientsViewArray() {
        return this.gradients;
    }

    /**
     * Creates an all-ones mask matching the minibatch dimension of an input of the given shape.
     *
     * <ul>
     *   <li>rank 2 {@code [mb, n]} -> mask shape {@code [mb, 1]}</li>
     *   <li>rank 3 {@code [mb, n, t]} -> mask shape {@code [mb, t]}</li>
     *   <li>rank 4 -> mask shape {@code [mb, 1, 1, 1]}</li>
     * </ul>
     * Any other rank fails via {@code Preconditions.throwEx}.
     */
    static INDArray createMask(DataType dataType, long[] shape) {
        switch (shape.length) {
            case 2:
                return Nd4j.ones(dataType, new long[]{shape[0], 1L});
            case 3:
                return Nd4j.ones(dataType, new long[]{shape[0], shape[2]});
            case 4:
                return Nd4j.ones(dataType, new long[]{shape[0], 1L, 1L, 1L});
            default:
                Preconditions.throwEx("Can not create all-ones-mask for given input shape %s.", new Object[]{Arrays.toString(shape)});
                return null;
        }
    }
}

