package org.deeplearning4j.util;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/deeplearning4j/util/CapsuleUtils.class */
/**
 * Static helpers that assemble small SameDiff sub-graphs: a "squash" rescaling and a
 * softmax wrapped in an axis swap.
 */
public class CapsuleUtils {
    /**
     * Rescales {@code sDVariable} by a factor derived from its squared norm along one dimension.
     *
     * <p>With {@code s} being the sum of squares of {@code sDVariable} along dimension {@code i}
     * (the reduced dimension is kept), the result is
     * {@code sDVariable * s / ((s + 1) * sqrt(s + 1e-5))}, which is algebraically
     * {@code (s / (1 + s)) * (sDVariable / sqrt(s + 1e-5))}.
     *
     * @param sameDiff   SameDiff instance whose {@code math} namespace supplies the square and sqrt ops
     * @param sDVariable variable to rescale
     * @param i          dimension along which the squared values are summed
     * @return the variable representing the rescaled input
     */
    public static SDVariable squash(SameDiff sameDiff, SDVariable sDVariable, int i) {
        final SDVariable squaredNorm = sameDiff.math.square(sDVariable).sum(true, i);

        final SDVariable numerator = sDVariable.times(squaredNorm);
        final SDVariable onePlusSquaredNorm = squaredNorm.plus(1.0d);
        // The 1e-5 offset keeps the divisor strictly positive even when the squared norm is zero.
        final SDVariable norm = sameDiff.math.sqrt(squaredNorm.plus(1.0E-5d));
        final SDVariable denominator = onePlusSquaredNorm.times(norm);

        return numerator.div(denominator);
    }

    /**
     * Applies {@code sameDiff.nn.softmax} to {@code sDVariable} with axes {@code 0} and {@code i}
     * swapped, then restores the original axis order.
     *
     * <p>NOTE(review): which axis the single-argument {@code nn.softmax} normalises over is not
     * visible from this class; confirm it matches the intent of moving axis {@code i} to position 0.
     *
     * @param sameDiff   SameDiff instance whose {@code nn} namespace supplies the softmax op
     * @param sDVariable variable to run through softmax
     * @param i          axis exchanged with axis 0; must be a valid index into the permutation array
     * @param i2         exclusive upper bound handed to {@code ArrayUtil.range}
     *                   (presumably the rank of {@code sDVariable} - TODO confirm)
     * @return the softmax output, permuted back with the inverse of the swap permutation
     */
    public static SDVariable softmax(SameDiff sameDiff, SDVariable sDVariable, int i, int i2) {
        // Start from the range produced by ArrayUtil, then exchange entries 0 and i.
        final int[] axisOrder = ArrayUtil.range(0, i2);
        axisOrder[0] = i;
        axisOrder[i] = 0;

        final SDVariable swapped = sDVariable.permute(axisOrder);
        final SDVariable activated = sameDiff.nn.softmax(swapped);
        final int[] restoreOrder = ArrayUtil.invertPermutation(axisOrder);
        return activated.permute(restoreOrder);
    }
}
