/*
 * Decompiled with CFR 0.152.
 */
package deepboof.impl.backward.standard;

import deepboof.backward.DFunctionDropOut;
import deepboof.impl.backward.standard.BaseDFunction;
import deepboof.misc.TensorOps_F64;
import deepboof.tensors.Tensor_F64;
import java.util.List;
import java.util.Random;

/**
 * Dropout layer for {@link Tensor_F64} with forward and backward passes.
 *
 * <p>In learning mode each input element is independently zeroed with probability
 * {@code dropRate} (otherwise passed through unchanged), and the 0/1 mask is remembered for
 * the backward pass. Outside learning mode every element is instead multiplied by
 * {@code 1 - dropRate}, so the inference output equals the expected value of the training
 * output (the classic, non-inverted dropout formulation).
 */
public class DFunctionDropOut_F64
extends BaseDFunction<Tensor_F64>
implements DFunctionDropOut<Tensor_F64> {
    /** Source of randomness for the drop mask; seeded in the constructor for reproducibility. */
    Random random;
    /** Probability in [0, 1] that an individual element is dropped during learning. */
    double dropRate;
    /** 0/1 mask produced by the most recent learning-mode forward pass (0 = dropped). */
    Tensor_F64 drops = new Tensor_F64();

    /**
     * Creates a dropout function.
     *
     * @param randomSeed seed for the internal {@link Random} used to sample the drop mask
     * @param dropRate probability, in [0, 1], that an element is dropped while learning
     * @throws IllegalArgumentException if {@code dropRate} is NaN or outside [0, 1]
     */
    public DFunctionDropOut_F64(long randomSeed, double dropRate) {
        // Written as a negated range check so that NaN (which fails every comparison) is rejected.
        if (!(dropRate >= 0.0 && dropRate <= 1.0)) {
            throw new IllegalArgumentException("dropRate must be in [0, 1], got " + dropRate);
        }
        this.random = new Random(randomSeed);
        this.dropRate = dropRate;
    }

    /** Dropout is element-wise, so the output shape is a copy of the input shape. */
    public void _initialize() {
        this.shapeOutput = this.shapeInput.clone();
    }

    /** Dropout has no learnable parameters; nothing to store. */
    public void _setParameters(List<Tensor_F64> parameters) {
    }

    /**
     * Applies dropout to {@code input}, writing the result into {@code output}.
     *
     * <p>Learning mode: samples a fresh 0/1 mask into {@link #drops} (one
     * {@code random.nextDouble()} draw per element) and writes {@code input * mask}.
     * Otherwise: writes {@code input * (1 - dropRate)}.
     *
     * @param input tensor to read
     * @param output tensor to write; assumed to already have the input's shape
     */
    public void _forward(Tensor_F64 input, Tensor_F64 output) {
        if (!this.learningMode) {
            // Inference: scale by the keep probability to match the training-time expectation.
            TensorOps_F64.elementMult(input, 1.0 - this.dropRate, output);
            return;
        }

        this.drops.reshape(input.shape);
        final int N = this.drops.length();
        int indexIn = input.startIndex;
        int indexOut = output.startIndex;
        // NOTE(review): mask is written from index 0, assuming `drops` (created here with the
        // default constructor) keeps startIndex == 0 after reshape — confirm against Tensor_F64.
        for (int i = 0; i < N; ++i) {
            double keep = this.random.nextDouble() < this.dropRate ? 0.0 : 1.0;
            this.drops.d[i] = keep;
            output.d[indexOut++] = input.d[indexIn++] * keep;
        }
    }

    /**
     * Returns the probability that an element is dropped during learning.
     *
     * @return the drop rate, in [0, 1]
     */
    @Override
    public double getDropRate() {
        return this.dropRate;
    }

    /**
     * Back-propagates through the dropout mask: {@code gradientInput = dout * drops}.
     *
     * <p>NOTE(review): uses the mask from the most recent forward pass, so this presumes that
     * pass ran in learning mode on an input of the same shape — confirm callers guarantee it.
     * No parameter gradients are produced because dropout has no parameters.
     */
    @Override
    protected void _backwards(Tensor_F64 input, Tensor_F64 dout, Tensor_F64 gradientInput, List<Tensor_F64> gradientParameters) {
        TensorOps_F64.elementMult(dout, this.drops, gradientInput);
    }

    /** @return the tensor type this function operates on, {@code Tensor_F64} */
    public Class<Tensor_F64> getTensorType() {
        return Tensor_F64.class;
    }
}

