/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers;

import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BasePretrainNetwork;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.util.Dropout;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

public abstract class BaseLayer<LayerConfT extends Layer>
implements org.deeplearning4j.nn.api.Layer {
    protected INDArray input;
    protected INDArray paramsFlattened;
    protected INDArray gradientsFlattened;
    protected Map<String, INDArray> params;
    protected transient Map<String, INDArray> gradientViews;
    protected NeuralNetConfiguration conf;
    protected INDArray dropoutMask;
    protected boolean dropoutApplied = false;
    protected double score = 0.0;
    protected ConvexOptimizer optimizer;
    protected Gradient gradient;
    protected Collection<IterationListener> iterationListeners = new ArrayList<IterationListener>();
    protected int index = 0;
    protected INDArray maskArray;
    protected Solver solver;

    public BaseLayer(NeuralNetConfiguration conf) {
        this.conf = conf;
    }

    public BaseLayer(NeuralNetConfiguration conf, INDArray input) {
        this.input = input;
        this.conf = conf;
    }

    protected LayerConfT layerConf() {
        return (LayerConfT)this.conf.getLayer();
    }

    public INDArray getInput() {
        return this.input;
    }

    @Override
    public void setInput(INDArray input) {
        this.input = input;
        this.dropoutApplied = false;
    }

    @Override
    public int getIndex() {
        return this.index;
    }

    @Override
    public void setIndex(int index) {
        this.index = index;
    }

    @Override
    public Collection<IterationListener> getListeners() {
        return this.iterationListeners;
    }

    @Override
    public void setListeners(Collection<IterationListener> listeners) {
        this.iterationListeners = listeners != null ? listeners : new ArrayList();
    }

    @Override
    public void setListeners(IterationListener ... listeners) {
        this.iterationListeners = new ArrayList<IterationListener>();
        for (IterationListener l : listeners) {
            this.iterationListeners.add(l);
        }
    }

    @Override
    public Gradient error(INDArray errorSignal) {
        INDArray W = this.getParam("W");
        DefaultGradient nextLayerGradient = new DefaultGradient();
        INDArray wErrorSignal = errorSignal.mmul(W.transpose());
        nextLayerGradient.gradientForVariable().put("W", wErrorSignal);
        return nextLayerGradient;
    }

    @Override
    public INDArray derivativeActivation(INDArray input) {
        INDArray deriv = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.conf().getLayer().getActivationFunction(), input).derivative());
        return deriv;
    }

    @Override
    public Gradient calcGradient(Gradient layerError, INDArray activation) {
        DefaultGradient ret = new DefaultGradient();
        INDArray weightErrorSignal = layerError.getGradientFor("W");
        INDArray weightError = weightErrorSignal.transpose().mmul(activation).transpose();
        ret.gradientForVariable().put("W", weightError);
        INDArray biasGradient = weightError.mean(new int[]{0});
        ret.gradientForVariable().put("b", biasGradient);
        return ret;
    }

    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        INDArray z = this.preOutput(true);
        INDArray activationDerivative = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.conf().getLayer().getActivationFunction(), z).derivative());
        INDArray delta = epsilon.muli(activationDerivative);
        if (this.maskArray != null) {
            delta.muliColumnVector(this.maskArray);
        }
        DefaultGradient ret = new DefaultGradient();
        INDArray weightGrad = this.gradientViews.get("W");
        Nd4j.gemm((INDArray)this.input, (INDArray)delta, (INDArray)weightGrad, (boolean)true, (boolean)false, (double)1.0, (double)0.0);
        INDArray biasGrad = this.gradientViews.get("b");
        biasGrad.assign(delta.sum(new int[]{0}));
        ret.gradientForVariable().put("W", weightGrad);
        ret.gradientForVariable().put("b", biasGrad);
        INDArray epsilonNext = this.params.get("W").mmul(delta.transpose()).transpose();
        return new Pair<Gradient, INDArray>(ret, epsilonNext);
    }

    @Override
    public void fit() {
        this.fit(this.input);
    }

    @Override
    public void computeGradientAndScore() {
        if (this.input == null) {
            return;
        }
        INDArray output = this.activate(true);
        this.setScoreWithZ(output);
    }

    protected void setScoreWithZ(INDArray z) {
    }

    @Override
    public INDArray preOutput(INDArray x, Layer.TrainingMode training) {
        return this.preOutput(x, training == Layer.TrainingMode.TRAIN);
    }

    @Override
    public INDArray activate(Layer.TrainingMode training) {
        return this.activate(training == Layer.TrainingMode.TRAIN);
    }

    @Override
    public INDArray activate(INDArray input, Layer.TrainingMode training) {
        return this.activate(input, training == Layer.TrainingMode.TRAIN);
    }

    @Override
    public double score() {
        return this.score;
    }

    @Override
    public Gradient gradient() {
        return this.gradient;
    }

    @Override
    public void iterate(INDArray input) {
        this.setInput(input.dup());
        this.applyDropOutIfNecessary(true);
        Gradient gradient = this.gradient();
        for (String paramType : gradient.gradientForVariable().keySet()) {
            this.update(gradient.getGradientFor(paramType), paramType);
        }
    }

    @Override
    public void update(Gradient gradient) {
        for (String paramType : gradient.gradientForVariable().keySet()) {
            this.update(gradient.getGradientFor(paramType), paramType);
        }
    }

    @Override
    public void update(INDArray gradient, String paramType) {
        this.setParam(paramType, this.getParam(paramType).addi(gradient));
    }

    @Override
    public ConvexOptimizer getOptimizer() {
        if (this.optimizer == null) {
            Solver solver = new Solver.Builder().model(this).configure(this.conf()).build();
            this.optimizer = solver.getOptimizer();
        }
        return this.optimizer;
    }

    @Override
    public void setConf(NeuralNetConfiguration conf) {
        this.conf = conf;
    }

    @Override
    public INDArray params() {
        return Nd4j.toFlattened((char)'f', this.params.values());
    }

    @Override
    public INDArray getParam(String param) {
        return this.params.get(param);
    }

    @Override
    public void setParam(String key, INDArray val) {
        if (this.params.containsKey(key)) {
            this.params.get(key).assign(val);
        } else {
            this.params.put(key, val);
        }
    }

    @Override
    public void setParams(INDArray params) {
        if (params == this.paramsFlattened) {
            return;
        }
        this.setParams(params, 'f');
    }

    protected void setParams(INDArray params, char order) {
        List<String> parameterList = this.conf.variables();
        int length = 0;
        for (String s : parameterList) {
            length += this.getParam(s).length();
        }
        if (params.length() != length) {
            throw new IllegalArgumentException("Unable to set parameters: must be of length " + length);
        }
        int idx = 0;
        Set<String> paramKeySet = this.params.keySet();
        for (String s : paramKeySet) {
            INDArray param = this.getParam(s);
            INDArray get = params.get(new INDArrayIndex[]{NDArrayIndex.point((int)0), NDArrayIndex.interval((int)idx, (int)(idx + param.length()))});
            if (param.length() != get.length()) {
                throw new IllegalStateException("Parameter " + s + " should have been of length " + param.length() + " but was " + get.length());
            }
            param.assign(get.reshape(order, param.shape()));
            idx += param.length();
        }
    }

    @Override
    public void setParamsViewArray(INDArray params) {
        if (this.params != null && params.length() != this.numParams()) {
            throw new IllegalArgumentException("Invalid input: expect params of length " + this.numParams() + ", got params of length " + params.length());
        }
        this.paramsFlattened = params;
    }

    @Override
    public void setBackpropGradientsViewArray(INDArray gradients) {
        if (this.params != null && gradients.length() != this.numParams(true)) {
            throw new IllegalArgumentException("Invalid input: expect gradients array of length " + this.numParams(true) + ", got params of length " + gradients.length());
        }
        this.gradientsFlattened = gradients;
        this.gradientViews = this.conf.getLayer().initializer().getGradientsFromFlattened(this.conf, gradients);
    }

    @Override
    public void setParamTable(Map<String, INDArray> paramTable) {
        this.params = paramTable;
    }

    @Override
    public void initParams() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public Map<String, INDArray> paramTable() {
        return this.params;
    }

    @Override
    public INDArray preOutput(INDArray x, boolean training) {
        if (x == null) {
            throw new IllegalArgumentException("No null input allowed");
        }
        this.setInput(x);
        return this.preOutput(training);
    }

    public INDArray preOutput(boolean training) {
        this.applyDropOutIfNecessary(training);
        INDArray b = this.getParam("b");
        INDArray W = this.getParam("W");
        if (this.input.rank() != 2 || this.input.columns() != W.rows()) {
            if (this.input.rank() != 2) {
                throw new DL4JInvalidInputException("Input that is not a matrix; expected matrix (rank 2), got rank " + this.input.rank() + " array with shape " + Arrays.toString(this.input.shape()));
            }
            throw new DL4JInvalidInputException("Input size (" + this.input.columns() + " columns; shape = " + Arrays.toString(this.input.shape()) + ") is invalid: does not match layer input size (layer # inputs = " + W.size(0) + ")");
        }
        if (this.conf.isUseDropConnect() && training && this.conf.getLayer().getDropOut() > 0.0) {
            W = Dropout.applyDropConnect(this, "W");
        }
        INDArray ret = this.input.mmul(W).addiRowVector(b);
        if (this.maskArray != null) {
            ret.muliColumnVector(this.maskArray);
        }
        return ret;
    }

    @Override
    public INDArray activate(boolean training) {
        INDArray z = this.preOutput(training);
        INDArray ret = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.conf.getLayer().getActivationFunction(), z, this.conf.getExtraArgs()));
        if (this.maskArray != null) {
            ret.muliColumnVector(this.maskArray);
        }
        return ret;
    }

    @Override
    public INDArray activate(INDArray input) {
        this.setInput(input);
        return this.activate(true);
    }

    @Override
    public INDArray activate(INDArray input, boolean training) {
        this.setInput(input);
        return this.activate(training);
    }

    @Override
    public INDArray activate() {
        return this.activate(false);
    }

    @Override
    public INDArray preOutput(INDArray x) {
        return this.preOutput(x, true);
    }

    @Override
    public double calcL2() {
        if (!this.conf.isUseRegularization() || this.conf.getLayer().getL2() <= 0.0) {
            return 0.0;
        }
        double l2Norm = this.getParam("W").norm2Number().doubleValue();
        return 0.5 * this.conf.getLayer().getL2() * l2Norm * l2Norm;
    }

    @Override
    public double calcL1() {
        if (!this.conf.isUseRegularization() || this.conf.getLayer().getL1() <= 0.0) {
            return 0.0;
        }
        return this.conf.getLayer().getL1() * this.getParam("W").norm1Number().doubleValue();
    }

    @Override
    public int batchSize() {
        return this.input.size(0);
    }

    @Override
    public INDArray activationMean() {
        INDArray b = this.getParam("b");
        INDArray W = this.getParam("W");
        return this.input().mmul(W).addiRowVector(b);
    }

    @Override
    public NeuralNetConfiguration conf() {
        return this.conf;
    }

    @Override
    public void clear() {
        if (this.input != null) {
            this.input.data().destroy();
            this.input = null;
        }
    }

    protected void applyDropOutIfNecessary(boolean training) {
        if (this.conf.getLayer().getDropOut() > 0.0 && !this.conf.isUseDropConnect() && training && !this.dropoutApplied) {
            Dropout.applyDropout(this.input, this.conf.getLayer().getDropOut());
            this.dropoutApplied = true;
        }
    }

    @Override
    public void merge(org.deeplearning4j.nn.api.Layer l, int batchSize) {
        this.setParams(this.params().addi(l.params().divi((Number)batchSize)));
        this.computeGradientAndScore();
    }

    @Override
    public org.deeplearning4j.nn.api.Layer clone() {
        org.deeplearning4j.nn.api.Layer layer = null;
        try {
            Constructor<?> c = this.getClass().getConstructor(NeuralNetConfiguration.class);
            layer = (org.deeplearning4j.nn.api.Layer)c.newInstance(this.conf);
            LinkedHashMap<String, INDArray> linkedTable = new LinkedHashMap<String, INDArray>();
            for (Map.Entry<String, INDArray> entry : this.params.entrySet()) {
                linkedTable.put(entry.getKey(), entry.getValue().dup());
            }
            layer.setParamTable(linkedTable);
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        return layer;
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.FEED_FORWARD;
    }

    @Override
    public int numParams() {
        int ret = 0;
        for (INDArray val : this.params.values()) {
            ret += val.length();
        }
        return ret;
    }

    @Override
    public int numParams(boolean backwards) {
        if (backwards) {
            int ret = 0;
            for (Map.Entry<String, INDArray> entry : this.params.entrySet()) {
                if (this instanceof BasePretrainNetwork && "bB".equals(entry.getKey())) continue;
                ret += entry.getValue().length();
            }
            return ret;
        }
        return this.numParams();
    }

    @Override
    public void fit(INDArray input) {
        if (input != null) {
            this.setInput(input.dup());
            this.applyDropOutIfNecessary(true);
        }
        if (this.solver == null) {
            this.solver = new Solver.Builder().model(this).configure(this.conf()).listeners(this.getListeners()).build();
            Updater updater = this.solver.getOptimizer().getUpdater();
            int updaterStateSize = updater.stateSizeForLayer(this);
            if (updaterStateSize > 0) {
                updater.setStateViewArray(this, Nd4j.createUninitialized((int[])new int[]{1, updaterStateSize}, (char)Nd4j.order().charValue()), true);
            }
        }
        this.optimizer = this.solver.getOptimizer();
        this.solver.optimize();
    }

    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<Gradient, Double>(this.gradient(), this.score());
    }

    @Override
    public INDArray input() {
        return this.input;
    }

    @Override
    public void validateInput() {
    }

    protected Gradient createGradient(INDArray ... gradients) {
        DefaultGradient ret = new DefaultGradient();
        if (gradients.length != this.conf.variables().size()) {
            throw new IllegalArgumentException("Unable to create gradients...not equal to number of parameters");
        }
        for (int i = 0; i < gradients.length; ++i) {
            INDArray paramI = this.getParam(this.conf.variables().get(i));
            if (!Arrays.equals(paramI.shape(), gradients[i].shape())) {
                throw new IllegalArgumentException("Gradient at index " + i + " had wrong gradient size of " + Arrays.toString(gradients[i].shape()) + " when should have been " + Arrays.toString(paramI.shape()));
            }
            ret.gradientForVariable().put(this.conf.variables().get(i), gradients[i]);
        }
        return ret;
    }

    public String toString() {
        return this.getClass().getName() + "{conf=" + this.conf + ", dropoutMask=" + this.dropoutMask + ", score=" + this.score + ", optimizer=" + this.optimizer + ", listeners=" + this.iterationListeners + '}';
    }

    @Override
    public org.deeplearning4j.nn.api.Layer transpose() {
        org.deeplearning4j.nn.api.Layer layer;
        if (!(this.conf.getLayer() instanceof FeedForwardLayer)) {
            throw new UnsupportedOperationException("unsupported layer type: " + this.conf.getLayer().getClass().getName());
        }
        INDArray w = this.getParam("W");
        INDArray b = this.getParam("b");
        INDArray vb = this.getParam("bB");
        try {
            INDArray newB;
            NeuralNetConfiguration clone = this.conf.clone();
            FeedForwardLayer clonedLayerConf = (FeedForwardLayer)clone.getLayer();
            int nIn = clonedLayerConf.getNOut();
            int nOut = clonedLayerConf.getNIn();
            clonedLayerConf.setNIn(nIn);
            clonedLayerConf.setNOut(nOut);
            INDArray newVB = null;
            if (vb != null) {
                newB = vb.dup();
                newVB = b.dup();
            } else {
                newB = Nd4j.create((int)1, (int)nOut);
            }
            INDArray paramsView = Nd4j.create((int)1, (int)(w.length() + nOut));
            layer = clone.getLayer().instantiate(clone, this.iterationListeners, this.index, paramsView, true);
            layer.setParam("W", w.transpose().dup());
            layer.setParam("b", newB);
            if (vb != null) {
                layer.setParam("bB", newVB);
            }
        }
        catch (Exception e) {
            throw new RuntimeException("unable to construct transposed layer", e);
        }
        return layer;
    }

    @Override
    public void accumulateScore(double accum) {
        this.score += accum;
    }

    @Override
    public void setInputMiniBatchSize(int size) {
    }

    @Override
    public int getInputMiniBatchSize() {
        return this.input.size(0);
    }

    @Override
    public void applyLearningRateScoreDecay() {
        for (Map.Entry<String, Double> lrPair : this.conf.getLearningRateByParam().entrySet()) {
            this.conf.setLearningRateByParam(lrPair.getKey(), lrPair.getValue() * (this.conf.getLrPolicyDecayRate() + Nd4j.EPS_THRESHOLD));
        }
    }

    @Override
    public void setMaskArray(INDArray maskArray) {
        this.maskArray = maskArray;
    }

    @Override
    public INDArray getMaskArray() {
        return this.maskArray;
    }
}

