public class TensorShapeValidation
extends java.lang.Object
| Modifier and Type | Method and Description |
|---|---|
| `static long[]` | `checkAllShapesMatch(java.util.Collection<long[]> shapes)` |
| `static long[]` | `checkAllShapesMatch(long[]... shapes)` |
| `static long[]` | `checkAllShapesMatch(java.lang.String errorMessage, java.util.Collection<long[]> shapes)` |
| `static long[]` | `checkAllShapesMatch(java.lang.String errorMessage, long[]... shapes)` |
| `static void` | `checkDimensionExistsInShape(int dimension, long[] shape)` — Checks whether the given dimension exists within the shape. |
| `static long[]` | `checkHasOneNonLengthOneShapeOrAllLengthOne(long[]... shapes)` — Ensures there is at most a single non-length-one shape. |
| `static void` | `checkIndexIsValid(long[] shape, long... index)` |
| `static long[]` | `checkIsBroadcastable(long[] left, long[] right)` |
| `static void` | `checkShapeIsSquareMatrix(long[] shape)` |
| `static long[]` | `checkShapesCanBeConcatenated(int dimension, long[]... shapes)` |
| `static void` | `checkShapesMatch(long[] actual, long[] expected)` |
| `static void` | `checkTensorsAreScalar(java.lang.String message, long[]... shapes)` |
| `static void` | `checkTensorsMatchNonLengthOneShapeOrAreLengthOne(long[] proposalShape, long[]... shapes)` — A common check that each tensor either has the same shape as the proposal in question OR is length one. |
| `static long[]` | `checkTernaryConditionShapeIsValid(long[] predicate, long[] thn, long[] els)` |
| `static long[]` | `getMatrixMultiplicationResultingShape(long[] left, long[] right)` |
| `static long[]` | `getTensorMultiplyResultShape(long[] leftShape, long[] rightShape, int[] dimsLeft, int[] dimsRight)` |
| `static boolean` | `isBroadcastable(long[] left, long[] right)` |
public static void checkTensorsMatchNonLengthOneShapeOrAreLengthOne(long[] proposalShape,
long[]... shapes)
Parameters:
proposalShape - the tensor shape being validated
shapes - the tensors being validated against
Throws:
java.lang.IllegalArgumentException - if there is more than one non-length-one shape OR if the non-length-one shape does not match the proposal shape.
public static void checkDimensionExistsInShape(int dimension,
long[] shape)
Parameters:
dimension - the proposed dimension
shape - the shape to check
Throws:
java.lang.IllegalArgumentException - if the dimension exceeds the rank of the shape.
public static void checkTensorsAreScalar(java.lang.String message,
long[]... shapes)
public static long[] checkHasOneNonLengthOneShapeOrAllLengthOne(long[]... shapes)
Parameters:
shapes - the tensor shapes to check
Throws:
java.lang.IllegalArgumentException - if there is more than one non-length-one shape, or if the length-one shapes have multiple ranks.
public static long[] checkIsBroadcastable(long[] left,
long[] right)
public static boolean isBroadcastable(long[] left,
long[] right)
public static long[] checkTernaryConditionShapeIsValid(long[] predicate,
long[] thn,
long[] els)
Parameters:
predicate - shape of the predicate
thn - shape of the "then" operand
els - shape of the "else" operand
public static void checkShapeIsSquareMatrix(long[] shape)
public static void checkShapesMatch(long[] actual,
long[] expected)
public static long[] checkAllShapesMatch(long[]... shapes)
public static long[] checkAllShapesMatch(java.lang.String errorMessage,
long[]... shapes)
public static long[] checkAllShapesMatch(java.lang.String errorMessage,
java.util.Collection<long[]> shapes)
public static long[] checkAllShapesMatch(java.util.Collection<long[]> shapes)
public static long[] checkShapesCanBeConcatenated(int dimension,
long[]... shapes)
public static void checkIndexIsValid(long[] shape,
long... index)
public static long[] getTensorMultiplyResultShape(long[] leftShape,
long[] rightShape,
int[] dimsLeft,
int[] dimsRight)
public static long[] getMatrixMultiplicationResultingShape(long[] left,
long[] right)