10000 ENH: lib: break reference cycle in NpzFile (#2048) · numpy-buildbot/numpy@c4482f5 · GitHub
[go: up one dir, main page]

Skip to content

Commit c4482f5

Browse files
pvcharris
authored andcommitted
ENH: lib: break reference cycle in NpzFile (numpy#2048)
This allows these objects to be freed by refcount, rather than requiring the gc, which can be useful in some situations.
1 parent 56a5472 commit c4482f5

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

numpy/lib/npyio.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import sys
1111
import itertools
1212
import warnings
13+
import weakref
1314
from operator import itemgetter
1415

1516
from cPickle import load as _cload, loads
@@ -108,7 +109,8 @@ class BagObj(object):
108109
109110
"""
110111
def __init__(self, obj):
111-
self._obj = obj
112+
# Use weakref to make NpzFile objects collectable by refcount
113+
self._obj = weakref.proxy(obj)
112114
def __getattribute__(self, key):
113115
try:
114116
return object.__getattribute__(self, '_obj')[key]
@@ -212,6 +214,7 @@ def close(self):
212214
if self.fid is not None:
213215
self.fid.close()
214216
self.fid = None
217+
self.f = None # break reference cycle
215218

216219
def __del__(self):
217220
self.close()

numpy/lib/tests/test_io.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import time
77
from datetime import datetime
88
import warnings
9+
import gc
910
from numpy.testing.utils import WarningManager
1011

1112
import numpy as np
@@ -1525,6 +1526,20 @@ def test_npzfile_dict():
15251526

15261527
assert_('x' in list(z.iterkeys()))
15271528

1529+
def test_load_refcount():
1530+
# Check that objects returned by np.load are directly freed based on
1531+
# their refcount, rather than needing the gc to collect them.
1532+
1533+
f = StringIO()
1534+
np.savez(f, [1, 2, 3])
1535+
f.seek(0)
1536+
1537+
gc.collect()
1538+
n_before = len(gc.get_objects())
1539+
np.load(f)
1540+
n_after = len(gc.get_objects())
1541+
1542+
assert_equal(n_before, n_after)
15281543

15291544
if __name__ == "__main__":
15301545
run_module_suite()

0 commit comments

Comments
 (0)
0