/*
 * Decompiled with CFR 0.152.
 */
package org.tensorflow.op.core;

import java.util.Arrays;
import org.tensorflow.Operand;
import org.tensorflow.ndarray.index.Indices;
import org.tensorflow.op.Scope;
import org.tensorflow.op.core.BroadcastTo;
import org.tensorflow.op.core.Concat;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.ReduceProd;
import org.tensorflow.op.core.Reshape;
import org.tensorflow.op.core.Shape;
import org.tensorflow.op.core.StridedSlice;
import org.tensorflow.op.core.StridedSliceHelper;
import org.tensorflow.op.core.TensorScatterNdUpdate;
import org.tensorflow.op.core.Where;
import org.tensorflow.types.TBool;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.family.TType;

/**
 * Graph-building helper that replaces the slices of {@code tensor} selected by a boolean
 * {@code mask} with {@code updates} and returns the resulting tensor.
 *
 * <p>The mask is aligned with dimensions {@code [axis, axis + rank(mask))} of {@code tensor}. For
 * every combination of the outer indices (dimensions {@code [0, axis)}), each slice whose mask
 * entry is {@code true} is overwritten. The result is produced with {@code TensorScatterNdUpdate}
 * on a reshaped copy of {@code tensor}, then reshaped back to the runtime shape of {@code tensor}.
 */
public abstract class BooleanMaskUpdate {

    /**
     * Adds the ops that write {@code updates} into the masked positions of {@code tensor}.
     *
     * <p>Let {@code k} be the number of {@code true} entries in {@code mask}. The expected update
     * shape is {@code tensor.shape[:axis] + [k] + tensor.shape[axis + rank(mask):]}. The same shape
     * is produced by a boolean-mask selection along {@code axis}.
     *
     * @param scope current scope; a {@code "BooleanMaskUpdate"} sub-scope is created from it
     * @param tensor the tensor whose masked slices are replaced
     * @param mask boolean mask; must not be a scalar, must have a fully known static shape, and must
     *     be compatible with dimensions {@code [axis, axis + rank(mask))} of {@code tensor}
     * @param updates the replacement values. When broadcasting is enabled (the default) they are
     *     broadcast to the expected update shape. Otherwise they must already have that shape.
     * @param options optional {@link Options#axis(Integer) axis} and
     *     {@link Options#broadcast(Boolean) broadcast} settings; later options override earlier ones
     * @return a tensor with the runtime shape of {@code tensor} and the masked slices replaced
     * @throws IllegalArgumentException if {@code mask} is a scalar, has an unknown dimension, or has
     *     a shape incompatible with the selected dimensions of {@code tensor}
     */
    public static <T extends TType> Operand<T> create(Scope scope, Operand<T> tensor, Operand<TBool> mask, Operand<T> updates, Options ... options) {
        scope = scope.withNameAsSubScope("BooleanMaskUpdate");
        int axis = 0;
        boolean broadcast = true;
        if (options != null) {
            for (Options opts : options) {
                if (opts.axis != null) {
                    axis = opts.axis;
                }
                if (opts.broadcast != null) {
                    broadcast = opts.broadcast;
                }
            }
        }
        if (axis < 0) {
            // Negative axes count from the end of the tensor's (static) rank.
            axis += tensor.rank();
        }
        org.tensorflow.ndarray.Shape maskShape = mask.shape();
        org.tensorflow.ndarray.Shape tensorShape = tensor.shape();
        if (maskShape.numDimensions() == 0) {
            throw new IllegalArgumentException("Mask cannot be a scalar.");
        }
        if (maskShape.hasUnknownDimension()) {
            throw new IllegalArgumentException("Mask cannot have unknown number of dimensions");
        }
        // First tensor dimension that is NOT covered by the mask.
        int maskEnd = axis + maskShape.numDimensions();
        org.tensorflow.ndarray.Shape requiredMaskShape = tensorShape.subShape(axis, maskEnd);
        if (!requiredMaskShape.isCompatibleWith(maskShape)) {
            throw new IllegalArgumentException("Mask shape " + maskShape + " is not compatible with the required mask shape: " + requiredMaskShape + ".");
        }

        Shape<TInt32> liveShape = Shape.create(scope, tensor);
        // Runtime shape of the outer dims followed by the masked dims: tensor.shape[:maskEnd].
        Operand<TInt32> leadingShape = StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceTo(maskEnd));
        Operand<TInt32> leadingSize = ReduceProd.create(scope, leadingShape, Constant.arrayOf(scope, 0));
        // Runtime shape of the dims after the mask: tensor.shape[maskEnd:].
        Operand<TInt32> innerShape = StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceFrom(maskEnd));

        // Collapse outer + masked dims into a single dim so that every masked slice is addressed by
        // one flat (row-major) index: reshaped has shape [prod(leadingShape), inner...].
        Operand<T> reshaped = Reshape.create(scope, tensor, Concat.create(scope,
                Arrays.<Operand<TInt32>>asList(Reshape.create(scope, leadingSize, Constant.arrayOf(scope, 1)), innerShape),
                Constant.scalarOf(scope, 0)));

        // Flat positions of the true entries of the mask itself; shape [k, 1]. Flattening is what
        // makes multi-dimensional masks index the single collapsed dimension of `reshaped`.
        Where maskIndices = Where.create(scope, Reshape.create(scope, mask, Constant.arrayOf(scope, -1)));

        Where scatterIndices;
        if (axis == 0) {
            // No outer dims: the mask's flat positions are already positions in `reshaped`.
            scatterIndices = maskIndices;
        } else {
            // Repeat the mask across all outer dims before flattening. The resulting flat indices
            // then cover every outer slice, in row-major (outer-major) order; shape [outer * k, 1].
            Operand<TBool> fullMask = BroadcastTo.create(scope, mask, leadingShape);
            scatterIndices = Where.create(scope, Reshape.create(scope, fullMask, Constant.arrayOf(scope, -1)));
        }

        Operand<T> flatUpdates = updates;
        if (broadcast) {
            // [k]: number of selected mask entries, taken from the shape of maskIndices ([k, 1]).
            Operand<TInt32> selectedCount = StridedSliceHelper.stridedSlice(scope, Shape.create(scope, maskIndices), Indices.sliceTo(-1L));
            Iterable<Operand<TInt32>> updateShapeParts = axis == 0
                    ? Arrays.<Operand<TInt32>>asList(selectedCount, innerShape)
                    : Arrays.<Operand<TInt32>>asList(StridedSliceHelper.stridedSlice(scope, liveShape, Indices.sliceTo(axis)), selectedCount, innerShape);
            Operand<TInt32> updateShape = Concat.create(scope, updateShapeParts, Constant.scalarOf(scope, 0));
            flatUpdates = BroadcastTo.create(scope, flatUpdates, updateShape);
        }
        if (axis > 0) {
            // [outer..., k, inner...] -> [outer * k, inner...], matching the order of scatterIndices.
            flatUpdates = Reshape.create(scope, flatUpdates, Concat.create(scope,
                    Arrays.<Operand<TInt32>>asList(Constant.arrayOf(scope, -1), innerShape),
                    Constant.scalarOf(scope, 0)));
        }

        Operand<T> newValue = TensorScatterNdUpdate.create(scope, reshaped, scatterIndices, flatUpdates);
        return Reshape.create(scope, newValue, liveShape);
    }

    /**
     * Creates options that set the axis of {@code tensor} at which the mask dimensions start.
     *
     * @param axis the first masked dimension; negative values count from the end of the tensor's
     *     rank. Defaults to 0 when not set.
     * @return a new options instance carrying {@code axis}
     */
    public static Options axis(Integer axis) {
        return new Options().axis(axis);
    }

    /**
     * Creates options that control whether {@code updates} is broadcast to the expected update shape.
     *
     * @param broadcast {@code true} to broadcast (the default when not set), {@code false} to use
     *     {@code updates} as given
     * @return a new options instance carrying {@code broadcast}
     */
    public static Options broadcast(Boolean broadcast) {
        return new Options().broadcast(broadcast);
    }

    /** Optional attributes for {@link BooleanMaskUpdate}; a {@code null} field means "use the default". */
    public static class Options {
        private Integer axis;
        private Boolean broadcast;

        /**
         * Sets the axis of {@code tensor} at which the mask dimensions start.
         *
         * @param axis the first masked dimension (negative values count from the end)
         * @return this instance
         */
        public Options axis(Integer axis) {
            this.axis = axis;
            return this;
        }

        /**
         * Sets whether {@code updates} is broadcast to the expected update shape.
         *
         * @param broadcast whether to broadcast {@code updates}
         * @return this instance
         */
        public Options broadcast(Boolean broadcast) {
            this.broadcast = broadcast;
            return this;
        }

        private Options() {
        }
    }
}

