package org.nd4j.linalg.api.ops.impl.accum;

import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BaseAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/accum/Variance.class */
/**
 * Variance accumulation over the elements of {@code x}.
 *
 * <p>When {@link #isBiasCorrected()} is {@code true} (the default) the final result is
 * {@code (S - B^2 / n) / (n - 1)}, where {@code S} is the accumulated sum of squared deviations
 * from the mean and {@code B} is the final result of the {@code Bias} op on {@code x}
 * (presumably the sum of deviations from the mean, i.e. the corrected two-pass formula —
 * NOTE(review): confirm against {@code Bias}). Otherwise the result is {@code S / n}.
 */
public class Variance extends BaseAccumulation {
    /** Mean of {@code x}; set by {@link #exec()}, or by {@link #init} in JAVA execution mode. */
    protected double mean;
    /** Final result of the {@code Bias} op on {@code x}; only computed when bias-corrected. */
    protected double bias;
    /** {@code true} selects the {@code n - 1} denominator with the bias correction term. */
    protected boolean biasCorrected = true;

    public Variance() {
    }

    public Variance(INDArray x, INDArray y, INDArray z, long n) {
        super(x, y, z, n);
        init(x, y, z, n);
    }

    public Variance(INDArray x, INDArray y, long n) {
        this(x, y, x, n);
    }

    public Variance(INDArray x) {
        this(x, null, x, x.lengthLong(), true);
    }

    public Variance(INDArray x, INDArray y) {
        super(x, y);
    }

    public Variance(INDArray x, INDArray y, INDArray z, long n, boolean biasCorrected) {
        super(x, y, z, n);
        this.biasCorrected = biasCorrected;
        init(x, y, z, n);
    }

    public Variance(INDArray x, INDArray y, long n, boolean biasCorrected) {
        super(x, y, n);
        this.biasCorrected = biasCorrected;
        init(x, y, this.z, n);
    }

    public Variance(INDArray x, boolean biasCorrected) {
        super(x);
        this.biasCorrected = biasCorrected;
        init(x, this.y, this.z, this.n);
    }

    public Variance(INDArray x, INDArray y, boolean biasCorrected) {
        super(x, y);
        this.biasCorrected = biasCorrected;
        init(x, y, x, x.lengthLong());
    }

    /** Returns an array of zeros shaped like {@code x}. */
    @Override
    public INDArray noOp() {
        return Nd4j.zerosLike(x());
    }

    /** Centres a value by subtracting the stored mean. */
    @Override
    public double op(double origin) {
        return origin - this.mean;
    }

    /** Centres a value by subtracting the stored mean (computed in double precision). */
    @Override
    public float op(float origin) {
        return (float) (origin - this.mean);
    }

    // NOTE(review): unlike the complex overloads below, the real-valued update overloads do not
    // subtract the mean themselves; they appear to expect a value already centred by op(...).

    /** Adds the square of {@code x} to the running sum. */
    @Override
    public double update(double accum, double x) {
        return accum + (x * x);
    }

    /** Adds the square of {@code x} to the running sum; {@code y} is ignored. */
    @Override
    public double update(double accum, double x, double y) {
        return accum + (x * x);
    }

    /** Adds the square of {@code x} to the running sum. */
    @Override
    public float update(float accum, float x) {
        return accum + (x * x);
    }

    /** Adds the square of {@code x} to the running sum; {@code y} is ignored. */
    @Override
    public float update(float accum, float x, float y) {
        return accum + (x * x);
    }

    /** Adds the squared deviation of {@code x} from the mean to the running sum. */
    @Override
    public IComplexNumber update(IComplexNumber accum, double x) {
        return accum.add(Double.valueOf(squaredDeviation(x)));
    }

    /** Adds the squared deviation of {@code x} from the mean; {@code y} is ignored. */
    @Override
    public IComplexNumber update(IComplexNumber accum, double x, double y) {
        return accum.add(Double.valueOf(squaredDeviation(x)));
    }

    /** Adds the squared deviation of {@code x} from the mean to the running sum. */
    @Override
    public IComplexNumber update(IComplexNumber accum, IComplexNumber x) {
        return accum.add(squaredDeviation(x));
    }

    /** Adds the squared deviation of {@code x} from the mean; {@code y} is ignored. */
    @Override
    public IComplexNumber update(IComplexNumber accum, IComplexNumber x, IComplexNumber y) {
        return accum.add(squaredDeviation(x));
    }

    /** Adds the squared deviation of {@code x} from the mean; {@code y} is ignored. */
    @Override
    public IComplexNumber update(IComplexNumber accum, IComplexNumber x, double y) {
        return accum.add(squaredDeviation(x));
    }

    @Override
    public int opNum() {
        return 0;
    }

    @Override
    public String name() {
        return "var";
    }

    /**
     * Builds a variance op over one vector of {@code x} (and {@code y}, if present) along
     * {@code dimension}, inheriting this op's bias-correction and final-transform settings.
     */
    @Override
    public Op opForDimension(int index, int dimension) {
        INDArray xAlongDimension = this.x.vectorAlongDimension(index, dimension);
        INDArray yAlongDimension = y() != null ? this.y.vectorAlongDimension(index, dimension) : null;
        Variance variance = new Variance(xAlongDimension, yAlongDimension, xAlongDimension,
                xAlongDimension.lengthLong(), this.biasCorrected);
        variance.setApplyFinalTransform(applyFinalTransform());
        return variance;
    }

    /**
     * Builds a variance op over one tensor of {@code x} (and {@code y}, if present) along
     * {@code dimension}, inheriting this op's bias-correction and final-transform settings.
     */
    @Override
    public Variance opForDimension(int index, int... dimension) {
        INDArray xAlongDimension = this.x.tensorAlongDimension(index, dimension);
        Variance variance;
        if (y() != null) {
            // Pass the flag at construction so init() sees it, matching the no-y branch.
            variance = new Variance(xAlongDimension, this.y.tensorAlongDimension(index, dimension),
                    xAlongDimension, xAlongDimension.lengthLong(), this.biasCorrected);
        } else {
            variance = new Variance(xAlongDimension, this.biasCorrected);
        }
        variance.setApplyFinalTransform(applyFinalTransform());
        return variance;
    }

    /**
     * Initializes the op. In JAVA execution mode, this also eagerly computes the mean of
     * {@code x}, and the bias term when bias correction is enabled.
     */
    @Override
    public void init(INDArray x, INDArray y, INDArray z, long n) {
        super.init(x, y, z, n);
        if (Nd4j.executionMode == OpExecutioner.ExecutionMode.JAVA) {
            if (this.biasCorrected) {
                this.bias = finalResultOf(new Bias(x));
            }
            this.mean = finalResultOf(new Mean(x));
        }
    }

    // NOTE(review): presumably tells the executioner to call exec() directly instead of driving
    // the element-wise update loop — confirm against the executioner implementation.
    @Override
    public boolean isPassThrough() {
        return true;
    }

    /**
     * Computes the variance of all of {@code x}, stores it as the final result and wraps it as a
     * scalar in {@code z}.
     */
    @Override
    public void exec() {
        if (this.biasCorrected) {
            this.bias = finalResultOf(new Bias(this.x));
        }
        this.mean = finalResultOf(new Mean(this.x));
        INDArray deviations = this.x.sub(Double.valueOf(this.mean));
        // deviations is a fresh copy, so squaring it in place does not touch x.
        double sumOfSquares = finalResultOf(new Sum(deviations.muli(deviations)));
        getAndSetFinalResult(sumOfSquares);
        this.z = Nd4j.scalar(this.finalResult);
    }

    /**
     * Computes the variance along the given dimensions and writes one value per tensor into a
     * new {@code z}. A single {@link Integer#MAX_VALUE} dimension means "whole array".
     */
    @Override
    public void exec(int... dimension) {
        if (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE) {
            exec();
            return;
        }
        int[] resultShape = ArrayUtil.removeIndex(this.x.shape(), dimension);
        int tensorCount = this.x.tensorssAlongDimension(dimension);
        this.z = Nd4j.create(resultShape);
        for (int i = 0; i < tensorCount; i++) {
            this.z.putScalar(i, finalResultOf(opForDimension(i, dimension)));
        }
    }

    @Override
    public double combineSubResults(double first, double second) {
        return first + second;
    }

    @Override
    public float combineSubResults(float first, float second) {
        return first + second;
    }

    @Override
    public IComplexNumber combineSubResults(IComplexNumber first, IComplexNumber second) {
        return first.add(second);
    }

    /** Converts an accumulated sum of squares into the variance over {@code n()} elements. */
    @Override
    public double getAndSetFinalResult(double accum) {
        return storeVariance(accum, n());
    }

    /** Float variant of {@link #getAndSetFinalResult(double)}. */
    @Override
    public float getAndSetFinalResult(float accum) {
        // The explicit widening is required: with a float argument Java resolves
        // getAndSetFinalResult(accum) to this very overload and recurses forever.
        return (float) getAndSetFinalResult((double) accum);
    }

    /** @throws UnsupportedOperationException always; complex variance is not implemented */
    @Override
    public IComplexNumber getAndSetFinalResult(IComplexNumber accum) {
        throw new UnsupportedOperationException();
    }

    /** Converts an accumulated sum of squares into the variance over {@code n} elements. */
    @Override
    public double calculateFinalResult(double accum, long n) {
        return storeVariance(accum, n);
    }

    /** Float variant of {@link #calculateFinalResult(double, long)}. */
    @Override
    public float calculateFinalResult(float accum, long n) {
        // Widen explicitly; otherwise this overload would be selected again (infinite recursion).
        return (float) calculateFinalResult((double) accum, n);
    }

    public boolean isBiasCorrected() {
        return this.biasCorrected;
    }

    public void setBiasCorrected(boolean biasCorrected) {
        this.biasCorrected = biasCorrected;
    }

    /**
     * Applies the variance formula to a sum of squared deviations, stores it in
     * {@code finalResult} and returns it.
     */
    private double storeVariance(double sumOfSquares, double count) {
        double variance = this.biasCorrected
                ? (sumOfSquares - FastMath.pow(this.bias, 2.0d) / count) / (count - 1.0d)
                : sumOfSquares / count;
        this.finalResult = Double.valueOf(variance);
        return variance;
    }

    /** Squared deviation of a real value from the stored mean. */
    private double squaredDeviation(double value) {
        double deviation = value - this.mean;
        return deviation * deviation;
    }

    /** Squared deviation of a complex value from the stored (real) mean. */
    private IComplexNumber squaredDeviation(IComplexNumber value) {
        IComplexNumber deviation = value.sub(Double.valueOf(this.mean));
        return deviation.mul(deviation);
    }

    /** Executes an accumulation with the current executioner and returns its final result. */
    private static double finalResultOf(Accumulation op) {
        return Nd4j.getExecutioner().execAndReturn(op).getFinalResult().doubleValue();
    }
}
