/*
 * Decompiled with CFR 0.152.
 */
package develop.p2p.lib;

import java.util.Arrays;

/**
 * Static collection of neural-network activation functions and their
 * first derivatives (the methods suffixed with {@code Def}).
 *
 * <p>All methods are pure functions of their arguments and hold no state.
 */
public class LearnMath {
    /** Default negative-side slope used by the single-argument leaky ReLU. */
    private static final double DEFAULT_LRELU_ALPHA = 0.01;

    /** Default alpha used by the single-argument ELU. */
    private static final double DEFAULT_ELU_ALPHA = 1.0;

    /**
     * Logistic sigmoid.
     *
     * @param x input value
     * @return {@code 1 / (1 + e^-x)}
     */
    public static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * Derivative of {@link #sigmoid(double)}.
     *
     * @param x input value
     * @return {@code sigmoid(x) * (1 - sigmoid(x))}
     */
    public static double sigmoidDef(double x) {
        double s = sigmoid(x);
        return s * (1.0 - s);
    }

    /**
     * Swish activation.
     *
     * @param x input value
     * @return {@code x * sigmoid(x)}
     */
    public static double swish(double x) {
        return x * sigmoid(x);
    }

    /**
     * Derivative of {@link #swish(double)}.
     *
     * @param x input value
     * @return {@code swish(x) + sigmoid(x) * (1 - swish(x))}
     */
    public static double swishDef(double x) {
        double sw = swish(x);
        return sw + sigmoid(x) * (1.0 - sw);
    }

    /**
     * Binary step function.
     *
     * @param x input value
     * @return {@code 1.0} when {@code x >= 0}, otherwise {@code 0.0}
     */
    public static double step(double x) {
        return x >= 0.0 ? 1.0 : 0.0;
    }

    /**
     * Rectified linear unit.
     *
     * @param x input value
     * @return {@code max(x, 0)}
     */
    public static double relu(double x) {
        // Previously x * max(x, 0), which squared every positive input.
        return Math.max(x, 0.0);
    }

    /**
     * Derivative of {@link #relu(double)}.
     *
     * @param x input value
     * @return {@code 1.0} when {@code x > 0}, otherwise {@code 0.0}
     */
    public static double reluDef(double x) {
        return x > 0.0 ? 1.0 : 0.0;
    }

    /**
     * Leaky ReLU with the default negative slope of {@code 0.01}.
     *
     * @param x input value
     * @return {@code x} when {@code x >= 0}, otherwise {@code 0.01 * x}
     */
    public static double lrelu(double x) {
        return lrelu(x, DEFAULT_LRELU_ALPHA);
    }

    /**
     * Leaky ReLU with a caller-supplied negative slope.
     *
     * @param x     input value
     * @param alpha slope applied when {@code x < 0}
     * @return {@code x} when {@code x >= 0}, otherwise {@code alpha * x}
     */
    public static double lrelu(double x, double alpha) {
        return x >= 0.0 ? x : alpha * x;
    }

    /**
     * Derivative of {@link #lrelu(double)}.
     *
     * @param x input value
     * @return {@code 1.0} when {@code x >= 0}, otherwise {@code 0.01}
     */
    public static double lreluDef(double x) {
        return lreluDef(x, DEFAULT_LRELU_ALPHA);
    }

    /**
     * Derivative of {@link #lrelu(double, double)}.
     *
     * @param x     input value
     * @param alpha slope applied when {@code x < 0}
     * @return {@code 1.0} when {@code x >= 0}, otherwise {@code alpha}
     */
    public static double lreluDef(double x, double alpha) {
        // Previously returned alpha on the positive side and a hard-coded
        // 0.01 on the negative side, which did not match lrelu(x, alpha).
        return x >= 0.0 ? 1.0 : alpha;
    }

    /**
     * Exponential linear unit with {@code alpha = 1}.
     *
     * @param x input value
     * @return {@code x} when {@code x > 0}, otherwise {@code e^x - 1}
     */
    public static double elu(double x) {
        return elu(x, DEFAULT_ELU_ALPHA);
    }

    /**
     * Exponential linear unit.
     *
     * @param x     input value
     * @param alpha scale of the negative branch
     * @return {@code x} when {@code x > 0}, otherwise {@code alpha * (e^x - 1)}
     */
    public static double elu(double x, double alpha) {
        return x > 0.0 ? x : alpha * (Math.exp(x) - 1.0);
    }

    /**
     * Derivative of {@link #elu(double)}.
     *
     * @param x input value
     * @return {@code 1.0} when {@code x > 0}, otherwise {@code elu(x) + 1}
     */
    public static double eluDef(double x) {
        return eluDef(x, DEFAULT_ELU_ALPHA);
    }

    /**
     * Derivative of {@link #elu(double, double)}.
     *
     * @param x     input value
     * @param alpha scale of the negative branch
     * @return {@code 1.0} when {@code x > 0}, otherwise
     *         {@code elu(x, alpha) + alpha}
     */
    public static double eluDef(double x, double alpha) {
        return x > 0.0 ? 1.0 : elu(x, alpha) + alpha;
    }

    /**
     * Scaled exponential linear unit.
     *
     * @param x     input value
     * @param scale multiplier applied to both branches
     * @param alpha scale of the negative branch
     * @return {@code scale * x} when {@code x > 0}, otherwise
     *         {@code scale * alpha * (e^x - 1)}
     */
    public static double selu(double x, double scale, double alpha) {
        // Parenthesised explicitly: without the parentheses the ternary binds
        // looser than '*', so scale ended up in the condition, not the result.
        return scale * (x > 0.0 ? x : alpha * (Math.exp(x) - 1.0));
    }

    /**
     * Derivative of {@link #selu(double, double, double)}.
     *
     * @param x     input value
     * @param scale multiplier applied to both branches
     * @param alpha scale of the negative branch
     * @return {@code scale} when {@code x > 0}, otherwise
     *         {@code scale * alpha * e^x}
     */
    public static double seluDef(double x, double scale, double alpha) {
        return scale * (x > 0.0 ? 1.0 : alpha * Math.exp(x));
    }

    /**
     * Hyperbolic tangent.
     *
     * @param x input value
     * @return {@code tanh(x)}
     */
    public static double tanH(double x) {
        // Math.tanh saturates to +/-1 for large |x|; the hand-written
        // (e^x - e^-x) / (e^x + e^-x) form evaluated to Infinity / Infinity.
        return Math.tanh(x);
    }

    /**
     * Derivative of {@link #tanH(double)}.
     *
     * @param x input value
     * @return {@code 1 - tanh(x)^2}
     */
    public static double tanHDef(double x) {
        double t = tanH(x);
        return 1.0 - t * t;
    }

    /**
     * Softplus activation.
     *
     * @param x input value
     * @return {@code ln(1 + e^x)}
     */
    public static double softplus(double x) {
        // Algebraically equal to log(1 + exp(x)), but exp() only ever sees a
        // non-positive argument, so it cannot overflow for large x.
        return Math.max(x, 0.0) + Math.log1p(Math.exp(-Math.abs(x)));
    }

    /**
     * Derivative of {@link #softplus(double)}, which is the logistic sigmoid.
     *
     * @param x input value
     * @return {@code 1 / (1 + e^-x)}
     */
    public static double softplusDef(double x) {
        return sigmoid(x);
    }

    /**
     * Numerator polynomial of the closed-form Mish derivative
     * {@code e^x * omega(x) / delta(x)^2}.
     *
     * @param x input value
     * @return {@code 4(x + 1) + 4e^(2x) + e^(3x) + e^x (4x + 6)}
     */
    public static double omega(double x) {
        return 4.0 * (x + 1.0) + 4.0 * Math.exp(2.0 * x) + Math.exp(3.0 * x) + Math.exp(x) * (4.0 * x + 6.0);
    }

    /**
     * Denominator term of the closed-form Mish derivative
     * {@code e^x * omega(x) / delta(x)^2}.
     *
     * @param x input value
     * @return {@code 2e^x + e^(2x) + 2}
     */
    public static double delta(double x) {
        return 2.0 * Math.exp(x) + Math.exp(2.0 * x) + 2.0;
    }

    /**
     * Mish activation.
     *
     * @param x input value
     * @return {@code x * tanh(softplus(x))}
     */
    public static double mish(double x) {
        return x * tanH(softplus(x));
    }

    /**
     * Derivative of {@link #mish(double)}.
     *
     * <p>Evaluated via the product rule as
     * {@code tanh(sp) + x * sigmoid(x) * (1 - tanh(sp)^2)} with
     * {@code sp = softplus(x)}. This is mathematically equal to
     * {@code e^x * omega(x) / delta(x)^2}, but that closed form overflows to
     * {@code Infinity / Infinity} once {@code e^(3x)} exceeds the double range.
     *
     * @param x input value
     * @return the slope of Mish at {@code x}
     */
    public static double mishDef(double x) {
        double t = tanH(softplus(x));
        return t + x * sigmoid(x) * (1.0 - t * t);
    }

    /**
     * Identity activation.
     *
     * @param x input value
     * @return {@code x} unchanged
     */
    public static double identity(double x) {
        return x;
    }

    /**
     * Derivative of {@link #identity(double)}.
     *
     * @return always {@code 1.0}
     */
    public static double identityDef() {
        return 1.0;
    }

    /**
     * Softmax over a vector. The maximum element is subtracted before
     * exponentiation so the largest exponent is {@code e^0}.
     *
     * @param x input vector; not modified
     * @return a new array of the same length holding
     *         {@code e^(x[i] - max) / sum_j e^(x[j] - max)}; an empty array
     *         when {@code x} is empty
     */
    public static double[] softmax(double[] x) {
        if (x.length == 0) {
            return new double[0];
        }
        // Max and sum are computed once up front rather than once per element.
        double max = Arrays.stream(x).max().getAsDouble();
        double[] result = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            result[i] = Math.exp(x[i] - max);
        }
        double sum = Arrays.stream(result).sum();
        for (int i = 0; i < result.length; i++) {
            result[i] /= sum;
        }
        return result;
    }

    /**
     * Row-wise softmax: applies {@link #softmax(double[])} to each row
     * independently.
     *
     * @param x input matrix; rows may have different lengths; not modified
     * @return a new matrix whose i-th row is {@code softmax(x[i])}
     */
    public static double[][] softmax(double[][] x) {
        double[][] result = new double[x.length][];
        for (int i = 0; i < x.length; i++) {
            result[i] = softmax(x[i]);
        }
        return result;
    }
}

