/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.activations.impl;

import java.util.Objects;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarMin;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
import org.nd4j.linalg.factory.Nd4j;

/**
 * ReLU activation with three optional settings: an upper clamp ({@code max}),
 * a {@code threshold}, and a {@code negativeSlope} applied below the threshold.
 *
 * <p>Each setting may be {@code null}, meaning "not configured". When both
 * {@code threshold} and {@code negativeSlope} are {@code null} the plain
 * {@link RectifiedLinear} op is used; otherwise a missing value is treated as
 * {@code 0.0}.
 *
 * <p>Both {@link #getActivation(INDArray, boolean)} and
 * {@link #backprop(INDArray, INDArray)} may overwrite the contents of the
 * {@code in} array they are given.
 */
public class ActivationReLU extends BaseActivationFunction {
    // Fields are deliberately left non-final and boxed: null means "not configured".
    private Double max;
    private Double threshold;
    private Double negativeSlope;

    /** Creates an activation with no max, no threshold and no negative slope configured. */
    public ActivationReLU() {
        this(null, null, null);
    }

    /**
     * @param maxValue      upper clamp applied to the activations, or {@code null} for none
     * @param threshold     threshold value, or {@code null} (treated as {@code 0.0} where needed)
     * @param negativeSlope slope used below the threshold, or {@code null}
     *                      (treated as {@code 0.0} where needed)
     */
    public ActivationReLU(Double maxValue, Double threshold, Double negativeSlope) {
        this.max = maxValue;
        this.threshold = threshold;
        this.negativeSlope = negativeSlope;
    }

    /**
     * Applies the activation to {@code in} and returns that same array.
     *
     * @param in       pre-activation values; overwritten with the result
     * @param training not used by this implementation
     * @return {@code in}
     */
    @Override
    public INDArray getActivation(INDArray in, boolean training) {
        if (this.negativeSlope != null || this.threshold != null) {
            double t = this.threshold == null ? 0.0 : this.threshold;
            double ns = this.negativeSlope == null ? 0.0 : this.negativeSlope;
            if (t == 0.0) {
                // The returned array is discarded and `in` is returned below, so this
                // relies on the op writing its result into `in` (presumably in-place).
                Nd4j.getExecutioner().execAndReturn(new LeakyReLU(in, ns));
            } else {
                // Non-zero threshold: build the result from two 0/1 masks.
                //   x <  t  ->  ns * (x - t)
                //   x >= t  ->  x
                INDArray oneGte = in.gte(t).castTo(in.dataType());
                INDArray oneLt = in.lt(t).castTo(in.dataType());
                INDArray lower = oneLt.muli(ns).muli(in.sub(t));
                INDArray upper = oneGte.muli(in);
                in.assign(lower.addi(upper));
            }
        } else {
            Nd4j.getExecutioner().exec(new RectifiedLinear(in, in));
        }
        if (this.max != null) {
            // Upper clamp, written back into `in`.
            Nd4j.exec(new ScalarMin(in, null, in, this.max));
        }
        return in;
    }

    /**
     * Computes the gradient with respect to the pre-activations.
     *
     * @param in      pre-activation values; overwritten on the non-zero-threshold path
     * @param epsilon gradient arriving from the layer above; must pass
     *                {@code assertShape(in, epsilon)}
     * @return a pair whose first element is the computed gradient and whose second
     *         element is {@code null}
     */
    @Override
    public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
        this.assertShape(in, epsilon);

        // Must be computed before `in` is overwritten further down.
        // NOTE(review): max == 0.0 skips the mask here, while getActivation still
        // clamps whenever max is non-null - confirm this asymmetry is intended.
        // NOTE(review): the mask is the raw comparison result and, unlike the masks
        // below, is not cast to in.dataType() before muli - confirm that is fine.
        INDArray maxMask = (this.max == null || this.max == 0.0) ? null : in.lt(this.max);

        INDArray dLdz;
        if (this.negativeSlope != null || this.threshold != null) {
            double t = this.threshold == null ? 0.0 : this.threshold;
            double ns = this.negativeSlope == null ? 0.0 : this.negativeSlope;
            if (t == 0.0) {
                dLdz = Nd4j.getExecutioner().exec(new LeakyReLUBp(in, epsilon, in.ulike(), ns))[0];
            } else {
                // Local derivative is ns where x < t and 1 where x >= t; it is written
                // into `in` and then scaled by epsilon.
                INDArray oneGte = in.gte(t).castTo(in.dataType());
                INDArray oneLt = in.lt(t).castTo(in.dataType());
                INDArray derivative = oneLt.muli(ns).addi(oneGte);
                dLdz = in.assign(derivative).muli(epsilon);
            }
        } else {
            // threshold is null on this path, so the cutoff is always 0.0.
            dLdz = Nd4j.getExecutioner().exec(new RectifiedLinearDerivative(in, epsilon, in.ulike(), 0.0))[0];
        }

        if (maxMask != null) {
            dLdz.muli(maxMask);
        }
        return new Pair<>(dLdz, null);
    }

    /** @return the fixed name {@code "relu"}, regardless of configuration */
    @Override
    public String toString() {
        return "relu";
    }

    /**
     * Two instances are equal when {@code max}, {@code threshold} and
     * {@code negativeSlope} (read through the getters) are pairwise equal and
     * the other object accepts this one via {@link #canEqual(Object)}.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ActivationReLU)) {
            return false;
        }
        ActivationReLU other = (ActivationReLU) o;
        if (!other.canEqual(this)) {
            return false;
        }
        return Objects.equals(this.getMax(), other.getMax())
                && Objects.equals(this.getThreshold(), other.getThreshold())
                && Objects.equals(this.getNegativeSlope(), other.getNegativeSlope());
    }

    /** Hook used by {@link #equals(Object)} so subclasses can restrict equality. */
    protected boolean canEqual(Object other) {
        return other instanceof ActivationReLU;
    }

    /**
     * Hash over {@code max}, {@code threshold} and {@code negativeSlope}, using the
     * same multiplier (59) and null placeholder (43) as before so that hash values
     * are unchanged.
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + hashOrNullPlaceholder(this.getMax());
        result = result * prime + hashOrNullPlaceholder(this.getThreshold());
        result = result * prime + hashOrNullPlaceholder(this.getNegativeSlope());
        return result;
    }

    /** Returns the value's hash code, or 43 when the value is {@code null}. */
    private static int hashOrNullPlaceholder(Double value) {
        return value == null ? 43 : value.hashCode();
    }

    /** @return the configured upper clamp, or {@code null} if none */
    public Double getMax() {
        return this.max;
    }

    /** @return the configured threshold, or {@code null} if none */
    public Double getThreshold() {
        return this.threshold;
    }

    /** @return the configured negative slope, or {@code null} if none */
    public Double getNegativeSlope() {
        return this.negativeSlope;
    }
}

