/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.tensor;

import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;

/**
 * Computes ordinary least-squares statistics for a straight-line fit
 * {@code y = gradient * x + intercept} over paired x/y observations held in
 * two {@link DoubleTensor}s of matching shape.
 *
 * <p>Nothing is cached: every call re-derives its result from the tensors
 * supplied at construction time.
 */
public class BivariateDataStatisticsCalculator {
    private final DoubleTensor xData;
    private final DoubleTensor yData;

    /**
     * Stores the two data sets after checking that their shapes match.
     *
     * @param xData the x observations
     * @param yData the y observations; its shape is validated against {@code xData}'s via
     *              {@code TensorShapeValidation.checkAllShapesMatch}
     */
    public BivariateDataStatisticsCalculator(DoubleTensor xData, DoubleTensor yData) {
        TensorShapeValidation.checkAllShapesMatch(xData.getShape(), yData.getShape());
        this.xData = xData;
        this.yData = yData;
    }

    /**
     * @return the length reported by the x tensor, used throughout as the sample count {@code n}
     */
    public long size() {
        return xData.getLength();
    }

    /**
     * @return the average of the x data
     */
    public double xMean() {
        return xData.average();
    }

    /**
     * @return the average of the y data
     */
    public double yMean() {
        return yData.average();
    }

    /**
     * Estimates the slope of the fitted line as the x/y cross moment divided by the x moment.
     *
     * @return the estimated gradient; not finite when the x moment is zero
     */
    public double estimatedGradient() {
        return secondMomentOf(xData, yData) / secondMomentOf(xData);
    }

    /**
     * Estimates the intercept of the fitted line as {@code yMean - gradient * xMean}.
     *
     * @return the estimated intercept
     */
    public double estimatedIntercept() {
        return interceptFor(estimatedGradient());
    }

    /**
     * Computes the sum of squared residuals of the fitted line divided by {@code size() - 2}.
     *
     * <p>No guard is applied to the divisor: with exactly two points it is zero, and with
     * fewer it is negative, so the result is only meaningful for more than two points.
     *
     * @return the residual sum of squares over {@code size() - 2}
     */
    public double meanSquaredError() {
        // Derive the gradient once and reuse it for the intercept instead of recomputing it.
        final double gradient = estimatedGradient();
        final double intercept = interceptFor(gradient);

        // plusInPlace mutates the tensor returned by times(...).
        // NOTE(review): assumes times(...) returns a fresh tensor rather than xData itself - confirm.
        final DoubleTensor calculatedY = xData.times(gradient).plusInPlace(intercept);
        final DoubleTensor residuals = yData.minus(calculatedY);

        final long unbiasedMultiplier = size() - 2L;
        return residuals.times(residuals).sum() / unbiasedMultiplier;
    }

    /**
     * @return {@code sqrt(meanSquaredError / xMoment)}, the standard error of the gradient estimate
     */
    public double standardErrorForGradient() {
        return Math.sqrt(meanSquaredError() / secondMomentOf(xData));
    }

    /**
     * @return {@code sqrt(meanSquaredError * (xMean^2 / xMoment + 1 / size))}, the standard error
     *         of the intercept estimate
     */
    public double standardErrorForIntercept() {
        final double meanOfX = xMean();
        final double scale = meanOfX * meanOfX / secondMomentOf(xData) + 1.0 / size();
        return Math.sqrt(scale * meanSquaredError());
    }

    /**
     * Intercept implied by a given gradient: {@code yMean - gradient * xMean}.
     */
    private double interceptFor(double gradient) {
        return yMean() - gradient * xMean();
    }

    /**
     * Moment of a single data set with itself; see {@link #secondMomentOf(DoubleTensor, DoubleTensor)}.
     */
    private double secondMomentOf(DoubleTensor data) {
        return secondMomentOf(data, data);
    }

    /**
     * Computes {@code sum(data1.times(data2)) - sum(data1) * sum(data2) / size()}.
     *
     * <p>For an element-wise product this is algebraically the sum of products of deviations
     * from the respective means (presumably {@code times} is element-wise - confirm against
     * the tensor API).
     */
    private double secondMomentOf(DoubleTensor data1, DoubleTensor data2) {
        final double sum1 = data1.sum();
        final double sum2 = data2.sum();
        final double sumOfProducts = data1.times(data2).sum();
        return sumOfProducts - sum1 * sum2 / size();
    }
}

