/*
 * Decompiled with CFR 0.152.
 */
package smile.math;

import java.io.Serializable;
import java.util.Arrays;

/**
 * A function of an integer time step, e.g. a learning-rate schedule.
 *
 * <p>The static factories below return anonymous implementations (the
 * interface extends {@link Serializable}) whose {@code toString()} reports
 * the parameters they were built with.
 */
public interface TimeFunction
extends Serializable {
    /**
     * Returns the function value at the given time step.
     *
     * @param t the time step.
     * @return the function value at {@code t}.
     */
    public double apply(int t);

    /**
     * Returns a function that yields {@code alpha} for every time step.
     *
     * @param alpha the constant value.
     * @return the constant function.
     */
    public static TimeFunction constant(final double alpha) {
        return new TimeFunction(){

            @Override
            public double apply(int t) {
                return alpha;
            }

            @Override
            public String toString() {
                return String.format("Constant(%f)", alpha);
            }
        };
    }

    /**
     * Returns a piecewise constant function. With {@code n = boundaries.length},
     * the result is {@code values[0]} for {@code t <= boundaries[0]},
     * {@code values[i]} for {@code boundaries[i-1] < t <= boundaries[i]},
     * and {@code values[n]} for {@code t > boundaries[n-1]}.
     *
     * <p>Both arrays are copied, so later changes by the caller do not
     * affect the returned function.
     *
     * @param boundaries the strictly increasing interval boundaries.
     * @param values the value of each interval; must have exactly one more
     *               element than {@code boundaries}.
     * @return the piecewise constant function.
     * @throws IllegalArgumentException if the array lengths do not match or
     *         {@code boundaries} is not strictly increasing.
     */
    public static TimeFunction piecewise(final int[] boundaries, final double[] values) {
        if (values.length != boundaries.length + 1) {
            throw new IllegalArgumentException("values should have one more element than boundaries");
        }
        // Arrays.binarySearch is undefined on unsorted input and ambiguous on
        // duplicates, so reject anything that is not strictly increasing.
        for (int k = 1; k < boundaries.length; k++) {
            if (boundaries[k] <= boundaries[k - 1]) {
                throw new IllegalArgumentException(String.format(
                        "boundaries must be strictly increasing: boundaries[%d] = %d, boundaries[%d] = %d",
                        k - 1, boundaries[k - 1], k, boundaries[k]));
            }
        }
        // Defensive copies: otherwise a caller mutating its arrays would
        // silently alter (or un-sort) this schedule.
        final int[] bounds = boundaries.clone();
        final double[] vals = values.clone();
        return new TimeFunction(){

            @Override
            public double apply(int t) {
                int i = Arrays.binarySearch(bounds, t);
                if (i < 0) {
                    // Not found: -(insertion point) - 1, i.e. the first boundary > t.
                    i = -i - 1;
                }
                return vals[i];
            }

            @Override
            public String toString() {
                return String.format("PiecewiseConstant(%s, %s)", Arrays.toString(bounds), Arrays.toString(vals));
            }
        };
    }

    /**
     * Returns a linear decay from {@code initLearningRate} to {@code 1.0E-4}
     * over {@code decaySteps} steps; equivalent to
     * {@code linear(initLearningRate, decaySteps, 1.0E-4)}.
     *
     * @param initLearningRate the value at {@code t = 0}.
     * @param decaySteps the number of steps over which to decay.
     * @return the linear decay function.
     * @throws IllegalArgumentException if {@code decaySteps} is not positive.
     */
    public static TimeFunction linear(double initLearningRate, double decaySteps) {
        return TimeFunction.linear(initLearningRate, decaySteps, 1.0E-4);
    }

    /**
     * Returns a linear decay from {@code initLearningRate} to
     * {@code endLearningRate} over {@code decaySteps} steps, constant at
     * {@code endLearningRate} afterwards. Equivalent to a non-cyclic
     * {@link #polynomial} decay with {@code power = 1}.
     *
     * @param initLearningRate the value at {@code t = 0}.
     * @param decaySteps the number of steps over which to decay.
     * @param endLearningRate the value reached at {@code t = decaySteps}.
     * @return the linear decay function.
     * @throws IllegalArgumentException if {@code decaySteps} is not positive.
     */
    public static TimeFunction linear(double initLearningRate, double decaySteps, double endLearningRate) {
        return TimeFunction.polynomial(initLearningRate, decaySteps, endLearningRate, false, 1.0);
    }

    /**
     * Returns a polynomial decay
     * {@code (init - end) * (1 - s / T)^power + end}.
     * Without cycling, {@code s = min(t, decaySteps)} and {@code T = decaySteps},
     * so the value stays at {@code end} after {@code decaySteps}. With cycling,
     * {@code s = t} and {@code T = decaySteps * max(1, ceil(t / decaySteps))}.
     *
     * @param initLearningRate the value at {@code t = 0}.
     * @param decaySteps the decay period, must be positive.
     * @param endLearningRate the value reached at the end of a decay period.
     * @param cycle whether to restart the decay after each period.
     * @param power the polynomial exponent.
     * @return the polynomial decay function.
     * @throws IllegalArgumentException if {@code decaySteps} is not positive.
     */
    public static TimeFunction polynomial(final double initLearningRate, final double decaySteps, final double endLearningRate, final boolean cycle, final double power) {
        // !(x > 0) also rejects NaN.
        if (!(decaySteps > 0)) {
            throw new IllegalArgumentException("Invalid decay steps: " + decaySteps);
        }
        return new TimeFunction(){

            @Override
            public double apply(int t) {
                if (cycle) {
                    double T = decaySteps * Math.max(1.0, Math.ceil((double) t / decaySteps));
                    return (initLearningRate - endLearningRate) * Math.pow(1.0 - (double) t / T, power) + endLearningRate;
                }
                double steps = Math.min((double) t, decaySteps);
                return (initLearningRate - endLearningRate) * Math.pow(1.0 - steps / decaySteps, power) + endLearningRate;
            }

            @Override
            public String toString() {
                return String.format("PolynomialDecay(initial learning rate = %f, decay steps = %.0f, end learning rate = %f, cycle = %s, power = %f)", initLearningRate, decaySteps, endLearningRate, cycle, power);
            }
        };
    }

    /**
     * Returns the inverse time decay
     * {@code initLearningRate * decaySteps / (decaySteps + t)}.
     *
     * @param initLearningRate the value at {@code t = 0}.
     * @param decaySteps the decay scale, must be positive.
     * @return the inverse time decay function.
     * @throws IllegalArgumentException if {@code decaySteps} is not positive.
     */
    public static TimeFunction inverse(final double initLearningRate, final double decaySteps) {
        if (!(decaySteps > 0)) {
            throw new IllegalArgumentException("Invalid decay steps: " + decaySteps);
        }
        return new TimeFunction(){

            @Override
            public double apply(int t) {
                return initLearningRate * decaySteps / (decaySteps + (double) t);
            }

            @Override
            public String toString() {
                return String.format("InverseTimeDecay(initial learning rate = %f, decaySteps = %.0f)", initLearningRate, decaySteps);
            }
        };
    }

    /**
     * Returns the continuous (non-staircase) inverse time decay
     * {@code initLearningRate / (1 + decayRate * t / decaySteps)}.
     *
     * @param initLearningRate the value at {@code t = 0}.
     * @param decaySteps the decay scale, must be positive.
     * @param decayRate the decay rate.
     * @return the inverse time decay function.
     * @throws IllegalArgumentException if {@code decaySteps} is not positive.
     */
    public static TimeFunction inverse(double initLearningRate, double decaySteps, double decayRate) {
        return TimeFunction.inverse(initLearningRate, decaySteps, decayRate, false);
    }

    /**
     * Returns the inverse time decay
     * {@code initLearningRate / (1 + decayRate * p)}, where
     * {@code p = t / decaySteps}, or {@code floor(t / decaySteps)} when
     * {@code staircase} is true.
     *
     * @param initLearningRate the value at {@code t = 0}.
     * @param decaySteps the decay scale, must be positive.
     * @param decayRate the decay rate.
     * @param staircase whether to decay in discrete steps.
     * @return the inverse time decay function.
     * @throws IllegalArgumentException if {@code decaySteps} is not positive.
     */
    public static TimeFunction inverse(final double initLearningRate, final double decaySteps, final double decayRate, final boolean staircase) {
        if (!(decaySteps > 0)) {
            throw new IllegalArgumentException("Invalid decay steps: " + decaySteps);
        }
        return new TimeFunction(){

            @Override
            public double apply(int t) {
                if (staircase) {
                    return initLearningRate / (1.0 + decayRate * Math.floor((double) t / decaySteps));
                }
                return initLearningRate / (1.0 + decayRate * (double) t / decaySteps);
            }

            @Override
            public String toString() {
                return String.format("InverseTimeDecay(initial learning rate = %f, decay steps = %.0f, decay rate = %f, staircase = %s)", initLearningRate, decaySteps, decayRate, staircase);
            }
        };
    }

    /**
     * Returns the exponential decay
     * {@code initLearningRate * exp(-t / decaySteps)}.
     *
     * @param initLearningRate the value at {@code t = 0}.
     * @param decaySteps the decay scale, must be positive.
     * @return the exponential decay function.
     * @throws IllegalArgumentException if {@code decaySteps} is not positive.
     */
    public static TimeFunction exp(final double initLearningRate, final double decaySteps) {
        if (!(decaySteps > 0)) {
            throw new IllegalArgumentException("Invalid decay steps: " + decaySteps);
        }
        return new TimeFunction(){

            @Override
            public double apply(int t) {
                // Negate after widening: -t in int arithmetic overflows for Integer.MIN_VALUE.
                return initLearningRate * Math.exp(-(double) t / decaySteps);
            }

            @Override
            public String toString() {
                return String.format("ExponentialDecay(initial learning rate = %f, decay steps = %.0f)", initLearningRate, decaySteps);
            }
        };
    }

    /**
     * Returns an exponential decay from {@code initLearningRate} to
     * {@code endLearningRate} over {@code decaySteps} steps:
     * {@code init * (end / init)^(min(t, decaySteps) / decaySteps)}. The
     * value stays at {@code endLearningRate} after {@code decaySteps}.
     *
     * @param initLearningRate the value at {@code t = 0}.
     * @param decaySteps the number of steps over which to decay, must be positive.
     * @param endLearningRate the value reached at {@code t = decaySteps}.
     * @return the exponential decay function.
     * @throws IllegalArgumentException if {@code decaySteps} is not positive.
     */
    public static TimeFunction exp(final double initLearningRate, final double decaySteps, final double endLearningRate) {
        if (!(decaySteps > 0)) {
            throw new IllegalArgumentException("Invalid decay steps: " + decaySteps);
        }
        final double decayRate = endLearningRate / initLearningRate;
        return new TimeFunction(){

            @Override
            public double apply(int t) {
                return initLearningRate * Math.pow(decayRate, Math.min((double) t, decaySteps) / decaySteps);
            }

            @Override
            public String toString() {
                return String.format("ExponentialDecay(initial learning rate = %f, decay steps = %.0f, end learning rate = %f)", initLearningRate, decaySteps, endLearningRate);
            }
        };
    }

    /**
     * Returns the exponential decay {@code initLearningRate * decayRate^p},
     * where {@code p = t / decaySteps}, or {@code floor(t / decaySteps)} when
     * {@code staircase} is true.
     *
     * @param initLearningRate the value at {@code t = 0}.
     * @param decaySteps the decay period, must be positive.
     * @param decayRate the multiplicative factor applied per period.
     * @param staircase whether to decay in discrete steps.
     * @return the exponential decay function.
     * @throws IllegalArgumentException if {@code decaySteps} is not positive.
     */
    public static TimeFunction exp(final double initLearningRate, final double decaySteps, final double decayRate, final boolean staircase) {
        if (!(decaySteps > 0)) {
            throw new IllegalArgumentException("Invalid decay steps: " + decaySteps);
        }
        return new TimeFunction(){

            @Override
            public double apply(int t) {
                if (staircase) {
                    return initLearningRate * Math.pow(decayRate, Math.floor((double) t / decaySteps));
                }
                return initLearningRate * Math.pow(decayRate, (double) t / decaySteps);
            }

            @Override
            public String toString() {
                return String.format("ExponentialDecay(initial learning rate = %f, decay steps = %.0f, decay rate = %f, staircase = %s)", initLearningRate, decaySteps, decayRate, staircase);
            }
        };
    }
}

