Class BaseFunction<T extends Tensor>

java.lang.Object
deepboof.impl.forward.standard.BaseFunction<T>
All Implemented Interfaces:
Function<T>
Direct Known Subclasses:
BaseSpatialWindow, ElementWiseFunction, FunctionBatchNorm_F32, FunctionBatchNorm_F64, FunctionLinear_F32, FunctionLinear_F64

public abstract class BaseFunction<T extends Tensor>
extends Object
implements Function<T>
Base class that implements functionality common to all functions.
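The split between public methods (initialize, setParameters, forward) and abstract "_"-prefixed hooks (_initialize, _setParameters, _forward) suggests a template-method layout. The sketch below illustrates that layout under stated assumptions. ToyTensor, ToyFunction, ToyBaseFunction, ArrayTensor and ToyIdentity are hypothetical stand-ins, not DeepBoof types. The bookkeeping shown (copying the input shape, recording the mini-batch size from the leading dimension) is a plausible reading of this page, not the library's actual code.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for deepboof's Tensor: only the shape is needed here.
interface ToyTensor {
    int[] shape();
}

// Hypothetical subset of the Function<T> interface described on this page.
interface ToyFunction<T extends ToyTensor> {
    void initialize(int... shapeInput);
    void setParameters(List<T> parameters);
    List<T> getParameters();
    void forward(T input, T output);
    List<int[]> getParameterShapes();
    int[] getOutputShape();
}

// Template-method sketch: public methods do shared bookkeeping,
// abstract "_" hooks hold the function-specific work.
abstract class ToyBaseFunction<T extends ToyTensor> implements ToyFunction<T> {
    protected int[] shapeInput;
    protected List<int[]> shapeParameters = new ArrayList<>();
    protected int[] shapeOutput;
    protected List<T> parameters;      // stays null until setParameters() is called
    protected int miniBatchSize;

    @Override
    public void initialize(int... shapeInput) {
        this.shapeInput = shapeInput.clone();  // defensive copy of the per-sample shape
        _initialize();                         // subclass fills shapeOutput / shapeParameters
    }

    public abstract void _initialize();

    @Override
    public void setParameters(List<T> parameters) {
        this.parameters = parameters;          // reference is kept, not copied
        _setParameters(parameters);
    }

    public abstract void _setParameters(List<T> parameters);

    @Override
    public List<T> getParameters() { return parameters; }

    @Override
    public void forward(T input, T output) {
        miniBatchSize = input.shape()[0];      // N is the leading dimension
        _forward(input, output);
    }

    public abstract void _forward(T input, T output);

    @Override
    public List<int[]> getParameterShapes() { return shapeParameters; }

    @Override
    public int[] getOutputShape() { return shapeOutput; }
}

// Hypothetical dense tensor: a shape plus flat row-major data.
final class ArrayTensor implements ToyTensor {
    final int[] shape;
    final double[] data;

    ArrayTensor(int... shape) {
        this.shape = shape.clone();
        int n = 1;
        for (int d : shape) n *= d;
        this.data = new double[n];
    }

    public int[] shape() { return shape; }
}

// Example subclass: a parameter-free identity function.
class ToyIdentity extends ToyBaseFunction<ArrayTensor> {
    @Override public void _initialize() { shapeOutput = shapeInput.clone(); }
    @Override public void _setParameters(List<ArrayTensor> parameters) { /* no parameters */ }
    @Override public void _forward(ArrayTensor input, ArrayTensor output) {
        System.arraycopy(input.data, 0, output.data, 0, input.data.length);
    }
}

class ToyBaseFunctionDemo {
    public static void main(String[] args) {
        ToyIdentity f = new ToyIdentity();
        f.initialize(3);                       // per-sample shape (3)
        ArrayTensor in = new ArrayTensor(2, 3);
        ArrayTensor out = new ArrayTensor(2, 3);
        in.data[4] = 7.0;
        f.forward(in, out);
        System.out.println(f.miniBatchSize + " " + out.data[4]); // 2 7.0
    }
}
```

A subclass only supplies the three hooks; the base class keeps the shared state consistent across all functions.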
  • Field Details

    • shapeInput

      protected int[] shapeInput
    • shapeParameters

      protected List<int[]> shapeParameters
    • shapeOutput

      protected int[] shapeOutput
    • parameters

      protected List<T> parameters
    • miniBatchSize

      protected int miniBatchSize
      Number of inputs in the mini-batch
  • Constructor Details

    • BaseFunction

      public BaseFunction()
  • Method Details

    • initialize

      public void initialize(int... shapeInput)
      Description copied from interface: Function
      Initializes internal data structures given the shape of the input tensor, excluding the stacked-input dimension. For example, if the actual input has shape (N,B,C,D), then (B,C,D) is passed to initialize. N is the number of stacked inputs and may vary after initialization.
      Specified by:
      initialize in interface Function<T extends Tensor>
      Parameters:
      shapeInput - Shape of the input tensor
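The shape convention for initialize can be illustrated with two small helpers. StackedShapes, perSample and stacked are hypothetical names used only for this sketch. The example shows that a single per-sample shape (B,C,D) serves mini-batches of any size N.

```java
import java.util.Arrays;

// Hypothetical helpers illustrating the shape convention: initialize() receives
// the per-sample shape (B,C,D); forward() later receives (N,B,C,D).
class StackedShapes {
    // Shape handed to initialize(): drop the leading stacked-input dimension N.
    static int[] perSample(int[] stackedShape) {
        if (stackedShape.length < 1)
            throw new IllegalArgumentException("stacked shape needs at least one dimension");
        return Arrays.copyOfRange(stackedShape, 1, stackedShape.length);
    }

    // Shape of the actual input for a mini-batch of n samples.
    static int[] stacked(int n, int[] perSample) {
        int[] out = new int[perSample.length + 1];
        out[0] = n;
        System.arraycopy(perSample, 0, out, 1, perSample.length);
        return out;
    }

    public static void main(String[] args) {
        int[] perSample = {3, 32, 32};  // (B,C,D)
        // N may differ between calls without re-initializing.
        System.out.println(Arrays.toString(stacked(10, perSample))); // [10, 3, 32, 32]
        System.out.println(Arrays.toString(stacked(1, perSample)));  // [1, 3, 32, 32]
        System.out.println(Arrays.toString(perSample(new int[]{10, 3, 32, 32}))); // [3, 32, 32]
    }
}
```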
    • _initialize

      public abstract void _initialize()
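_initialize() has no description on this page. Judging from the protected fields, a subclass presumably uses it to derive shapeOutput and shapeParameters from shapeInput. Below is a minimal sketch for a hypothetical fully connected layer. The class LinearShapeSketch and its parameter layout (weights (M,D) and bias (M), where D is the product of the per-sample dimensions) are assumptions, not necessarily what FunctionLinear_F32 or FunctionLinear_F64 use.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical _initialize() for a fully connected layer with M outputs.
// The parameter layout (weights (M,D), bias (M)) is an assumption for illustration.
class LinearShapeSketch {
    final int numOutputs;                       // M
    int[] shapeInput;                           // set by initialize(), as in BaseFunction
    int[] shapeOutput;
    final List<int[]> shapeParameters = new ArrayList<>();

    LinearShapeSketch(int numOutputs) { this.numOutputs = numOutputs; }

    void initialize(int... shapeInput) {
        this.shapeInput = shapeInput.clone();
        _initialize();
    }

    // Derives output and parameter shapes from shapeInput alone.
    void _initialize() {
        int d = 1;
        for (int v : shapeInput) d *= v;        // flatten the per-sample input to D values
        shapeOutput = new int[]{numOutputs};    // (M)
        shapeParameters.clear();                // list is recycled on re-initialization
        shapeParameters.add(new int[]{numOutputs, d}); // weights (M,D)
        shapeParameters.add(new int[]{numOutputs});    // bias (M)
    }

    public static void main(String[] args) {
        LinearShapeSketch f = new LinearShapeSketch(10);
        f.initialize(3, 4);
        System.out.println(Arrays.toString(f.shapeOutput));            // [10]
        System.out.println(Arrays.toString(f.shapeParameters.get(0))); // [10, 12]
        System.out.println(Arrays.toString(f.shapeParameters.get(1))); // [10]
    }
}
```

Because the shapes depend only on shapeInput, they are available before any parameters exist, which is what lets getParameterShapes() be valid right after initialize.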
    • setParameters

      public void setParameters(List<T> parameters)
      Description copied from interface: Function

      Specifies learnable function parameters, e.g. weights for linear functions. This method needs to be called only once after each modification of the parameters, and it must be called before Function.forward(T, T).

      NOTE: A reference to the parameters may be saved internally, so the tensors should not be modified externally.
      Specified by:
      setParameters in interface Function<T extends Tensor>
      Parameters:
      parameters - Tensors containing the parameters being optimized. Not modified.
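Since parameter shapes are known once the function has been initialized, a setParameters implementation can compare the supplied tensors against getParameterShapes() before storing them. This page does not say whether BaseFunction performs such a check. The sketch below, using the hypothetical class ParameterCheck on plain int[] shapes, only shows what one could look like. It reads its arguments and never modifies them, consistent with the "Not modified" contract.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical validation a setParameters() implementation could run before
// storing the list and delegating to _setParameters().
class ParameterCheck {
    static void check(List<int[]> expectedShapes, List<int[]> actualShapes) {
        if (expectedShapes.size() != actualShapes.size())
            throw new IllegalArgumentException("Expected " + expectedShapes.size()
                    + " parameter tensors, got " + actualShapes.size());
        for (int i = 0; i < expectedShapes.size(); i++) {
            if (!Arrays.equals(expectedShapes.get(i), actualShapes.get(i)))
                throw new IllegalArgumentException("Parameter " + i + ": expected shape "
                        + Arrays.toString(expectedShapes.get(i)) + ", got "
                        + Arrays.toString(actualShapes.get(i)));
        }
    }

    public static void main(String[] args) {
        List<int[]> expected = List.of(new int[]{10, 12}, new int[]{10});
        check(expected, List.of(new int[]{10, 12}, new int[]{10}));  // passes silently
        try {
            check(expected, List.of(new int[]{12, 10}, new int[]{10}));
        } catch (IllegalArgumentException e) {
            // Parameter 0: expected shape [10, 12], got [12, 10]
            System.out.println(e.getMessage());
        }
    }
}
```

Failing early here gives a clearer error than an index exception deep inside _forward.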
    • _setParameters

      public abstract void _setParameters(List<T> parameters)
    • getParameters

      public List<T> getParameters()
      Description copied from interface: Function
      If the parameters have been set, then this returns the list of parameters. Otherwise null is returned.
      Specified by:
      getParameters in interface Function<T extends Tensor>
      Returns:
      List of parameters or null if they have not been set yet
    • forward

      public void forward(T input, T output)
      Description copied from interface: Function
      Performs the forward pass of the function on the provided inputs.
       Input tensor shape = (N,variable ... )
       - N is the mini-batch size
       - Other dimensions are implementation specific.
       
      Specified by:
      forward in interface Function<T extends Tensor>
      Parameters:
      input - Input to the function.
      output - Output tensor. Modified.
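The forward contract ties the input tensor to the shape passed to initialize: one extra leading dimension, N, which may change from call to call. The sketch below shows how N could be extracted and checked, assuming the trailing dimensions must equal the initialized shape exactly. ForwardShapeCheck and miniBatchSize are hypothetical names, and the real implementation may validate differently.

```java
import java.util.Arrays;

// Hypothetical check a forward() implementation could run before _forward():
// the input must be (N, shapeInput...), and N becomes the mini-batch size.
class ForwardShapeCheck {
    static int miniBatchSize(int[] shapeInput, int[] inputShape) {
        if (inputShape.length != shapeInput.length + 1)
            throw new IllegalArgumentException("Input must have " + (shapeInput.length + 1)
                    + " dimensions, found " + inputShape.length);
        int[] perSample = Arrays.copyOfRange(inputShape, 1, inputShape.length);
        if (!Arrays.equals(perSample, shapeInput))
            throw new IllegalArgumentException("Per-sample shape " + Arrays.toString(perSample)
                    + " does not match initialized shape " + Arrays.toString(shapeInput));
        return inputShape[0];  // N, free to change between calls
    }

    public static void main(String[] args) {
        int[] shapeInput = {3, 32, 32};
        System.out.println(miniBatchSize(shapeInput, new int[]{8, 3, 32, 32})); // 8
        System.out.println(miniBatchSize(shapeInput, new int[]{2, 3, 32, 32})); // 2
    }
}
```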
    • _forward

      public abstract void _forward(T input, T output)
    • getParameterShapes

      public List<int[]> getParameterShapes()
      Description copied from interface: Function
      Returns the shapes of the parameter tensors. Only valid after Function.initialize(int...) has been called.
      Specified by:
      getParameterShapes in interface Function<T extends Tensor>
      Returns:
      Expected shapes of the parameter tensors. This data structure may be recycled and is modified on the next call to Function.initialize(int...).
    • getOutputShape

      public int[] getOutputShape()
      Description copied from interface: Function
      Returns the output tensor's shape, without the mini-batch dimension. Only valid after Function.initialize(int...) has been called.
      Specified by:
      getOutputShape in interface Function<T extends Tensor>
      Returns:
      Expected shape of output tensor. This data structure may be recycled and is modified on the next call to Function.initialize(int...).
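Both getParameterShapes() and getOutputShape() warn that the returned data structure may be recycled. The hypothetical RecyclingFunction below reuses one array across initialize() calls to show the consequence. A caller that keeps the returned reference sees it change, while a clone taken beforehand does not.

```java
import java.util.Arrays;

// Hypothetical function that recycles its output-shape array, illustrating
// why callers should copy the result of getOutputShape() if they keep it.
class RecyclingFunction {
    private final int[] shapeOutput = new int[1];  // reused across initialize() calls

    void initialize(int... shapeInput) {
        int n = 1;
        for (int v : shapeInput) n *= v;
        shapeOutput[0] = n;                        // overwritten in place
    }

    int[] getOutputShape() { return shapeOutput; }

    public static void main(String[] args) {
        RecyclingFunction f = new RecyclingFunction();
        f.initialize(2, 3);
        int[] alias = f.getOutputShape();          // same array as the internal field
        int[] saved = f.getOutputShape().clone();  // independent copy
        f.initialize(4, 5);
        System.out.println(Arrays.toString(alias)); // [20] (changed underneath the caller)
        System.out.println(Arrays.toString(saved)); // [6]
    }
}
```

For getParameterShapes(), copying the list alone is not enough if the int[] elements are also reused; a deep copy such as shapes.stream().map(int[]::clone).collect(Collectors.toList()) avoids that.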