package io.improbable.keanu.tensor;

import io.improbable.keanu.tensor.dbl.DoubleTensor;

/* loaded from: input_file:io/improbable/keanu/tensor/BivariateDataStatisticsCalculator.class */
/**
 * Derives least-squares straight-line fit statistics (gradient, intercept, mean squared
 * error and standard errors) from two {@link DoubleTensor}s of matching shape holding
 * the x and y observations.
 *
 * <p>Every statistic is recomputed from the tensors on each call; nothing is cached.
 * No guard is applied against degenerate inputs: quantities that divide by the x
 * second moment or by {@code size() - 2} yield non-finite values (NaN or infinity)
 * when that divisor is zero.
 */
public class BivariateDataStatisticsCalculator {
    private final DoubleTensor xData;
    private final DoubleTensor yData;

    /**
     * Stores the two data tensors after asking {@code TensorShapeValidation} to check
     * that their shapes match.
     *
     * @param xData the x observations
     * @param yData the y observations
     */
    public BivariateDataStatisticsCalculator(DoubleTensor xData, DoubleTensor yData) {
        // Build a genuine two-dimensional array; a long[] cannot legally hold (or be cast
        // to) an array of shape arrays.
        TensorShapeValidation.checkAllShapesMatch(new long[][] {xData.getShape(), yData.getShape()});
        this.xData = xData;
        this.yData = yData;
    }

    /**
     * @return the length reported by the x tensor
     */
    public long size() {
        return this.xData.getLength();
    }

    /**
     * @return the average of the x data
     */
    public double xMean() {
        return ((Double) this.xData.average()).doubleValue();
    }

    /**
     * @return the average of the y data
     */
    public double yMean() {
        return ((Double) this.yData.average()).doubleValue();
    }

    /**
     * @return the x/y second moment divided by the x/x second moment
     */
    public double estimatedGradient() {
        return secondMomentOf(this.xData, this.yData) / secondMomentOf(this.xData);
    }

    /**
     * @return {@code yMean() - estimatedGradient() * xMean()}
     */
    public double estimatedIntercept() {
        return yMean() - (estimatedGradient() * xMean());
    }

    /**
     * Sums the squared differences between the y data and the fitted line, then divides
     * by {@code size() - 2}.
     *
     * @return the residual sum of squares over {@code size() - 2}; non-finite when
     *         {@code size()} is 2
     */
    public double meanSquaredError() {
        // NOTE(review): `times2` looks like a decompiler-assigned name for a scalar
        // overload of `times` - confirm against the DoubleTensor declaration.
        // The intercept is passed as a boxed scalar; a Double can never be a DoubleTensor,
        // so no cast is applied to it.
        DoubleTensor fitted = (DoubleTensor) this.xData.times2(estimatedGradient())
                .plusInPlace(Double.valueOf(estimatedIntercept()));
        DoubleTensor residuals = (DoubleTensor) this.yData.minus(fitted);
        double residualSumOfSquares = ((Double) ((DoubleTensor) residuals.times(residuals)).sum()).doubleValue();
        return residualSumOfSquares / (size() - 2);
    }

    /**
     * @return the square root of {@code meanSquaredError()} divided by the x/x second moment
     */
    public double standardErrorForGradient() {
        return Math.sqrt(meanSquaredError() / secondMomentOf(this.xData));
    }

    /**
     * @return the square root of
     *         {@code (xMean^2 / secondMoment(x) + 1 / size()) * meanSquaredError()}
     */
    public double standardErrorForIntercept() {
        // Evaluate the mean once instead of twice.
        double meanOfX = xMean();
        double scale = ((meanOfX * meanOfX) / secondMomentOf(this.xData)) + (1.0d / size());
        return Math.sqrt(scale * meanSquaredError());
    }

    /** Second moment of a tensor with itself. */
    private double secondMomentOf(DoubleTensor data) {
        return secondMomentOf(data, data);
    }

    /**
     * Computes {@code sum(a * b) - sum(a) * sum(b) / size()} for the two tensors.
     */
    private double secondMomentOf(DoubleTensor a, DoubleTensor b) {
        double sumOfProducts = ((Double) ((DoubleTensor) a.times(b)).sum()).doubleValue();
        double productOfSums = ((Double) a.sum()).doubleValue() * ((Double) b.sum()).doubleValue();
        return sumOfProducts - (productOfSums / size());
    }
}
