8000 Issue/655: Standard deviation and variance by allmywatts · Pull Request #753 · rust-ndarray/ndarray · GitHub
[go: up one dir, main page]

Skip to content

Issue/655: Standard deviation and variance #753

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Use var in var_axis
  • Loading branch information
LukeMathWalker committed Dec 9, 2019
commit 0c8662d0b2813e34ee6f53eab179a7170d5db3da
24 changes: 6 additions & 18 deletions src/numeric/impl_numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,25 +235,13 @@ where
A: Float + FromPrimitive,
D: RemoveAxis,
{
let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
let n = A::from_usize(self.len_of(axis)).expect("Converting length to `A` must not fail.");
assert!(
!(ddof < zero || ddof > n),
"`ddof` must not be less than zero or greater than the length of \
the axis",
);
let dof = n - ddof;
let mut mean = Array::<A, _>::zeros(self.dim.remove_axis(axis));
let mut sum_sq = Array::<A, _>::zeros(self.dim.remove_axis(axis));
for (i, subview) in self.axis_iter(axis).enumerate() {
let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
azip!((mean in &mut mean, sum_sq in &mut sum_sq, &x in &subview) {
let delta = x - *mean;
*mean = *mean + delta / count;
*sum_sq = (x - *mean).mul_add(delta, *sum_sq);
let mut output = Array::zeros(self.dim.remove_axis(axis));
Zip::from(output.view_mut())
.and(self.lanes(axis))
.apply(|o, l| {
*o = l.var(ddof);
});
}
sum_sq.mapv_into(|s| s / dof)
output
}

/// Return standard deviation along `axis`.
Expand Down
0