/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.learning.regularization;

import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.Axpy;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.schedule.FixedSchedule;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Weight decay regularization.
 *
 * <p>When applied, the gradient view is modified in place as
 * {@code gradView = gradView + scale * param}. Here {@code scale} is the schedule's coefficient
 * for the current iteration/epoch. It is additionally multiplied by the learning rate when
 * {@code applyLR} is true. This regularization reports the
 * {@link Regularization.ApplyStep#POST_UPDATER} apply step.
 *
 * <p>The score contribution is {@code coeff * 0.5 * ||param||_2^2}. The learning rate is not part
 * of the score, even when {@code applyLR} is true.
 */
public class WeightDecay implements Regularization {
    /** Schedule supplying the decay coefficient per iteration/epoch; rejected if null at construction. */
    protected final ISchedule coeff;
    /** When true, the scheduled coefficient is multiplied by the learning rate in {@link #apply}. */
    protected final boolean applyLR;

    /**
     * Creates a weight decay regularization with a constant coefficient.
     *
     * @param coeff   the decay coefficient, wrapped in a {@link FixedSchedule}
     * @param applyLR whether to multiply the coefficient by the learning rate when applying
     */
    public WeightDecay(double coeff, boolean applyLR) {
        this(new FixedSchedule(coeff), applyLR);
    }

    /**
     * Creates a weight decay regularization whose coefficient follows a schedule.
     *
     * @param coeff   schedule giving the decay coefficient for each iteration/epoch; must not be null
     * @param applyLR whether to multiply the coefficient by the learning rate when applying
     * @throws NullPointerException if {@code coeff} is null
     */
    public WeightDecay(@JsonProperty(value="coeff") @NonNull ISchedule coeff, @JsonProperty(value="applyLR") boolean applyLR) {
        // Fail fast so apply()/score()/clone() never dereference a null schedule.
        if (coeff == null) {
            throw new NullPointerException("coeff is marked non-null but is null");
        }
        this.coeff = coeff;
        this.applyLR = applyLR;
    }

    /**
     * Returns the step at which this regularization is applied.
     *
     * @return always {@link Regularization.ApplyStep#POST_UPDATER}
     */
    @Override
    public Regularization.ApplyStep applyStep() {
        return Regularization.ApplyStep.POST_UPDATER;
    }

    /**
     * Adds the weight decay term to {@code gradView} in place.
     *
     * @param param     the parameters being regularized
     * @param gradView  the gradient/update array; used as both input and output of the Axpy op
     * @param lr        learning rate; only used when {@code applyLR} is true
     * @param iteration current iteration, passed to the coefficient schedule
     * @param epoch     current epoch, passed to the coefficient schedule
     */
    @Override
    public void apply(INDArray param, INDArray gradView, double lr, int iteration, int epoch) {
        double scale = this.coeff.valueAt(iteration, epoch);
        if (this.applyLR) {
            scale *= lr;
        }
        // Axpy with z == gradView: gradView = scale * param + gradView
        Nd4j.exec(new Axpy(param, gradView, gradView, scale));
    }

    /**
     * Computes the regularization score {@code coeff * 0.5 * ||param||_2^2}.
     *
     * @param param     the parameters being regularized
     * @param iteration current iteration, passed to the coefficient schedule
     * @param epoch     current epoch, passed to the coefficient schedule
     * @return the score contribution; the learning rate is not included
     */
    @Override
    public double score(INDArray param, int iteration, int epoch) {
        // norm2 is sqrt(sum_i x[i]^2), so norm2 * norm2 recovers the sum of squares.
        double norm2 = param.norm2Number().doubleValue();
        return this.coeff.valueAt(iteration, epoch) * 0.5 * norm2 * norm2;
    }

    /**
     * Returns a copy of this regularization with a cloned coefficient schedule.
     *
     * @return a new {@code WeightDecay} with the same settings
     */
    @Override
    public Regularization clone() {
        return new WeightDecay(this.coeff.clone(), this.applyLR);
    }

    /** @return the coefficient schedule */
    public ISchedule getCoeff() {
        return this.coeff;
    }

    /** @return whether the coefficient is multiplied by the learning rate when applying */
    public boolean isApplyLR() {
        return this.applyLR;
    }

    /**
     * Two instances are equal when both report the same {@code applyLR} flag and equal schedules.
     * Symmetry with subclasses is handled through {@link #canEqual(Object)}.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof WeightDecay)) {
            return false;
        }
        WeightDecay other = (WeightDecay) o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (this.isApplyLR() != other.isApplyLR()) {
            return false;
        }
        ISchedule thisCoeff = this.getCoeff();
        ISchedule otherCoeff = other.getCoeff();
        return thisCoeff == null ? otherCoeff == null : thisCoeff.equals(otherCoeff);
    }

    /**
     * Hook allowing subclasses to opt out of equality with this type.
     *
     * @param other the object being compared
     * @return true if {@code other} is a {@code WeightDecay}
     */
    protected boolean canEqual(Object other) {
        return other instanceof WeightDecay;
    }

    /** Hash consistent with {@link #equals(Object)}; combines the {@code applyLR} flag and the schedule. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + (this.isApplyLR() ? 79 : 97);
        ISchedule c = this.getCoeff();
        result = result * prime + (c == null ? 43 : c.hashCode());
        return result;
    }

    @Override
    public String toString() {
        return "WeightDecay(coeff=" + this.getCoeff() + ", applyLR=" + this.isApplyLR() + ")";
    }
}

