/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.optimizer.learningrate;

import ai.djl.training.optimizer.learningrate.FactorTracker;
import ai.djl.training.optimizer.learningrate.LearningRateTracker;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link LearningRateTracker} that multiplies the learning rate by a constant factor each time
 * the update count passes one of a configured list of step boundaries.
 *
 * <p>The boundaries may be spaced unevenly. Instances are created through {@link Builder}.
 *
 * <p>This tracker is stateful: {@link #getNewLearningRate(int)} advances an internal step index and
 * updates {@code baseLearningRate} in place, without any synchronization.
 */
public class MultiFactorTracker extends LearningRateTracker {
    // Keyed on this class (not FactorTracker) so log output is attributed to the right tracker.
    private static final Logger logger = LoggerFactory.getLogger(MultiFactorTracker.class);

    /** Strictly increasing update counts (each at least 1) after which the rate is decayed. */
    private final int[] steps;
    /** Multiplier applied to the base learning rate at each boundary. */
    private final float factor;
    /** Index of the next boundary in {@link #steps} that has not been passed yet. */
    private int stepIndex;

    /**
     * Creates a tracker from the given builder.
     *
     * @param builder the builder holding the step boundaries and the decay factor
     */
    public MultiFactorTracker(Builder builder) {
        super(builder);
        this.steps = builder.steps;
        this.factor = builder.factor;
    }

    /**
     * Returns the learning rate to use for the given update count.
     *
     * <p>While {@code numUpdate} is below the warm-up step count, the result of {@code
     * getWarmUpLearningRate} is returned. Afterwards, every not-yet-passed boundary {@code s} with
     * {@code numUpdate > s} multiplies the base learning rate by the factor exactly once. Passed
     * boundaries are never revisited, so a later call with a smaller {@code numUpdate} does not
     * restore an earlier rate.
     *
     * @param numUpdate the current update count
     * @return the learning rate for {@code numUpdate}
     */
    @Override
    public float getNewLearningRate(int numUpdate) {
        if (numUpdate < warmUpSteps) {
            return getWarmUpLearningRate(numUpdate);
        }
        // Catch up on every boundary crossed since the previous call; several may be crossed at once.
        while (stepIndex < steps.length && numUpdate > steps[stepIndex]) {
            stepIndex++;
            baseLearningRate *= factor;
            // Guard so String.format is only evaluated when the message will actually be logged.
            if (logger.isDebugEnabled()) {
                logger.debug(
                        "Update[{}]: Change learning rate to {}",
                        numUpdate,
                        String.format("%.5e", baseLearningRate));
            }
        }
        checkLearningRate(baseLearningRate);
        return baseLearningRate;
    }

    /** Builder for {@link MultiFactorTracker}. */
    public static final class Builder extends LearningRateTracker.LrBaseBuilder<Builder> {
        private int[] steps;
        private float factor = 1.0f;

        /** {@inheritDoc} */
        @Override
        protected Builder self() {
            return this;
        }

        /**
         * Sets the update counts after which the learning rate is multiplied by the factor.
         *
         * <p>The array is copied, so later changes to the caller's array do not affect this
         * builder. Use {@link FactorTracker} when the rate should change at a constant interval.
         *
         * @param steps at least two strictly increasing values, each at least 1
         * @return this builder
         * @throws NullPointerException if {@code steps} is null
         * @throws IllegalArgumentException if {@code steps} has fewer than two elements, is not
         *     strictly increasing, or contains a value below 1
         */
        public Builder setSteps(int[] steps) {
            // Validate and keep a private copy so the caller cannot break the invariants later.
            int[] copy = Objects.requireNonNull(steps, "steps").clone();
            if (copy.length <= 1) {
                throw new IllegalArgumentException(
                        "Steps should be an array of integers indicating when the learning rate"
                                + " should be changed, usually in an uneven interval of steps."
                                + " Use FactorTracker if you want learning rate to be changed at a"
                                + " constant interval of steps");
            }
            for (int i = 0; i < copy.length; ++i) {
                if (i > 0 && copy[i] <= copy[i - 1]) {
                    throw new IllegalArgumentException("Steps must be an increasing list");
                }
                if (copy[i] < 1) {
                    throw new IllegalArgumentException("Step must be larger or equal to 1");
                }
            }
            this.steps = copy;
            return this;
        }

        /**
         * Sets the multiplier applied at each step boundary. Defaults to {@code 1.0f}, which
         * leaves the learning rate unchanged.
         *
         * @param factor the decay factor; must not exceed 1 and must not be NaN
         * @return this builder
         * @throws IllegalArgumentException if {@code factor} is greater than 1 or NaN
         */
        public Builder optFactor(float factor) {
            // Negated <= so that NaN, which fails every comparison, is rejected as well.
            if (!(factor <= 1.0f)) {
                throw new IllegalArgumentException("factor should be no more than 1");
            }
            this.factor = factor;
            return this;
        }

        /**
         * Builds a {@link MultiFactorTracker} from the current settings.
         *
         * @return a new tracker
         * @throws IllegalArgumentException if {@link #setSteps(int[])} has not been called
         */
        public MultiFactorTracker build() {
            if (steps == null) {
                throw new IllegalArgumentException("Steps must be set to change learning rate");
            }
            return new MultiFactorTracker(this);
        }
    }
}

