-
Notifications
You must be signed in to change notification settings - Fork 6
align sessions #501
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Milestone
Comments
# Here is some UNTESTED code to generalize the above:
def align_all(arrays, axes=None, fill_value=nan, join='outer'):
    """Align several arrays on the same axes.

    Parameters
    ----------
    arrays : iterable of arrays
        Arrays to align. Any iterable is accepted; it is consumed exactly once.
    axes : int, str, Axis or sequence of those, optional
        Axes to align. When None (default): if every axis of every array is
        anonymous, the first N axes are aligned by position (N being the
        smallest number of axes among the arrays); if every axis is named, the
        axes whose names are common to all arrays are aligned.
    fill_value : scalar, optional
        Passed to ``reindex`` for each array. Defaults to nan.
    join : str, optional
        Join method passed to ``Axis.align``. Defaults to 'outer'.

    Returns
    -------
    tuple
        The reindexed arrays, in input order. Empty when ``arrays`` is empty.

    Raises
    ------
    ValueError
        If ``axes`` is None and the arrays mix anonymous and named axes.
    """
    # Materialize once: the arrays are traversed twice (axes lookup, then
    # reindexing), which would silently yield () if a generator were passed.
    arrays = list(arrays)
    if not arrays:
        # Nothing to align (also avoids min() of an empty sequence below).
        return ()
    axes_collections = [arr.axes for arr in arrays]

    if axes is None:
        all_names = [name for col in axes_collections for name in col.names]
        if all(name is None for name in all_names):
            # Only anonymous axes: align the first N axes by position.
            join_axes = list(range(min(len(col) for col in axes_collections)))
        elif any(name is None for name in all_names):
            raise ValueError("axes collections with mixed anonymous/non anonymous axes are not supported by align "
                             "without specifying axes explicitly")
        else:
            # Only named axes: keep the names common to all arrays, starting
            # from the first array's axis names.
            join_axes = OrderedSet(axes_collections[0].names)
            for col in axes_collections[1:]:
                join_axes &= OrderedSet(col.names)
    elif isinstance(axes, (int, str, Axis)):
        join_axes = [axes]
    else:
        # Copy so a one-shot iterable of axis references is not exhausted.
        join_axes = list(axes)

    # (axis reference, aligned axis) pairs, in the format consumed by
    # AxisCollection.replace below.
    axes_changes = []
    for axis_ref in join_axes:
        # Fold Axis.align over every array's version of this axis.
        aligned_axis = axes_collections[0][axis_ref]
        for col in axes_collections[1:]:
            aligned_axis = aligned_axis.align(col[axis_ref], join=join)
        axes_changes.append((axis_ref, aligned_axis))
    return tuple(arr.reindex(arr.axes.replace(axes_changes), fill_value=fill_value) for arr in arrays)
def align(*sessions, join='outer', fill_value=nan):
    """Align the arrays of several sessions with one another.

    Every array name found in any session is processed, in first-seen order.
    A session lacking a given name contributes a NaN placeholder for it. For
    each name, the arrays are aligned with ``align_all``. The results are
    gathered into new sessions.

    Parameters
    ----------
    *sessions : Session
        Sessions to align.
    join : str, optional
        Join method forwarded to ``align_all``. Defaults to 'outer'.
    fill_value : scalar, optional
        Forwarded to ``align_all`` (which passes it to ``reindex``).
        Defaults to nan.

    Returns
    -------
    tuple of Session
        One new session per input session, in input order.

    Raises
    ------
    TypeError
        If any argument is not a Session.
    Exception
        Any error raised by ``align_all`` propagates unchanged.
    """
    if not all(isinstance(s, Session) for s in sessions):
        raise TypeError("Session.align only supports aligning with other Session objects")

    # Union of all array names, in first-seen order (dicts preserve insertion order).
    all_keys = list(dict.fromkeys(key for s in sessions for key in s.keys()))

    # aligned_items[i] collects the (name, aligned_array) pairs for the i-th output session.
    aligned_items = [[] for _ in sessions]
    for name in all_keys:
        # NOTE(review): the placeholder for a missing name presumably becomes a
        # 0-d array without axes. In that case the set of common axes computed
        # by align_all is empty, and the other sessions' arrays are not aligned
        # with each other. Confirm the intended behavior.
        arrays = [aslarray(s.get(name, np.nan)) for s in sessions]
        # Give anonymous axes a positional name ('axis0', 'axis1', ...) so that
        # align_all sees only named axes instead of rejecting a mix.
        arrays = [array.rename({axis_num: axis.name if axis.name is not None else f'axis{axis_num}'
                                for axis_num, axis in enumerate(array.axes)})
                  for array in arrays]
        # Alignment errors propagate to the caller; the former debug print and
        # the unreachable NaN fallback after `raise` were removed.
        aligned_arrays = align_all(arrays, join=join, fill_value=fill_value)
        for res, array in zip(aligned_items, aligned_arrays):
            res.append((name, array))

    # Create sessions with the results.
    return tuple(Session(sess_items) for sess_items in aligned_items)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Here is some preliminary code:
The text was updated successfully, but these errors were encountered: