/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.activations.impl;

import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.ELU;
import org.nd4j.linalg.api.ops.impl.transforms.ELUDerivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;

/**
 * Exponential linear unit (ELU) activation with a configurable {@code alpha}.
 *
 * <p>f(x) = x for x &gt;= 0 and alpha * (exp(x) - 1) for x &lt; 0.
 *
 * <p>The native {@link ELU} / {@link ELUDerivative} ops are constructed here without an alpha
 * argument. They are therefore used directly only when {@code alpha == 1.0}. For any other alpha,
 * the negative branch of their output is rescaled by {@code alpha} in this class.
 */
public class ActivationELU
extends BaseActivationFunction {
    /** Alpha used by the no-argument constructor. */
    public static final double DEFAULT_ALPHA = 1.0;

    /** Scale applied to the negative branch of the activation. */
    private double alpha;

    /** Creates an ELU activation with {@link #DEFAULT_ALPHA}. */
    public ActivationELU() {
        this(DEFAULT_ALPHA);
    }

    /**
     * Creates an ELU activation with the given negative-branch scale.
     *
     * @param alpha multiplier applied to {@code exp(x) - 1} for negative inputs
     */
    public ActivationELU(double alpha) {
        this.alpha = alpha;
    }

    /**
     * Applies ELU to {@code in} in place and returns the same array.
     *
     * @param in       pre-activation values; overwritten with the activation
     * @param training unused by this activation
     * @return {@code in}, now holding the activated values
     */
    @Override
    public INDArray getActivation(INDArray in, boolean training) {
        if (this.alpha != 1.0) {
            // Run the native (alpha-less) ELU on a copy, scale it, and copy only the negative
            // entries back into `in`; non-negative entries of `in` already equal f(x) = x.
            INDArray scaledNegative = Nd4j.getExecutioner().execAndReturn(new ELU(in.dup()));
            scaledNegative.muli(this.alpha);
            BooleanIndexing.replaceWhere(in, scaledNegative, Conditions.lessThan(0));
        } else {
            Nd4j.getExecutioner().execAndReturn(new ELU(in));
        }
        return in;
    }

    /**
     * Computes the gradient of the loss with respect to the pre-activation:
     * dL/dz = f'(in) * epsilon.
     *
     * @param in      pre-activation values. On the alpha == 1 path, the derivative op is built
     *                directly on this array and may overwrite it.
     * @param epsilon upstream gradient dL/da (expected to match the shape of {@code in})
     * @return pair of (dL/dz, null); this activation has no parameter gradients
     */
    @Override
    public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
        if (this.alpha != 1.0) {
            // Assumed output of the native derivative: 1 for x >= 0 and exp(x) (< 1) for x < 0.
            INDArray dLdz = Nd4j.getExecutioner().execAndReturn(new ELUDerivative(in.dup()));
            INDArray negativeSlope = dLdz.mul(this.alpha);
            // Select the negative branch from the *unscaled* derivative (strictly below 1)
            // rather than matching scaled values against alpha. Value matching breaks for
            // alpha == 0: every entry then equals alpha, so every gradient became 1. It would
            // also catch near-alpha entries if the comparison is tolerance-based.
            BooleanIndexing.replaceWhere(dLdz, negativeSlope, Conditions.lessThan(1));
            dLdz.muli(epsilon);
            return new Pair<>(dLdz, null);
        }
        INDArray dLdz = Nd4j.getExecutioner().execAndReturn(new ELU(in).derivative());
        dLdz.muli(epsilon);
        return new Pair<>(dLdz, null);
    }

    @Override
    public String toString() {
        return "elu(alpha=" + this.alpha + ")";
    }

    /** Two instances are equal when their alpha values compare equal via {@link Double#compare}. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ActivationELU)) {
            return false;
        }
        ActivationELU other = (ActivationELU) o;
        if (!other.canEqual(this)) {
            return false;
        }
        return Double.compare(this.getAlpha(), other.getAlpha()) == 0;
    }

    /**
     * Symmetry hook for {@link #equals(Object)}, so a subclass can refuse equality with this type.
     *
     * @param other candidate object
     * @return true when {@code other} is an {@code ActivationELU}
     */
    protected boolean canEqual(Object other) {
        return other instanceof ActivationELU;
    }

    /** Hash consistent with {@link #equals(Object)}: derived solely from the bits of alpha. */
    @Override
    public int hashCode() {
        final int PRIME = 59;
        int result = 1;
        long alphaBits = Double.doubleToLongBits(this.getAlpha());
        result = result * PRIME + (int) (alphaBits >>> 32 ^ alphaBits);
        return result;
    }

    /** @return the negative-branch scale of this activation */
    public double getAlpha() {
        return this.alpha;
    }
}

