diff --git a/src/interpolator/one/strategies.rs b/src/interpolator/one/strategies.rs index 808c506..c3df870 100644 --- a/src/interpolator/one/strategies.rs +++ b/src/interpolator/one/strategies.rs @@ -1,4 +1,5 @@ use super::*; +use strategy::cubic::*; use strategy::*; impl Strategy1D for Linear @@ -37,6 +38,143 @@ where } } +impl Strategy1D for Cubic +where + D: Data + RawDataClone + Clone, + D::Elem: Float + Euclid + Debug, +{ + fn init(&mut self, data: &InterpData1D) -> Result<(), ValidateError> { + // Number of segments + let n = data.grid[0].len() - 1; + + let zero = D::Elem::zero(); + let one = D::Elem::one(); + let two = ::from(2.).unwrap(); + let six = ::from(6.).unwrap(); + + let h = Array1::from_shape_fn(n, |i| data.grid[0][i + 1] - data.grid[0][i]); + let v = Array1::from_shape_fn(n - 1, |i| two * (h[i + 1] + h[i])); + let b = Array1::from_shape_fn(n, |i| (data.values[i + 1] - data.values[i]) / h[i]); + let u = Array1::from_shape_fn(n - 1, |i| six * (b[i + 1] - b[i])); + + let (sub, diag, sup, rhs) = match &self.boundary_condition { + CubicBC::Natural => { + let zero = array![zero]; + let one = array![one]; + ( + &ndarray::concatenate(Axis(0), &[h.slice(s![0..n - 1]), zero.view()]).unwrap(), + &ndarray::concatenate(Axis(0), &[one.view(), v.view(), one.view()]).unwrap(), + &ndarray::concatenate(Axis(0), &[zero.view(), h.slice(s![1..n])]).unwrap(), + &ndarray::concatenate(Axis(0), &[zero.view(), u.view(), zero.view()]).unwrap(), + ) + } + CubicBC::Clamped(l, r) => { + let diag_0 = array![two * h[0]]; + let diag_n = array![two * h[n - 1]]; + let rhs_0 = array![six * (b[0] - *l)]; + let rhs_n = array![six * (*r - b[n - 1])]; + ( + &h, + &ndarray::concatenate(Axis(0), &[diag_0.view(), v.view(), diag_n.view()]) + .unwrap(), + &h, + &ndarray::concatenate(Axis(0), &[rhs_0.view(), u.view(), rhs_n.view()]) + .unwrap(), + ) + } + CubicBC::NotAKnot => { + let three = two + one; + let sub_n = + array![two * h[n - 1].powi(2) + three * h[n - 1] * h[n - 2] + h[n - 2].powi(2)]; + let diag_0 = array![h[0].powi(2) - h[1].powi(2)]; + let diag_n = array![h[n - 1].powi(2) - h[n - 2].powi(2)]; + let sup_0 = array![two * h[0].powi(2) + three * h[0] * h[1] + h[1].powi(2)]; + let rhs_0 = array![h[0] * u[0]]; + let rhs_n = array![h[n - 1] * u[n - 2]]; + + println!( + "sub {:?}", + &ndarray::concatenate(Axis(0), &[h.slice(s![0..n - 1]), sub_n.view()]).unwrap() + ); + println!( + "dia {:?}", + &ndarray::concatenate(Axis(0), &[diag_0.view(), v.view(), diag_n.view()]) + .unwrap() + ); + println!( + "sup {:?}", + &ndarray::concatenate(Axis(0), &[sup_0.view(), h.slice(s![1..n])]).unwrap() + ); + println!( + "rhs {:?}", + &ndarray::concatenate(Axis(0), &[rhs_0.view(), u.view(), rhs_n.view()]) + .unwrap() + ); + ( + &ndarray::concatenate(Axis(0), &[h.slice(s![0..n - 1]), sub_n.view()]).unwrap(), + &ndarray::concatenate(Axis(0), &[diag_0.view(), v.view(), diag_n.view()]) + .unwrap(), + &ndarray::concatenate(Axis(0), &[sup_0.view(), h.slice(s![1..n])]).unwrap(), + &ndarray::concatenate(Axis(0), &[rhs_0.view(), u.view(), rhs_n.view()]) + .unwrap(), + ) + } + _ => unreachable!(), + }; + + self.z = Self::thomas(sub.view(), diag.view(), sup.view(), rhs.view()).into_dyn(); + + Ok(()) + } + + fn interpolate( + &self, + data: &InterpData1D, + point: &[::Elem; 1], + ) -> Result<::Elem, InterpolateError> { + let last = data.grid[0].len() - 1; + let l = if point[0] < data.grid[0][0] { + match &self.extrapolate { + CubicExtrapolate::Linear => { + let h0 = data.grid[0][1] - data.grid[0][0]; + let k0 = (data.values[1] - data.values[0]) / h0 + - h0 * self.z[1] / ::from(6.).unwrap(); + return Ok(k0 * (point[0] - data.grid[0][0]) + data.values[0]); + } + CubicExtrapolate::Spline => 0, + CubicExtrapolate::Wrap => { + let point = [wrap(point[0], data.grid[0][0], data.grid[0][last])]; + let l = find_nearest_index(data.grid[0].view(), &point[0]); + return self.evaluate_1d(&point, l, data); + } + } + } else if point[0] > data.grid[0][last] { + match &self.extrapolate { + CubicExtrapolate::Linear => { + let hn = data.grid[0][last] - data.grid[0][last - 1]; + let kn = (data.values[last] - data.values[last - 1]) / hn + + hn * self.z[last - 1] / ::from(6.).unwrap(); + return Ok(kn * (point[0] - data.grid[0][last]) + data.values[last]); + } + CubicExtrapolate::Spline => last - 1, + CubicExtrapolate::Wrap => { + let point = [wrap(point[0], data.grid[0][0], data.grid[0][last])]; + let l = find_nearest_index(data.grid[0].view(), &point[0]); + return self.evaluate_1d(&point, l, data); + } + } + } else { + find_nearest_index(data.grid[0].view(), &point[0]) + }; + self.evaluate_1d(point, l, data) + } + + /// Returns `true` + fn allow_extrapolate(&self) -> bool { + true + } +} + impl Strategy1D for Nearest where D: Data + RawDataClone + Clone, diff --git a/src/interpolator/one/tests.rs b/src/interpolator/one/tests.rs index 1638df0..b6c6dac 100644 --- a/src/interpolator/one/tests.rs +++ b/src/interpolator/one/tests.rs @@ -162,3 +162,234 @@ fn test_extrapolate() { assert_approx_eq!(interp.interpolate(&[-0.75]).unwrap(), 0.05); assert_eq!(interp.interpolate(&[5.]).unwrap(), 1.2); } + +#[test] +fn test_cubic_natural() { + let x = array![1., 2.4, 3.1, 5., 7.6, 8.3, 10., 10.1]; + let f_x = array![3., -90., 19., 99., 291., 444., 222., 250.]; + + let interp = Interp1D::new( + x.view(), + f_x.view(), + strategy::Cubic::natural(), + Extrapolate::Enable, + ) + .unwrap(); + + // Interpolating at knots returns values + for i in 0..x.len() { + assert_approx_eq!(interp.interpolate(&[x[i]]).unwrap(), f_x[i]); + } + + let x0 = x.first().unwrap(); + let xn = x.last().unwrap(); + let y0 = f_x.first().unwrap(); + let yn = f_x.last().unwrap(); + + let range = xn - x0; + + let x_low = x0 - 0.2 * range; + let y_low = interp.interpolate(&[x_low]).unwrap(); + let slope_low = (y0 - y_low) / (x0 - x_low); + + let x_high = xn + 0.2 * range; + let y_high = interp.interpolate(&[x_high]).unwrap(); + let slope_high = (y_high - yn) / (x_high - xn); + + let xs_left = Array1::linspace(*x0, x0 + 2e-6, 50); + let xs_right = Array1::linspace(xn - 2e-6, *xn, 50); + + // Left extrapolation is linear + let ys: Array1 = xs_left + .iter() + .map(|&x| interp.interpolate(&[x]).unwrap()) + .collect(); + let slopes: Array1 = xs_left + .windows(2) + .into_iter() + .zip(ys.windows(2)) + .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) + .collect(); + assert_approx_eq!(slopes.mean().unwrap(), slope_low); + + // Right extrapolation is linear + let ys: Array1 = xs_right + .iter() + .map(|&x| interp.interpolate(&[x]).unwrap()) + .collect(); + let slopes: Array1 = xs_right + .windows(2) + .into_iter() + .zip(ys.windows(2)) + .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) + .collect(); + assert_approx_eq!(slopes.mean().unwrap(), slope_high); +} + +#[test] +fn test_cubic_clamped() { + let x = array![1., 2.4, 3.1, 5., 7.6, 8.3, 10., 10.1]; + let f_x = array![3., -90., 19., 99., 291., 444., 222., 250.]; + + let xs_left = Array1::linspace(x.first().unwrap() - 1e-6, x.first().unwrap() + 1e-6, 50); + let xs_right = Array1::linspace(x.last().unwrap() - 1e-6, x.last().unwrap() + 1e-6, 50); + + for (a, b) in [(-5., 10.), (0., 0.), (2.4, -5.2)] { + let interp = Interp1D::new( + x.view(), + f_x.view(), + strategy::Cubic::clamped(a, b), + Extrapolate::Enable, + ) + .unwrap(); + + // Interpolating at knots returns values + for i in 0..x.len() { + assert_approx_eq!(interp.interpolate(&[x[i]]).unwrap(), f_x[i]); + } + + // Left slope = a + let ys: Array1 = xs_left + .iter() + .map(|&x| interp.interpolate(&[x]).unwrap()) + .collect(); + let slopes: Array1 = xs_left + .windows(2) + .into_iter() + .zip(ys.windows(2)) + .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) + .collect(); + assert_approx_eq!(slopes.mean().unwrap(), a); + + // Right slope = b + let ys: Array1 = xs_right + .iter() + .map(|&x| interp.interpolate(&[x]).unwrap()) + .collect(); + let slopes: Array1 = xs_right + .windows(2) + .into_iter() + .zip(ys.windows(2)) + .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) + .collect(); + assert_approx_eq!(slopes.mean().unwrap(), b); + } +} + +#[test] +fn test_cubic_not_a_knot() { + let x = array![1., 2.4, 3.1, 5., 7.6, 8.3, 10., 10.1]; + let f_x = array![3., -90., 19., 99., 291., 444., 222., 250.]; + + let x = array![1., 2., 3., 5., 7., 8., 10.]; + let f_x = array![3., -90., 19., 99., 291., 444., 222.]; + + let interp = Interp1D::new( + x.view(), + f_x.view(), + strategy::Cubic::not_a_knot(), + Extrapolate::Enable, + ) + .unwrap(); + + // Interpolating at knots returns values + for i in 0..x.len() { + assert_approx_eq!(interp.interpolate(&[x[i]]).unwrap(), f_x[i]); + } + + // // Left slope = a + // let ys: Array1 = xs_left + // .iter() + // .map(|&x| interp.interpolate(&[x]).unwrap()) + // .collect(); + // let slopes: Array1 = xs_left + // .windows(2) + // .into_iter() + // .zip(ys.windows(2)) + // .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) + // .collect(); + // assert_approx_eq!(slopes.mean().unwrap(), a); + + // // Right slope = b + // let ys: Array1 = xs_right + // .iter() + // .map(|&x| interp.interpolate(&[x]).unwrap()) + // .collect(); + // let slopes: Array1 = xs_right + // .windows(2) + // .into_iter() + // .zip(ys.windows(2)) + // .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) + // .collect(); + // assert_approx_eq!(slopes.mean().unwrap(), b); +} + +// #[test] +// fn test_cubic_periodic() { +// let x = array![1., 2., 3., 5., 7., 8.]; +// let f_x = array![3., -90., 19., 99., 291., 444.]; +// +// let interp_extrap_enable = +// Interp1D::new(x.view(), f_x.view(), strategy::Cubic::periodic(), Extrapolate::Enable).unwrap(); +// let interp_extrap_wrap = +// Interp1D::new(x.view(), f_x.view(), strategy::Cubic::periodic(), Extrapolate::Wrap).unwrap(); +// +// // Interpolating at knots returns values +// for i in 0..x.len() { +// assert_approx_eq!(interp_extrap_enable.interpolate(&[x[i]]).unwrap(), f_x[i]); +// assert_approx_eq!(interp_extrap_wrap.interpolate(&[x[i]]).unwrap(), f_x[i]); +// } +// +// // Extrapolate::Enable is equivalent to Extrapolate::Wrap for Cubic::periodic() +// let x0 = x.first().unwrap(); +// let xn = x.last().unwrap(); +// let range = xn - x0; +// let x_low = x0 - 0.2 * range; +// let x_high = x0 + 0.2 * range; +// let xs_left = Array1::linspace(x_low, *x0, 50); +// let xs_right = Array1::linspace(*xn, x_high, 50); +// for x in xs_left { +// assert_eq!( +// interp_extrap_enable.interpolate(&[x]).unwrap(), +// interp_extrap_wrap.interpolate(&[x]).unwrap() +// ); +// } +// for x in xs_right { +// assert_eq!( +// interp_extrap_enable.interpolate(&[x]).unwrap(), +// interp_extrap_wrap.interpolate(&[x]).unwrap() +// ); +// } +// +// // Slope left +// let xs_left = Array1::linspace(x_low, x_low + 2e6, 50); +// let ys_left: Array1 = xs_left +// .iter() +// .map(|&x| interp_extrap_enable.interpolate(&[x]).unwrap()) +// .collect(); +// let slopes_left: Array1 = xs_left +// .windows(2) +// .into_iter() +// .zip(ys_left.windows(2)) +// .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) +// .collect(); +// let slope_left = slopes_left.mean().unwrap(); +// // Slope right +// let xs_right = Array1::linspace(x_high - 2e6, x_high, 50); +// let ys_right: Array1 = xs_right +// .iter() +// .map(|&x| interp_extrap_enable.interpolate(&[x]).unwrap()) +// .collect(); +// let slopes_right: Array1 = xs_right +// .windows(2) +// .into_iter() +// .zip(ys_right.windows(2)) +// .map(|(x, y)| (y[1] - y[0]) / (x[1] - x[0])) +// .collect(); +// let slope_right = slopes_right.mean().unwrap(); +// // Slopes at left and right are equal +// assert_approx_eq!(slope_left, slope_right); +// // Second derivatives at left and right are equal +// let z = interp_extrap_enable.strategy.z; +// assert_approx_eq!(z.first().unwrap(), z.last().unwrap()); +// } diff --git a/src/lib.rs b/src/lib.rs index 19f1dbd..279bc90 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -148,7 +148,7 @@ pub use ndarray; pub(crate) use ndarray::prelude::*; pub(crate) use ndarray::{Data, Ix, RawDataClone}; -pub(crate) use num_traits::{clamp, Euclid, Num, One}; +pub(crate) use num_traits::{clamp, Euclid, Float, Num, NumCast, One, Zero}; pub(crate) use dyn_clone::*; diff --git a/src/strategy/cubic.rs b/src/strategy/cubic.rs new file mode 100644 index 0000000..5cf6675 --- /dev/null +++ b/src/strategy/cubic.rs @@ -0,0 +1,187 @@ +use super::*; + +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] +pub struct Cubic { + /// Cubic spline boundary conditions. + pub boundary_condition: CubicBC, + /// Behavior of [`Extrapolate::Enable`]. + pub extrapolate: CubicExtrapolate, + /// Solved second derivatives. + pub z: ArrayD, +} + +/// Cubic spline boundary conditions. +#[derive(Copy, Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] +pub enum CubicBC { + Natural, + Clamped(T, T), + NotAKnot, + // https://math.ou.edu/~npetrov/project-5093-s11.pdf + Periodic, +} + +/// [`Extrapolate::Enable`] behavior for cubic splines +#[derive(Copy, Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))] +pub enum CubicExtrapolate { + /// Linear extrapolation, default for natural splines. + Linear, + /// Use nearest spline to extrapolate. + Spline, + /// Same as [`Extrapolate::Wrap`], default for periodic splines. + Wrap, +} + +impl Cubic { + /// Cubic spline with given boundary condition and extrapolation behavior. + pub fn new(boundary_condition: CubicBC, extrapolate: CubicExtrapolate) -> Self { + Self { + boundary_condition, + extrapolate, + z: Array1::from_vec(Vec::new()).into_dyn(), + } + } + + /// Natural cubic spline + /// (splines straighten at outermost knots). + /// + /// 2nd derivatives at outermost knots are zero: + /// z0 = zn = 0 + /// + /// [`Extrapolate::Enable`] defaults to [`CubicExtrapolate::Linear`]. + pub fn natural() -> Self { + Self::new(CubicBC::Natural, CubicExtrapolate::Linear) + } + + /// Clamped cubic spline. + /// + /// 1st derivatives at outermost knots (k0, kn) are given. + /// + /// [`Extrapolate::Enable`] defaults to [`CubicExtrapolate::Spline`]. + pub fn clamped(k0: T, kn: T) -> Self { + Self::new(CubicBC::Clamped(k0, kn), CubicExtrapolate::Spline) + } + + /// Not-a-knot cubic spline. + /// + /// Spline 3rd derivatives at second and second-to-last knots are continuous, respectively: + /// S'''0(x1) = S'''1(x1) and + /// S'''n-1(xn-1) = S'''n(xn-1). + /// + /// In other words, this means the first and second spline at data boundaries are the same cubic. + /// + /// [`Extrapolate::Enable`] defaults to [`CubicExtrapolate::Spline`]. + pub fn not_a_knot() -> Self { + Self::new(CubicBC::NotAKnot, CubicExtrapolate::Spline) + } + + /// Periodic cubic spline. + /// + /// Spline 1st & 2nd derivatives at outermost knots are equal: + /// k0 = kn, z0 = zn + /// + /// [`Extrapolate::Enable`] defaults to [`CubicExtrapolate::Wrap`], + /// thus is equivalent to [`Extrapolate::Wrap`]. + pub fn periodic() -> Self { + Self::new(CubicBC::Periodic, CubicExtrapolate::Wrap) + } +} + +impl Cubic +where + T: Float + Debug, +{ + // Reference: https://www.math.ntnu.no/emner/TMA4215/2008h/cubicsplines.pdf + pub(crate) fn evaluate_1d + RawDataClone + Clone>( + &self, + point: &[T; 1], + l: usize, + data: &InterpData1D, + ) -> Result { + let six = ::from(6.).unwrap(); + let u = l + 1; + let h_i = data.grid[0][u] - data.grid[0][l]; + Ok( + self.z[u] / (six * h_i) * (point[0] - data.grid[0][l]).powi(3) + + self.z[l] / (six * h_i) * (data.grid[0][u] - point[0]).powi(3) + + (data.values[u] / h_i - self.z[u] * h_i / six) * (point[0] - data.grid[0][l]) + + (data.values[l] / h_i - self.z[l] * h_i / six) * (data.grid[0][u] - point[0]), + ) + } + + pub(crate) fn evaluate_2d + RawDataClone + Clone>( + &self, + point: &[T; 2], + l: usize, + data: &InterpData2D, + ) -> Result { + todo!() + } + + pub(crate) fn evaluate_3d + RawDataClone + Clone>( + &self, + point: &[T; 3], + l: usize, + data: &InterpData3D, + ) -> Result { + todo!() + } + + pub(crate) fn evaluate_nd + RawDataClone + Clone>( + &self, + point: &[T], + l: usize, + data: &InterpDataND, + ) -> Result { + todo!() + } +} + +impl Cubic +where + T: Num + Copy, +{ + /// Solves Ax = d for a tridiagonal matrix A using the + /// [Thomas algorithm](https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm). + /// - `a`: sub-diagonal (1 element shorter than `b` and `d`) + /// - `b`: diagonal + /// - `c`: super-diagonal (1 element shorter than `b` and `d`) + /// - `d`: right-hand side + pub(crate) fn thomas( + a: ArrayView1, + b: ArrayView1, + c: ArrayView1, + d: ArrayView1, + ) -> Array1 { + let n = d.len(); + assert_eq!(a.len(), n - 1); + assert_eq!(b.len(), n); + assert_eq!(c.len(), n - 1); + + let mut c_prime = Array1::zeros(n - 1); + let mut d_prime = Array1::zeros(n); + let mut x = Array1::zeros(n); + + // Forward sweep + c_prime[0] = c[0] / b[0]; + d_prime[0] = d[0] / b[0]; + + for i in 1..n - 1 { + let denom = b[i] - a[i - 1] * c_prime[i - 1]; + c_prime[i] = c[i] / denom; + d_prime[i] = (d[i] - a[i - 1] * d_prime[i - 1]) / denom; + } + d_prime[n - 1] = + (d[n - 1] - a[n - 2] * d_prime[n - 2]) / (b[n - 1] - a[n - 2] * c_prime[n - 2]); + + // Back substitution + x[n - 1] = d_prime[n - 1]; + for i in (0..n - 1).rev() { + x[i] = d_prime[i] - c_prime[i] * x[i + 1]; + } + + x + } +} diff --git a/src/strategy/mod.rs b/src/strategy/mod.rs index 898f1c8..de51dd9 100644 --- a/src/strategy/mod.rs +++ b/src/strategy/mod.rs @@ -2,8 +2,11 @@ use super::*; +pub mod cubic; pub mod traits; +pub use cubic::Cubic; + /// Linear interpolation: #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]