diff --git a/array_api_strict/linalg.py b/array_api_strict/linalg.py index b4b21c0..78e9ec4 100644 --- a/array_api_strict/linalg.py +++ b/array_api_strict/linalg.py @@ -73,9 +73,6 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: """ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in cross') - # Note: this is different from np.cross(), which broadcasts - if x1.shape != x2.shape: - raise ValueError('x1 and x2 must have the same shape') if x1.ndim == 0: raise ValueError('cross() requires arrays of dimension at least 1') # Note: this is different from np.cross(), which allows dimension 2