package io.improbable.keanu.model.regression;

import io.improbable.keanu.model.ModelFitter;
import io.improbable.keanu.model.PredictiveModel;
import io.improbable.keanu.model.regression.LinearRegressionGraph;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.bool.probabilistic.BernoulliVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.SigmoidVertex;
import io.improbable.keanu.vertices.dbl.probabilistic.GaussianVertex;
import java.util.function.Function;

/**
 * A predictive model backed by a {@link LinearRegressionGraph}.
 *
 * <p>Prediction is delegated to the underlying graph, and training is delegated to the supplied
 * {@link ModelFitter}. The static {@code withTrainingData} entry points return a
 * {@link RegressionModelBuilder} whose output transform is chosen by the type of the output
 * training data:
 * <ul>
 *   <li>a Gaussian observation model for {@link DoubleTensor} outputs;</li>
 *   <li>a logistic (sigmoid + Bernoulli) model for {@link BooleanTensor} outputs.</li>
 * </ul>
 *
 * @param <OUTPUT> the tensor type returned by {@link #predict(DoubleTensor)}
 */
public class RegressionModel<OUTPUT> implements PredictiveModel<DoubleTensor, OUTPUT> {

    /** Standard deviation of the Gaussian observation vertex used for real-valued outputs. */
    private static final double DEFAULT_OBSERVATION_SIGMA = 1.0d;

    private final ModelFitter fitter;
    private final LinearRegressionGraph<OUTPUT> modelGraph;

    /**
     * Creates a model from an already-constructed graph and fitter.
     *
     * <p>NOTE(review): the decompiler reports this constructor was package-private in the original
     * source; it is kept {@code public} here so that existing callers are not broken.
     *
     * @param modelGraph the graph that exposes the intercept, weight and output vertices and
     *                   performs prediction
     * @param fitter     the strategy that {@link #fit()} uses to train {@code modelGraph}
     */
    public RegressionModel(LinearRegressionGraph<OUTPUT> modelGraph, ModelFitter fitter) {
        this.modelGraph = modelGraph;
        this.fitter = fitter;
    }

    /**
     * Starts building a model with real-valued outputs.
     *
     * <p>The returned builder uses {@link #gaussianOutputTransform(double)} with a standard
     * deviation of {@value #DEFAULT_OBSERVATION_SIGMA}.
     *
     * @param inputTrainingData  the training inputs
     * @param outputTrainingData the real-valued training outputs
     * @return a builder configured with a Gaussian output transform
     */
    public static RegressionModelBuilder<DoubleTensor> withTrainingData(DoubleTensor inputTrainingData,
                                                                        DoubleTensor outputTrainingData) {
        return new RegressionModelBuilder<>(
            inputTrainingData, outputTrainingData, gaussianOutputTransform(DEFAULT_OBSERVATION_SIGMA));
    }

    /**
     * Starts building a model with boolean outputs (logistic regression).
     *
     * @param inputTrainingData  the training inputs
     * @param outputTrainingData the boolean training outputs
     * @return a builder configured with {@link #logisticOutputTransform()}
     */
    public static RegressionModelBuilder<BooleanTensor> withTrainingData(DoubleTensor inputTrainingData,
                                                                         BooleanTensor outputTrainingData) {
        return new RegressionModelBuilder<>(inputTrainingData, outputTrainingData, logisticOutputTransform());
    }

    /**
     * Returns a transform that turns the graph's linear output vertex into Gaussian output vertices.
     *
     * <p>The resulting {@link LinearRegressionGraph.OutputVertices} pairs the linear output vertex
     * itself with a {@link GaussianVertex} whose mean is that vertex and whose standard deviation
     * is {@code observationSigma}.
     *
     * @param observationSigma standard deviation passed to the {@link GaussianVertex}
     * @return a function from the linear output vertex to the pair of output vertices
     */
    static Function<DoubleVertex, LinearRegressionGraph.OutputVertices<DoubleTensor>> gaussianOutputTransform(
        double observationSigma) {
        // Diamond (not a raw type) so the compiler checks both vertices against OutputVertices<DoubleTensor>.
        return linearOutput -> new LinearRegressionGraph.OutputVertices<>(
            linearOutput, new GaussianVertex(linearOutput, observationSigma));
    }

    /**
     * Returns a transform that turns the graph's linear output vertex into logistic output vertices.
     *
     * <p>The linear output is passed through a sigmoid. The resulting
     * {@link LinearRegressionGraph.OutputVertices} pairs two vertices:
     * <ul>
     *   <li>the sigmoid thresholded with {@code greaterThan(0.5)};</li>
     *   <li>a {@link BernoulliVertex} whose probability is the sigmoid.</li>
     * </ul>
     *
     * @return a function from the linear output vertex to the pair of output vertices
     */
    static Function<DoubleVertex, LinearRegressionGraph.OutputVertices<BooleanTensor>> logisticOutputTransform() {
        return linearOutput -> {
            SigmoidVertex probability = linearOutput.sigmoid();
            // Diamond (not a raw type) so the compiler checks both vertices against OutputVertices<BooleanTensor>.
            return new LinearRegressionGraph.OutputVertices<>(
                probability.greaterThan(ConstantVertex.of(0.5d)),
                new BernoulliVertex(probability));
        };
    }

    /**
     * @return the intercept vertex of the underlying graph
     */
    public DoubleVertex getInterceptVertex() {
        return modelGraph.getInterceptVertex();
    }

    /**
     * @return the weight vertex of the underlying graph
     */
    public DoubleVertex getWeightVertex() {
        return modelGraph.getWeightVertex();
    }

    /**
     * @return the output vertex of the underlying graph
     */
    public Vertex<OUTPUT> getOutputVertex() {
        return modelGraph.getOutputVertex();
    }

    /**
     * Predicts outputs for the given inputs by delegating to the underlying graph.
     *
     * @param input the inputs to predict for
     * @return the graph's prediction for {@code input}
     */
    @Override
    public OUTPUT predict(DoubleTensor input) {
        return modelGraph.predict(input);
    }

    /**
     * Trains the model by handing the underlying graph to the configured {@link ModelFitter}.
     */
    public void fit() {
        fitter.fit(modelGraph);
    }
}
