package io.improbable.keanu.model.regression;

import com.google.common.base.Preconditions;
import io.improbable.keanu.model.ModelGraph;
import io.improbable.keanu.network.BayesianNetwork;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.ConstantDoubleVertex;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;

/**
 * Model graph for linear regression.
 *
 * <p>Builds the linear predictor {@code x.matrixMultiply(weights) + intercept} over a
 * feature-tensor vertex. When the weights vertex is length one, it builds
 * {@code weights * x + intercept} instead. The predictor is handed to a caller-supplied
 * transform, which creates the output vertex and the vertex that receives observations.
 *
 * @param <OUTPUT> value type of the output and observation vertices
 */
public class LinearRegressionGraph<OUTPUT> implements ModelGraph<DoubleTensor, OUTPUT> {
    /** Feature-tensor vertex; its value is replaced by {@link #predict} and {@code observeValues}. */
    private final DoubleVertex xVertex;
    /** Vertex whose value {@link #predict} returns. */
    private final Vertex<OUTPUT> yVertex;
    /** Vertex that {@code observeValues} observes. */
    private final Vertex<OUTPUT> yObservationVertex;
    private final DoubleVertex weightsVertex;
    private final DoubleVertex interceptVertex;
    /** Network built from the graph connected to {@link #yVertex}. */
    private final BayesianNetwork bayesianNetwork;

    /**
     * Pair of vertices produced by the output transform: the vertex to read predictions from,
     * and the vertex to attach observations to (the two may be the same vertex).
     *
     * @param <OUTPUT> value type of both vertices
     */
    public static final class OutputVertices<OUTPUT> {
        private final Vertex<OUTPUT> outputVertex;
        private final Vertex<OUTPUT> observedVertex;

        /**
         * @param outputVertex   vertex whose value is the model's prediction
         * @param observedVertex vertex that receives observed outputs
         */
        public OutputVertices(Vertex<OUTPUT> outputVertex, Vertex<OUTPUT> observedVertex) {
            this.outputVertex = outputVertex;
            this.observedVertex = observedVertex;
        }

        public Vertex<OUTPUT> getOutputVertex() {
            return this.outputVertex;
        }

        public Vertex<OUTPUT> getObservedVertex() {
            return this.observedVertex;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof OutputVertices)) {
                return false;
            }
            OutputVertices<?> other = (OutputVertices<?>) obj;
            return Objects.equals(outputVertex, other.outputVertex)
                && Objects.equals(observedVertex, other.observedVertex);
        }

        @Override
        public int hashCode() {
            return Objects.hash(outputVertex, observedVertex);
        }

        @Override
        public String toString() {
            return "LinearRegressionGraph.OutputVertices(outputVertex=" + getOutputVertex() + ", observedVertex=" + getObservedVertex() + ")";
        }
    }

    /**
     * Builds the regression graph.
     *
     * @param featureShape    shape of the feature tensor. Element {@code [1]} is taken as the
     *                        feature count (presumably {@code [samples, features]} — confirm with callers).
     * @param outputTransform maps the linear predictor to the output/observed vertex pair
     * @param interceptVertex intercept; must be length one
     * @param weightsVertex   weights; must have shape {@code [featureCount, 1]}
     * @throws IllegalArgumentException if {@code featureShape} has fewer than two dimensions or the
     *                                  intercept is not length one
     * @throws NullPointerException     if {@code outputTransform} returns {@code null}
     */
    public LinearRegressionGraph(long[] featureShape,
                                 Function<DoubleVertex, OutputVertices<OUTPUT>> outputTransform,
                                 DoubleVertex interceptVertex,
                                 DoubleVertex weightsVertex) {
        // Reject bad shapes up front instead of failing with ArrayIndexOutOfBoundsException below.
        Preconditions.checkArgument(featureShape.length >= 2,
            "featureShape must have at least 2 dimensions but has %s", featureShape.length);
        long featureCount = featureShape[1];
        Preconditions.checkArgument(TensorShape.isLengthOne(interceptVertex.getShape()),
            "interceptVertex must be length one");
        TensorShapeValidation.checkShapesMatch(weightsVertex.getShape(), new long[]{featureCount, 1});

        this.weightsVertex = weightsVertex;
        this.interceptVertex = interceptVertex;
        // Zero-filled placeholder; real features are set on predict/observe.
        this.xVertex = new ConstantDoubleVertex(DoubleTensor.zeros(featureShape));

        OutputVertices<OUTPUT> outputVertices = Objects.requireNonNull(
            outputTransform.apply(linearPredictor(this.xVertex, weightsVertex, interceptVertex)),
            "outputTransform must not return null");
        this.yVertex = outputVertices.getOutputVertex();
        this.yObservationVertex = outputVertices.getObservedVertex();
        this.bayesianNetwork = new BayesianNetwork(this.yVertex.getConnectedGraph());
    }

    /**
     * Builds the linear predictor {@code x·w + b}. A length-one weight is applied by element-wise
     * multiplication ({@code w * x}) rather than matrix multiplication.
     */
    private static DoubleVertex linearPredictor(DoubleVertex x, DoubleVertex weights, DoubleVertex intercept) {
        if (TensorShape.isLengthOne(weights.getShape())) {
            return weights.times(x).plus(intercept);
        }
        return x.matrixMultiply(weights).plus(intercept);
    }

    /**
     * Sets the features, cascades the change through the graph and returns the output vertex's value.
     *
     * @param input feature tensor
     * @return the current value of the output vertex
     */
    public OUTPUT predict(DoubleTensor input) {
        this.xVertex.setAndCascade(input);
        return this.yVertex.getValue();
    }

    /**
     * Sets the features (without cascading) and observes {@code output} on the observation vertex.
     *
     * @param input  feature tensor
     * @param output observed value
     */
    @Override
    public void observeValues(DoubleTensor input, OUTPUT output) {
        this.xVertex.setValue(input);
        this.yObservationVertex.observe(output);
    }

    /**
     * Alias of {@link #observeValues(DoubleTensor, Object)}. Retained only for compatibility.
     *
     * @deprecated use {@link #observeValues(DoubleTensor, Object)}
     */
    @Deprecated
    public void observeValues2(DoubleTensor input, OUTPUT output) {
        observeValues(input, output);
    }

    public DoubleVertex getInterceptVertex() {
        return this.interceptVertex;
    }

    public DoubleVertex getWeightVertex() {
        return this.weightsVertex;
    }

    /**
     * Returns the observation vertex, i.e. the vertex {@code observeValues} observes. This is the
     * observed vertex from the transform's {@link OutputVertices}, not the prediction vertex.
     */
    public Vertex<OUTPUT> getOutputVertex() {
        return this.yObservationVertex;
    }

    @Override
    public BayesianNetwork getBayesianNetwork() {
        return this.bayesianNetwork;
    }
}
