align sessions · Issue #501 · larray-project/larray · GitHub
[go: up one dir, main page]

Skip to content

align sessions #501

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
gdementen opened this issue Oct 25, 2017 · 1 comment
Open

align sessions #501

gdementen opened this issue Oct 25, 2017 · 1 comment

Comments

@gdementen
Copy link
Contributor

Here is some preliminary code:

from larray.util.misc import unique_list


def align(*sessions, join='outer', fill_value=nan):
    """Align the arrays of exactly two sessions.

    Every name found in either session appears in both results. For each name,
    the two values are converted with ``aslarray`` (a name missing from a
    session is taken as ``np.nan`` there), anonymous axes are renamed
    ``axis0``, ``axis1``, ... and the pair is aligned with ``Array.align``.

    Parameters
    ----------
    *sessions : Session
        The two sessions to align.
    join : str, optional
        Join method forwarded to ``Array.align``. Defaults to 'outer'.
    fill_value : scalar, optional
        Forwarded to ``Array.align``. Defaults to ``nan``.

    Returns
    -------
    tuple of two Session
        The aligned sessions. Items follow the name order collected by
        ``unique_list``. An item that could not be aligned is ``np.nan`` in
        both results.

    Raises
    ------
    ValueError
        If the number of sessions is not exactly two.
    TypeError
        If any argument is not a Session.
    """
    # TODO: support n sessions
    # Explicit check instead of ``assert``: asserts are stripped under ``python -O``.
    if len(sessions) != 2:
        raise ValueError("align currently supports exactly 2 sessions, got {}".format(len(sessions)))
    if not all(isinstance(s, Session) for s in sessions):
        raise TypeError("Session.align only supports aligning with other Session objects")

    # Union of the names of all sessions, accumulated in order by unique_list.
    seen = set()
    all_keys = []
    for s in sessions:
        unique_list(s.keys(), all_keys, seen)

    res1, res2 = [], []
    for name in all_keys:
        # A name missing from one session is represented there by NaN.
        arrays = [aslarray(s.get(name, np.nan)) for s in sessions]
        # rename anonymous axes because those are not supported by align()
        arrays = [array.rename({axis_num: axis.name if axis.name is not None else 'axis{}'.format(axis_num)
                                for axis_num, axis in enumerate(array.axes)})
                  for array in arrays]
        arr1, arr2 = arrays
        try:
            # TODO: align should support more than one "other" array
            aligned1, aligned2 = arr1.align(arr2, join=join, fill_value=fill_value)
        except Exception:
            # Deliberate best-effort: an item that cannot be aligned becomes NaN
            # in both results instead of aborting the alignment of the whole session.
            aligned1, aligned2 = np.nan, np.nan
        res1.append((name, aligned1))
        res2.append((name, aligned2))
    return Session(res1), Session(res2)


# Example usage: replace two existing Session objects by the pair of aligned sessions returned by align().
session1, session2 = align(session1, session2)
@alixdamman alixdamman added this to the 0.31 milestone Dec 21, 2018
@gdementen gdementen removed this from the 0.31 milestone Aug 1, 2019
@alixdamman alixdamman added this to the nice_to_have milestone Oct 10, 2019
@gdementen
Copy link
Contributor Author

Here is some UNTESTED code that generalizes the above to any number of sessions:

def align_all(arrays, axes=None, fill_value=nan, join='outer'):
    """Align several arrays on some or all of their axes.

    Parameters
    ----------
    arrays : iterable of Array
        Arrays to align. Any iterable is accepted; it is materialized once.
    axes : int, str, Axis or iterable of those, optional
        Axes to align. If None (default):

        * if every axis of every array is anonymous, the first
          ``min(len(array.axes))`` axes are aligned by position;
        * if every axis is named, the axis names common to all arrays are used;
        * mixing anonymous and named axes is an error.
    fill_value : scalar, optional
        Forwarded to ``reindex``. Defaults to ``nan``.
    join : str, optional
        Join method forwarded to ``Axis.align``. Defaults to 'outer'.

    Returns
    -------
    tuple
        One reindexed array per input array, in input order. Empty tuple if
        ``arrays`` is empty.

    Raises
    ------
    ValueError
        If ``axes`` is None and the arrays mix anonymous and named axes.
    """
    # Materialize first: ``arrays`` is iterated twice (to collect axes and to
    # reindex), which would silently produce an empty result for a generator.
    arrays = list(arrays)
    if not arrays:
        return ()
    axes_collections = [arr.axes for arr in arrays]
    if axes is None:
        all_names = [name for col in axes_collections for name in col.names]
        if all(name is None for name in all_names):
            # only anonymous axes: use the N first axes by position
            join_axes = list(range(min(len(col) for col in axes_collections)))
        elif any(name is None for name in all_names):
            raise ValueError("axes collections with mixed anonymous/non anonymous axes are not supported by align "
                             "without specifying axes explicitly")
        else:
            # only named axes: use all axes common to every array
            join_axes = OrderedSet(axes_collections[0].names)
            for col in axes_collections[1:]:
                join_axes &= OrderedSet(col.names)
    else:
        if isinstance(axes, (int, str, Axis)):
            axes = [axes]
        # list() because join_axes is iterated twice below (loop and zip)
        join_axes = list(axes)
    aligned_axes = []
    for axis_ref in join_axes:
        # combine this axis across all arrays, pairwise, with Axis.align
        aligned_axis = axes_collections[0][axis_ref]
        for col in axes_collections[1:]:
            aligned_axis = aligned_axis.align(col[axis_ref], join=join)
        aligned_axes.append(aligned_axis)
    axes_changes = list(zip(join_axes, aligned_axes))
    return tuple(arr.reindex(arr.axes.replace(axes_changes), fill_value=fill_value) for arr in arrays)

def align(*sessions, join='outer', fill_value=nan):
    """Align the arrays of any number of sessions.

    Every name found in any session appears in every result. For each name,
    the values are converted with ``aslarray`` (a name missing from a session
    is taken as ``np.nan`` there), anonymous axes are renamed ``axis0``,
    ``axis1``, ... and all values are aligned together with ``align_all``.

    Parameters
    ----------
    *sessions : Session
        Sessions to align.
    join : str, optional
        Join method forwarded to ``align_all``. Defaults to 'outer'.
    fill_value : scalar, optional
        Forwarded to ``align_all``. Defaults to ``nan``.

    Returns
    -------
    tuple of Session
        One aligned session per input session, in input order. Items follow
        the name order collected by ``unique_list``.

    Raises
    ------
    TypeError
        If any argument is not a Session.
    Exception
        Any error raised by ``align_all`` for an item propagates unchanged.
    """
    if not all(isinstance(s, Session) for s in sessions):
        raise TypeError("Session.align only supports aligning with other Session objects")

    # Union of the names of all sessions, accumulated in order by unique_list.
    seen = set()
    all_keys = []
    for s in sessions:
        unique_list(s.keys(), all_keys, seen)

    # list of [[(name0, aligned_array_for_name0), ...], [(name0, aligned_array_for_name0), ...], ...]
    aligned_items = [[] for _ in sessions]
    for name in all_keys:
        # A name missing from a session is represented there by NaN.
        arrays = [aslarray(s.get(name, np.nan)) for s in sessions]
        # rename anonymous axes because those are not supported by align()
        arrays = [array.rename({axis_num: axis.name if axis.name is not None else f'axis{axis_num}'
                                for axis_num, axis in enumerate(array.axes)})
                  for array in arrays]
        # Alignment errors propagate to the caller: there is no silent NaN fallback.
        aligned_arrays = align_all(arrays, join=join, fill_value=fill_value)
        for res, array in zip(aligned_items, aligned_arrays):
            res.append((name, array))
    # create sessions with the results
    return tuple(Session(sess_items) for sess_items in aligned_items)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

No branches or pull requests

2 participants
0