Describe the bug
sklearn.neighbors.KernelDensity supports automatic bandwidth selection via bandwidth='silverman' and bandwidth='scott'. It computes the appropriate sample-size-dependent factors (proportional to n^(-1/5) for one-dimensional data with n observations) but does not scale them by the standard deviation or interquartile range of the dataset. Roughly, the bandwidth should be the dataset's standard deviation multiplied by these rule-of-thumb factors, so the current result is only appropriate for data with unit spread.
See, e.g., Wikipedia. The implementation in scipy.stats._kde is correct.
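For reference, the one-dimensional rule-of-thumb bandwidths can be sketched in NumPy. The factor formulas follow the scott_factor() and silverman_factor() methods of scipy.stats.gaussian_kde (d is the number of dimensions); the helper names below are hypothetical, and the sketch assumes that the bandwidth in data units is the factor times the sample standard deviation, as in SciPy.

```python
import numpy as np

# Hypothetical helpers; the formulas follow scipy.stats.gaussian_kde's
# scott_factor() and silverman_factor() for d-dimensional data.
def scott_factor(n, d=1):
    return n ** (-1.0 / (d + 4))

def silverman_factor(n, d=1):
    return (n * (d + 2) / 4.0) ** (-1.0 / (d + 4))

def rule_of_thumb_bandwidth(x, rule="silverman"):
    """Bandwidth in data units for 1-D data: factor * sample standard deviation."""
    x = np.asarray(x, dtype=float).ravel()
    factor = silverman_factor(x.size) if rule == "silverman" else scott_factor(x.size)
    # The factor alone depends only on n; multiplying by the spread of the data
    # (ddof=1, matching scipy's covariance estimate) gives a bandwidth in data units.
    return factor * np.std(x, ddof=1)

rng = np.random.default_rng(0)
data = rng.normal(scale=0.01, size=100)
print(silverman_factor(data.size))    # about 0.42, regardless of the data's scale
print(rule_of_thumb_bandwidth(data))  # about 0.42 * std(data), roughly 0.004 here
```

In 1-D the Silverman factor equals (4/3)^(1/5) n^(-1/5), about 1.06 n^(-1/5), the usual normal-reference constant; the manual bandwidth in the reproduction below uses Silverman's more conservative 0.9 constant instead.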
Steps/Code to Reproduce
```python
import matplotlib.pyplot as plot
import numpy as np
from scipy.stats import gaussian_kde
from sklearn.neighbors import KernelDensity

data = np.random.normal(scale=0.01, size=100)
X = data.reshape(-1, 1)  # scikit-learn expects shape (n_samples, n_features)

#
# 1. sklearn (auto)
#
kd_sklearn_auto = KernelDensity(kernel='gaussian', bandwidth='silverman')
kd_sklearn_auto.fit(X)

#
# 2. sklearn (manual): Silverman's rule of thumb, 0.9 * std * n^(-1/5)
#
kd_sklearn_manual = KernelDensity(kernel='gaussian', bandwidth=0.9 * np.std(data) / len(data) ** (1 / 5))
kd_sklearn_manual.fit(X)

#
# 3. scipy
#
kd_scipy = gaussian_kde(data, bw_method='silverman')

#
# 4. show the difference
#
xs = np.arange(start=-0.05, stop=0.05, step=1e-4)
XS = xs.reshape(-1, 1)
plot.plot(xs, np.exp(kd_sklearn_auto.score_samples(XS)), label='KDE SKLearn (auto)')
plot.plot(xs, np.exp(kd_sklearn_manual.score_samples(XS)), label='KDE SKLearn (manual)')
plot.plot(xs, kd_scipy.pdf(xs), label='KDE SciPy')
plot.hist(data, density=True, label='Data')  # density=True puts the histogram on the PDFs' scale
plot.legend()
plot.show()
```
Expected Results
The curve from the automatic scikit-learn bandwidth should approximately match the SciPy curve, which roughly follows the shape of the data histogram.
Actual Results
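With the automatic bandwidth, scikit-learn produces a nearly flat PDF over the plotted range: the bandwidth it uses (about 0.42 for n = 100) is roughly 40 times the data's standard deviation of 0.01.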
Versions
```
System:
    python: 3.9.5 (default, Nov 23 2021, 15:27:38) [GCC 9.3.0]
    executable: /local_disk0/.ephemeral_nfs/envs/pythonEnv-67c47d19-3f15-49a2-ab8f-bcf25a2bc29f/bin/python
    machine: Linux-5.15.0-1038-azure-x86_64-with-glibc2.31

Python dependencies:
    sklearn: 1.2.2
    pip: 21.2.4
    setuptools: 58.0.4
    numpy: 1.20.3
    scipy: 1.7.1
    Cython: 0.29.24
    pandas: 1.3.4
    matplotlib: 3.4.3
    joblib: 1.2.0
    threadpoolctl: 2.2.0

Built with OpenMP: True

threadpoolctl info:
    filepath: /databricks/python3/lib/python3.9/site-packages/numpy.libs/libopenblasp-r0-5bebc122.3.13.dev.so
    prefix: libopenblas
    user_api: blas
    internal_api: openblas
    version: 0.3.13.dev
    num_threads: 6
    threading_layer: pthreads
    architecture: Haswell

    filepath: /local_disk0/.ephemeral_nfs/envs/pythonEnv-67c47d19-3f15-49a2-ab8f-bcf25a2bc29f/lib/python3.9/site-packages/scikit_learn.libs/libgomp-a34b3233.so.1.0.0
    prefix: libgomp
    user_api: openmp
    internal_api: openmp
    version: None
    num_threads: 6

    filepath: /databricks/python3/lib/python3.9/site-packages/scipy.libs/libopenblasp-r0-085ca80a.3.9.so
    prefix: libopenblas
    user_api: blas
    internal_api: openblas
    version: 0.3.9
    num_threads: 6
    threading_layer: pthreads
    architecture: Haswell
```