8000 MNT: refactor stringdtype compatibility checking out of common_instance · numpy/numpy@b356e04 · GitHub
[go: up one dir, main page]

Skip to content

Commit b356e04

Browse files
committed
MNT: refactor stringdtype compatibility checking out of common_instance
1 parent 9817861 commit b356e04

File tree

3 files changed

+96
-41
lines changed

3 files changed

+96
-41
lines changed

numpy/_core/src/multiarray/stringdtype/dtype.c

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -145,20 +145,13 @@ new_stringdtype_instance(PyObject *na_object, int coerce)
145145
}
146146

147147
NPY_NO_EXPORT int
148-
na_eq_cmp(PyObject *a, PyObject *b, int coerce_nulls) {
148+
na_eq_cmp(PyObject *a, PyObject *b) {
149149
if (a == b) {
150150
// catches None and other singletons like Pandas.NA
151151
return 1;
152152
}
153153
if (a == NULL || b == NULL) {
154-
if (coerce_nulls) {
155-
// an object with an explictly set NA object is considered
156-
// compatible for binary operations to one with no explicitly set NA
157-
return 1;
158-
}
159-
else {
160-
return 0;
161-
}
154+
return 0;
162155
}
163156
if (PyFloat_Check(a) && PyFloat_Check(b)) {
164157
// nan check catches np.nan and float('nan')
@@ -189,29 +182,66 @@ _eq_comparison(int scoerce, int ocoerce, PyObject *sna, PyObject *ona)
189182
if (scoerce != ocoerce) {
190183
return 0;
191184
}
192-
return na_eq_cmp(sna, ona, 0);
185+
return na_eq_cmp(sna, ona);
186+
}
187+
188+
// currently this can only return 1 or -1, the latter indicating that the
189+
// error indicator is set
190+
NPY_NO_EXPORT int
191+
stringdtype_compatible_na(PyObject *na1, PyObject *na2) {
192+
if ((na1 == NULL) != (na2 == NULL)) {
193+
return 1;
194+
}
195+
196+
int na_eq = na_eq_cmp(na1, na2);
197+
198+
if (na_eq < 0) {
199+
return -1;
200+
}
201+
else if (na_eq == 0) {
202+
PyErr_Format(PyExc_TypeError,
203+
"Cannot find a compatible null string value for "
204+
"null strings '%R' and '%R'", na1, na2);
205+
return -1;
206+
}
207+
return 1;
208+
}
209+
210+
NPY_NO_EXPORT int
211+
stringdtype_compatible_settings(PyObject *na1, PyObject *na2, PyObject **out_na,
212+
int coerce1, int coerce2, int *out_coerce) {
213+
int compatible = stringdtype_compatible_na(na1, na2);
214+
if (compatible == -1) {
215+
return -1;
216+
}
217+
*out_na = (na1 ? na1 : na2);
218+
*out_coerce = (coerce1 && coerce2);
219+
220+
return 0;
193221
}
194222

195223
/*
196224
* This is used to determine the correct dtype to return when dealing
197225
* with a mix of different dtypes (for example when creating an array
198226
* from a list of scalars).
199227
*/
200-
NPY_NO_EXPORT PyArray_StringDTypeObject *
228+
static PyArray_StringDTypeObject *
201229
common_instance(PyArray_StringDTypeObject *dtype1, PyArray_StringDTypeObject *dtype2)
202230
{
203-
int eq = na_eq_cmp(dtype1->na_object, dtype2->na_object, 1);
204-
205-
if (eq <= 0) {
206-
PyErr_SetString(
207-
PyExc_TypeError,
208-
"Cannot find common instance for incompatible dtype instances");
231+
int out_coerce = 1;
232+
PyObject *out_na_object = NULL;
233+
234+
if (stringdtype_compatible_settings(
235+
dtype1->na_object, dtype2->na_object, &out_na_object,
236+
dtype1->coerce, dtype2->coerce, &out_coerce) == -1) {
237+
PyErr_Format(PyExc_TypeError,
238+
"Cannot find common instance for incompatible dtypes "
239+
"'%R' and '%R'", (PyObject *)dtype1, (PyObject *)dtype2);
209240
return NULL;
210241
}
211242

212243
return (PyArray_StringDTypeObject *)new_stringdtype_instance(
213-
dtype1->na_object != NULL ? dtype1->na_object : dtype2->na_object,
214-
!((dtype1->coerce == 0) || (dtype2->coerce == 0)));
244+
out_na_object, out_coerce);
215245
}
216246

217247
/*
@@ -301,7 +331,7 @@ stringdtype_setitem(PyArray_StringDTypeObject *descr, PyObject *obj, char **data
301331
// so we do the comparison before acquiring the allocator.
302332

303333
if (na_object != NULL) {
304-
na_cmp = na_eq_cmp(obj, na_object, 1);
334+
na_cmp = na_eq_cmp(obj, na_object);
305335
if (na_cmp == -1) {
306336
return -1;
307337
}

numpy/_core/src/multiarray/stringdtype/dtype.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,12 @@ stringdtype_finalize_descr(PyArray_Descr *dtype);
4949
NPY_NO_EXPORT int
5050
_eq_comparison(int scoerce, int ocoerce, PyObject *sna, PyObject *ona);
5151

52-
NPY_NO_EXPORT PyArray_StringDTypeObject *
53-
common_instance(PyArray_StringDTypeObject *dtype1, PyArray_StringDTypeObject *dtype2);
52+
NPY_NO_EXPORT int
53+
stringdtype_compatible_settings(PyObject *na1, PyObject *na2, PyObject **out_na,
54+
int coerce1, int coerce2, int *out_coerce);
55+
56+
NPY_NO_EXPORT int
57+
stringdtype_compatible_na(PyObject *na1, PyObject *na2);
5458

5559
#ifdef __cplusplus
5660
}

numpy/_core/src/umath/stringdtype_ufuncs.cpp

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -246,9 +246,12 @@ binary_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
246246
{
247247
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
248248
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
249-
PyArray_StringDTypeObject *common_descr = common_instance(descr1, descr2);
249+
int out_coerce = 1;
250+
PyObject *out_na_object = NULL;
250251

251-
if (common_descr == NULL) {
252+
if (stringdtype_compatible_settings(
253+
descr1->na_object, descr2->na_object, &out_na_object,
254+
descr1->coerce, descr2->coerce, &out_coerce) == -1) {
252255
return (NPY_CASTING)-1;
253256
}
254257

@@ -261,7 +264,7 @@ binary_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
261264

262265
if (given_descrs[2] == NULL) {
263266
out_descr = (PyArray_Descr *)new_stringdtype_instance(
264-
common_descr->na_object, common_descr->coerce);
267+
out_na_object, out_coerce);
265268

266269
if (out_descr == NULL) {
267270
return (NPY_CASTING)-1;
@@ -552,9 +555,8 @@ string_comparison_resolve_descriptors(
552555
{
553556
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
554557
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
555-
PyArray_StringDTypeObject *common_descr = common_instance(descr1, descr2);
556558

557-
if (common_descr == NULL) {
559+
if (stringdtype_compatible_na(descr1->na_object, descr2->na_object) == -1) {
558560
return (NPY_CASTING)-1;
559561
}
560562

@@ -784,9 +786,12 @@ string_findlike_resolve_descriptors(
784786
{
785787
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
786788
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
787-
PyArray_StringDTypeObject *common_descr = common_instance(descr1, descr2);
789+
int out_coerce = 1;
790+
PyObject *out_na_object = NULL;
788791

789-
if (common_descr == NULL) {
792+
if (stringdtype_compatible_settings(
793+
descr1->na_object, descr2->na_object, &out_na_object,
794+
descr1->coerce, descr2->coerce, &out_coerce) == -1) {
790795
return (NPY_CASTING)-1;
791796
}
792797

@@ -834,9 +839,12 @@ string_startswith_endswith_resolve_descriptors(
834839
{
835840
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descr 10000 s[0];
836841
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
837-
PyArray_StringDTypeObject *common_descr = common_instance(descr1, descr2);
842+
int out_coerce = 1;
843+
PyObject *out_na_object = NULL;
838844

839-
if (common_descr == NULL) {
845+
if (stringdtype_compatible_settings(
846+
descr1->na_object, descr2->na_object, &out_na_object,
847+
descr1->coerce, descr2->coerce, &out_coerce) == -1) {
840848
return (NPY_CASTING)-1;
841849
}
842850

@@ -1242,11 +1250,18 @@ replace_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
12421250
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
12431251
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
12441252
PyArray_StringDTypeObject *descr3 = (PyArray_StringDTypeObject *)given_descrs[2];
1253+
int out_coerce = 1;
1254+
PyObject *out_na_object = NULL;
12451255

1246-
PyArray_StringDTypeObject *common_descr = common_instance(
1247-
common_instance(descr1, descr2), descr3);
1256+
if (stringdtype_compatible_settings(
1257+
descr1->na_object, descr2->na_object, &out_na_object,
1258+
descr1->coerce, descr2->coerce, &out_coerce) == -1) {
1259+
return (NPY_CASTING)-1;
1260+
}
12481261

1249-
if (common_descr == NULL) {
1262+
if (stringdtype_compatible_settings(
1263+
out_na_object, descr3->na_object, &out_na_object,
1264+
out_coerce, descr3->coerce, &out_coerce) == -1) {
12501265
return (NPY_CASTING)-1;
12511266
}
12521267

@@ -1263,7 +1278,7 @@ replace_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
12631278

12641279
if (given_descrs[4] == NULL) {
12651280
out_descr = (PyArray_Descr *)new_stringdtype_instance(
1266-
common_descr->na_object, common_descr->coerce);
1281+
out_na_object, out_coerce);
12671282

12681283
if (out_descr == NULL) {
12691284
return (NPY_CASTING)-1;
@@ -1510,9 +1525,12 @@ center_ljust_rjust_resolve_descriptors(
15101525
{
15111526
PyArray_StringDTypeObject *input_descr = (PyArray_StringDTypeObject *)given_descrs[0];
15121527
PyArray_StringDTypeObject *fill_descr = (PyArray_StringDTypeObject *)given_descrs[2];
1513-
PyArray_StringDTypeObject *common_descr = common_instance(input_descr, fill_descr);
1528+
int out_coerce = 1;
1529+
PyObject *out_na_object = NULL;
15141530

1515-
if (common_descr == NULL) {
1531+
if (stringdtype_compatible_settings(
1532+
input_descr->na_object, fill_descr->na_object, &out_na_object,
1533+
input_descr->coerce, fill_descr->coerce, &out_coerce) == -1) {
15161534
return (NPY_CASTING)-1;
15171535
}
15181536

@@ -1527,7 +1545,7 @@ center_ljust_rjust_resolve_descriptors(
15271545

15281546
if (given_descrs[3] == NULL) {
15291547
out_descr = (PyArray_Descr *)new_stringdtype_instance(
1530-
common_descr->na_object, common_descr->coerce);
1548+
out_na_object, out_coerce);
15311549

15321550
if (out_descr == NULL) {
15331551
return (NPY_CASTING)-1;
@@ -1817,9 +1835,12 @@ string_partition_resolve_descriptors(
18171835

18181836
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
18191837
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
1820-
PyArray_StringDTypeObject *common_descr = common_instance(descr1, descr2);
1838+
int out_coerce = 1;
1839+
PyObject *out_na_object = NULL;
18211840

1822-
if (common_descr == NULL) {
1841+
if (stringdtype_compatible_settings(
1842+
descr1->na_object, descr2->na_object, &out_na_object,
1843+
descr1->coerce, descr2->coerce, &out_coerce) == -1) {
18231844
return (NPY_CASTING)-1;
18241845
}
18251846

@@ -1830,7 +1851,7 @@ string_partition_resolve_descriptors(
18301851

18311852
for (int i=2; i<5; i++) {
18321853
loop_descrs[i] = (PyArray_Descr *)new_stringdtype_instance(
1833-
common_descr->na_object, common_descr->coerce);
1854+
out_na_object, out_coerce);
18341855
if (loop_descrs[i] == NULL) {
18351856
return (NPY_CASTING)-1;
18361857
}

0 commit comments

Comments
 (0)
0