8000 MAINT: respond to marten's comments · numpy/numpy@a39ed17 · GitHub
[go: up one dir, main page]

Skip to content

Commit a39ed17

Browse files
committed
MAINT: respond to marten's comments
1 parent 9ba3cd3 commit a39ed17

File tree

3 files changed

+39
-69
lines changed

3 files changed

+39
-69
lines changed

numpy/_core/src/multiarray/stringdtype/dtype.c

Lines changed: 21 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ new_stringdtype_instance(PyObject *na_object, int coerce)
144144
return NULL;
145145
}
146146

147-
NPY_NO_EXPORT int
147+
static int
148148
na_eq_cmp(PyObject *a, PyObject *b) {
149149
if (a == b) {
150150
// catches None and other singletons like Pandas.NA
@@ -185,41 +185,30 @@ _eq_comparison(int scoerce, int ocoerce, PyObject *sna, PyObject *ona)
185185
return na_eq_cmp(sna, ona);
186186
}
187187

188-
// currently this can only return 1 or -1, the latter indicating that the
189-
// error indicator is set
188+
// Currently this can only return 1 or -1, the latter indicating that the
189+
// error indicator is set. Pass in out_na if you want to figure out which
190+
// na is valid.
190191
NPY_NO_EXPORT int
191-
stringdtype_compatible_na(PyObject *na1, PyObject *na2) {
192-
if ((na1 == NULL) != (na2 == NULL)) {
193-
return 1;
194-
}
195-
196-
int na_eq = na_eq_cmp(na1, na2);
192+
stringdtype_compatible_na(PyObject *na1, PyObject *na2, PyObject **out_na) {
193+
if ((na1 != NULL) && (na2 != NULL)) {
194+
int na_eq = na_eq_cmp(na1, na2);
197195

198-
if (na_eq < 0) {
199-
return -1;
196+
if (na_eq < 0) {
197+
return -1;
198+
}
199+
else if (na_eq == 0) {
200+
PyErr_Format(PyExc_TypeError,
201+
"Cannot find a compatible null string value for "
202+
"null strings '%R' and '%R'", na1, na2);
203+
return -1;
204+
}
200205
}
201-
else if (na_eq == 0) {
202-
PyErr_Format(PyExc_TypeError,
203-
"Cannot find a compatible null string value for "
204-
"null strings '%R' and '%R'", na1, na2);
205-
return -1;
206+
if (out_na != NULL) {
207+
*out_na = na1 ? na1 : na2;
206208
}
207209
return 1;
208210
}
209211

210-
NPY_NO_EXPORT int
211-
stringdtype_compatible_settings(PyObject *na1, PyObject *na2, PyObject **out_na,
212-
int coerce1, int coerce2, int *out_coerce) {
213-
int compatible = stringdtype_compatible_na(na1, na2);
214-
if (compatible == -1) {
215-
return -1;
216-
}
217-
*out_na = (na1 ? na1 : na2);
218-
*out_coerce = (coerce1 && coerce2);
219-
220-
return 0;
221-
}
222-
223212
/*
224213
* This is used to determine the correct dtype to return when dealing
225214
* with a mix of different dtypes (for example when creating an array
@@ -228,20 +217,18 @@ stringdtype_compatible_settings(PyObject *na1, PyObject *na2, PyObject **out_na,
228217
static PyArray_StringDTypeObject *
229218
common_instance(PyArray_StringDTypeObject *dtype1, PyArray_StringDTypeObject *dtype2)
230219
{
231-
int out_coerce = 1;
232220
PyObject *out_na_object = NULL;
233221

234-
if (stringdtype_compatible_settings(
235-
dtype1->na_object, dtype2->na_object, &out_na_object,
236-
dtype1->coerce, dtype2->coerce, &out_coerce) == -1) {
222+
if (stringdtype_compatible_na(
223+
dtype1->na_object, dtype2->na_object, &out_na_object) == -1) {
237224
PyErr_Format(PyExc_TypeError,
238225
"Cannot find common instance for incompatible dtypes "
239226
"'%R' and '%R'", (PyObject *)dtype1, (PyObject *)dtype2);
240227
return NULL;
241228
}
242229

243230
return (PyArray_StringDTypeObject *)new_stringdtype_instance(
244-
out_na_object, out_coerce);
231+
out_na_object, dtype1->coerce && dtype1->coerce);
245232
}
246233

247234
/*

numpy/_core/src/multiarray/stringdtype/dtype.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,7 @@ NPY_NO_EXPORT int
5050
_eq_comparison(int scoerce, int ocoerce, PyObject *sna, PyObject *ona);
5151

5252
NPY_NO_EXPORT int
53-
stringdtype_compatible_settings(PyObject *na1, PyObject *na2, PyObject **out_na,
54-
int coerce1, int coerce2, int *out_coerce);
55-
56-
NPY_NO_EXPORT int
57-
stringdtype_compatible_na(PyObject *na1, PyObject *na2);
53+
stringdtype_compatible_na(PyObject *na1, PyObject *na2, PyObject **out_na);
5854

5955
#ifdef __cplusplus
6056
}

numpy/_core/src/umath/stringdtype_ufuncs.cpp

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -246,12 +246,11 @@ binary_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
246246
{
247247
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
248248
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
249-
int out_coerce = 1;
249+
int out_coerce = descr1->coerce && descr1->coerce;
250250
PyObject *out_na_object = NULL;
251251

252-
if (stringdtype_compatible_settings(
253-
descr1->na_object, descr2->na_object, &out_na_object,
254-
descr1->coerce, descr2->coerce, &out_coerce) == -1) {
252+
if (stringdtype_compatible_na(
253+
descr1->na_object, descr2->na_object, &out_na_object) == -1) {
255254
return (NPY_CASTING)-1;
256255
}
257256

@@ -556,7 +555,7 @@ string_comparison_resolve_descriptors(
556555
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
557556
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
558557

559-
if (stringdtype_compatible_na(descr1->na_object, descr2->na_object) == -1) {
558+
if (stringdtype_compatible_na(descr1->na_object, descr2->na_object, NULL) == -1) {
560559
return (NPY_CASTING)-1;
561560
}
562561

@@ -786,12 +785,8 @@ string_findlike_resolve_descriptors(
786785
{
787786
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
788787
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
789-
int out_coerce = 1;
790-
PyObject *out_na_object = NULL;
791788

792-
if (stringdtype_compatible_settings(
793-
descr1->na_object, descr2->na_object, &out_na_object,
794-
descr1->coerce, descr2->coerce, &out_coerce) == -1) {
789+
if (stringdtype_compatible_na(descr1->na_object, descr2->na_object, NULL) == -1) {
795790
return (NPY_CASTING)-1;
796791
}
797792

@@ -839,12 +834,8 @@ string_startswith_endswith_resolve_descriptors(
839834
{
840835
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
841836
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
842-
int out_coerce = 1;
843-
PyObject *out_na_object = NULL;
844837

845-
if (stringdtype_compatible_settings(
846-
descr1->na_object, descr2->na_object, &out_na_object,
847-
descr1->coerce, descr2->coerce, &out_coerce) == -1) {
838+
if (stringdtype_compatible_na(descr1->na_object, descr2->na_object, NULL) == -1) {
848839
return (NPY_CASTING)-1;
849840
}
850841

@@ -1250,18 +1241,16 @@ replace_resolve_descriptors(struct PyArrayMethodObject_tag *NPY_UNUSED(method),
12501241
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
12511242
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
12521243
PyArray_StringDTypeObject *descr3 = (PyArray_StringDTypeObject *)given_descrs[2];
1253-
int out_coerce = 1;
1244+
int out_coerce = descr1->coerce && descr2->coerce && descr3->coerce;
12541245
PyObject *out_na_object = NULL;
12551246

1256-
if (stringdtype_compatible_settings(
1257-
descr1->na_object, descr2->na_object, &out_na_object,
1258-
descr1->coerce, descr2->coerce, &out_coerce) == -1) {
1247+
if (stringdtype_compatible_na(
1248+
descr1->na_object, descr2->na_object, &out_na_object) == -1) {
12591249
return (NPY_CASTING)-1;
12601250
}
12611251

1262-
if (stringdtype_compatible_settings(
1263-
out_na_object, descr3->na_object, &out_na_object,
1264-
out_coerce, descr3->coerce, &out_coerce) == -1) {
1252+
if (stringdtype_compatible_na(
1253+
out_na_object, descr3->na_object, &out_na_object) == -1) {
12651254
return (NPY_CASTING)-1;
12661255
}
12671256

@@ -1525,12 +1514,11 @@ center_ljust_rjust_resolve_descriptors(
15251514
{
15261515
PyArray_StringDTypeObject *input_descr = (PyArray_StringDTypeObject *)given_descrs[0];
15271516
PyArray_StringDTypeObject *fill_descr = (PyArray_StringDTypeObject *)given_descrs[2];
1528-
int out_coerce = 1;
1517+
int out_coerce = input_descr->coerce && fill_descr->coerce;
15291518
PyObject *out_na_object = NULL;
15301519

1531-
if (stringdtype_compatible_settings(
1532-
input_descr->na_object, fill_descr->na_object, &out_na_object,
1533-
input_descr->coerce, fill_descr->coerce, &out_coerce) == -1) {
1520+
if (stringdtype_compatible_na(
1521+
input_descr->na_object, fill_descr->na_object, &out_na_object) == -1) {
15341522
return (NPY_CASTING)-1;
15351523
}
15361524

@@ -1835,12 +1823,11 @@ string_partition_resolve_descriptors(
18351823

18361824
PyArray_StringDTypeObject *descr1 = (PyArray_StringDTypeObject *)given_descrs[0];
18371825
PyArray_StringDTypeObject *descr2 = (PyArray_StringDTypeObject *)given_descrs[1];
1838-
int out_coerce = 1;
1826+
int out_coerce = descr1->coerce && descr2->coerce;
18391827
PyObject *out_na_object = NULL;
18401828

1841-
if (stringdtype_compatible_settings(
1842-
descr1->na_object, descr2->na_object, &out_na_object,
1843-
descr1->coerce, descr2->coerce, &out_coerce) == -1) {
1829+
if (stringdtype_compatible_na(
1830+
descr1->na_object, descr2->na_object, &out_na_object) == -1) {
18441831
return (NPY_CASTING)-1;
18451832
}
18461833

0 commit comments

Comments
 (0)
0