8000 Merged in sergem/numpy/partition_for_unsigned (pull request #9) · prototype99/numpy@24deb37 · GitHub
[go: up one dir, main page]

Skip to content

Commit 24deb37

Browse files
committed
Merged in sergem/numpy/partition_for_unsigned (pull request ganesh-k13#9)
Partition_for_unsigned
2 parents dd68e27 + 3903aba commit 24deb37

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

numpy/core/_partition_use.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
_type_to_suff = dict(zip(list_type, list_suff))
88
_dtype_to_cffi_type = {dtype('int32'): 'npy_int',
99
dtype('int64'): 'npy_longlong',
10+
dtype('uint32'): 'npy_uint',
11+
dtype('uint64'): 'npy_ulonglong',
1012
dtype('float64'): 'npy_double',
1113
dtype('float32'): 'npy_float',
1214
}

numpy/core/tests/test_partition.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from __future__ import division, absolute_import, print_function
2+
3+
import numpy as np
4+
from numpy.testing import TestCase
5+
import sys
6+
7+
8+
def is_pypy():
9+
return '__pypy__' in sys.builtin_module_names
10+
11+
12+
class TestPartition(TestCase):
13+
pivot = 2
14+
pivot_index = 2
15+
16+
def test_uint16(self):
17+
arr = np.arange(10, -1, -1, dtype='int16')
18+
19+
if is_pypy():
20+
with self.assertRaises(NotImplementedError):
21+
np.partition(arr, self.pivot_index)
22+
else:
23+
partition = np.partition(arr, self.pivot_index)
24+
self._check_partition(partition, self.pivot_index, self.pivot)
25+
26+
def test_uint32(self):
27+
arr = np.arange(10, -1, -1, dtype='int32')
28+
29+
partition = np.partition(arr, self.pivot_index)
30+
31+
self._check_partition(partition, self.pivot_index, self.pivot)
32+
33+
def test_uint64(self):
34+
arr = np.arange(10, -1, -1, dtype='int64')
35+
36+
partition = np.partition(arr, self.pivot_index)
37+
38+
self._check_partition(partition, self.pivot_index, self.pivot)
39+
40+
def _check_partition(self, partition, pivot_index, pivot):
41+
self.assertTrue(np.all(partition[:pivot_index] < partition[pivot_index]))
42+
self.assertTrue(np.all(partition[pivot_index:] >= partition[pivot_index]))
43+
self.assertTrue(partition[pivot_index] == pivot)

0 commit comments

Comments
 (0)
0