/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.TrainingConfig;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.misc.DummyConfig;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.OutputLayer;
import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.OneTimeLogger;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class FrozenLayer
extends BaseWrapperLayer {
    private static final Logger log = LoggerFactory.getLogger(FrozenLayer.class);
    private boolean logUpdate = false;
    private boolean logFit = false;
    private boolean logTestMode = false;
    private boolean logGradient = false;
    private Gradient zeroGradient;
    private transient DummyConfig config;

    public FrozenLayer(Layer insideLayer) {
        super(insideLayer);
        if (insideLayer instanceof OutputLayer) {
            throw new IllegalArgumentException("Output Layers are not allowed to be frozen " + this.layerId());
        }
        this.zeroGradient = new DefaultGradient(insideLayer.params());
        if (insideLayer.paramTable() != null) {
            for (String paramType : insideLayer.paramTable().keySet()) {
                this.zeroGradient.setGradientFor(paramType, null);
            }
        }
    }

    @Override
    public void setCacheMode(CacheMode mode) {
    }

    protected String layerId() {
        String name = this.underlying.conf().getLayer().getLayerName();
        return "(layer name: " + (name == null ? "\"\"" : name) + ", layer index: " + this.underlying.getIndex() + ")";
    }

    @Override
    public double calcRegularizationScore(boolean backpropParamsOnly) {
        return 0.0;
    }

    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        return new Pair((Object)this.zeroGradient, null);
    }

    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        this.logTestMode(training);
        return this.underlying.activate(false, workspaceMgr);
    }

    @Override
    public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr) {
        this.logTestMode(training);
        return this.underlying.activate(input, false, workspaceMgr);
    }

    @Override
    public void fit() {
        if (!this.logFit) {
            OneTimeLogger.info((Logger)log, (String)"Frozen layers cannot be fit. Warning will be issued only once per instance", (Object[])new Object[0]);
            this.logFit = true;
        }
    }

    @Override
    public void update(Gradient gradient) {
        if (!this.logUpdate) {
            OneTimeLogger.info((Logger)log, (String)"Frozen layers will not be updated. Warning will be issued only once per instance", (Object[])new Object[0]);
            this.logUpdate = true;
        }
    }

    @Override
    public void update(INDArray gradient, String paramType) {
        if (!this.logUpdate) {
            OneTimeLogger.info((Logger)log, (String)"Frozen layers will not be updated. Warning will be issued only once per instance", (Object[])new Object[0]);
            this.logUpdate = true;
        }
    }

    @Override
    public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) {
        if (!this.logGradient) {
            OneTimeLogger.info((Logger)log, (String)"Gradients for the frozen layer are not set and will therefore will not be updated.Warning will be issued only once per instance", (Object[])new Object[0]);
            this.logGradient = true;
        }
        this.underlying.score();
    }

    @Override
    public void setBackpropGradientsViewArray(INDArray gradients) {
        if (!this.logGradient) {
            OneTimeLogger.info((Logger)log, (String)"Gradients for the frozen layer are not set and will therefore will not be updated.Warning will be issued only once per instance", (Object[])new Object[0]);
            this.logGradient = true;
        }
    }

    @Override
    public void fit(INDArray data, LayerWorkspaceMgr workspaceMgr) {
        if (!this.logFit) {
            OneTimeLogger.info((Logger)log, (String)"Frozen layers cannot be fit.Warning will be issued only once per instance", (Object[])new Object[0]);
            this.logFit = true;
        }
    }

    @Override
    public Gradient gradient() {
        return this.zeroGradient;
    }

    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        if (!this.logGradient) {
            OneTimeLogger.info((Logger)log, (String)"Gradients for the frozen layer are not set and will therefore will not be updated.Warning will be issued only once per instance", (Object[])new Object[0]);
            this.logGradient = true;
        }
        return new Pair((Object)this.zeroGradient, (Object)this.underlying.score());
    }

    @Override
    public void applyConstraints(int iteration, int epoch) {
    }

    @Override
    public void init() {
    }

    public void logTestMode(boolean training) {
        if (!training) {
            return;
        }
        if (this.logTestMode) {
            return;
        }
        OneTimeLogger.info((Logger)log, (String)"Frozen layer instance found! Frozen layers are treated as always in test mode. Warning will only be issued once per instance", (Object[])new Object[0]);
        this.logTestMode = true;
    }

    public void logTestMode(Layer.TrainingMode training) {
        if (training.equals((Object)Layer.TrainingMode.TEST)) {
            return;
        }
        if (this.logTestMode) {
            return;
        }
        OneTimeLogger.info((Logger)log, (String)"Frozen layer instance found! Frozen layers are treated as always in test mode. Warning will only be issued once per instance", (Object[])new Object[0]);
        this.logTestMode = true;
    }

    public Layer getInsideLayer() {
        return this.underlying;
    }

    @Override
    public TrainingConfig getConfig() {
        if (this.config == null) {
            this.config = new DummyConfig(this.getUnderlying().getConfig().getLayerName());
        }
        return this.config;
    }
}

