package io.improbable.keanu.model.regression;

import io.improbable.keanu.model.ModelFitter;
import io.improbable.keanu.model.SamplingModelFitting;
import io.improbable.keanu.model.regression.LinearRegressionGraph;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import java.util.Objects;
import java.util.function.Function;

/* loaded from: input_file:io/improbable/keanu/model/regression/RegressionModelBuilder.class */
/**
 * Fluent builder for a {@link RegressionModel}.
 *
 * <p>The builder combines input/output training data, an output transform (handed to
 * {@link LinearRegressionGraph}), priors on the weights and intercept, and a fitting strategy.
 * The fitter comes from {@link #withSampling(SamplingModelFitting)} when a sampling algorithm has
 * been supplied, otherwise from the configured {@link RegressionRegularization} (default
 * {@link RegressionRegularization#NONE}). The regularization is always used to construct the
 * weight and intercept vertices.
 *
 * <p>Builder methods mutate this instance and return it for chaining.
 *
 * @param <OUTPUT> tensor type of the output training data
 */
public class RegressionModelBuilder<OUTPUT extends Tensor> {
    private static final double DEFAULT_MU = 0.0d;
    private static final double DEFAULT_SCALE_PARAMETER = 1.0d;

    private final DoubleTensor inputTrainingData;
    private final OUTPUT outputTrainingData;
    private final Function<DoubleVertex, LinearRegressionGraph.OutputVertices<OUTPUT>> outputTransform;

    private DoubleVertex priorOnWeightsScaleParameters;
    private DoubleVertex priorOnWeightsMeans;
    private DoubleVertex priorOnInterceptScaleParameter;
    private DoubleVertex priorOnInterceptMean;
    private RegressionRegularization regularization = RegressionRegularization.NONE;
    // null means "use the regularization's fitter" (see buildWithoutFitting).
    private SamplingModelFitting samplingAlgorithm = null;

    /**
     * Creates a builder over the given training data.
     *
     * <p>Rank-0 data is reshaped to 1 x 1 and rank-1 data of length n to 1 x n. Data of higher
     * rank is used unchanged.
     *
     * @param inputTrainingData  input (feature) data; must not be null
     * @param outputTrainingData observed outputs; must not be null
     * @param outputTransform    function passed to {@link LinearRegressionGraph} to build the output vertices
     * @throws IllegalArgumentException if either training-data argument is null
     */
    public RegressionModelBuilder(DoubleTensor inputTrainingData,
                                  OUTPUT outputTrainingData,
                                  Function<DoubleVertex, LinearRegressionGraph.OutputVertices<OUTPUT>> outputTransform) {
        // Validate before reshaping: reshapeToMatrix dereferences its argument, so a null would
        // otherwise surface as an uninformative NullPointerException.
        if (inputTrainingData == null) {
            throw new IllegalArgumentException("You have not provided input training data");
        }
        if (outputTrainingData == null) {
            throw new IllegalArgumentException("You have not provided output training data");
        }
        this.inputTrainingData = reshapeToMatrix(inputTrainingData);
        this.outputTrainingData = reshapeToMatrix(outputTrainingData);
        this.outputTransform = outputTransform;
    }

    /**
     * Sets the regularization used to build the prior vertices and, when no sampling algorithm is
     * set, to create the model fitter.
     *
     * @param regressionRegularization the regularization; must not be null
     * @return this builder
     * @throws NullPointerException if {@code regressionRegularization} is null
     */
    public RegressionModelBuilder<OUTPUT> withRegularization(RegressionRegularization regressionRegularization) {
        this.regularization = Objects.requireNonNull(regressionRegularization, "regressionRegularization");
        return this;
    }

    /**
     * Promotes scalars and vectors to matrices: rank 0 becomes 1 x 1, rank 1 of length n becomes
     * 1 x n, and anything else is returned unchanged.
     */
    // Tensor.reshape is used through the raw Tensor type; the cast assumes reshape returns the
    // same concrete tensor type as its receiver.
    @SuppressWarnings("unchecked")
    private static <T extends Tensor> T reshapeToMatrix(T data) {
        if (data.getRank() == 0) {
            return (T) data.reshape(1, 1);
        }
        if (data.getRank() == 1) {
            // NOTE(review): a rank-1 tensor becomes a single row (1 x n); confirm this orientation
            // matches what LinearRegressionGraph expects for both inputs and outputs.
            return (T) data.reshape(1, data.getShape()[0]);
        }
        return data;
    }

    /**
     * Sets the prior on the regression weights.
     *
     * @param means           vertex holding the prior means of the weights
     * @param scaleParameters vertex holding the prior scale parameters of the weights
     * @return this builder
     */
    public RegressionModelBuilder<OUTPUT> withPriorOnWeights(DoubleVertex means, DoubleVertex scaleParameters) {
        this.priorOnWeightsMeans = means;
        this.priorOnWeightsScaleParameters = scaleParameters;
        return this;
    }

    /**
     * Sets the prior on the regression weights from constant tensors.
     *
     * @param means           prior means of the weights
     * @param scaleParameters prior scale parameters of the weights
     * @return this builder
     */
    public RegressionModelBuilder<OUTPUT> withPriorOnWeights(DoubleTensor means, DoubleTensor scaleParameters) {
        return withPriorOnWeights(ConstantVertex.of(means), ConstantVertex.of(scaleParameters));
    }

    /**
     * Sets the same prior mean and scale parameter for all weights. Each value is wrapped in a
     * 1 x 1 constant tensor.
     *
     * @param mean           prior mean of every weight
     * @param scaleParameter prior scale parameter of every weight
     * @return this builder
     */
    public RegressionModelBuilder<OUTPUT> withPriorOnWeights(double mean, double scaleParameter) {
        return withPriorOnWeights(
            DoubleTensor.create(new double[]{mean}, 1, 1),
            DoubleTensor.create(new double[]{scaleParameter}, 1, 1));
    }

    /**
     * Sets the prior on the intercept.
     *
     * @param mean           vertex holding the prior mean of the intercept
     * @param scaleParameter vertex holding the prior scale parameter of the intercept
     * @return this builder
     */
    public RegressionModelBuilder<OUTPUT> withPriorOnIntercept(DoubleVertex mean, DoubleVertex scaleParameter) {
        this.priorOnInterceptMean = mean;
        this.priorOnInterceptScaleParameter = scaleParameter;
        return this;
    }

    /**
     * Sets the prior on the intercept from constant tensors.
     *
     * @param mean           prior mean of the intercept
     * @param scaleParameter prior scale parameter of the intercept
     * @return this builder
     */
    public RegressionModelBuilder<OUTPUT> withPriorOnIntercept(DoubleTensor mean, DoubleTensor scaleParameter) {
        return withPriorOnIntercept(ConstantVertex.of(mean), ConstantVertex.of(scaleParameter));
    }

    /**
     * Sets the prior on the intercept from scalar constants.
     *
     * @param mean           prior mean of the intercept
     * @param scaleParameter prior scale parameter of the intercept
     * @return this builder
     */
    public RegressionModelBuilder<OUTPUT> withPriorOnIntercept(double mean, double scaleParameter) {
        return withPriorOnIntercept(ConstantVertex.of(mean), ConstantVertex.of(scaleParameter));
    }

    /**
     * Applies the same prior mean and scale parameter to both the weights and the intercept.
     *
     * @param mean           prior mean
     * @param scaleParameter prior scale parameter
     * @return this builder
     */
    public RegressionModelBuilder<OUTPUT> withPriorOnWeightsAndIntercept(double mean, double scaleParameter) {
        withPriorOnWeights(mean, scaleParameter);
        withPriorOnIntercept(mean, scaleParameter);
        return this;
    }

    /**
     * Fits the model with the given sampling algorithm instead of the regularization's fitter.
     * The regularization is still used to build the prior vertices.
     *
     * @param samplingModelFitting the sampling strategy; null reverts to the regularization's fitter
     * @return this builder
     */
    public RegressionModelBuilder<OUTPUT> withSampling(SamplingModelFitting samplingModelFitting) {
        this.samplingAlgorithm = samplingModelFitting;
        return this;
    }

    /**
     * Builds the model and fits it to the training data.
     *
     * @return the fitted model
     */
    public RegressionModel<OUTPUT> build() {
        RegressionModel<OUTPUT> model = buildWithoutFitting();
        model.fit();
        return model;
    }

    /**
     * Builds the model and observes the training data, without fitting it.
     *
     * <p>Any prior not configured (mean or scale missing) is set to mean {@code 0.0} and scale
     * {@code 1.0}. The defaults are stored on this builder.
     *
     * @return the unfitted model
     */
    public RegressionModel<OUTPUT> buildWithoutFitting() {
        applyDefaultPriorsWhereMissing();
        LinearRegressionGraph regressionGraph = new LinearRegressionGraph(
            this.inputTrainingData.getShape(), this.outputTransform, getInterceptVertex(), getWeightsVertex());
        // A sampling algorithm, when supplied, takes precedence over the regularization's fitter.
        ModelFitter fitter = this.samplingAlgorithm == null
            ? this.regularization.createFitterForGraph()
            : this.samplingAlgorithm.createFitterForGraph();
        // NOTE(review): this call and its DoubleTensor cast come from decompiled code. If OUTPUT is
        // not a DoubleTensor, the cast throws ClassCastException. Confirm the intended
        // observeValues signature in LinearRegressionGraph.
        regressionGraph.observeValues2(this.inputTrainingData, (DoubleTensor) this.outputTrainingData);
        return new RegressionModel<>(regressionGraph, fitter);
    }

    /**
     * Replaces a partially or fully missing weights/intercept prior with the default
     * (mean {@value #DEFAULT_MU}, scale {@value #DEFAULT_SCALE_PARAMETER}).
     */
    private void applyDefaultPriorsWhereMissing() {
        if (this.priorOnWeightsMeans == null || this.priorOnWeightsScaleParameters == null) {
            withPriorOnWeights(DEFAULT_MU, DEFAULT_SCALE_PARAMETER);
        }
        if (this.priorOnInterceptMean == null || this.priorOnInterceptScaleParameter == null) {
            withPriorOnIntercept(DEFAULT_MU, DEFAULT_SCALE_PARAMETER);
        }
    }

    /** Builds the intercept vertex from the intercept prior via the configured regularization. */
    private DoubleVertex getInterceptVertex() {
        return this.regularization.getInterceptVertex(this.priorOnInterceptMean, this.priorOnInterceptScaleParameter);
    }

    /** Builds the weights vertex from the weights prior via the configured regularization. */
    private DoubleVertex getWeightsVertex() {
        return this.regularization.getWeightsVertex(getFeatureCount(), this.priorOnWeightsMeans, this.priorOnWeightsScaleParameters);
    }

    /** Returns the size of the second dimension of the (matrix-shaped) input data. */
    private long getFeatureCount() {
        return this.inputTrainingData.getShape()[1];
    }
}
