8000 FEAT: add support for filepaths (str or Path) in compare (closes #229) · larray-project/larray-editor@ba4dc63 · GitHub
[go: up one dir, main page]

Skip to content

Commit ba4dc63

Browse files
committed
FEAT: add support for filepaths (str or Path) in compare (closes #229)
also changed internal API to use only dict
1 parent bb77a65 commit ba4dc63

File tree

4 files changed

+54
-24
lines changed

4 files changed

+54
-24
lines changed

doc/source/changes/version_0_34.rst.inc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ Miscellaneous improvements
4343
* when displaying an expression (computed array), the window title includes the actual expression
4444
instead of using '<expr>'.
4545

46+
* :py:obj:`compare()` can now take filepaths as argument (and will load them as a Session) to make
47+
comparing a in-memory Session with an earlier Session saved on the disk. Those filepaths
48+
can be given as both str or Path objects. Closes :editor_issue:`229`.
49+
4650
* added support for Path objects (in addition to str representing paths) in :py:obj:`view()` and :py:obj:`edit()`.
4751
See :issue:`896`.
4852

larray_editor/api.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,8 @@ def _compare_dialog(parent, *args, **kwargs):
256256
else:
257257
caller_info = None
258258

259-
if any(isinstance(a, la.Session) for a in args):
259+
compare_sessions = any(isinstance(a, (la.Session, str, Path)) for a in args)
260+
if compare_sessions:
260261
from larray_editor.comparator import SessionComparator
261262
dlg = SessionComparator(parent)
262263
default_name = 'session'
@@ -267,15 +268,24 @@ def _compare_dialog(parent, *args, **kwargs):
267268

268269
if names is None:
269270
def get_name(i, obj, depth=0):
270-
obj_names = _find_names(obj, depth=depth + 1)
271-
return obj_names[0] if obj_names else f'{default_name} {i:d}'
271+
if isinstance(obj, (str, Path)):
272+
return os.path.basename(obj)
273+
else:
274+
obj_names = _find_names(obj, depth=depth + 1)
275+
return obj_names[0] if obj_names else f'{default_name} {i:d}'
272276

273277
# depth + 2 because of the list comprehension
274278
names = [get_name(i, a, depth=depth + 2) for i, a in enumerate(args)]
275279
else:
276280
assert isinstance(names, list) and len(names) == len(args)
277281

278-
if dlg.setup_and_check(args, names=names, title=title, caller_info=caller_info, **kwargs):
282+
if compare_sessions:
283+
args = [la.Session(a) if not isinstance(a, la.Session) else a
284+
for a in args]
285+
286+
data = dict(zip(names, args))
287+
288+
if dlg.setup_and_check(data, title=title, caller_info=caller_info, **kwargs):
279289
return dlg
280290
else:
281291
return None
@@ -287,8 +297,8 @@ def compare(*args, **kwargs):
287297
288298
Parameters
289299
----------
290-
*args : Arrays or Sessions
291-
Arrays or sessions to compare.
300+
*args : Arrays, Sessions, str or Path.
301+
Arrays or sessions to compare. Strings or Path will be loaded as Sessions from the corresponding files.
292302
title : str, optional
293303
Title for the window. Defaults to ''.
294304
names : list of str, optional

larray_editor/comparator.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
67E6 11
import ast
2+
import warnings
3+
24
import numpy as np
35
import larray as la
46

@@ -14,7 +16,7 @@
1416

1517
class ComparatorWidget(QWidget):
1618
"""Comparator Widget"""
17-
def __init__(self, parent=None, bg_gradient='red-white-blue', rtol=0, atol=0, nans_equal=True, **kwargs):
19+
def __init__(self, parent=None, bg_gradient='red-white-blue', rtol=0, atol=0, nans_equal=True):
1820
QWidget.__init__(self, parent)
1921

2022
layout = QVBoxLayout()
@@ -179,27 +181,33 @@ def _setup_and_check(self, widget, data, title, readonly, **kwargs):
179181
----------
180182
widget: QWidget
181183
Parent widget.
182-
data: list or tuple of Array, ndarray
183-
Arrays to compare.
184+
data: dict of Array
185+
Arrays to compare as a {name: Array} dict.
184186
title: str
185187
Title.
186188
readonly: bool
189+
Ignored argument (comparator is always read only)
187190
kwargs:
188191
189192
* rtol: int or float
190193
* atol: int or float
191194
* nans_equal: bool
192195
* bg_gradient: str
193-
* names: list of str
194196
"""
195-
arrays = [la.asarray(array) for array in data if isinstance(array, DISPLAY_IN_GRID)]
196-
names = kwargs.get('names', [f"Array{i}" for i in range(len(arrays))])
197+
if isinstance(data, (list, tuple)):
198+
names = kwargs.pop('names', [f"Array{i}" for i in range(len(data))])
199+
data = dict(zip(names, data))
200+
warnings.warn("For ArrayComparator.setup_and_check, using a list or tuple for the data argument, "
201+
"and using the names argument are both deprecated. Please use a dict instead",
202+
FutureWarning, stacklevel=3)
203+
204+
assert all(isinstance(s, la.Array) for s in data.values())
197205

198206
layout = QVBoxLayout()
199207
widget.setLayout(layout)
200208

201209
comparator_widget = ComparatorWidget(self, **kwargs)
202-
comparator_widget.set_data(arrays, la.Axis(names, 'array'))
210+
comparator_widget.set_data(data.values(), la.Axis(data.keys(), 'array'))
203211
layout.addWidget(comparator_widget)
204212

205213

@@ -224,26 +232,29 @@ def _setup_and_check(self, widget, data, title, readonly, **kwargs):
224232
----------
225233
widget: QWidget
226234
Parent widget.
227-
data: list or tuple of Session
228-
Sessions to compare.
235+
data: dict of Session
236+
Sessions to compare as a {name: Session} dict.
229237
title: str
230238
Title.
231239
readonly: bool
240+
Ignored argument (comparator is always read only)
232241
kwargs:
233242
234243
* rtol: int or float
235244
* atol: int or float
236245
* nans_equal: bool
237246
* bg_gradient: str
238-
* names: list of str
239-
* colors: str
240247
"""
241-
sessions = data
242-
names = kwargs.get('names', [f"Session{i}" for i in range(len(sessions))])
248+
if isinstance(data, (list, tuple)):
249+
names = kwargs.pop('names', [f"Session{i}" for i in range(len(data))])
250+
data = dict(zip(names, data))
251+
warnings.warn("For SessionComparator.setup_and_check, using a list or tuple for the data argument, "
252+
"and using the names argument are both deprecated. Please use a dict instead",
253+
FutureWarning, stacklevel=3)
243254

244-
assert all(isinstance(s, la.Session) for s in sessions)
245-
self.sessions = sessions
246-
self.stack_axis = la.Axis(names, 'session')
255+
assert all(isinstance(s, la.Session) for s in data.values())
256+
self.sessions = data.values()
257+
self.stack_axis = la.Axis(data.keys(), 'session')
247258

248259
layout = QVBoxLayout()
249260
widget.setLayout(layout)
@@ -260,6 +271,7 @@ def _setup_and_check(self, widget, data, title, readonly, **kwargs):
260271
self.listwidget = listwidget
261272

262273
comparatorwidget = ComparatorWidget(self, **kwargs)
274+
# do not call set_data on the comparatorwidget as it will be done by the setCurrentRow below
263275
self.arraywidget = comparatorwidget
264276

265277
main_splitter = QSplitter(Qt.Horizontal)

larray_editor/tests/test_api_larray.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,13 @@ def make_demo(width=20, ball_radius=5, path_radius=5, steps=30):
145145
# edit(arr2)
146146

147147
# compare(arr3, arr3 + 1.0)
148+
# compare(arr3, arr3 + 1.0, names=['arr3', 'arr3 + 1.0'])
148149
# compare(np.random.normal(0, 1, size=(10, 2)), np.random.normal(0, 1, size=(10, 2)))
149-
# compare(la.Session(arr4=arr4, arr3=arr3, data=data3),
150-
# la.Session(arr4=arr4 + 1.0, arr3=arr3 * 2.0, data=data3 * 1.05))
150+
# sess1 = la.Session(arr4=arr4, arr3=arr3, data=data3)
151+
# sess1.save('sess1.h5')
152+
# sess2 = la.Session(arr4=arr4 + 1.0, arr3=arr3 * 2.0, data=data3 * 1.05)
153+
# compare('sess1.h5', sess2)
154+
# compare(Path('sess1.h5'), sess2)
151155
# compare(la.Session(arr2=arr2, arr3=arr3),
152156
# la.Session(arr2=arr2 + 1.0, arr3=arr3 * 2.0))
153157

0 commit comments

Comments
 (0)
0