/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.model.regression;

import io.improbable.keanu.model.ModelFitter;
import io.improbable.keanu.model.SamplingModelFitting;
import io.improbable.keanu.model.regression.LinearRegressionGraph;
import io.improbable.keanu.model.regression.RegressionModel;
import io.improbable.keanu.model.regression.RegressionRegularization;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import java.util.function.Function;

/**
 * Fluent builder for a {@link RegressionModel} backed by a {@link LinearRegressionGraph}.
 *
 * <p>The builder is created with the training data and an output transform. Optional
 * settings (regularization, priors on the weights and intercept, a sampling algorithm)
 * are supplied through the {@code with...} methods, each of which returns this builder
 * so calls can be chained. Priors that are not supplied are replaced by defaults
 * ({@code 0.0} for the mean, {@code 1.0} for the scale parameter) when the model is built.
 *
 * @param <OUTPUT> tensor type of the output training data
 */
public class RegressionModelBuilder<OUTPUT extends Tensor> {
    /** Default mean used for any prior that has not been supplied. */
    private static final double DEFAULT_MU = 0.0;
    /** Default scale parameter used for any prior that has not been supplied. */
    private static final double DEFAULT_SCALE_PARAMETER = 1.0;

    private final DoubleTensor inputTrainingData;
    private final OUTPUT outputTrainingData;
    private final Function<DoubleVertex, LinearRegressionGraph.OutputVertices<OUTPUT>> outputTransform;
    private RegressionRegularization regularization = RegressionRegularization.NONE;
    private DoubleVertex priorOnWeightsScaleParameters;
    private DoubleVertex priorOnWeightsMeans;
    private DoubleVertex priorOnInterceptScaleParameter;
    private DoubleVertex priorOnInterceptMean;
    /** Optional sampling algorithm; when {@code null} the regularization supplies the fitter. */
    private SamplingModelFitting samplingAlgorithm;

    /**
     * Creates a builder for the given training data.
     *
     * <p>Both training tensors are reshaped to at least rank 2 (see
     * {@link #reshapeToMatrix(Tensor)}) before being stored.
     *
     * @param inputTrainingData  input training data; must not be {@code null}
     * @param outputTrainingData output training data; must not be {@code null}
     * @param outputTransform    function passed to the {@link LinearRegressionGraph} to
     *                           produce its output vertices
     * @throws IllegalArgumentException if either training data argument is {@code null}
     */
    public RegressionModelBuilder(DoubleTensor inputTrainingData, OUTPUT outputTrainingData, Function<DoubleVertex, LinearRegressionGraph.OutputVertices<OUTPUT>> outputTransform) {
        // Validate before reshaping: reshapeToMatrix dereferences its argument, so a null
        // would otherwise surface as a bare NullPointerException with no explanation.
        if (inputTrainingData == null) {
            throw new IllegalArgumentException("You have not provided input training data");
        }
        if (outputTrainingData == null) {
            throw new IllegalArgumentException("You have not provided output training data");
        }
        this.inputTrainingData = RegressionModelBuilder.reshapeToMatrix(inputTrainingData);
        this.outputTrainingData = RegressionModelBuilder.reshapeToMatrix(outputTrainingData);
        this.outputTransform = outputTransform;
    }

    /**
     * Sets the regularization used to create the weight and intercept vertices and,
     * when no sampling algorithm is set, the model fitter.
     *
     * @param regularization the regularization to use
     * @return this builder
     */
    public RegressionModelBuilder<OUTPUT> withRegularization(RegressionRegularization regularization) {
        this.regularization = regularization;
        return this;
    }

    /**
     * Ensures a tensor has at least rank 2.
     *
     * <p>A rank-0 tensor is reshaped to {@code 1 x 1}; a rank-1 tensor of length {@code n}
     * is reshaped to {@code 1 x n}; tensors of any other rank are returned unchanged.
     *
     * @param data the tensor to reshape; must not be {@code null}
     * @param <T>  the tensor type
     * @return the reshaped tensor, or {@code data} itself when no reshape is needed
     */
    @SuppressWarnings("unchecked") // reshape is invoked on the erased bound; the result is cast back to the caller's tensor type
    private static <T extends Tensor> T reshapeToMatrix(T data) {
        if (data.getRank() == 0) {
            return (T) data.reshape(1L, 1L);
        }
        if (data.getRank() == 1) {
            return (T) data.reshape(1L, data.getShape()[0]);
        }
        return data;
    }

    /**
     * Sets the prior on the weights.
     *
     * @param means           vertex holding the prior means
     * @param scaleParameters vertex holding the prior scale parameters
     * @return this builder
     */
    public RegressionModelBuilder<OUTPUT> withPriorOnWeights(DoubleVertex means, DoubleVertex scaleParameters) {
        this.priorOnWeightsMeans = means;
        this.priorOnWeightsScaleParameters = scaleParameters;
        return this;
    }

    /**
     * Sets the prior on the weights from tensors, wrapping each in a constant vertex.
     *
     * @param means           prior means
     * @param scaleParameters prior scale parameters
     * @return this builder
     */
    public RegressionModelBuilder<OUTPUT> withPriorOnWeights(DoubleTensor means, DoubleTensor scaleParameters) {
        return this.withPriorOnWeights(ConstantVertex.of(means), ConstantVertex.of(scaleParameters));
    }

    /**
     * Sets the prior on the weights from scalars. Each scalar is wrapped in a
     * {@code 1 x 1} tensor before being turned into a constant vertex.
     *
     * @param means           prior mean
     * @param scaleParameters prior scale parameter
     * @return this builder
     */
    public RegressionModelBuilder<OUTPUT> withPriorOnWeights(double means, double scaleParameters) {
        DoubleTensor meansTensor = DoubleTensor.create(new double[]{means}, 1L, 1L);
        DoubleTensor scaleTensor = DoubleTensor.create(new double[]{scaleParameters}, 1L, 1L);
        return this.withPriorOnWeights(meansTensor, scaleTensor);
    }

    /**
     * Sets the prior on the intercept.
     *
     * @param mean           vertex holding the prior mean
     * @param scaleParameter vertex holding the prior scale parameter
     * @return this builder
     */
    public RegressionModelBuilder<OUTPUT> withPriorOnIntercept(DoubleVertex mean, DoubleVertex scaleParameter) {
        this.priorOnInterceptMean = mean;
        this.priorOnInterceptScaleParameter = scaleParameter;
        return this;
    }

    /**
     * Sets the prior on the intercept from tensors, wrapping each in a constant vertex.
     *
     * @param mean           prior mean
     * @param scaleParameter prior scale parameter
     * @return this builder
     */
    public RegressionModelBuilder<OUTPUT> withPriorOnIntercept(DoubleTensor mean, DoubleTensor scaleParameter) {
        return this.withPriorOnIntercept(ConstantVertex.of(mean), ConstantVertex.of(scaleParameter));
    }

    /**
     * Sets the prior on the intercept from scalars, wrapping each in a constant vertex.
     *
     * @param mean           prior mean
     * @param scaleParameter prior scale parameter
     * @return this builder
     */
    public RegressionModelBuilder<OUTPUT> withPriorOnIntercept(double mean, double scaleParameter) {
        return this.withPriorOnIntercept(ConstantVertex.of(mean), ConstantVertex.of(scaleParameter));
    }

    /**
     * Applies the same scalar prior to both the weights and the intercept.
     *
     * @param mean           prior mean
     * @param scaleParameter prior scale parameter
     * @return this builder
     */
    public RegressionModelBuilder<OUTPUT> withPriorOnWeightsAndIntercept(double mean, double scaleParameter) {
        this.withPriorOnWeights(mean, scaleParameter);
        this.withPriorOnIntercept(mean, scaleParameter);
        return this;
    }

    /**
     * Sets a sampling algorithm. When one is set, it (rather than the regularization)
     * creates the model fitter.
     *
     * @param sampling the sampling algorithm, or {@code null} to fall back to the regularization's fitter
     * @return this builder
     */
    public RegressionModelBuilder<OUTPUT> withSampling(SamplingModelFitting sampling) {
        this.samplingAlgorithm = sampling;
        return this;
    }

    /**
     * Builds the model and fits it.
     *
     * @return the fitted model
     */
    public RegressionModel<OUTPUT> build() {
        RegressionModel<OUTPUT> model = this.buildWithoutFitting();
        model.fit();
        return model;
    }

    /**
     * Builds the model without calling {@code fit()} on it.
     *
     * <p>Missing priors are defaulted first. The graph is then created from the input
     * shape, the output transform and the intercept/weight vertices produced by the
     * regularization; the fitter is chosen; and the training data is observed on the graph.
     *
     * @return the unfitted model
     */
    public RegressionModel<OUTPUT> buildWithoutFitting() {
        this.checkVariablesAreCorrectlyInitialised();
        LinearRegressionGraph<OUTPUT> regressionGraph = new LinearRegressionGraph<OUTPUT>(this.inputTrainingData.getShape(), this.outputTransform, this.getInterceptVertex(), this.getWeightsVertex());
        // An explicitly chosen sampling algorithm takes precedence over the regularization's fitter.
        ModelFitter fitter = this.samplingAlgorithm == null ? this.regularization.createFitterForGraph() : this.samplingAlgorithm.createFitterForGraph();
        regressionGraph.observeValues(this.inputTrainingData, this.outputTrainingData);
        return new RegressionModel<OUTPUT>(regressionGraph, fitter);
    }

    /**
     * Fills in default priors for anything the caller did not supply.
     *
     * <p>Training data needs no check here: it is validated in the constructor and the
     * fields are final. Note that if either half of a prior pair (mean or scale parameter)
     * is missing, both halves are replaced by the defaults.
     */
    private void checkVariablesAreCorrectlyInitialised() {
        if (this.priorOnWeightsMeans == null || this.priorOnWeightsScaleParameters == null) {
            this.withPriorOnWeights(DEFAULT_MU, DEFAULT_SCALE_PARAMETER);
        }
        if (this.priorOnInterceptMean == null || this.priorOnInterceptScaleParameter == null) {
            this.withPriorOnIntercept(DEFAULT_MU, DEFAULT_SCALE_PARAMETER);
        }
    }

    /** Returns the intercept vertex the regularization creates from the intercept prior. */
    private DoubleVertex getInterceptVertex() {
        return this.regularization.getInterceptVertex(this.priorOnInterceptMean, this.priorOnInterceptScaleParameter);
    }

    /** Returns the weights vertex the regularization creates from the feature count and the weights prior. */
    private DoubleVertex getWeightsVertex() {
        return this.regularization.getWeightsVertex(this.getFeatureCount(), this.priorOnWeightsMeans, this.priorOnWeightsScaleParameters);
    }

    /** Returns the size of the second dimension of the (reshaped) input training data. */
    private long getFeatureCount() {
        return this.inputTrainingData.getShape()[1];
    }
}

