diff --git a/numpy/fft/fftpack.py b/numpy/fft/fftpack.py index 78cf214d27d3..8dc3eccbcc27 100644 --- a/numpy/fft/fftpack.py +++ b/numpy/fft/fftpack.py @@ -55,12 +55,13 @@ def _raw_fft(a, n=None, axis=-1, init_function=fftpack.cffti, raise ValueError("Invalid number of FFT data points (%d) specified." % n) - try: - # Thread-safety note: We rely on list.pop() here to atomically - # retrieve-and-remove a wsave from the cache. This ensures that no - # other thread can get the same wsave while we're using it. - wsave = fft_cache.setdefault(n, []).pop() - except (IndexError): + # We have to ensure that only a single thread can access a wsave array + # at any given time. Thus we remove it from the cache and insert it + # again after it has been used. Multiple threads might create multiple + # copies of the wsave array. This is intentional and a limitation of + # the current C code. + wsave = fft_cache.pop_twiddle_factors(n) + if wsave is None: wsave = init_function(n) if a.shape[axis] != n: @@ -86,7 +87,7 @@ def _raw_fft(a, n=None, axis=-1, init_function=fftpack.cffti, # As soon as we put wsave back into the cache, another thread could pick it # up and start using it, so we must not do this until after we're # completely done using it ourselves. - fft_cache[n].append(wsave) + fft_cache.put_twiddle_factors(n, wsave) return r diff --git a/numpy/fft/helper.py b/numpy/fft/helper.py index 5d51c1a24980..0832bc5a49f8 100644 --- a/numpy/fft/helper.py +++ b/numpy/fft/helper.py @@ -4,7 +4,8 @@ """ from __future__ import division, absolute_import, print_function -from collections import OrderedDict +import collections +import threading from numpy.compat import integer_types from numpy.core import ( @@ -228,7 +229,7 @@ def rfftfreq(n, d=1.0): class _FFTCache(object): """ - Cache for the FFT init functions as an LRU (least recently used) cache. + Cache for the FFT twiddle factors as an LRU (least recently used) cache. Parameters ---------- @@ -250,37 +251,73 @@ class _FFTCache(object): def __init__(self, max_size_in_mb, max_item_count): self._max_size_in_bytes = max_size_in_mb * 1024 ** 2 self._max_item_count = max_item_count - # Much simpler than inheriting from it and having to work around - # recursive behaviour. - self._dict = OrderedDict() - - def setdefault(self, key, value): - return self._dict.setdefault(key, value) - - def __getitem__(self, key): - # pop + add to move it to the end. - value = self._dict.pop(key) - self._dict[key] = value - self._prune_dict() - return value - - def __setitem__(self, key, value): - # Just setting is it not enough to move it to the end if it already - # exists. - try: - del self._dict[key] - except: - pass - self._dict[key] = value - self._prune_dict() - - def _prune_dict(self): + self._dict = collections.OrderedDict() + self._lock = threading.Lock() + + def put_twiddle_factors(self, n, factors): + """ + Store twiddle factors for an FFT of length n in the cache. + + Putting multiple twiddle factors for a certain n will store it multiple + times. + + Parameters + ---------- + n : int + Data length for the FFT. + factors : ndarray + The actual twiddle values. + """ + with self._lock: + # Pop + later add to move it to the end for LRU behavior. + # Internally everything is stored in a dictionary whose values are + # lists. + try: + value = self._dict.pop(n) + except KeyError: + value = [] + value.append(factors) + self._dict[n] = value + self._prune_cache() + + def pop_twiddle_factors(self, n): + """ + Pop twiddle factors for an FFT of length n from the cache. + + Will return None if the requested twiddle factors are not available in + the cache. + + Parameters + ---------- + n : int + Data length for the FFT. + + Returns + ------- + out : ndarray or None + The retrieved twiddle factors if available, else None. + """ + with self._lock: + if n not in self._dict or not self._dict[n]: + return None + # Pop + later add to move it to the end for LRU behavior. + all_values = self._dict.pop(n) + value = all_values.pop() + # Only put pack if there are still some arrays left in the list. + if all_values: + self._dict[n] = all_values + return value + + def _prune_cache(self): # Always keep at least one item. while len(self._dict) > 1 and ( len(self._dict) > self._max_item_count or self._check_size()): self._dict.popitem(last=False) def _check_size(self): - item_sizes = [_i[0].nbytes for _i in self._dict.values() if _i] + item_sizes = [sum(_j.nbytes for _j in _i) + for _i in self._dict.values() if _i] + if not item_sizes: + return False max_size = max(self._max_size_in_bytes, 1.5 * max(item_sizes)) return sum(item_sizes) > max_size diff --git a/numpy/fft/tests/test_helper.py b/numpy/fft/tests/test_helper.py index 9fd0e496db66..cb85755d20e4 100644 --- a/numpy/fft/tests/test_helper.py +++ b/numpy/fft/tests/test_helper.py @@ -79,86 +79,78 @@ class TestFFTCache(TestCase): def test_basic_behaviour(self): c = _FFTCache(max_size_in_mb=1, max_item_count=4) - # Setting - c[1] = [np.ones(2, dtype=np.float32)] - c[2] = [np.zeros(2, dtype=np.float32)] - # Getting - assert_array_almost_equal(c[1][0], np.ones(2, dtype=np.float32)) - assert_array_almost_equal(c[2][0], np.zeros(2, dtype=np.float32)) - # Setdefault - c.setdefault(1, [np.array([1, 2], dtype=np.float32)]) - assert_array_almost_equal(c[1][0], np.ones(2, dtype=np.float32)) - c.setdefault(3, [np.array([1, 2], dtype=np.float32)]) - assert_array_almost_equal(c[3][0], np.array([1, 2], dtype=np.float32)) - - self.assertEqual(len(c._dict), 3) + + # Put + c.put_twiddle_factors(1, np.ones(2, dtype=np.float32)) + c.put_twiddle_factors(2, np.zeros(2, dtype=np.float32)) + + # Get + assert_array_almost_equal(c.pop_twiddle_factors(1), + np.ones(2, dtype=np.float32)) + assert_array_almost_equal(c.pop_twiddle_factors(2), + np.zeros(2, dtype=np.float32)) + + # Nothing should be left. + self.assertEqual(len(c._dict), 0) + + # Now put everything in twice so it can be retrieved once and each will + # still have one item left. + for _ in range(2): + c.put_twiddle_factors(1, np.ones(2, dtype=np.float32)) + c.put_twiddle_factors(2, np.zeros(2, dtype=np.float32)) + assert_array_almost_equal(c.pop_twiddle_factors(1), + np.ones(2, dtype=np.float32)) + assert_array_almost_equal(c.pop_twiddle_factors(2), + np.zeros(2, dtype=np.float32)) + self.assertEqual(len(c._dict), 2) def test_automatic_pruning(self): - # Thats around 2600 single precision samples. + # That's around 2600 single precision samples. c = _FFTCache(max_size_in_mb=0.01, max_item_count=4) - c[1] = [np.ones(200, dtype=np.float32)] - c[2] = [np.ones(200, dtype=np.float32)] - # Don't raise errors. - c[1], c[2], c[1], c[2] + c.put_twiddle_factors(1, np.ones(200, dtype=np.float32)) + c.put_twiddle_factors(2, np.ones(200, dtype=np.float32)) + self.assertEqual(list(c._dict.keys()), [1, 2]) # This is larger than the limit but should still be kept. - c[3] = [np.ones(3000, dtype=np.float32)] - # Should exist. - c[1], c[2], c[3] + c.put_twiddle_factors(3, np.ones(3000, dtype=np.float32)) + self.assertEqual(list(c._dict.keys()), [1, 2, 3]) # Add one more. - c[4] = [np.ones(3000, dtype=np.float32)] - + c.put_twiddle_factors(4, np.ones(3000, dtype=np.float32)) # The other three should no longer exist. - with self.assertRaises(KeyError): - c[1] - with self.assertRaises(KeyError): - c[2] - with self.assertRaises(KeyError): - c[3] + self.assertEqual(list(c._dict.keys()), [4]) # Now test the max item count pruning. c = _FFTCache(max_size_in_mb=0.01, max_item_count=2) - c[1] = [np.empty(2)] - c[2] = [np.empty(2)] + c.put_twiddle_factors(2, np.empty(2)) + c.put_twiddle_factors(1, np.empty(2)) # Can still be accessed. - c[2], c[1] - - c[3] = [np.empty(2)] + self.assertEqual(list(c._dict.keys()), [2, 1]) + c.put_twiddle_factors(3, np.empty(2)) # 1 and 3 can still be accessed - c[2] has been touched least recently # and is thus evicted. - c[1], c[3] - - with self.assertRaises(KeyError): - c[2] - - c[1], c[3] + self.assertEqual(list(c._dict.keys()), [1, 3]) # One last test. We will add a single large item that is slightly # bigger then the cache size. Some small items can still be added. c = _FFTCache(max_size_in_mb=0.01, max_item_count=5) - c[1] = [np.ones(3000, dtype=np.float32)] - c[1] - c[2] = [np.ones(2, dtype=np.float32)] - c[3] = [np.ones(2, dtype=np.float32)] - c[4] = [np.ones(2, dtype=np.float32)] - c[1], c[2], c[3], c[4] - - # One more big item. - c[5] = [np.ones(3000, dtype=np.float32)] - - # c[1] no longer in the cache. Rest still in the cache. - c[2], c[3], c[4], c[5] - with self.assertRaises(KeyError): - c[1] + c.put_twiddle_factors(1, np.ones(3000, dtype=np.float32)) + c.put_twiddle_factors(2, np.ones(2, dtype=np.float32)) + c.put_twiddle_factors(3, np.ones(2, dtype=np.float32)) + c.put_twiddle_factors(4, np.ones(2, dtype=np.float32)) + self.assertEqual(list(c._dict.keys()), [1, 2, 3, 4]) + + # One more big item. This time it is 6 smaller ones but they are + # counted as one big item. + for _ in range(6): + c.put_twiddle_factors(5, np.ones(500, dtype=np.float32)) + # '1' no longer in the cache. Rest still in the cache. + self.assertEqual(list(c._dict.keys()), [2, 3, 4, 5]) # Another big item - should now be the only item in the cache. - c[6] = [np.ones(4000, dtype=np.float32)] - for _i in range(1, 6): - with self.assertRaises(KeyError): - c[_i] - c[6] + c.put_twiddle_factors(6, np.ones(4000, dtype=np.float32)) + self.assertEqual(list(c._dict.keys()), [6]) if __name__ == "__main__":