/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.earlystopping.termination;

import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Epoch termination condition that stops training once the score has not improved for a
 * configurable number of epochs.
 *
 * <p>Lower scores are treated as better: the improvement of an epoch is computed as
 * {@code bestScore - score}, and an epoch only counts as improving when that difference is
 * strictly greater than {@code minImprovement}. The first epoch seen after construction or
 * {@link #initialize()} becomes the baseline. Training terminates once the epoch number reaches
 * {@code bestEpoch + maxEpochsWithNoImprovement}.
 *
 * <p>The tracking state ({@code bestEpoch}, {@code bestScore}) is mutable, is annotated with
 * {@code @JsonProperty} alongside the configuration fields, and participates in
 * {@link #equals(Object)} / {@link #hashCode()}. The class performs no synchronization.
 */
public class ScoreImprovementEpochTerminationCondition
implements EpochTerminationCondition {
    private static final Logger log = LoggerFactory.getLogger(ScoreImprovementEpochTerminationCondition.class);

    /** Number of epochs after the last improving epoch at which training terminates. */
    @JsonProperty
    private int maxEpochsWithNoImprovement;
    /** Epoch at which {@link #bestScore} was recorded; {@code -1} means no baseline yet. */
    @JsonProperty
    private int bestEpoch = -1;
    /** Score recorded at {@link #bestEpoch}; only meaningful when {@code bestEpoch != -1}. */
    @JsonProperty
    private double bestScore;
    /** Score decrease that must be strictly exceeded for an epoch to count as an improvement. */
    @JsonProperty
    private double minImprovement = 0.0;

    /**
     * Creates a condition requiring any strict decrease in score to count as an improvement.
     *
     * @param maxEpochsWithNoImprovement number of epochs after the last improvement at which to stop
     */
    public ScoreImprovementEpochTerminationCondition(int maxEpochsWithNoImprovement) {
        this.maxEpochsWithNoImprovement = maxEpochsWithNoImprovement;
    }

    /**
     * Creates a condition with an explicit improvement threshold.
     *
     * @param maxEpochsWithNoImprovement number of epochs after the last improvement at which to stop
     * @param minImprovement             score decrease that must be strictly exceeded to count as
     *                                   an improvement
     */
    public ScoreImprovementEpochTerminationCondition(int maxEpochsWithNoImprovement, double minImprovement) {
        this.maxEpochsWithNoImprovement = maxEpochsWithNoImprovement;
        this.minImprovement = minImprovement;
    }

    /**
     * Clears the tracking state so the next {@link #terminate(int, double)} call establishes a new
     * baseline. The configuration ({@code maxEpochsWithNoImprovement}, {@code minImprovement}) is
     * left unchanged.
     */
    @Override
    public void initialize() {
        this.bestEpoch = -1;
        this.bestScore = Double.NaN;
    }

    /**
     * Records the score of an epoch and decides whether training should stop.
     *
     * @param epochNum epoch index supplied by the caller
     * @param score    score of that epoch; lower is better
     * @return {@code false} for the baseline epoch and for any improving epoch; otherwise
     *     {@code true} once {@code epochNum >= bestEpoch + maxEpochsWithNoImprovement}
     */
    @Override
    public boolean terminate(int epochNum, double score) {
        if (this.bestEpoch == -1) {
            // No baseline yet: this epoch becomes the reference point.
            this.bestEpoch = epochNum;
            this.bestScore = score;
            return false;
        }
        // If either score is NaN the difference is NaN and the comparison below is false,
        // so NaN scores never count as an improvement.
        double improvement = this.bestScore - score;
        if (improvement > this.minImprovement) {
            if (this.minImprovement > 0.0) {
                log.info("Epoch with score greater than threshold * * *");
            }
            this.bestScore = score;
            this.bestEpoch = epochNum;
            return false;
        }
        // Widen to long: with a very large patience (e.g. Integer.MAX_VALUE meaning "never stop")
        // the int sum would wrap negative and terminate immediately.
        return epochNum >= (long) this.bestEpoch + this.maxEpochsWithNoImprovement;
    }

    @Override
    public String toString() {
        return "ScoreImprovementEpochTerminationCondition(maxEpochsWithNoImprovement=" + this.maxEpochsWithNoImprovement + ", minImprovement=" + this.minImprovement + ")";
    }

    /** @return number of epochs after the last improvement at which training terminates */
    public int getMaxEpochsWithNoImprovement() {
        return this.maxEpochsWithNoImprovement;
    }

    /** @return epoch of the last recorded improvement, or {@code -1} if no baseline exists */
    public int getBestEpoch() {
        return this.bestEpoch;
    }

    /** @return score recorded at {@link #getBestEpoch()} */
    public double getBestScore() {
        return this.bestScore;
    }

    /** @return score decrease that must be strictly exceeded to count as an improvement */
    public double getMinImprovement() {
        return this.minImprovement;
    }

    /** @param maxEpochsWithNoImprovement new patience, in epochs */
    public void setMaxEpochsWithNoImprovement(int maxEpochsWithNoImprovement) {
        this.maxEpochsWithNoImprovement = maxEpochsWithNoImprovement;
    }

    /** @param bestEpoch new best epoch; {@code -1} makes the next call establish a baseline */
    public void setBestEpoch(int bestEpoch) {
        this.bestEpoch = bestEpoch;
    }

    /** @param bestScore new best score */
    public void setBestScore(double bestScore) {
        this.bestScore = bestScore;
    }

    /** @param minImprovement new improvement threshold */
    public void setMinImprovement(double minImprovement) {
        this.minImprovement = minImprovement;
    }

    /**
     * Compares configuration and tracking state. Doubles are compared with
     * {@link Double#compare(double, double)}, so two NaN best scores are considered equal.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ScoreImprovementEpochTerminationCondition)) {
            return false;
        }
        ScoreImprovementEpochTerminationCondition other = (ScoreImprovementEpochTerminationCondition) o;
        // Symmetry guard for subclasses that override canEqual.
        if (!other.canEqual(this)) {
            return false;
        }
        return this.getMaxEpochsWithNoImprovement() == other.getMaxEpochsWithNoImprovement()
                && this.getBestEpoch() == other.getBestEpoch()
                && Double.compare(this.getBestScore(), other.getBestScore()) == 0
                && Double.compare(this.getMinImprovement(), other.getMinImprovement()) == 0;
    }

    /**
     * @param other candidate object
     * @return whether {@code other} may be equal to instances of this class
     */
    protected boolean canEqual(Object other) {
        return other instanceof ScoreImprovementEpochTerminationCondition;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + this.getMaxEpochsWithNoImprovement();
        result = result * prime + this.getBestEpoch();
        // Double.hashCode(d) == (int) (bits ^ (bits >>> 32)), matching the previous manual folding.
        result = result * prime + Double.hashCode(this.getBestScore());
        result = result * prime + Double.hashCode(this.getMinImprovement());
        return result;
    }
}

