/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.util;

import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.transforms.IsMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Masked global pooling (MAX, AVG, SUM, PNORM) for time-series (rank 3) and CNN (rank 4) activations,
 * together with the matching backprop (epsilon) computations.
 * <p>
 * The mask is a rank-2 array broadcast against the activations along two dimensions: {0, 2} for time
 * series, and {0, 2} (along height) or {0, 3} (along width) for CNNs. Assuming a 0/1 mask, masked-out
 * positions are excluded from the reduction. MAX adds -infinity at those positions before taking the
 * max. AVG, SUM and PNORM multiply the activations by the mask.
 * <p>
 * Explicit casts on {@code OpExecutioner} calls are kept so the same executioner overloads are selected
 * as in the previously compiled code.
 */
public class MaskedReductionUtil {
    /** Mask broadcast dimensions for CNN activations masked along height. */
    private static final int[] CNN_DIM_MASK_H = new int[]{0, 2};
    /** Mask broadcast dimensions for CNN activations masked along width. */
    private static final int[] CNN_DIM_MASK_W = new int[]{0, 3};

    /**
     * Applies masked global pooling over dimension 2 of a rank-3 array.
     *
     * @param poolingType pooling to apply; {@code NONE} is not supported
     * @param toReduce    rank-3 activations (presumably [minibatch, size, timeSeriesLength] - TODO confirm)
     * @param mask        rank-2 mask, broadcast along dimensions 0 and 2 of {@code toReduce}
     * @param pnorm       the p of the p-norm; only used for {@code PNORM}
     * @return the pooled activations, reduced over dimension 2
     * @throws IllegalArgumentException      if {@code toReduce} is not rank 3 or {@code mask} is not rank 2
     * @throws UnsupportedOperationException for {@code NONE} or any other unsupported pooling type
     */
    public static INDArray maskedPoolingTimeSeries(PoolingType poolingType, INDArray toReduce, INDArray mask, int pnorm) {
        checkRank(toReduce, 3, "Expect rank 3 array: got ");
        checkRank(mask, 2, "Expect rank 2 array for mask: got ");
        // Each mask entry covers exactly one time step, so the AVG divisor is the plain mask sum.
        return maskedPool(poolingType, toReduce, mask, new int[]{0, 2}, new int[]{2}, 1L, pnorm);
    }

    /**
     * Computes the gradient of {@link #maskedPoolingTimeSeries} with respect to its input.
     *
     * @param poolingType pooling that was applied in the forward pass; {@code NONE} is not supported
     * @param input       rank-3 activations that were pooled
     * @param mask        rank-2 mask, broadcast along dimensions 0 and 2 of {@code input}
     * @param epsilon2d   rank-2 gradient w.r.t. the pooled output, broadcast along dimensions 0 and 1 of {@code input}
     * @param pnorm       the p of the p-norm; only used for {@code PNORM}
     * @return the gradient w.r.t. {@code input}, with the same shape as {@code input}
     * @throws IllegalArgumentException      if any argument has the wrong rank
     * @throws UnsupportedOperationException for {@code NONE} or any other unsupported pooling type
     */
    public static INDArray maskedPoolingEpsilonTimeSeries(PoolingType poolingType, INDArray input, INDArray mask, INDArray epsilon2d, int pnorm) {
        checkRank(input, 3, "Expect rank 3 input activation array: got ");
        checkRank(mask, 2, "Expect rank 2 array for mask: got ");
        checkRank(epsilon2d, 2, "Expected rank 2 array for errors: got ");
        return maskedPoolEpsilon(poolingType, input, mask, epsilon2d, new int[]{0, 2}, new int[]{2}, 1L, pnorm);
    }

    /**
     * Applies masked global pooling over dimensions 2 and 3 of a rank-4 (CNN) array.
     * <p>
     * For {@code AVG}, the divisor is the per-row mask sum multiplied by the size of the spatial dimension
     * the mask is not broadcast along. When that dimension has size 1, this is just the mask sum.
     *
     * @param poolingType pooling to apply; {@code NONE} is not supported
     * @param toReduce    rank-4 activations to reduce
     * @param mask        rank-2 mask, broadcast along dimensions {0, 2} or {0, 3} of {@code toReduce}
     * @param alongHeight {@code true} to broadcast the mask along dimension 2 (height), {@code false} for
     *                    dimension 3 (width)
     * @param pnorm       the p of the p-norm; only used for {@code PNORM}
     * @return the pooled activations, reduced over dimensions 2 and 3
     * @throws IllegalArgumentException      if {@code toReduce} is not rank 4 or {@code mask} is not rank 2
     * @throws UnsupportedOperationException for {@code NONE} or any other unsupported pooling type
     */
    public static INDArray maskedPoolingConvolution(PoolingType poolingType, INDArray toReduce, INDArray mask, boolean alongHeight, int pnorm) {
        checkRank(toReduce, 4, "Expect rank 4 array: got ");
        checkRank(mask, 2, "Expect rank 2 array for mask: got ");
        int[] dimensions = alongHeight ? CNN_DIM_MASK_H : CNN_DIM_MASK_W;
        return maskedPool(poolingType, toReduce, mask, dimensions, new int[]{2, 3},
                unmaskedSpatialSize(toReduce, alongHeight), pnorm);
    }

    /**
     * Computes the gradient of {@link #maskedPoolingConvolution} with respect to its input.
     *
     * @param poolingType pooling that was applied in the forward pass; {@code NONE} is not supported
     * @param input       rank-4 activations that were pooled
     * @param mask        rank-2 mask, broadcast along dimensions {0, 2} or {0, 3} of {@code input}
     * @param epsilon2d   rank-2 gradient w.r.t. the pooled output, broadcast along dimensions 0 and 1 of {@code input}
     * @param alongHeight {@code true} if the mask is broadcast along dimension 2 (height), {@code false} for
     *                    dimension 3 (width)
     * @param pnorm       the p of the p-norm; only used for {@code PNORM}
     * @return the gradient w.r.t. {@code input}, with the same shape as {@code input}
     * @throws IllegalArgumentException      if any argument has the wrong rank
     * @throws UnsupportedOperationException for {@code NONE} or any other unsupported pooling type
     */
    public static INDArray maskedPoolingEpsilonCnn(PoolingType poolingType, INDArray input, INDArray mask, INDArray epsilon2d, boolean alongHeight, int pnorm) {
        checkRank(input, 4, "Expect rank 4 input activation array: got ");
        checkRank(mask, 2, "Expect rank 2 array for mask: got ");
        checkRank(epsilon2d, 2, "Expected rank 2 array for errors: got ");
        int[] dimensions = alongHeight ? CNN_DIM_MASK_H : CNN_DIM_MASK_W;
        return maskedPoolEpsilon(poolingType, input, mask, epsilon2d, dimensions, new int[]{2, 3},
                unmaskedSpatialSize(input, alongHeight), pnorm);
    }

    /** Shared forward pass: pools {@code toReduce} over {@code reduceDims}, with the mask broadcast along {@code maskDims}. */
    private static INDArray maskedPool(PoolingType poolingType, INDArray toReduce, INDArray mask, int[] maskDims,
                    int[] reduceDims, long positionsPerMaskEntry, int pnorm) {
        switch (poolingType) {
            case MAX: {
                INDArray negInf = negInfMask(mask);
                INDArray withInf = Nd4j.createUninitialized(toReduce.shape());
                Nd4j.getExecutioner().exec((Op) new BroadcastAddOp(toReduce, negInf, withInf, maskDims));
                // Masked-out positions are now -inf, so they cannot be selected by max.
                return withInf.max(reduceDims);
            }
            case AVG:
            case SUM: {
                INDArray summed = multiplyByMask(toReduce, mask, maskDims).sum(reduceDims);
                if (poolingType == PoolingType.SUM) {
                    return summed;
                }
                summed.diviColumnVector(avgDivisor(mask, positionsPerMaskEntry));
                return summed;
            }
            case PNORM:
                return maskedPNorm(toReduce, mask, maskDims, reduceDims, pnorm);
            case NONE:
                throw new UnsupportedOperationException("NONE pooling type not supported");
            default:
                throw new UnsupportedOperationException("Unknown or not supported pooling type: " + poolingType);
        }
    }

    /** Shared backward pass: gradient of {@link #maskedPool} w.r.t. {@code input}. */
    private static INDArray maskedPoolEpsilon(PoolingType poolingType, INDArray input, INDArray mask, INDArray epsilon2d,
                    int[] maskDims, int[] reduceDims, long positionsPerMaskEntry, int pnorm) {
        switch (poolingType) {
            case MAX: {
                INDArray negInf = negInfMask(mask);
                INDArray withInf = Nd4j.createUninitialized(input.shape());
                Nd4j.getExecutioner().exec((Op) new BroadcastAddOp(input, negInf, withInf, maskDims));
                // IsMax runs in place on withInf (a local copy), turning it into a max-indicator array.
                INDArray isMax = Nd4j.getExecutioner().execAndReturn((TransformOp) new IsMax(withInf, reduceDims));
                return Nd4j.getExecutioner().execAndReturn(
                        (BroadcastOp) new BroadcastMulOp(isMax, epsilon2d, isMax, new int[]{0, 1}));
            }
            case AVG:
            case SUM: {
                // Spread epsilon over the pooled dimensions, then zero out the masked positions.
                INDArray out = Nd4j.createUninitialized(input.shape(), 'f');
                Nd4j.getExecutioner().exec((Op) new BroadcastCopyOp(out, epsilon2d, out, new int[]{0, 1}));
                Nd4j.getExecutioner().exec((Op) new BroadcastMulOp(out, mask, out, maskDims));
                if (poolingType == PoolingType.SUM) {
                    return out;
                }
                INDArray divisor = avgDivisor(mask, positionsPerMaskEntry);
                Nd4j.getExecutioner().exec((Op) new BroadcastDivOp(out, divisor, out, new int[]{0}));
                return out;
            }
            case PNORM: {
                // d/dx_i ||x||_p = x_i * |x_i|^(p-2) / ||x||_p^(p-1)
                INDArray pNorm = maskedPNorm(input, mask, maskDims, reduceDims, pnorm);
                INDArray numerator;
                if (pnorm == 2) {
                    numerator = input.dup();
                } else {
                    INDArray absp2 = Transforms.pow(Transforms.abs(input, true), pnorm - 2, false);
                    numerator = input.mul(absp2);
                }
                INDArray denom = Transforms.pow(pNorm, pnorm - 1, false);
                // NOTE(review): a zero p-norm makes denom zero here, which yields non-finite gradients.
                denom.rdivi(epsilon2d);
                Nd4j.getExecutioner().execAndReturn(
                        (BroadcastOp) new BroadcastMulOp(numerator, denom, numerator, new int[]{0, 1}));
                Nd4j.getExecutioner().exec((Op) new BroadcastMulOp(numerator, mask, numerator, maskDims));
                return numerator;
            }
            case NONE:
                throw new UnsupportedOperationException("NONE pooling type not supported");
            default:
                throw new UnsupportedOperationException("Unknown or not supported pooling type: " + poolingType);
        }
    }

    /** Throws {@link IllegalArgumentException} with {@code messagePrefix + actualRank} unless the rank matches. */
    private static void checkRank(INDArray array, int expectedRank, String messagePrefix) {
        if (array.rank() != expectedRank) {
            throw new IllegalArgumentException(messagePrefix + array.rank());
        }
    }

    /**
     * Builds the additive MAX mask. Entries equal to 1 in {@code Transforms.not(mask)} become -infinity;
     * assuming a 0/1 mask, these are exactly the masked-out positions. All other entries stay 0.
     */
    private static INDArray negInfMask(INDArray mask) {
        INDArray negInf = Transforms.not(mask);
        BooleanIndexing.replaceWhere(negInf, Double.NEGATIVE_INFINITY, Conditions.equals(1.0));
        return negInf;
    }

    /** Returns a new array equal to {@code x} multiplied by {@code mask}, broadcast along {@code maskDims}. */
    private static INDArray multiplyByMask(INDArray x, INDArray mask, int[] maskDims) {
        INDArray masked = Nd4j.createUninitialized(x.shape());
        Nd4j.getExecutioner().exec((Op) new BroadcastMulOp(x, mask, masked, maskDims));
        return masked;
    }

    /** Computes (sum over {@code reduceDims} of |x * mask|^pnorm)^(1/pnorm). */
    private static INDArray maskedPNorm(INDArray x, INDArray mask, int[] maskDims, int[] reduceDims, int pnorm) {
        // multiplyByMask returns a fresh array, so abs/pow can safely run in place on it.
        INDArray absPow = Transforms.abs(multiplyByMask(x, mask, maskDims), false);
        Transforms.pow(absPow, pnorm, false);
        return Transforms.pow(absPow.sum(reduceDims), 1.0 / pnorm);
    }

    /**
     * AVG divisor: the per-row mask sum times {@code positionsPerMaskEntry}.
     * NOTE(review): an all-zero mask row gives a zero divisor, so that row's result is non-finite.
     */
    private static INDArray avgDivisor(INDArray mask, long positionsPerMaskEntry) {
        INDArray counts = mask.sum(1);
        if (positionsPerMaskEntry != 1) {
            counts.muli(positionsPerMaskEntry);
        }
        return counts;
    }

    /**
     * Size of the CNN spatial dimension the mask is not broadcast along: dimension 3 when masking along
     * height, dimension 2 when masking along width. Each mask entry covers this many pooled positions.
     */
    private static long unmaskedSpatialSize(INDArray activations, boolean alongHeight) {
        return activations.size(alongHeight ? 3 : 2);
    }
}

