ENH: make numpy.lib._arraysetops.intersect1d work on multiple arrays · numpy/numpy@e2c956e · GitHub
[go: up one dir, main page]

Skip to content

Commit e2c956e

Browse files
committed
ENH: make numpy.lib._arraysetops.intersect1d work on multiple arrays
Intersect1d can be used with multiple arrays now and also returns the indices of all arrays when using `return_indices=True`.
1 parent 578670b commit e2c956e

File tree

1 file changed

+48
-38
lines changed

1 file changed

+48
-38
lines changed

numpy/lib/_arraysetops_impl.py

Lines changed: 48 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -571,28 +571,28 @@ def unique_values(x):
571571

572572

573573
def _intersect1d_dispatcher(
574-
ar1, ar2, assume_unique=None, return_indices=None):
575-
return (ar1, ar2)
574+
*ars, assume_unique=None, return_indices=None):
575+
return ars
576576

577577

578578
@array_function_dispatch(_intersect1d_dispatcher)
579-
def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
579+
def intersect1d(*ars, assume_unique=False, return_indices=False):
580580
"""
581-
Find the intersection of two arrays.
581+
Find the intersection of multiple arrays.
582582
583-
Return the sorted, unique values that are in both of the input arrays.
583+
Return the sorted, unique values that are in all of the input arrays.
584584
585585
Parameters
586586
----------
587-
ar1, ar2 : array_like
588-
Input arrays. Will be flattened if not already 1D.
587+
*ars : array_like
588+
Input arrays. Each will be flattened if not already 1D.
589589
assume_unique : bool
590-
If True, the input arrays are both assumed to be unique, which
591-
can speed up the calculation. If True but ``ar1`` or ``ar2`` are not
592-
unique, incorrect results and out-of-bounds indices could result.
590+
If True, the input arrays are all assumed to be unique, which
591+
can speed up the calculation. If True but any of the arrays in ars is
592+
not unique, incorrect results and out-of-bounds indices could result.
593593
Default is False.
594594
return_indices : bool
595-
If True, the indices which correspond to the intersection of the two
595+
If True, the indices which correspond to the intersection of all
596596
arrays are returned. The first instance of a value is used if there are
597597
multiple. Default is False.
598598
@@ -602,23 +602,28 @@ def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
602602
-------
603603
intersect1d : ndarray
604604
Sorted 1D array of common and unique elements.
605-
comm1 : ndarray
606-
The indices of the first occurrences of the common values in `ar1`.
607-
Only provided if `return_indices` is True.
608-
comm2 : ndarray
609-
The indices of the first occurrences of the common values in `ar2`.
605+
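*comms : ndarray
606+
One index array per input array, returned as additional elements of the result tuple (not wrapped in a list). Each holds the indices of the first occurrences of the common values in the corresponding array of `ars`.
610607
Only provided if `return_indices` is True.
608+
``comms[0]`` contains the indices for ``ars[0]``,
609+
``comms[1]`` contains the indices for ``ars[1]``, and so on.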
610+
611+
612+
See Also
613+
--------
614+
numpy.lib.arraysetops : Module with a number of other functions for
615+
performing set operations on arrays.
611616
612617
Examples
613618
--------
614619
>>> np.intersect1d([1, 3, 4, 3], [3, 1, 2, 1])
615620
array([1, 3])
616621
617-
To intersect more than two arrays, use functools.reduce:
622+
To intersect more than two arrays, pass them all as positional arguments (here the indices are requested as well):
618623
619-
>>> from functools import reduce
620-
>>> reduce(np.intersect1d, ([1, 3, 4, 3], [3, 1, 2, 1], [6, 3, 4, 2]))
621-
array([3])
624+
>>> ars = ([1, 3, 4, 3], [3, 1, 2, 1], [6, 3, 4, 2])
625+
>>> np.intersect1d(*ars, return_indices=True)
626+
(array([3]), array([1]), array([0]), array([1]))
622627
623628
To return the indices of the values common to the input arrays
624629
along with the intersected values:
@@ -632,38 +637,43 @@ def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
632637
(array([1, 2, 4]), array([1, 2, 4]), array([1, 2, 4]))
633638
634639
"""
635-
ar1 = np.asanyarray(ar1)
636-
ar2 = np.asanyarray(ar2)
640+
ars = [np.asanyarray(ar) for ar in ars]
637641

638642
if not assume_unique:
639643
if return_indices:
640-
ar1, ind1 = unique(ar1, return_index=True)
641-
ar2, ind2 = unique(ar2, return_index=True)
644+
inds = [None] * len(ars)
645+
for i, ar in enumerate(ars):
646+
ars[i], inds[i] = unique(ar, return_index=True)
642647
else:
643-
ar1 = unique(ar1)
644-
ar2 = unique(ar2)
648+
for i, ar in enumerate(ars):
649+
ars[i] = unique(ar)
645650
else:
646-
ar1 = ar1.ravel()
647-
ar2 = ar2.ravel()
651+
for i, ar in enumerate(ars):
652+
ars[i] = ar.ravel()
648653

649-
aux = np.concatenate((ar1, ar2))
654+
aux = np.concatenate(ars)
650655
if return_indices:
651656
aux_sort_indices = np.argsort(aux, kind='mergesort')
652657
aux = aux[aux_sort_indices]
653658
else:
654659
aux.sort()
655660

656-
mask = aux[1:] == aux[:-1]
657-
int1d = aux[:-1][mask]
661+
# aux is sorted and each array in ars has only unique elements.
662+
# Finding the same value len(ars)-1 positions further along therefore means
663+
# that the value occurs in every one of the arrays in ars.
664+
mask = aux[:-len(ars)+1] == aux[len(ars)-1:]
665+
int1d = aux[:-len(ars)+1][mask]
658666

659667
if return_indices:
660-
ar1_indices = aux_sort_indices[:-1][mask]
661-
ar2_indices = aux_sort_indices[1:][mask] - ar1.size
662-
if not assume_unique:
663-
ar1_indices = ind1[ar1_indices]
664-
ar2_indices = ind2[ar2_indices]
665-
666-
return int1d, ar1_indices, ar2_indices
668+
ret_indizes = [None] * len(ars)
669+
offset = 0
670+
for i, ar in enumerate(ars):
671+
imax = aux_sort_indices.size - len(ars) + i + 1
672+
ret_indizes[i] = aux_sort_indices[i:imax][mask] - offset
673+
offset += ar.size
674+
if not assume_unique:
675+
ret_indizes[i] = inds[i][ret_indizes[i]]
676+
return int1d, *ret_indizes
667677
else:
668678
return int1d
669679

0 commit comments

Comments
 (0)
0