/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.activation;

import org.nd4j.linalg.api.activation.BaseActivationFunction;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.ElementWiseOp;
import org.nd4j.linalg.util.ComplexUtil;

/**
 * Hard tanh activation function, f(x) = max(-1, min(1, x)).
 *
 * <p>The forward transform is delegated to {@link org.nd4j.linalg.ops.transforms.HardTanh}
 * (see {@link #transformClazz()}); this class supplies the derivative used during
 * back-propagation.
 */
public class HardTanh extends BaseActivationFunction {
    private static final long serialVersionUID = -8484119406683594852L;

    /** Lower edge of hardtanh's linear (identity) region. */
    private static final double LINEAR_REGION_MIN = -1.0;
    /** Upper edge of hardtanh's linear (identity) region. */
    private static final double LINEAR_REGION_MAX = 1.0;

    @Override
    public Class<? extends ElementWiseOp> transformClazz() {
        return org.nd4j.linalg.ops.transforms.HardTanh.class;
    }

    @Override
    public String type() {
        return "hardtanh";
    }

    /**
     * Replaces every element of {@code input} with the derivative of hardtanh at that element.
     *
     * <p>Hardtanh is piecewise linear: slope 1 on [-1, 1] and slope 0 outside it. The previous
     * implementation instead wrote the clamped function value (-1 / 1) outside the region and the
     * smooth-tanh derivative {@code 1 - tanh(x)^2} inside it, neither of which is the hardtanh
     * gradient.
     *
     * <p>For complex input the region test uses the real component. The result is written as
     * {@code grad + 0i}. NOTE(review): this assumes the complex forward transform clamps only the
     * real part. Confirm against {@link org.nd4j.linalg.ops.transforms.HardTanh}.
     *
     * <p>The operation is in place. Values are written through {@code linearView()}, and
     * {@code input} itself is returned. This relies on {@code linearView()} sharing storage with
     * {@code input}, as the original implementation did.
     *
     * @param input the pre-activation values; overwritten with the derivative
     * @return {@code input}, now holding the derivative values
     */
    @Override
    public INDArray applyDerivative(INDArray input) {
        if (input instanceof IComplexNDArray) {
            IComplexNDArray linear = ((IComplexNDArray) input).linearView();
            final int n = linear.length();
            for (int i = 0; i < n; i++) {
                double x = linear.getComplex(i).realComponent().doubleValue();
                linear.putScalar(i, Nd4j.createDouble(hardTanhGradient(x), 0.0));
            }
        } else {
            INDArray linear = input.linearView();
            final int n = linear.length();
            for (int i = 0; i < n; i++) {
                linear.putScalar(i, (float) hardTanhGradient(linear.getFloat(i)));
            }
        }
        return input;
    }

    /** Returns the slope of hardtanh at {@code x}: 1 inside [-1, 1], 0 outside. */
    private static double hardTanhGradient(double x) {
        return (x < LINEAR_REGION_MIN || x > LINEAR_REGION_MAX) ? 0.0 : 1.0;
    }
}

