/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.model.regression;

import com.google.common.base.Preconditions;
import io.improbable.keanu.model.ModelGraph;
import io.improbable.keanu.network.BayesianNetwork;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.ConstantDoubleVertex;
import java.util.Arrays;
import java.util.Objects;
import java.util.function.Function;

/**
 * Model graph for a linear regression.
 *
 * <p>The graph holds a constant input vertex {@code x}. It computes the linear predictor
 * {@code x * weights + intercept}, using element-wise multiplication when the weights have length
 * one and matrix multiplication otherwise. The predictor is handed to a caller-supplied transform,
 * which returns the vertex used for prediction and the vertex that observations are attached to.
 *
 * <p>NOTE(review): {@link #predict} and {@link #observeValues} mutate the value held by the shared
 * input vertex, so an instance presumably should not be used from several threads at once.
 *
 * @param <OUTPUT> value type of the output / observed vertices
 */
public class LinearRegressionGraph<OUTPUT>
implements ModelGraph<DoubleTensor, OUTPUT> {
    private final DoubleVertex xVertex;
    private final Vertex<OUTPUT> yVertex;
    private final Vertex<OUTPUT> yObservationVertex;
    private final DoubleVertex weightsVertex;
    private final DoubleVertex interceptVertex;
    private final BayesianNetwork bayesianNetwork;

    /**
     * Builds the regression graph.
     *
     * @param featureShape shape of the input feature tensor. It must have at least two dimensions;
     *     element {@code [1]} is taken as the number of features (presumably the shape is
     *     {@code [samples, features]} — TODO confirm against callers). The input vertex is
     *     initialised to zeros of this shape.
     * @param outputTransform maps the linear-predictor vertex to the output vertex (read by
     *     {@link #predict}) and the observed vertex (used by {@link #observeValues}); must not
     *     return {@code null}
     * @param interceptVertex intercept term; must have length one
     * @param weightsVertex weights; must have shape {@code [featureCount, 1]}, which is checked via
     *     {@link TensorShapeValidation#checkShapesMatch}
     * @throws NullPointerException if {@code featureShape} is null or {@code outputTransform}
     *     returns null
     * @throws IllegalArgumentException if {@code featureShape} has fewer than two dimensions or
     *     {@code interceptVertex} does not have length one
     */
    public LinearRegressionGraph(long[] featureShape, Function<DoubleVertex, OutputVertices<OUTPUT>> outputTransform, DoubleVertex interceptVertex, DoubleVertex weightsVertex) {
        Objects.requireNonNull(featureShape, "featureShape");
        // Previously a short shape surfaced as an ArrayIndexOutOfBoundsException on featureShape[1].
        Preconditions.checkArgument(featureShape.length >= 2,
            "featureShape must have at least 2 dimensions but was %s", Arrays.toString(featureShape));
        long featureCount = featureShape[1];
        Preconditions.checkArgument(TensorShape.isLengthOne(interceptVertex.getShape()),
            "interceptVertex must have length one but has shape %s", Arrays.toString(interceptVertex.getShape()));
        TensorShapeValidation.checkShapesMatch(weightsVertex.getShape(), new long[]{featureCount, 1L});

        this.weightsVertex = weightsVertex;
        this.interceptVertex = interceptVertex;
        this.xVertex = new ConstantDoubleVertex(DoubleTensor.zeros(featureShape));

        OutputVertices<OUTPUT> outputVertices = Objects.requireNonNull(
            outputTransform.apply(linearPredictor(this.xVertex, weightsVertex, interceptVertex)),
            "outputTransform returned null");
        this.yVertex = outputVertices.getOutputVertex();
        this.yObservationVertex = outputVertices.getObservedVertex();
        this.bayesianNetwork = new BayesianNetwork(this.yVertex.getConnectedGraph());
    }

    /**
     * Builds the linear predictor {@code x * weights + intercept}.
     *
     * <p>Uses element-wise {@code times} when the weights have length one, and
     * {@code matrixMultiply} otherwise.
     */
    private static DoubleVertex linearPredictor(DoubleVertex x, DoubleVertex weights, DoubleVertex intercept) {
        return TensorShape.isLengthOne(weights.getShape())
            ? weights.times(x).plus(intercept)
            : x.matrixMultiply(weights).plus(intercept);
    }

    /**
     * Sets the input vertex to {@code input}, cascades the change through the graph, and returns
     * the resulting value of the output vertex.
     *
     * @param input feature tensor to predict for
     * @return the output vertex's value after the cascade
     */
    public OUTPUT predict(DoubleTensor input) {
        this.xVertex.setAndCascade(input);
        return this.yVertex.getValue();
    }

    /**
     * Sets the input vertex to {@code input} and observes {@code output} on the observed vertex.
     *
     * <p>NOTE(review): this uses {@code setValue} rather than {@code setAndCascade}, so downstream
     * vertex values are not recomputed here. Presumably the subsequent fitting step recomputes them
     * — confirm.
     *
     * @param input feature tensor of the training data
     * @param output observed outputs to attach
     */
    @Override
    public void observeValues(DoubleTensor input, OUTPUT output) {
        this.xVertex.setValue(input);
        this.yObservationVertex.observe(output);
    }

    /** @return the intercept vertex supplied at construction */
    public DoubleVertex getInterceptVertex() {
        return this.interceptVertex;
    }

    /** @return the weights vertex supplied at construction */
    public DoubleVertex getWeightVertex() {
        return this.weightsVertex;
    }

    /**
     * Returns the <em>observed</em> vertex produced by the output transform.
     *
     * <p>Note that this is not necessarily the vertex whose value {@link #predict} returns.
     *
     * @return the vertex observations are attached to
     */
    public Vertex<OUTPUT> getOutputVertex() {
        return this.yObservationVertex;
    }

    /** @return the network built from the graph connected to the output vertex */
    @Override
    public BayesianNetwork getBayesianNetwork() {
        return this.bayesianNetwork;
    }

    /**
     * Pair of vertices returned by the output transform.
     *
     * <p>The pair consists of the vertex used for prediction and the vertex that observations are
     * attached to. The two may be the same vertex.
     *
     * @param <OUTPUT> value type of both vertices
     */
    public static final class OutputVertices<OUTPUT> {
        private final Vertex<OUTPUT> outputVertex;
        private final Vertex<OUTPUT> observedVertex;

        /**
         * @param outputVertex vertex whose value is returned by predictions
         * @param observedVertex vertex on which observed outputs are set
         */
        public OutputVertices(Vertex<OUTPUT> outputVertex, Vertex<OUTPUT> observedVertex) {
            this.outputVertex = outputVertex;
            this.observedVertex = observedVertex;
        }

        public Vertex<OUTPUT> getOutputVertex() {
            return this.outputVertex;
        }

        public Vertex<OUTPUT> getObservedVertex() {
            return this.observedVertex;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof OutputVertices)) {
                return false;
            }
            OutputVertices<?> other = (OutputVertices<?>) o;
            return Objects.equals(this.outputVertex, other.outputVertex)
                && Objects.equals(this.observedVertex, other.observedVertex);
        }

        @Override
        public int hashCode() {
            // Same constants as the previously generated implementation, so hash values are unchanged.
            final int prime = 59;
            final int nullHash = 43;
            int result = 1;
            result = result * prime + (this.outputVertex == null ? nullHash : this.outputVertex.hashCode());
            result = result * prime + (this.observedVertex == null ? nullHash : this.observedVertex.hashCode());
            return result;
        }

        @Override
        public String toString() {
            return "LinearRegressionGraph.OutputVertices(outputVertex=" + this.getOutputVertex() + ", observedVertex=" + this.getObservedVertex() + ")";
        }
    }
}

