public class LinearRegressionGraph<OUTPUT> extends java.lang.Object implements ModelGraph<DoubleTensor,OUTPUT>
| Modifier and Type | Class and Description |
|---|---|
| static class | LinearRegressionGraph.OutputVertices<OUTPUT> |
| Constructor and Description |
|---|
| LinearRegressionGraph(long[] featureShape, java.util.function.Function<DoubleVertex,LinearRegressionGraph.OutputVertices<OUTPUT>> outputTransform, DoubleVertex interceptVertex, DoubleVertex weightsVertex) |
| Modifier and Type | Method and Description |
|---|---|
| DoubleVertex | getInterceptVertex() |
| Vertex<OUTPUT> | getOutputVertex() |
| DoubleVertex | getWeightVertex() |
| void | observeValues(DoubleTensor input, OUTPUT output) |
| OUTPUT | predict(DoubleTensor input) |
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait

Methods inherited from interface ModelGraph: getBayesianNetwork

public LinearRegressionGraph(long[] featureShape,
                             java.util.function.Function<DoubleVertex,LinearRegressionGraph.OutputVertices<OUTPUT>> outputTransform,
                             DoubleVertex interceptVertex,
                             DoubleVertex weightsVertex)
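As a rough sketch, assuming the graph follows the usual linear-regression form (this page does not spell out the internal wiring), the constructor arguments would combine roughly as below. Here b stands for the value of interceptVertex, w for the value of weightsVertex, x for an input whose shape matches featureShape, and g for outputTransform; the symbols are illustrative, not names from the API.

$$
\eta = b + \mathbf{w}^\top \mathbf{x}, \qquad \text{output} = g(\eta)
$$

Because outputTransform returns OutputVertices rather than a plain number, g may wrap η in a distribution, for example a Gaussian centred on η for linear regression, or a Bernoulli with probability σ(η) for a logistic model. Both are illustrative choices, not documented behavior of this class.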
public OUTPUT predict(DoubleTensor input)
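A minimal sketch of what a prediction of this shape computes, assuming one row of features per sample and a deterministic transform. The class PredictSketch and its predict method are hypothetical stand-ins that use plain double arrays in place of DoubleTensor and DoubleVertex; they are not Keanu's API.

```java
import java.util.Arrays;
import java.util.function.DoubleUnaryOperator;

/**
 * Hypothetical stand-in for a linear-regression prediction path.
 * Computes intercept + weights . x for each sample row, then applies
 * a transform. Plain arrays replace DoubleTensor/DoubleVertex (assumption).
 */
class PredictSketch {

    static double[] predict(double intercept, double[] weights,
                            double[][] samples, DoubleUnaryOperator transform) {
        double[] out = new double[samples.length];
        for (int i = 0; i < samples.length; i++) {
            if (samples[i].length != weights.length) {
                throw new IllegalArgumentException("feature count mismatch at row " + i);
            }
            // Linear predictor: intercept plus dot product of weights and features
            double eta = intercept;
            for (int j = 0; j < weights.length; j++) {
                eta += weights[j] * samples[i][j];
            }
            // Map the linear predictor to the output space
            out[i] = transform.applyAsDouble(eta);
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] x = {{1.0, 2.0}, {3.0, 4.0}};
        double[] w = {0.5, -1.0};
        // Identity transform: plain linear prediction
        double[] linear = predict(2.0, w, x, v -> v);
        // Sigmoid transform: a probability, as a logistic model might produce
        double[] logistic = predict(2.0, w, x, v -> 1.0 / (1.0 + Math.exp(-v)));
        System.out.println(Arrays.toString(linear));   // [0.5, -0.5]
        System.out.println(Arrays.toString(logistic));
    }
}
```

Passing the transform in as a function mirrors the constructor's outputTransform parameter: the same linear core serves both a linear and a logistic output.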
public void observeValues(DoubleTensor input, OUTPUT output)
Specified by: observeValues in interface ModelGraph<DoubleTensor,OUTPUT>

public DoubleVertex getInterceptVertex()
public DoubleVertex getWeightVertex()
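The intercept and weight vertices are presumably what a fitting step updates once outputs have been observed, and getInterceptVertex() and getWeightVertex() expose them for reading. This page does not describe how fitting is done. As a swapped-in, non-Bayesian stand-in, the sketch below uses ordinary least squares on a single feature to show what a fitted intercept and weight mean; under Gaussian noise with flat priors, a maximum-a-posteriori estimate coincides with this fit. OlsSketch and fit are hypothetical names, not part of the API.

```java
/**
 * Hypothetical, non-Bayesian stand-in: fits an intercept and a single
 * weight by ordinary least squares from observed (input, output) pairs.
 * Illustrates the quantities an intercept/weight vertex would hold
 * after fitting; it is not how LinearRegressionGraph is trained.
 */
class OlsSketch {

    /** Returns {intercept, weight} for y ~ intercept + weight * x. */
    static double[] fit(double[] x, double[] y) {
        if (x.length != y.length || x.length < 2) {
            throw new IllegalArgumentException("need at least two paired observations");
        }
        double meanX = 0, meanY = 0;
        for (int i = 0; i < x.length; i++) {
            meanX += x[i];
            meanY += y[i];
        }
        meanX /= x.length;
        meanY /= y.length;

        // Sample covariance of (x, y) and variance of x, up to a common factor
        double cov = 0, varX = 0;
        for (int i = 0; i < x.length; i++) {
            cov += (x[i] - meanX) * (y[i] - meanY);
            varX += (x[i] - meanX) * (x[i] - meanX);
        }
        if (varX == 0) {
            throw new IllegalArgumentException("inputs have zero variance");
        }
        double weight = cov / varX;
        double intercept = meanY - weight * meanX;
        return new double[] {intercept, weight};
    }

    public static void main(String[] args) {
        // Observations generated from y = 1 + 2x with no noise
        double[] x = {0, 1, 2, 3};
        double[] y = {1, 3, 5, 7};
        double[] fitted = fit(x, y);
        System.out.printf("intercept=%.3f weight=%.3f%n", fitted[0], fitted[1]);
        // prints "intercept=1.000 weight=2.000"
    }
}
```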