/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.model.regression;

import io.improbable.keanu.model.ModelFitter;
import io.improbable.keanu.model.PredictiveModel;
import io.improbable.keanu.model.regression.LinearRegressionGraph;
import io.improbable.keanu.model.regression.RegressionModelBuilder;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.bool.probabilistic.BernoulliVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.SigmoidVertex;
import io.improbable.keanu.vertices.dbl.probabilistic.GaussianVertex;
import java.util.function.Function;

/**
 * A predictive regression model backed by a {@link LinearRegressionGraph}.
 *
 * <p>{@link #predict(DoubleTensor)} delegates to the underlying graph, and {@link #fit()}
 * delegates to the {@link ModelFitter} supplied at construction time. The constructor is
 * package-private; callers start from one of the {@code withTrainingData} factory methods,
 * which return a {@link RegressionModelBuilder}.
 *
 * @param <OUTPUT> the type produced by {@link #predict(DoubleTensor)}: {@link DoubleTensor}
 *     for the Gaussian (linear) factory, {@link BooleanTensor} for the logistic factory
 */
public class RegressionModel<OUTPUT>
implements PredictiveModel<DoubleTensor, OUTPUT> {
    /** Sigma of the Gaussian output vertex used by the linear-regression factory. */
    private static final double DEFAULT_OBSERVATION_SIGMA = 1.0;
    private final ModelFitter fitter;
    private final LinearRegressionGraph<OUTPUT> modelGraph;

    /**
     * Creates a model over an already-built graph.
     *
     * @param modelGraph graph that holds the intercept/weight vertices and performs prediction
     * @param fitter strategy invoked by {@link #fit()} on {@code modelGraph}
     */
    RegressionModel(LinearRegressionGraph<OUTPUT> modelGraph, ModelFitter fitter) {
        this.modelGraph = modelGraph;
        this.fitter = fitter;
    }

    /**
     * Starts building a linear regression whose output is modelled with a Gaussian of sigma
     * {@value #DEFAULT_OBSERVATION_SIGMA}.
     *
     * @param inputTrainingData training inputs
     * @param outputTrainingData real-valued training outputs
     * @return a builder configured with a Gaussian output transform
     */
    public static RegressionModelBuilder<DoubleTensor> withTrainingData(DoubleTensor inputTrainingData, DoubleTensor outputTrainingData) {
        // Use the named constant rather than a duplicated literal so the default lives in one place.
        return new RegressionModelBuilder<DoubleTensor>(inputTrainingData, outputTrainingData, RegressionModel.gaussianOutputTransform(DEFAULT_OBSERVATION_SIGMA));
    }

    /**
     * Starts building a logistic regression over boolean outputs.
     *
     * @param inputTrainingData training inputs
     * @param outputTrainingData boolean training outputs
     * @return a builder configured with a logistic (sigmoid + Bernoulli) output transform
     */
    public static RegressionModelBuilder<BooleanTensor> withTrainingData(DoubleTensor inputTrainingData, BooleanTensor outputTrainingData) {
        return new RegressionModelBuilder<BooleanTensor>(inputTrainingData, outputTrainingData, RegressionModel.logisticOutputTransform());
    }

    /**
     * Builds the output transform for linear regression.
     *
     * <p>The returned function pairs the linear predictor itself with a {@link GaussianVertex}
     * centred on it.
     *
     * @param measurementSigma sigma of the Gaussian around the linear predictor
     * @return function mapping the linear predictor to its output vertices
     */
    static Function<DoubleVertex, LinearRegressionGraph.OutputVertices<DoubleTensor>> gaussianOutputTransform(double measurementSigma) {
        // yVertex is already a DoubleVertex (and thus a Vertex<DoubleTensor>); no casts needed.
        return yVertex -> new LinearRegressionGraph.OutputVertices<DoubleTensor>(yVertex, new GaussianVertex(yVertex, measurementSigma));
    }

    /**
     * Builds the output transform for logistic regression.
     *
     * <p>The linear predictor is passed through a sigmoid. The returned function pairs a boolean
     * vertex that is {@code sigmoid > 0.5} with a {@link BernoulliVertex} whose probability is the
     * sigmoid output.
     *
     * @return function mapping the linear predictor to its output vertices
     */
    static Function<DoubleVertex, LinearRegressionGraph.OutputVertices<BooleanTensor>> logisticOutputTransform() {
        return linearPredictor -> {
            // The incoming vertex is pre-sigmoid; only after this call is it a probability.
            SigmoidVertex sigmoid = linearPredictor.sigmoid();
            return new LinearRegressionGraph.OutputVertices<BooleanTensor>(sigmoid.greaterThan(ConstantVertex.of(0.5)), new BernoulliVertex(sigmoid));
        };
    }

    /** @return the intercept vertex of the underlying graph */
    public DoubleVertex getInterceptVertex() {
        return this.modelGraph.getInterceptVertex();
    }

    /** @return the weight vertex of the underlying graph */
    public DoubleVertex getWeightVertex() {
        return this.modelGraph.getWeightVertex();
    }

    /** @return the output vertex of the underlying graph */
    public Vertex<OUTPUT> getOutputVertex() {
        return this.modelGraph.getOutputVertex();
    }

    /**
     * Predicts outputs for the given inputs by delegating to the underlying graph.
     *
     * @param tensor input data
     * @return the graph's prediction for {@code tensor}
     */
    @Override
    public OUTPUT predict(DoubleTensor tensor) {
        return this.modelGraph.predict(tensor);
    }

    /** Fits the underlying graph using the configured {@link ModelFitter}. */
    public void fit() {
        this.fitter.fit(this.modelGraph);
    }
}

