8000 Fix using BLAS for all compatible cases of memory layout by bluss · Pull Request #1419 · rust-ndarray/ndarray · GitHub
[go: up one dir, main page]

Skip to content

Fix using BLAS for all compatible cases of memory layout #1419

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 8, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
blas: Fix to skip array with too short stride
If we have a matrix of dimension say 5 x 5, BLAS requires the leading
stride to be >= 5. Smaller cases are possible for read-only array views
in ndarray(broadcasting and custom strides).

In this case we mark the array as not BLAS compatible
  • Loading branch information
bluss committed Aug 8, 2024
commit 01bb218ada456c80d937710ce2d3c997db96bb18
39 changes: 34 additions & 5 deletions src/linalg/impl_linalg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,7 @@ where

#[cfg(feature = "blas")]
#[derive(Copy, Clone)]
#[cfg_attr(test, derive(PartialEq, Eq, Debug))]
enum MemoryOrder
{
C,
Expand All @@ -887,24 +888,34 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool
let (m, n) = dim.into_pattern();
let s0 = stride[0] as isize;
let s1 = stride[1] as isize;
let (inner_stride, outer_dim) = match order {
MemoryOrder::C => (s1, n),
MemoryOrder::F => (s0, m),
let (inner_stride, outer_stride, inner_dim, outer_dim) = match order {
MemoryOrder::C => (s1, s0, m, n),
MemoryOrder::F => (s0, s1, n, m),
};

if !(inner_stride == 1 || outer_dim == 1) {
return false;
}

if s0 < 1 || s1 < 1 {
return false;
}

if (s0 > blas_index::MAX as isize || s0 < blas_index::MIN as isize)
|| (s1 > blas_index::MAX as isize || s1 < blas_index::MIN as isize)
{
return false;
}

// leading stride must >= the dimension (no broadcasting/aliasing)
if inner_dim > 1 && (outer_stride as usize) < outer_dim {
return false;
}

if m > blas_index::MAX as usize || n > blas_index::MAX as usize {
return false;
}

true
}

Expand Down Expand Up @@ -1068,8 +1079,26 @@ mod blas_tests
}

#[test]
fn test()
fn blas_too_short_stride()
{
//WIP test that stride is larger than other dimension
// leading stride must be longer than the other dimension
// Example, in a 5 x 5 matrix, the leading stride must be >= 5 for BLAS.

const N: usize = 5;
const MAXSTRIDE: usize = N + 2;
let mut data = [0; MAXSTRIDE * N];
let mut iter = 0..data.len();
data.fill_with(|| iter.next().unwrap());

for stride in 1..=MAXSTRIDE {
let m = ArrayView::from_shape((N, N).strides((stride, 1)), &data).unwrap();
eprintln!("{:?}", m);

if stride < N {
assert_eq!(get_blas_compatible_layout(&m), None);
} else {
assert_eq!(get_blas_compatible_layout(&m), Some(MemoryOrder::C));
}
}
}
}
0