/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold;

import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithmReducer;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link ThresholdAlgorithm} that adapts the encoding threshold each iteration so that the observed
 * sparsity ratio moves toward a fixed target.
 * <p>
 * On the first call (no previous threshold known) the initial threshold is returned. On later calls the
 * previous threshold is:
 * <ul>
 *     <li>multiplied by {@code decayRate} (i.e. lowered, since {@code decayRate < 1}) when the previous
 *     sparsity was below the target,</li>
 *     <li>multiplied by {@code 1 / decayRate} (i.e. raised) when the previous sparsity was above the target,</li>
 *     <li>left unchanged when the previous sparsity equals the target.</li>
 * </ul>
 * Instances are stateful (they remember the last threshold and last sparsity) and perform no internal
 * synchronization.
 */
public class TargetSparsityThresholdAlgorithm implements ThresholdAlgorithm {
    private static final Logger log = LoggerFactory.getLogger(TargetSparsityThresholdAlgorithm.class);

    public static final double DEFAULT_INITIAL_THRESHOLD = 1.0E-4;
    public static final double DEFAULT_SPARSITY_TARGET = 0.001;
    public static final double DEFAULT_DECAY_RATE = Math.pow(0.5, 0.05);

    /**
     * Sparsity ratio assumed for an iteration whose update was dense (and no explicit ratio was given).
     * It is also the exclusive upper bound on the sparsity target, so a dense iteration always takes the
     * "raise threshold" branch.
     */
    private static final double DENSE_SPARSITY = 1.0 / 16;

    private final double initialThreshold;
    private final double sparsityTarget;
    private final double decayRate;
    /** Threshold returned by the most recent call; NaN until the first call. */
    private double lastThreshold = Double.NaN;
    /** Sparsity ratio used by the most recent adaptation step; NaN until one has happened. */
    private double lastSparsity = Double.NaN;

    /**
     * Creates an instance using {@link #DEFAULT_INITIAL_THRESHOLD}, {@link #DEFAULT_SPARSITY_TARGET} and
     * {@link #DEFAULT_DECAY_RATE}.
     */
    public TargetSparsityThresholdAlgorithm() {
        this(DEFAULT_INITIAL_THRESHOLD, DEFAULT_SPARSITY_TARGET, DEFAULT_DECAY_RATE);
    }

    /**
     * @param initialThreshold threshold returned on the first call; must be positive
     * @param sparsityTarget   target sparsity ratio; must be in (0, 1/16), both ends exclusive
     * @param decayRate        multiplicative adaptation factor; must be in [0.5, 1.0)
     * @throws IllegalArgumentException if any argument is out of range (or NaN)
     */
    public TargetSparsityThresholdAlgorithm(double initialThreshold, double sparsityTarget, double decayRate) {
        Preconditions.checkArgument(initialThreshold > 0.0,
                "Initial threshold must be positive. Got: %s", initialThreshold);
        // Upper bound is exclusive: a dense iteration is treated as sparsity DENSE_SPARSITY, which must be
        // strictly greater than the target so that dense iterations always raise the threshold.
        Preconditions.checkArgument(sparsityTarget > 0.0 && sparsityTarget < DENSE_SPARSITY,
                "Sparsity target must be between 0 (exclusive) and 1.0/16 (exclusive), got %s", sparsityTarget);
        Preconditions.checkArgument(decayRate >= 0.5 && decayRate < 1.0,
                "Decay rate must be a number in range 0.5 (inclusive) to 1.0 (exclusive). Usually decay rate is in range 0.95 to 0.999. Got decay rate: %s",
                decayRate);
        this.initialThreshold = initialThreshold;
        this.sparsityTarget = sparsityTarget;
        this.decayRate = decayRate;
    }

    /**
     * Computes the threshold for the current iteration and records it (and the sparsity used) on this instance.
     *
     * @param iteration           current iteration; used only in log and exception messages
     * @param epoch               current epoch; used only in log and exception messages
     * @param lastThreshold       threshold used in the previous iteration, or {@code null} to fall back to the
     *                            threshold remembered by this instance
     * @param lastWasDense        whether the previous update was dense; if {@code true} and no sparsity ratio
     *                            is given, the previous sparsity is taken to be 1/16
     * @param lastSparsityRatio   sparsity ratio observed in the previous iteration, or {@code null}
     * @param updatesPlusResidual not used by this implementation
     * @return the threshold to use for the current iteration
     * @throws IllegalStateException if this is not the first call but no previous sparsity can be determined
     */
    @Override
    public double calculateThreshold(int iteration, int epoch, Double lastThreshold, Boolean lastWasDense,
                                     Double lastSparsityRatio, INDArray updatesPlusResidual) {
        if (lastThreshold == null && Double.isNaN(this.lastThreshold)) {
            // First call: nothing to adapt from yet
            this.lastThreshold = initialThreshold;
            return initialThreshold;
        }

        double adaptFromThreshold = lastThreshold != null ? lastThreshold : this.lastThreshold;
        double prevSparsity = resolvePreviousSparsity(iteration, epoch, lastThreshold, lastWasDense, lastSparsityRatio);
        this.lastSparsity = prevSparsity;

        double newThreshold;
        if (prevSparsity < sparsityTarget) {
            newThreshold = decayRate * adaptFromThreshold;
            if (log.isDebugEnabled()) {
                log.debug("TargetSparsityThresholdAlgorithm: iter {} epoch {}: prev sparsity {} < target sparsity {}, reducing threshold from {} to  {}",
                        iteration, epoch, prevSparsity, sparsityTarget, adaptFromThreshold, newThreshold);
            }
        } else if (prevSparsity > sparsityTarget) {
            newThreshold = (1.0 / decayRate) * adaptFromThreshold;
            if (log.isDebugEnabled()) {
                log.debug("TargetSparsityThresholdAlgorithm: iter {} epoch {}: prev sparsity {} > max sparsity {}, increasing threshold from {} to  {}",
                        iteration, epoch, prevSparsity, sparsityTarget, adaptFromThreshold, newThreshold);
            }
        } else {
            newThreshold = adaptFromThreshold;
            if (log.isDebugEnabled()) {
                log.debug("TargetSparsityThresholdAlgorithm: keeping existing threshold of {}, previous sparsity {}, target sparsity {}",
                        adaptFromThreshold, prevSparsity, sparsityTarget);
            }
        }
        this.lastThreshold = newThreshold;
        return newThreshold;
    }

    /**
     * Determines the previous iteration's sparsity. In priority order, it uses the explicit ratio, then
     * {@link #DENSE_SPARSITY} if the previous update was dense, then the value remembered by this instance.
     *
     * @throws IllegalStateException if none of those sources is available
     */
    private double resolvePreviousSparsity(int iteration, int epoch, Double lastThreshold, Boolean lastWasDense,
                                           Double lastSparsityRatio) {
        if (lastSparsityRatio != null) {
            return lastSparsityRatio;
        }
        if (Boolean.TRUE.equals(lastWasDense)) {
            return DENSE_SPARSITY;
        }
        if (!Double.isNaN(this.lastSparsity)) {
            return this.lastSparsity;
        }
        throw new IllegalStateException("Unexpected state: not first iteration but no last sparsity value is available: iteration=" + iteration + ", epoch=" + epoch + ", lastThreshold=" + lastThreshold + ", lastWasDense=" + lastWasDense + ", lastSparsityRatio=" + lastSparsityRatio + ", this.lastSparsity=" + this.lastSparsity);
    }

    /** Returns a reducer that averages the last threshold and last sparsity over many instances. */
    @Override
    public ThresholdAlgorithmReducer newReducer() {
        return new Reducer(initialThreshold, sparsityTarget, decayRate);
    }

    /** Returns a copy with the same configuration and the same remembered threshold/sparsity state. */
    @Override
    public TargetSparsityThresholdAlgorithm clone() {
        TargetSparsityThresholdAlgorithm copy =
                new TargetSparsityThresholdAlgorithm(initialThreshold, sparsityTarget, decayRate);
        copy.lastThreshold = this.lastThreshold;
        copy.lastSparsity = this.lastSparsity;
        return copy;
    }

    @Override
    public String toString() {
        String s = "TargetSparsityThresholdAlgorithm(initialThreshold=" + initialThreshold + ",targetSparsity=" + sparsityTarget + ",decayRate=" + decayRate;
        if (Double.isNaN(lastThreshold)) {
            return s + ")";
        }
        return s + ",lastThreshold=" + lastThreshold + ")";
    }

    /** Equality is based on configuration only; the remembered threshold/sparsity state is ignored. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof TargetSparsityThresholdAlgorithm)) {
            return false;
        }
        TargetSparsityThresholdAlgorithm other = (TargetSparsityThresholdAlgorithm) o;
        return other.canEqual(this)
                && Double.compare(initialThreshold, other.initialThreshold) == 0
                && Double.compare(sparsityTarget, other.sparsityTarget) == 0
                && Double.compare(decayRate, other.decayRate) == 0;
    }

    protected boolean canEqual(Object other) {
        return other instanceof TargetSparsityThresholdAlgorithm;
    }

    /** Hash over the same configuration fields used by {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + Double.hashCode(initialThreshold);
        result = result * prime + Double.hashCode(sparsityTarget);
        result = result * prime + Double.hashCode(decayRate);
        return result;
    }

    /** @return the threshold returned by the most recent call, or NaN if none has been made */
    public double getLastThreshold() {
        return lastThreshold;
    }

    /** @return the sparsity ratio used in the most recent adaptation step, or NaN if none has happened */
    public double getLastSparsity() {
        return lastSparsity;
    }

    /**
     * Averages {@code lastThreshold} and {@code lastSparsity} across instances. Threshold and sparsity are
     * averaged independently, over the instances for which each value is known (non-NaN). This prevents an
     * instance that has a threshold but no sparsity yet from turning the averaged sparsity into NaN.
     */
    private static class Reducer implements ThresholdAlgorithmReducer {
        private final double initialThreshold;
        private final double targetSparsity;
        private final double decayRate;
        private double lastThresholdSum;
        private double lastSparsitySum;
        /** Number of instances contributing to lastThresholdSum. */
        private int count;
        /** Number of instances contributing to lastSparsitySum. */
        private int sparsityCount;

        private Reducer(double initialThreshold, double targetSparsity, double decayRate) {
            this.initialThreshold = initialThreshold;
            this.targetSparsity = targetSparsity;
            this.decayRate = decayRate;
        }

        @Override
        public void add(ThresholdAlgorithm instance) {
            TargetSparsityThresholdAlgorithm a = (TargetSparsityThresholdAlgorithm) instance;
            if (a == null || Double.isNaN(a.lastThreshold)) {
                // Instance has not produced a threshold yet; nothing to contribute
                return;
            }
            lastThresholdSum += a.lastThreshold;
            ++count;
            // Sparsity is still NaN after the first call, which only sets the threshold
            if (!Double.isNaN(a.lastSparsity)) {
                lastSparsitySum += a.lastSparsity;
                ++sparsityCount;
            }
        }

        @Override
        public ThresholdAlgorithmReducer merge(ThresholdAlgorithmReducer other) {
            if (other == null) {
                return this;
            }
            Reducer r = (Reducer) other;
            lastThresholdSum += r.lastThresholdSum;
            lastSparsitySum += r.lastSparsitySum;
            count += r.count;
            sparsityCount += r.sparsityCount;
            return this;
        }

        @Override
        public ThresholdAlgorithm getFinalResult() {
            TargetSparsityThresholdAlgorithm ret =
                    new TargetSparsityThresholdAlgorithm(initialThreshold, targetSparsity, decayRate);
            if (count > 0) {
                ret.lastThreshold = lastThresholdSum / count;
            }
            if (sparsityCount > 0) {
                ret.lastSparsity = lastSparsitySum / sparsityCount;
            }
            return ret;
        }
    }
}

