/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.updater;

import java.beans.ConstructorProperties;
import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.LearningRatePolicy;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * A group of parameters, described by {@link ParamState} entries, that is driven by one shared
 * {@link GradientUpdater}.
 *
 * <p>The block stores the offsets of its parameters ({@code paramOffsetStart} /
 * {@code paramOffsetEnd}), which are used to slice network-wide gradient arrays when an external
 * gradient is supplied. The updater-view offsets are only stored here and exposed through accessors.
 *
 * <p>Configuration shared by the whole block (updater type, learning rate policy, pretrain flag) is
 * always read from the <em>first</em> entry of {@link #getLayersAndVariablesInBlock()}, so that list
 * must be non-empty before the block is used.
 *
 * <p>{@link #getGradientUpdater()} creates the updater on first use; {@link #equals(Object)},
 * {@link #hashCode()} and {@link #toString()} all go through that getter and may therefore trigger
 * {@link #init()}.
 */
public class UpdaterBlock {
    /** Multiplier used when combining field hashes; same value the generated code used. */
    private static final int HASH_PRIME = 59;
    /** Hash contribution of a {@code null} field; same value the generated code used. */
    private static final int NULL_HASH = 43;

    private int paramOffsetStart;
    private int paramOffsetEnd;
    private int updaterViewOffsetStart;
    private int updaterViewOffsetEnd;
    private List<ParamState> layersAndVariablesInBlock = new ArrayList<ParamState>();
    private INDArray updaterView;
    private INDArray gradientView;
    private boolean updaterViewRequiresInitialization;
    private GradientUpdater gradientUpdater;

    /**
     * Creates a block. The updater view, gradient view and updater itself are not set here; they are
     * supplied later through the setters (and {@link #init()} for the updater).
     *
     * @param paramOffsetStart          start offset of this block's parameters
     * @param paramOffsetEnd            end offset of this block's parameters
     * @param updaterViewOffsetStart    start offset of this block's updater state
     * @param updaterViewOffsetEnd      end offset of this block's updater state
     * @param layersAndVariablesInBlock the (layer, parameter) entries that make up this block
     */
    public UpdaterBlock(int paramOffsetStart, int paramOffsetEnd, int updaterViewOffsetStart,
                    int updaterViewOffsetEnd, List<ParamState> layersAndVariablesInBlock) {
        this.paramOffsetStart = paramOffsetStart;
        this.paramOffsetEnd = paramOffsetEnd;
        this.updaterViewOffsetStart = updaterViewOffsetStart;
        this.updaterViewOffsetEnd = updaterViewOffsetEnd;
        this.layersAndVariablesInBlock = layersAndVariablesInBlock;
    }

    /**
     * Instantiates the gradient updater if it has not been created (or set) yet. The updater
     * configuration is looked up by the parameter name of the first entry in the block and is
     * instantiated with the current updater view and its "requires initialization" flag.
     */
    public void init() {
        if (gradientUpdater != null) {
            return;
        }
        ParamState first = layersAndVariablesInBlock.get(0);
        gradientUpdater = first.getLayer().conf().getLayer().getIUpdaterByParam(first.getParamName())
                        .instantiate(updaterView, updaterViewRequiresInitialization);
    }

    /**
     * @return whether the first entry's layer configuration reports its parameter as a pretrain
     *         parameter
     */
    public boolean isPretrainUpdaterBlock() {
        ParamState first = layersAndVariablesInBlock.get(0);
        return first.getLayer().conf().getLayer().isPretrainParam(first.getParamName());
    }

    /**
     * @return {@code true} only when this is a pretrain block (see {@link #isPretrainUpdaterBlock()})
     *         and the first entry's layer configuration has pretraining switched off
     */
    public boolean skipDueToPretrainConfig() {
        if (!isPretrainUpdaterBlock()) {
            return false;
        }
        return !layersAndVariablesInBlock.get(0).getLayer().conf().isPretrain();
    }

    /**
     * Returns the gradient updater, calling {@link #init()} first when none exists yet.
     *
     * @return the updater for this block
     */
    public GradientUpdater getGradientUpdater() {
        if (gradientUpdater == null) {
            init();
        }
        return gradientUpdater;
    }

    /**
     * Applies the update using this block's own gradient view and each entry's own parameter and
     * gradient views.
     *
     * @param iteration current iteration, forwarded to the learning rate policy and the updater
     */
    public void update(int iteration) {
        update(iteration, false, gradientView, null);
    }

    /**
     * Applies the update to externally supplied arrays instead of the views held by this block. The
     * relevant sub-ranges are taken from row 0 of the supplied arrays using the stored offsets.
     *
     * @param iteration               current iteration
     * @param fullNetworkGradientView gradient array to slice and update
     * @param fullNetworkParamsArray  parameter array to slice for the L1/L2 terms
     */
    public void updateExternalGradient(int iteration, INDArray fullNetworkGradientView,
                    INDArray fullNetworkParamsArray) {
        update(iteration, true, fullNetworkGradientView, fullNetworkParamsArray);
    }

    /**
     * Shared implementation of {@link #update(int)} and
     * {@link #updateExternalGradient(int, INDArray, INDArray)}: lazily creates the updater, applies the
     * learning rate policy, runs the updater on the block's gradient and finally adds the
     * regularization terms per entry via {@link #postApply(Layer, String, INDArray, INDArray)}.
     */
    private void update(int iteration, boolean externalGradient, INDArray fullNetworkGradientView,
                    INDArray fullNetworkParamsArray) {
        if (gradientUpdater == null) {
            init();
        }

        INDArray blockGradViewArray = externalGradient
                        ? rowZeroInterval(fullNetworkGradientView, paramOffsetStart, paramOffsetEnd)
                        : gradientView;

        // Policy and updater type are read from the first entry only.
        Layer l0 = layersAndVariablesInBlock.get(0).getLayer();
        if (!(l0.conf().getLayer() instanceof BaseLayer)) {
            // Nothing is applied for configurations that are not a BaseLayer.
            return;
        }
        BaseLayer baseLayer = (BaseLayer) l0.conf().getLayer();
        LearningRatePolicy lrPolicy = l0.conf().getLearningRatePolicy();
        if (lrPolicy != LearningRatePolicy.None || baseLayer.getIUpdater() instanceof Nesterovs) {
            applyLrDecayPolicy(lrPolicy, iteration);
        }

        gradientUpdater.applyUpdater(blockGradViewArray, iteration);

        for (ParamState p : layersAndVariablesInBlock) {
            INDArray paramView;
            INDArray gradView;
            if (externalGradient) {
                paramView = rowZeroInterval(fullNetworkParamsArray, p.getParamOffsetStart(), p.getParamOffsetEnd());
                gradView = rowZeroInterval(fullNetworkGradientView, p.getParamOffsetStart(), p.getParamOffsetEnd());
            } else {
                paramView = p.getParamView();
                gradView = p.getGradView();
            }
            postApply(p.getLayer(), p.getParamName(), gradView, paramView);
        }
    }

    /** Returns the {@code [from, to)}-style interval view of row 0, as produced by {@code NDArrayIndex.interval}. */
    private static INDArray rowZeroInterval(INDArray fullView, int from, int to) {
        return fullView.get(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(from, to)});
    }

    /**
     * Adds the L2 and L1 regularization terms to the gradient, in place, when the layer configuration
     * has regularization enabled and the respective coefficient for {@code paramName} is positive.
     * L2 adds {@code l2 * params} (via BLAS axpy); L1 adds {@code l1 * sign(params)}.
     *
     * @param layer        layer whose configuration supplies the coefficients
     * @param paramName    parameter the views belong to
     * @param gradientView gradient to modify in place
     * @param paramsView   current parameter values (not modified; the sign is computed on a copy)
     */
    public void postApply(Layer layer, String paramName, INDArray gradientView, INDArray paramsView) {
        NeuralNetConfiguration conf = layer.conf();
        double l2 = conf.getL2ByParam(paramName);
        if (!conf.isUseRegularization()) {
            return;
        }
        if (l2 > 0.0) {
            int length = gradientView.length();
            Nd4j.getBlasWrapper().level1().axpy(length, l2, paramsView, gradientView);
        }
        double l1 = conf.getL1ByParam(paramName);
        if (l1 > 0.0) {
            gradientView.addi(Transforms.sign(paramsView, true).muli(l1));
        }
    }

    /**
     * Computes the learning rate for {@code iteration} under the given policy, stores it (and, for
     * Nesterovs updaters, the momentum) in the configuration of every entry in the block, and passes
     * the new rate to the updater's config via {@code applySchedules}.
     *
     * <p>Does nothing when the first entry's layer configuration is not a {@link BaseLayer}. The
     * updater is created on demand, so this method may be called before {@link #init()}.
     *
     * @param decay     learning rate policy to apply
     * @param iteration current iteration
     * @throws RuntimeException if {@code decay} is not one of the handled policies
     */
    public void applyLrDecayPolicy(LearningRatePolicy decay, int iteration) {
        ParamState first = layersAndVariablesInBlock.get(0);
        NeuralNetConfiguration conf = first.getLayer().conf();
        double decayRate = conf.getLrPolicyDecayRate();
        double lr = conf.getLearningRateByParam(first.getParamName());
        if (!(conf.getLayer() instanceof BaseLayer)) {
            return;
        }
        BaseLayer baseLayer = (BaseLayer) conf.getLayer();

        double newLr;
        switch (decay) {
            case Exponential:
                newLr = lr * Math.pow(decayRate, iteration);
                break;
            case Inverse:
                newLr = lr / Math.pow(1.0 + decayRate * (double) iteration, conf.getLrPolicyPower());
                break;
            case Step:
                newLr = lr * Math.pow(decayRate, Math.floor((double) iteration / conf.getLrPolicySteps()));
                break;
            case TorchStep:
                // NOTE(review): this tests "steps % iteration", not "iteration % steps"; kept as-is
                // because changing it would alter training behaviour - confirm whether it is intended.
                if (iteration > 1 && conf.getLrPolicySteps() % (double) iteration == 0.0) {
                    newLr = lr * decayRate;
                } else {
                    newLr = lr;
                }
                break;
            case Poly:
                newLr = lr * Math.pow(1.0 - (double) iteration / (double) conf.getNumIterations(),
                                conf.getLrPolicyPower());
                break;
            case Sigmoid:
                newLr = lr / (1.0 + Math.exp(-decayRate * ((double) iteration - conf.getLrPolicySteps())));
                break;
            case Schedule:
                if (baseLayer.getLearningRateSchedule().containsKey(iteration)) {
                    newLr = baseLayer.getLearningRateSchedule().get(iteration);
                } else {
                    newLr = lr;
                }
                break;
            case None:
            case Score:
                newLr = lr;
                break;
            default:
                throw new RuntimeException("Unknown Learning rate decay value: " + decay);
        }

        // Momentum is only tracked for Nesterovs; the check is the same for every entry, so do it once.
        boolean usesNesterovs = baseLayer.getIUpdater() instanceof Nesterovs;
        double newMomentum = 0.0;
        if (usesNesterovs) {
            if (baseLayer.getMomentumSchedule() != null && baseLayer.getMomentumSchedule().containsKey(iteration)) {
                newMomentum = baseLayer.getMomentumSchedule().get(iteration).doubleValue();
            } else {
                newMomentum = baseLayer.getMomentum();
            }
        }

        for (ParamState vs : layersAndVariablesInBlock) {
            vs.getLayer().conf().setLearningRateByParam(vs.getParamName(), newLr);
            if (usesNesterovs) {
                ((BaseLayer) vs.getLayer().conf().getLayer()).setMomentum(newMomentum);
            }
        }

        // Go through the lazy getter: this method is public and may run before the updater exists.
        getGradientUpdater().getConfig().applySchedules(iteration, newLr);
    }

    /** @return start offset of this block's parameters */
    public int getParamOffsetStart() {
        return paramOffsetStart;
    }

    /** @return end offset of this block's parameters */
    public int getParamOffsetEnd() {
        return paramOffsetEnd;
    }

    /** @return start offset of this block's updater state */
    public int getUpdaterViewOffsetStart() {
        return updaterViewOffsetStart;
    }

    /** @return end offset of this block's updater state */
    public int getUpdaterViewOffsetEnd() {
        return updaterViewOffsetEnd;
    }

    /** @return the entries in this block (the internal list itself, not a copy) */
    public List<ParamState> getLayersAndVariablesInBlock() {
        return layersAndVariablesInBlock;
    }

    /** @return the updater view handed to the updater on instantiation; may be {@code null} if unset */
    public INDArray getUpdaterView() {
        return updaterView;
    }

    /** @return the gradient view used by {@link #update(int)}; may be {@code null} if unset */
    public INDArray getGradientView() {
        return gradientView;
    }

    /** @return the flag passed to the updater on instantiation alongside the updater view */
    public boolean isUpdaterViewRequiresInitialization() {
        return updaterViewRequiresInitialization;
    }

    public void setParamOffsetStart(int paramOffsetStart) {
        this.paramOffsetStart = paramOffsetStart;
    }

    public void setParamOffsetEnd(int paramOffsetEnd) {
        this.paramOffsetEnd = paramOffsetEnd;
    }

    public void setUpdaterViewOffsetStart(int updaterViewOffsetStart) {
        this.updaterViewOffsetStart = updaterViewOffsetStart;
    }

    public void setUpdaterViewOffsetEnd(int updaterViewOffsetEnd) {
        this.updaterViewOffsetEnd = updaterViewOffsetEnd;
    }

    public void setLayersAndVariablesInBlock(List<ParamState> layersAndVariablesInBlock) {
        this.layersAndVariablesInBlock = layersAndVariablesInBlock;
    }

    public void setUpdaterView(INDArray updaterView) {
        this.updaterView = updaterView;
    }

    public void setGradientView(INDArray gradientView) {
        this.gradientView = gradientView;
    }

    public void setUpdaterViewRequiresInitialization(boolean updaterViewRequiresInitialization) {
        this.updaterViewRequiresInitialization = updaterViewRequiresInitialization;
    }

    /** Replaces the updater; a non-null value prevents {@link #init()} from creating one. */
    public void setGradientUpdater(GradientUpdater gradientUpdater) {
        this.gradientUpdater = gradientUpdater;
    }

    /**
     * Field-by-field equality over all nine properties, read through their getters (so the gradient
     * updater of both objects may be lazily created by this call).
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof UpdaterBlock)) {
            return false;
        }
        UpdaterBlock other = (UpdaterBlock) o;
        if (!other.canEqual(this)) {
            return false;
        }
        return getParamOffsetStart() == other.getParamOffsetStart()
                        && getParamOffsetEnd() == other.getParamOffsetEnd()
                        && getUpdaterViewOffsetStart() == other.getUpdaterViewOffsetStart()
                        && getUpdaterViewOffsetEnd() == other.getUpdaterViewOffsetEnd()
                        && sameOrBothNull(getLayersAndVariablesInBlock(), other.getLayersAndVariablesInBlock())
                        && sameOrBothNull(getUpdaterView(), other.getUpdaterView())
                        && sameOrBothNull(getGradientView(), other.getGradientView())
                        && isUpdaterViewRequiresInitialization() == other.isUpdaterViewRequiresInitialization()
                        && sameOrBothNull(getGradientUpdater(), other.getGradientUpdater());
    }

    /** Hook that lets subclasses opt out of equality with plain {@code UpdaterBlock} instances. */
    protected boolean canEqual(Object other) {
        return other instanceof UpdaterBlock;
    }

    /** Hash over the same properties, in the same order, as {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        int result = 1;
        result = result * HASH_PRIME + getParamOffsetStart();
        result = result * HASH_PRIME + getParamOffsetEnd();
        result = result * HASH_PRIME + getUpdaterViewOffsetStart();
        result = result * HASH_PRIME + getUpdaterViewOffsetEnd();
        result = result * HASH_PRIME + hashOrNullMarker(getLayersAndVariablesInBlock());
        result = result * HASH_PRIME + hashOrNullMarker(getUpdaterView());
        result = result * HASH_PRIME + hashOrNullMarker(getGradientView());
        result = result * HASH_PRIME + (isUpdaterViewRequiresInitialization() ? 79 : 97);
        result = result * HASH_PRIME + hashOrNullMarker(getGradientUpdater());
        return result;
    }

    @Override
    public String toString() {
        return "UpdaterBlock(paramOffsetStart=" + getParamOffsetStart()
                        + ", paramOffsetEnd=" + getParamOffsetEnd()
                        + ", updaterViewOffsetStart=" + getUpdaterViewOffsetStart()
                        + ", updaterViewOffsetEnd=" + getUpdaterViewOffsetEnd()
                        + ", layersAndVariablesInBlock=" + getLayersAndVariablesInBlock()
                        + ", updaterView=" + getUpdaterView()
                        + ", gradientView=" + getGradientView()
                        + ", updaterViewRequiresInitialization=" + isUpdaterViewRequiresInitialization()
                        + ", gradientUpdater=" + getGradientUpdater() + ")";
    }

    /**
     * Null-safe equality that always delegates to {@code a.equals(b)} for a non-null {@code a}
     * (deliberately no identity shortcut, to keep the exact semantics of the generated code).
     */
    private static boolean sameOrBothNull(Object a, Object b) {
        return a == null ? b == null : a.equals(b);
    }

    /** Hash code of {@code value}, or {@link #NULL_HASH} when it is {@code null}. */
    private static int hashOrNullMarker(Object value) {
        return value == null ? NULL_HASH : value.hashCode();
    }

    /**
     * Immutable description of one parameter inside an {@link UpdaterBlock}: the owning layer, the
     * parameter's name, its offsets, and its parameter and gradient views.
     */
    public static class ParamState {
        private final Layer layer;
        private final String paramName;
        private final int paramOffsetStart;
        private final int paramOffsetEnd;
        private final INDArray paramView;
        private final INDArray gradView;

        /**
         * @param layer            layer that owns the parameter
         * @param paramName        name of the parameter within the layer
         * @param paramOffsetStart start offset used when slicing externally supplied arrays
         * @param paramOffsetEnd   end offset used when slicing externally supplied arrays
         * @param paramView        view of the parameter values
         * @param gradView         view of the parameter's gradient
         */
        @ConstructorProperties(value={"layer", "paramName", "paramOffsetStart", "paramOffsetEnd", "paramView", "gradView"})
        public ParamState(Layer layer, String paramName, int paramOffsetStart, int paramOffsetEnd,
                        INDArray paramView, INDArray gradView) {
            this.layer = layer;
            this.paramName = paramName;
            this.paramOffsetStart = paramOffsetStart;
            this.paramOffsetEnd = paramOffsetEnd;
            this.paramView = paramView;
            this.gradView = gradView;
        }

        public Layer getLayer() {
            return layer;
        }

        public String getParamName() {
            return paramName;
        }

        public int getParamOffsetStart() {
            return paramOffsetStart;
        }

        public int getParamOffsetEnd() {
            return paramOffsetEnd;
        }

        public INDArray getParamView() {
            return paramView;
        }

        public INDArray getGradView() {
            return gradView;
        }

        /** Field-by-field equality over all six properties. */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof ParamState)) {
                return false;
            }
            ParamState other = (ParamState) o;
            if (!other.canEqual(this)) {
                return false;
            }
            return sameOrBothNull(getLayer(), other.getLayer())
                            && sameOrBothNull(getParamName(), other.getParamName())
                            && getParamOffsetStart() == other.getParamOffsetStart()
                            && getParamOffsetEnd() == other.getParamOffsetEnd()
                            && sameOrBothNull(getParamView(), other.getParamView())
                            && sameOrBothNull(getGradView(), other.getGradView());
        }

        /** Hook that lets subclasses opt out of equality with plain {@code ParamState} instances. */
        protected boolean canEqual(Object other) {
            return other instanceof ParamState;
        }

        /** Hash over the same properties, in the same order, as {@link #equals(Object)}. */
        @Override
        public int hashCode() {
            int result = 1;
            result = result * HASH_PRIME + hashOrNullMarker(getLayer());
            result = result * HASH_PRIME + hashOrNullMarker(getParamName());
            result = result * HASH_PRIME + getParamOffsetStart();
            result = result * HASH_PRIME + getParamOffsetEnd();
            result = result * HASH_PRIME + hashOrNullMarker(getParamView());
            result = result * HASH_PRIME + hashOrNullMarker(getGradView());
            return result;
        }

        @Override
        public String toString() {
            return "UpdaterBlock.ParamState(layer=" + getLayer()
                            + ", paramName=" + getParamName()
                            + ", paramOffsetStart=" + getParamOffsetStart()
                            + ", paramOffsetEnd=" + getParamOffsetEnd()
                            + ", paramView=" + getParamView()
                            + ", gradView=" + getGradView() + ")";
        }
    }
}

