/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.norm;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.IOException;
import java.util.Objects;

/**
 * A block that applies dropout to its single input array.
 *
 * <p>The computation itself is delegated to the engine-specific
 * {@link NDArrayEx#dropout} implementation obtained from the input array; this
 * class only stores the configuration (probability and shared axes) and forwards
 * it. Instances are created through {@link #builder()}.
 */
public class Dropout
extends AbstractBlock {
    /**
     * Encoding version handed to {@link AbstractBlock}. Version 1 is still
     * accepted by {@link #loadMetadata(byte, DataInputStream)}.
     */
    private static final byte VERSION = 2;

    /** Probability forwarded to the dropout operator (builder default 0.5). */
    private final float probability;

    /** Axes forwarded to the dropout operator as shared axes (builder default: none). */
    private final int[] sharedAxes;

    /**
     * Creates a dropout block from a configured builder.
     *
     * @param builder the builder holding the probability and shared axes
     */
    Dropout(Builder builder) {
        // Use the named constant rather than a literal so the version stays in
        // sync with loadMetadata.
        super(VERSION);
        this.probability = builder.probability;
        // The builder already holds a private copy that it never mutates, so it
        // is safe to keep the reference here.
        this.sharedAxes = builder.sharedAxes;
    }

    /**
     * Applies dropout to the array in {@code inputs}.
     *
     * @param parameterStore not used by this block
     * @param inputs expected to hold exactly one array; {@code singletonOrThrow}
     *     is used to obtain the engine-specific operator from it
     * @param training forwarded to the dropout operator (NOTE(review): presumably
     *     dropout is only active when {@code true} — confirm against the engine)
     * @param params additional parameters forwarded unchanged to the operator
     * @return the result produced by the engine's dropout operator
     */
    @Override
    public NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        NDArrayEx ex = inputs.singletonOrThrow().getNDArrayInternal();
        return ex.dropout(inputs, this.probability, this.sharedAxes, training, params);
    }

    /**
     * Returns the output shape, which is identical to the first input shape.
     *
     * @param manager not used by this block
     * @param inputShapes the input shapes; only the first one is consulted
     * @return a single-element array containing {@code inputShapes[0]}
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        return new Shape[]{inputShapes[0]};
    }

    /**
     * Restores block metadata for the given encoding version.
     *
     * <p>Version {@value #VERSION} reads the stored input shapes; version 1
     * carries no extra metadata for this block.
     *
     * @param version the encoding version found in the stream
     * @param is the stream to read from
     * @throws IOException if reading the input shapes fails
     * @throws MalformedModelException if {@code version} is neither 1 nor
     *     {@value #VERSION}
     */
    @Override
    public void loadMetadata(byte version, DataInputStream is) throws IOException, MalformedModelException {
        if (version == VERSION) {
            this.readInputShapes(is);
        } else if (version != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
    }

    /**
     * Creates a builder for {@link Dropout}.
     *
     * @return a new builder with default settings
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Builder for {@link Dropout} blocks. */
    public static final class Builder {
        private float probability = 0.5f;
        private int[] sharedAxes = new int[0];

        Builder() {
        }

        /**
         * Sets the dropout probability (default 0.5). No range validation is
         * performed here; the value is passed to the engine as-is.
         *
         * @param probability the probability to forward to the dropout operator
         * @return this builder
         */
        public Builder optProbability(float probability) {
            this.probability = probability;
            return this;
        }

        /**
         * Sets the shared axes forwarded to the dropout operator (default: none).
         *
         * <p>The array is copied, so later changes by the caller do not affect
         * this builder or any block built from it.
         *
         * @param sharedAxes the axes to share; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code sharedAxes} is {@code null}
         */
        public Builder optSharedAxes(int[] sharedAxes) {
            // Defensive copy: storing the caller's array would let external
            // mutation silently alter an already-built block.
            this.sharedAxes = Objects.requireNonNull(sharedAxes, "sharedAxes").clone();
            return this;
        }

        /**
         * Builds a {@link Dropout} block from the current settings.
         *
         * @return a new dropout block
         */
        public Dropout build() {
            return new Dropout(this);
        }
    }
}

