/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.optimizer.learningrate;

import ai.djl.TrainingDivergedException;
import ai.djl.training.optimizer.learningrate.FactorTracker;
import ai.djl.training.optimizer.learningrate.FixedLearningRate;
import ai.djl.training.optimizer.learningrate.MultiFactorTracker;
import ai.djl.training.optimizer.learningrate.WarmUpMode;

/**
 * Base class for learning-rate schedules.
 *
 * <p>A tracker holds a base learning rate plus an optional warm-up phase, both configured through
 * a subclass of {@link LrBaseBuilder}. Subclasses implement {@link #getNewLearningRate(int)} and
 * can use {@link #getWarmUpLearningRate(int)} to compute the rate during warm-up.
 */
public abstract class LearningRateTracker {
    /** Base learning rate taken from the builder (builder default: 0.01). */
    float baseLearningRate;
    /**
     * Length of the warm-up phase in updates; with linear warm-up the rate reaches {@link
     * #warmUpFinalLearningRate} when {@code numUpdate == warmUpSteps}.
     */
    int warmUpSteps;
    /** Learning rate used at update 0 of the warm-up phase. */
    float warmUpBeginLearningRate;
    /** Target rate of the warm-up ramp; always initialized to {@link #baseLearningRate}. */
    float warmUpFinalLearningRate;
    /** Shape of the warm-up phase; only {@link WarmUpMode#LINEAR} interpolates. */
    WarmUpMode warmUpMode;

    /**
     * Copies the shared configuration out of a builder.
     *
     * @param builder the builder holding the base learning rate and warm-up settings
     */
    LearningRateTracker(LrBaseBuilder<?> builder) {
        this.baseLearningRate = builder.baseLearningRate;
        this.warmUpSteps = builder.warmUpSteps;
        this.warmUpBeginLearningRate = builder.warmUpBeginLearningRate;
        this.warmUpMode = builder.warmUpMode;
        // The warm-up ramp always ends at the base learning rate.
        this.warmUpFinalLearningRate = this.baseLearningRate;
    }

    /**
     * Computes the learning rate for an update inside the warm-up phase.
     *
     * <p>With {@link WarmUpMode#LINEAR} the rate is interpolated linearly from {@code
     * warmUpBeginLearningRate} (update 0) to {@code warmUpFinalLearningRate} (update {@code
     * warmUpSteps}). For any other mode the begin rate is returned unchanged. If {@code
     * warmUpSteps} is not positive there is no window to interpolate over, so the final rate is
     * used instead.
     *
     * @param numUpdate the number of updates performed so far
     * @return the warm-up learning rate
     * @throws TrainingDivergedException if the resulting rate is NaN or infinite
     */
    float getWarmUpLearningRate(int numUpdate) {
        float learningRate = this.warmUpBeginLearningRate;
        if (this.warmUpMode == WarmUpMode.LINEAR) {
            if (this.warmUpSteps <= 0) {
                // Dividing by a zero-length window would yield NaN (0/0) or +/-Infinity; a
                // negative window would flip the ramp. Treat the warm-up as already finished.
                learningRate = this.warmUpFinalLearningRate;
            } else {
                float span = this.warmUpFinalLearningRate - this.warmUpBeginLearningRate;
                learningRate =
                        this.warmUpBeginLearningRate
                                + span * (float) numUpdate / (float) this.warmUpSteps;
            }
        }
        this.checkLearningRate(learningRate);
        return learningRate;
    }

    /**
     * Returns the learning rate to use for the given update.
     *
     * @param numUpdate the number of updates performed so far
     * @return the learning rate for this update
     */
    public abstract float getNewLearningRate(int numUpdate);

    /**
     * Rejects learning rates that cannot be used for training.
     *
     * @param learningRate the rate to validate
     * @throws TrainingDivergedException if {@code learningRate} is NaN or infinite
     */
    void checkLearningRate(float learningRate) {
        if (Float.isNaN(learningRate)) {
            throw new TrainingDivergedException("Learning rate is Nan.");
        }
        if (Float.isInfinite(learningRate)) {
            throw new TrainingDivergedException("Learning rate is infinite.");
        }
    }

    /**
     * Creates a builder for a {@link FactorTracker}.
     *
     * @return a new {@link FactorTracker.Builder}
     */
    public static FactorTracker.Builder factorTracker() {
        return new FactorTracker.Builder();
    }

    /**
     * Creates a builder for a {@link MultiFactorTracker}.
     *
     * @return a new {@link MultiFactorTracker.Builder}
     */
    public static MultiFactorTracker.Builder multiFactorTracker() {
        return new MultiFactorTracker.Builder();
    }

    /**
     * Creates a {@link FixedLearningRate} tracker with the given base learning rate.
     *
     * @param learningRate the base learning rate
     * @return the built {@link FixedLearningRate}
     */
    public static FixedLearningRate fixedLearningRate(float learningRate) {
        return ((FixedLearningRate.Builder) FixedLearningRate.builder().optBaseLearningRate(learningRate))
                .build();
    }

    /**
     * Shared builder for learning-rate trackers; holds the base rate and warm-up settings.
     *
     * @param <T> the concrete builder type, returned from every setter for fluent chaining
     */
    public abstract static class LrBaseBuilder<T extends LrBaseBuilder<T>> {
        float baseLearningRate = 0.01f;
        int warmUpSteps;
        float warmUpBeginLearningRate;
        WarmUpMode warmUpMode = WarmUpMode.LINEAR;

        /**
         * Sets the base learning rate (default 0.01).
         *
         * @param baseLearningRate the base learning rate
         * @return this builder
         */
        public T optBaseLearningRate(float baseLearningRate) {
            this.baseLearningRate = baseLearningRate;
            return this.self();
        }

        /**
         * Sets the length of the warm-up phase in updates (default 0).
         *
         * @param warmUpSteps the number of warm-up updates
         * @return this builder
         */
        public T optWarmUpSteps(int warmUpSteps) {
            this.warmUpSteps = warmUpSteps;
            return this.self();
        }

        /**
         * Sets the learning rate at the start of warm-up (default 0).
         *
         * @param warmUpBeginLearningRate the starting warm-up learning rate
         * @return this builder
         */
        public T optWarmUpBeginLearningRate(float warmUpBeginLearningRate) {
            this.warmUpBeginLearningRate = warmUpBeginLearningRate;
            return this.self();
        }

        /**
         * Sets how the learning rate evolves during warm-up (default {@link WarmUpMode#LINEAR}).
         *
         * @param warmUpMode the warm-up mode
         * @return this builder
         */
        public T optWarmUpMode(WarmUpMode warmUpMode) {
            this.warmUpMode = warmUpMode;
            return this.self();
        }

        /**
         * Returns this builder typed as the concrete subclass.
         *
         * @return this builder
         */
        protected abstract T self();
    }
}

