/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers;

import java.lang.reflect.Constructor;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNormBp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.primitives.Pair;

public abstract class BaseLayer<LayerConfT extends org.deeplearning4j.nn.conf.layers.BaseLayer>
extends AbstractLayer<LayerConfT> {
    // Flattened parameter array; assigned by setParamsViewArray(...) and returned by params().
    protected INDArray paramsFlattened;
    // Flattened gradient array; assigned by setBackpropGradientsViewArray(...).
    protected INDArray gradientsFlattened;
    // Parameter arrays keyed by parameter name (this class reads the keys "W", "b" and "g").
    protected Map<String, INDArray> params;
    // Per-parameter gradient arrays, obtained from the layer's initializer in setBackpropGradientsViewArray(...).
    protected transient Map<String, INDArray> gradientViews;
    // Value returned by score(); never assigned anywhere else in this class.
    protected double score = 0.0;
    // Optimizer, created lazily in getOptimizer() or taken from the solver in fit(INDArray, LayerWorkspaceMgr).
    protected ConvexOptimizer optimizer;
    // Value returned by gradient(); never assigned in this class.
    protected Gradient gradient;
    // Solver, created lazily in fit(INDArray, LayerWorkspaceMgr).
    protected Solver solver;
    // Cache of parameters returned by the weight-noise configuration during training; cleared after backprop and by clear().
    protected Map<String, INDArray> weightNoiseParams = new HashMap<String, INDArray>();

    /**
     * Creates the layer by delegating to the {@link AbstractLayer} constructor.
     *
     * @param conf     network configuration, passed through to the superclass
     * @param dataType data type, passed through to the superclass
     */
    public BaseLayer(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
    }

    /**
     * Returns the layer configuration stored in {@code conf}, cast to this layer's
     * configuration type.
     *
     * @return the configuration object of this layer
     */
    @Override
    public LayerConfT layerConf() {
        org.deeplearning4j.nn.conf.layers.BaseLayer rawConf =
                (org.deeplearning4j.nn.conf.layers.BaseLayer)conf.getLayer();
        return (LayerConfT)rawConf;
    }

    /**
     * Computes the parameter gradients of this layer and the epsilon to pass on, given the
     * incoming {@code epsilon}.
     *
     * <p>Gradients are written into the arrays held in {@code gradientViews} (keys "W", and
     * "b"/"g" when {@link #hasBias()}/{@link #hasLayerNorm()} are true) and the same arrays are
     * registered in the returned {@link Gradient}.
     *
     * @param epsilon      incoming gradient, handed to the activation function's backprop
     * @param workspaceMgr workspace manager used to allocate the outgoing epsilon
     * @return pair of (gradient for this layer, epsilon for the next backprop step)
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(true);
        // Recompute the pre-activation (z) and, for layer norm, the value before normalisation.
        Pair<INDArray, INDArray> zAndPreNorm = this.preOutputWithPreNorm(true, true, workspaceMgr);
        INDArray z = (INDArray)zAndPreNorm.getFirst();
        INDArray preNorm = (INDArray)zAndPreNorm.getSecond();
        // delta = first element of the activation function's backprop result for (z, epsilon).
        INDArray delta = (INDArray)((org.deeplearning4j.nn.conf.layers.BaseLayer)this.layerConf()).getActivationFn().backprop(z, epsilon).getFirst();
        if (this.maskArray != null) {
            this.applyMask(delta);
        }
        DefaultGradient ret = new DefaultGradient();
        if (this.hasBias()) {
            // Bias gradient: sum of delta along dimension 0, written directly into the "b" gradient view.
            // This happens before the layer-norm backprop below, which receives delta as an output argument.
            INDArray biasGrad = this.gradientViews.get("b");
            delta.sum(biasGrad, new int[]{0});
            ret.gradientForVariable().put("b", biasGrad);
        }
        INDArray W = this.getParamWithNoise("W", true, workspaceMgr);
        // Result buffer for W * delta^T: shape [W.size(0), delta.size(0)], 'f' order.
        INDArray epsilonNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, delta.dataType(), new long[]{W.size(0), delta.size(0)}, 'f');
        if (this.hasLayerNorm()) {
            INDArray g = this.getParam("g");
            INDArray dldg = this.gradientViews.get("g");
            // NOTE(review): delta is passed twice to LayerNormBp - presumably as both the incoming
            // gradient and the output buffer (in-place update); confirm against the op's signature.
            Nd4j.getExecutioner().exec((CustomOp)new LayerNormBp(preNorm, g, delta, delta, dldg, new int[]{1}));
            ret.gradientForVariable().put("g", dldg);
        }
        // epsilonNext = (W * delta^T)^T, computed into the pre-allocated buffer.
        epsilonNext = W.mmuli(delta.transpose(), epsilonNext).transpose();
        INDArray weightGrad = this.gradientViews.get("W");
        // Weight gradient: gemm with transposeA=true, transposeB=false, alpha=1, beta=0, i.e. input^T * delta
        // written into the "W" gradient view.
        Nd4j.gemm((INDArray)this.input.castTo(weightGrad.dataType()), (INDArray)delta, (INDArray)weightGrad, (boolean)true, (boolean)false, (double)1.0, (double)0.0);
        ret.gradientForVariable().put("W", weightGrad);
        // Drop the cached noisy parameters now that this backward pass no longer needs them.
        this.weightNoiseParams.clear();
        epsilonNext = this.backpropDropOutIfPresent(epsilonNext);
        return new Pair((Object)ret, (Object)epsilonNext);
    }

    /**
     * Not supported by this class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void fit() {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * Runs a training-mode forward pass and hands the activations to
     * {@link #setScoreWithZ(INDArray)}. Does nothing when no input has been set.
     *
     * @param workspaceMgr workspace manager forwarded to {@code activate}
     */
    @Override
    public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) {
        if (input != null) {
            INDArray activations = activate(true, workspaceMgr);
            setScoreWithZ(activations);
        }
    }

    /**
     * Hook called by {@link #computeGradientAndScore(LayerWorkspaceMgr)} with the layer's
     * activations. The implementation in this class is empty.
     *
     * @param z array passed in by the caller (the result of {@code activate(true, ...)})
     */
    protected void setScoreWithZ(INDArray z) {
    }

    /**
     * @return the current value of the {@code score} field
     */
    @Override
    public double score() {
        return score;
    }

    /**
     * @return the current value of the {@code gradient} field (may be null; this class never assigns it)
     */
    @Override
    public Gradient gradient() {
        return gradient;
    }

    /**
     * Applies every per-parameter gradient contained in {@code gradient} by delegating to
     * {@link #update(INDArray, String)} once per parameter name.
     *
     * @param gradient gradient holder whose entries are applied
     */
    @Override
    public void update(Gradient gradient) {
        Set<String> names = gradient.gradientForVariable().keySet();
        for (String name : names) {
            INDArray delta = gradient.getGradientFor(name);
            update(delta, name);
        }
    }

    /**
     * Adds {@code gradient} in place to the named parameter and stores the result via
     * {@link #setParam(String, INDArray)}.
     *
     * @param gradient  array added to the parameter
     * @param paramType name of the parameter to update
     */
    @Override
    public void update(INDArray gradient, String paramType) {
        INDArray current = getParam(paramType);
        INDArray updated = current.addi(gradient);
        setParam(paramType, updated);
    }

    /**
     * Returns the optimizer, building it from a new {@link Solver} on first use.
     * The solver built here is local; the {@code solver} field is not touched.
     *
     * @return the cached or newly created optimizer
     */
    @Override
    public ConvexOptimizer getOptimizer() {
        if (optimizer != null) {
            return optimizer;
        }
        Solver built = new Solver.Builder().model(this).configure(conf()).build();
        optimizer = built.getOptimizer();
        return optimizer;
    }

    /**
     * @return the flattened parameter array last set through {@link #setParamsViewArray(INDArray)}
     */
    @Override
    public INDArray params() {
        return paramsFlattened;
    }

    /**
     * Looks up a parameter array by name in the parameter table.
     *
     * @param param parameter name
     * @return the mapped array, or {@code null} when the table has no entry for {@code param}
     */
    @Override
    public INDArray getParam(String param) {
        return params.get(param);
    }

    /**
     * Sets a parameter. If the key is already present, the values of {@code val} are assigned
     * into the existing array; otherwise {@code val} itself is stored under {@code key}.
     *
     * @param key parameter name
     * @param val new values
     */
    @Override
    public void setParam(String key, INDArray val) {
        if (!params.containsKey(key)) {
            params.put(key, val);
            return;
        }
        params.get(key).assign(val);
    }

    /**
     * Copies {@code params} into this layer's parameters using 'f' order, unless the argument
     * is the very same object as the flattened parameter array (in which case nothing is done).
     *
     * @param params flattened parameter values
     */
    @Override
    public void setParams(INDArray params) {
        if (params != paramsFlattened) {
            setParams(params, 'f');
        }
    }

    /**
     * Copies the values of a flattened parameter array into the individual parameter arrays.
     *
     * <p>The expected total length is the sum of the lengths of the parameters named by
     * {@code conf.variables()}; the copy itself walks the key set of the parameter table in
     * iteration order, taking consecutive intervals of row 0 of {@code params}.
     *
     * <p>Lengths and offsets are accumulated as {@code long}; array lengths are {@code long}
     * values and narrowing them to {@code int} could overflow for very large layers.
     *
     * @param params flattened parameter values; its length must equal the expected total
     * @param order  ordering character used when reshaping each slice to the parameter's shape
     * @throws IllegalArgumentException if {@code params} does not have the expected length
     * @throws IllegalStateException    if a slice does not have the length of its target parameter
     */
    @Override
    protected void setParams(INDArray params, char order) {
        List<String> parameterList = this.conf.variables();
        long length = 0L;
        for (String s : parameterList) {
            length += this.getParam(s).length();
        }
        if (params.length() != length) {
            throw new IllegalArgumentException("Unable to set parameters: must be of length " + length + ", got params of length " + params.length() + " - " + this.layerId());
        }
        long idx = 0L;
        Set<String> paramKeySet = this.params.keySet();
        for (String s : paramKeySet) {
            INDArray param = this.getParam(s);
            INDArray get = params.get(new INDArrayIndex[]{NDArrayIndex.point(0L), NDArrayIndex.interval(idx, idx + param.length())});
            if (param.length() != get.length()) {
                throw new IllegalStateException("Parameter " + s + " should have been of length " + param.length() + " but was " + get.length() + " - " + this.layerId());
            }
            // assign (rather than replace) so the existing parameter array object is kept
            param.assign(get.reshape(order, param.shape()));
            idx += param.length();
        }
    }

    /**
     * Stores {@code params} as the flattened parameter array. When a parameter table exists,
     * the array length must equal {@link #numParams()}.
     *
     * @param params flattened parameter array
     * @throws IllegalArgumentException if the length check fails
     */
    @Override
    public void setParamsViewArray(INDArray params) {
        boolean lengthMismatch = this.params != null && params.length() != numParams();
        if (lengthMismatch) {
            throw new IllegalArgumentException("Invalid input: expect params of length " + numParams() + ", got params of length " + params.length() + " - " + layerId());
        }
        paramsFlattened = params;
    }

    /**
     * @return the flattened gradient array last set through
     *         {@link #setBackpropGradientsViewArray(INDArray)}
     */
    @Override
    public INDArray getGradientsViewArray() {
        return gradientsFlattened;
    }

    /**
     * Stores {@code gradients} as the flattened gradient array and rebuilds the per-parameter
     * gradient views from it through the layer's parameter initializer.
     *
     * <p>When a parameter table exists, the array length must equal {@link #numParams()}; the
     * error message reports exactly the value that was compared against.
     *
     * @param gradients flattened gradient array
     * @throws IllegalArgumentException if the length check fails
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray gradients) {
        if (this.params != null) {
            long expected = this.numParams();
            if (gradients.length() != expected) {
                throw new IllegalArgumentException("Invalid input: expect gradients array of length " + expected + ", got array of length " + gradients.length() + " - " + this.layerId());
            }
        }
        this.gradientsFlattened = gradients;
        this.gradientViews = this.conf.getLayer().initializer().getGradientsFromFlattened(this.conf, gradients);
    }

    /**
     * Replaces the parameter table with the given map (the reference is stored, not copied).
     *
     * @param paramTable new parameter table
     */
    @Override
    public void setParamTable(Map<String, INDArray> paramTable) {
        params = paramTable;
    }

    /**
     * Shorthand for {@code paramTable(false)}.
     *
     * @return the parameter table
     */
    @Override
    public Map<String, INDArray> paramTable() {
        return paramTable(false);
    }

    /**
     * Returns the parameter table. In this class the flag is not consulted; the same map is
     * returned for either value.
     *
     * @param backpropParamsOnly unused by this implementation
     * @return the internal parameter map (not a copy)
     */
    @Override
    public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
        return params;
    }

    /**
     * Returns the named parameter, routed through the configured weight noise when one is set.
     *
     * <p>Without weight noise this is simply {@link #getParam(String)}. With weight noise, a
     * value cached in {@code weightNoiseParams} is reused during training; otherwise the
     * parameter is requested from the weight-noise object while scoped out of workspaces,
     * and cached when {@code training} is true.
     *
     * @param param        parameter name
     * @param training     whether the call is part of training
     * @param workspaceMgr workspace manager forwarded to the weight-noise object
     * @return the parameter array to use
     */
    protected INDArray getParamWithNoise(String param, boolean training, LayerWorkspaceMgr workspaceMgr) {
        if (layerConf().getWeightNoise() == null) {
            return getParam(param);
        }
        if (training && !weightNoiseParams.isEmpty() && weightNoiseParams.containsKey(param)) {
            return weightNoiseParams.get(param);
        }
        INDArray noisy;
        try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            noisy = layerConf().getWeightNoise().getParameter(this, param, getIterationCount(), getEpochCount(), training, workspaceMgr);
        }
        if (training) {
            weightNoiseParams.put(param, noisy);
        }
        return noisy;
    }

    /**
     * Computes the pre-activation output, discarding the pre-normalisation array.
     *
     * @param training     whether the call is part of training
     * @param workspaceMgr workspace manager used for allocation
     * @return first element of {@code preOutputWithPreNorm(training, false, workspaceMgr)}
     */
    protected INDArray preOutput(boolean training, LayerWorkspaceMgr workspaceMgr) {
        Pair<INDArray, INDArray> both = preOutputWithPreNorm(training, false, workspaceMgr);
        return both.getFirst();
    }

    /**
     * Computes the layer's pre-activation output ({@code input * W}, optionally layer-normalised,
     * plus bias, then masked) together with the array that was fed into layer normalisation.
     *
     * <p>When layer norm is disabled, both elements of the returned pair are the same array.
     * When it is enabled and {@code forBackprop} is true, the pre-normalisation values are kept
     * in a separate copy; otherwise normalisation is applied with the same array as input and
     * output.
     *
     * @param training     whether the call is part of training
     * @param forBackprop  forwarded to {@code assertInputSet}; also controls whether the
     *                     pre-normalisation array is duplicated
     * @param workspaceMgr workspace manager used to allocate the result
     * @return pair of (pre-activation output, pre-normalisation array)
     * @throws DL4JInvalidInputException if the input is not rank 2 or its column count does not
     *                                   match the number of rows of {@code W}
     */
    protected Pair<INDArray, INDArray> preOutputWithPreNorm(boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) {
        assertInputSet(forBackprop);
        applyDropOutIfNecessary(training, workspaceMgr);

        INDArray weights = getParamWithNoise("W", training, workspaceMgr);
        INDArray bias = getParamWithNoise("b", training, workspaceMgr);
        INDArray gain = null;
        if (hasLayerNorm()) {
            gain = getParam("g");
        }
        INDArray in = this.input.castTo(this.dataType);

        // Validate shape: rank first, then feature count against the weight matrix.
        if (in.rank() != 2) {
            throw new DL4JInvalidInputException("Input that is not a matrix; expected matrix (rank 2), got rank " + in.rank() + " array with shape " + Arrays.toString(in.shape()) + ". Missing preprocessor or wrong input type? " + layerId());
        }
        if (in.columns() != weights.rows()) {
            throw new DL4JInvalidInputException("Input size (" + in.columns() + " columns; shape = " + Arrays.toString(in.shape()) + ") is invalid: does not match layer input size (layer # inputs = " + weights.size(0) + ") " + layerId());
        }

        INDArray out = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, weights.dataType(), new long[]{in.size(0), weights.size(1)});
        in.castTo(out.dataType()).mmuli(weights, out);

        INDArray preNorm = out;
        if (hasLayerNorm()) {
            if (forBackprop) {
                // keep an untouched copy of the un-normalised values for the backward pass
                preNorm = out.dup(out.ordering());
            }
            Nd4j.getExecutioner().exec((CustomOp)new LayerNorm(preNorm, gain, out, new int[]{1}));
        }
        if (hasBias()) {
            out.addiRowVector(bias);
        }
        if (maskArray != null) {
            applyMask(out);
        }
        return new Pair<INDArray, INDArray>(out, preNorm);
    }

    /**
     * Forward pass: computes the pre-activation output, applies the configured activation
     * function, and applies the mask array when one is set.
     *
     * @param training     whether the call is part of training
     * @param workspaceMgr workspace manager used for allocation
     * @return the array returned by the activation function (masked if a mask is set)
     */
    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        INDArray preActivation = preOutput(training, workspaceMgr);
        INDArray activated = layerConf().getActivationFn().getActivation(preActivation, training);
        if (maskArray == null) {
            return activated;
        }
        applyMask(activated);
        return activated;
    }

    /**
     * Sums the regularization score of every parameter in {@link #paramTable()} over all
     * regularization terms the layer configuration reports for that parameter.
     *
     * <p>NOTE(review): {@code backpropParamsOnly} is not consulted here; the no-argument
     * {@code paramTable()} is always used. Confirm whether that is intended.
     *
     * @param backpropParamsOnly unused by this implementation
     * @return the accumulated score, {@code 0.0} when no parameter has regularization
     */
    @Override
    public double calcRegularizationScore(boolean backpropParamsOnly) {
        double total = 0.0;
        for (Map.Entry<String, INDArray> entry : paramTable().entrySet()) {
            List<Regularization> terms = layerConf().getRegularizationByParam(entry.getKey());
            if (terms != null && !terms.isEmpty()) {
                for (Regularization term : terms) {
                    total += term.score(entry.getValue(), getIterationCount(), getEpochCount());
                }
            }
        }
        return total;
    }

    /**
     * Creates a new instance of this layer's runtime class via reflection and gives it a
     * parameter table containing duplicates of this layer's parameter arrays.
     *
     * <p>The public {@code (NeuralNetConfiguration, DataType)} constructor is tried first, since
     * that is the signature this base class itself declares; if the runtime class has no such
     * constructor, the single-argument {@code (NeuralNetConfiguration)} constructor is used
     * instead. The new instance receives the same {@code conf} object (it is not copied).
     *
     * <p>This remains a best-effort operation: any exception is printed and the method returns
     * whatever was constructed so far, which is {@code null} if instantiation itself failed.
     *
     * @return the cloned layer, or {@code null} if no instance could be created
     */
    public Layer clone() {
        Layer layer = null;
        try {
            Class<?> type = this.getClass();
            Constructor<?> c;
            Object[] ctorArgs;
            try {
                c = type.getConstructor(NeuralNetConfiguration.class, DataType.class);
                ctorArgs = new Object[]{this.conf, this.dataType};
            }
            catch (NoSuchMethodException twoArgCtorMissing) {
                // Fall back to the single-argument constructor form.
                c = type.getConstructor(NeuralNetConfiguration.class);
                ctorArgs = new Object[]{this.conf};
            }
            layer = (Layer)c.newInstance(ctorArgs);
            // LinkedHashMap keeps the iteration order of the source table.
            LinkedHashMap<String, INDArray> linkedTable = new LinkedHashMap<String, INDArray>();
            for (Map.Entry<String, INDArray> entry : this.params.entrySet()) {
                linkedTable.put(entry.getKey(), entry.getValue().dup());
            }
            layer.setParamTable(linkedTable);
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        return layer;
    }

    /**
     * Returns the total number of elements across all arrays in the parameter table.
     *
     * <p>The sum is accumulated as a {@code long}, matching the return type, so large layers
     * are not truncated by an intermediate {@code int}.
     *
     * @return sum of {@code length()} over all parameter arrays
     */
    @Override
    public long numParams() {
        long total = 0L;
        for (INDArray val : this.params.values()) {
            total += val.length();
        }
        return total;
    }

    /**
     * Fits this layer using a {@link Solver}. A non-null {@code input} is first set on the
     * layer and dropout is applied in training mode. The solver is created on first use and
     * its optimizer is stored in {@code optimizer} before optimisation runs.
     *
     * @param input        input to fit on; when {@code null}, the input is left as it is
     * @param workspaceMgr workspace manager forwarded to the layer and the solver
     */
    @Override
    public void fit(INDArray input, LayerWorkspaceMgr workspaceMgr) {
        if (input != null) {
            setInput(input, workspaceMgr);
            applyDropOutIfNecessary(true, workspaceMgr);
        }
        if (solver == null) {
            solver = new Solver.Builder()
                    .model(this)
                    .configure(conf())
                    .listeners(getListeners())
                    .build();
        }
        optimizer = solver.getOptimizer();
        solver.optimize(workspaceMgr);
    }

    /**
     * @return the runtime class name followed by the configuration, score, optimizer and
     *         training listeners of this layer
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder();
        text.append(getClass().getName());
        text.append("{conf=").append(conf);
        text.append(", score=").append(score);
        text.append(", optimizer=").append(optimizer);
        text.append(", listeners=").append(trainingListeners);
        text.append('}');
        return text.toString();
    }

    /**
     * Runs the superclass {@code clear()} and then empties the cached weight-noise parameters.
     */
    @Override
    public void clear() {
        super.clear();
        weightNoiseParams.clear();
    }

    /**
     * Empties the cache of weight-noise parameters.
     */
    @Override
    public void clearNoiseWeightParams() {
        weightNoiseParams.clear();
    }

    /**
     * Whether this layer uses a bias parameter. The forward and backward passes of this class
     * only touch the "b" parameter when this returns {@code true}.
     *
     * @return {@code true} in this base implementation
     */
    public boolean hasBias() {
        return true;
    }

    /**
     * Whether this layer applies layer normalisation. The forward and backward passes of this
     * class only use the "g" parameter when this returns {@code true}.
     *
     * @return {@code false} in this base implementation
     */
    public boolean hasLayerNorm() {
        return false;
    }
}

