diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index 640f4fa32d04..ce76f4963bd3 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -356,6 +356,16 @@ def load(file, mmap_mode=None, allow_pickle=True, fix_imports=True, """ import gzip + try: + from pathlib import Path + supports_pathlib = True + except: + supports_pathlib = False + + if supports_pathlib: + if isinstance(file, Path): + file = file.__str__() + own_fid = False if isinstance(file, basestring): @@ -505,10 +515,11 @@ def savez(file, *args, **kwds): Parameters ---------- - file : str or file - Either the file name (string) or an open file (file-like object) - where the data will be saved. If file is a string, the ``.npz`` - extension will be appended to the file name if it is not already there. + file : str, file, or Path (from pathlib library) + Either the file name (string), an open file (file-like object) + where the data will be saved, or a Path from the pathlib library. + If file is a string, the ``.npz`` extension will be appended to + the file name if it is not already there. args : Arguments, optional Arrays to save to the file. Since it is not possible for Python to know the names of the arrays outside `savez`, the arrays will be saved @@ -584,7 +595,7 @@ def savez_compressed(file, *args, **kwds): Parameters ---------- - file : str + file : str, file, or Path (from pathlib library) File name of ``.npz`` file. args : Arguments Function arguments. @@ -606,10 +617,18 @@ def _savez(file, args, kwds, compress, allow_pickle=True, pickle_kwargs=None): import zipfile # Import deferred for startup time improvement import tempfile + try: + from pathlib import Path + supports_pathlib = True + except: + supports_pathlib = False if isinstance(file, basestring): if not file.endswith('.npz'): file = file + '.npz' + elif supports_pathlib: + if isinstance(file, Path): + file = file.__str__() namedict = kwds for i, val in enumerate(args): diff --git a/numpy/lib/tests/test_format.py b/numpy/lib/tests/test_format.py index 4f8a651489da..114497c2239c 100644 --- a/numpy/lib/tests/test_format.py +++ b/numpy/lib/tests/test_format.py @@ -523,6 +523,19 @@ def test_compressed_roundtrip(): arr1 = np.load(npz_file)['arr'] assert_array_equal(arr, arr1) +def test_compressed_roundtrip(): + arr = np.random.rand(200, 200) + npz_file = os.path.join(tempdir, 'compressed.npz') + try: + import pathlib + except: + return + npz_file = pathlib.Path(npz_file) + np.savez_compressed(npz_file, arr=arr) + arr1 = np.load(npz_file)['arr'] + assert_array_equal(arr, arr1) + + def test_python2_python3_interoperability(): if sys.version_info[0] >= 3: