8000 ENH: ticket #1675, Add scalar support for the format() function. · certik/numpy@5fe46bd · GitHub
[go: up one dir, main page]

Skip to content

Commit 5fe46bd

Browse files
mwiebecharris 8000
authored andcommitted
ENH: ticket numpy#1675, Add scalar support for the format() function.
Backport of 88e8c15.
1 parent a292a72 commit 5fe46bd

File tree

2 files changed

+106
-0
lines changed

2 files changed

+106
-0
lines changed

numpy/core/src/multiarray/scalartypes.c.src

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,79 @@ gentype_repr(PyObject *self)
361361
return ret;
362362
}
363363

364+
#if PY_VERSION_HEX >= 0x02060000
365+
/*
366+
* The __format__ method for PEP 3101.
367+
*/
368+
static PyObject *
369+
gentype_format(PyObject *self, PyObject *args)
370+
{
371+
PyObject *format_spec;
372+
PyObject *obj, *ret;
373+
374+
#if defined(NPY_PY3K)
375+
if (!PyArg_ParseTuple(args, "U:__format__", &format_spec)) {
376+
return NULL;
377+
}
378+
#else
379+
if (!PyArg_ParseTuple(args, &quo 10000 t;O:__format__", &format_spec)) {
380+
return NULL;
381+
}
382+
383+
if (!PyUnicode_Check(format_spec) && !PyString_Check(format_spec)) {
384+
PyErr_SetString(PyExc_TypeError,
385+
"format must be a string");
386+
return NULL;
387+
}
388+
#endif
389+
390+
/*
391+
* Convert to an appropriate Python type and call its format.
392+
* TODO: For some types, like long double, this isn't right,
393+
* because it throws away precision.
394+
*/
395+
if (Py_TYPE(self) == &PyBoolArrType_Type) {
396+
obj = PyBool_FromLong(((PyBoolScalarObject *)self)->obval);
397+
}
398+
else if (PyArray_IsScalar(self, Integer)) {
399+
#if defined(NPY_PY3K)
400+
obj = Py_TYPE(self)->tp_as_number->nb_int(self);
401+
#else
402+
obj = Py_TYPE(self)->tp_as_number->nb_long(self);
403+
#endif
404+
}
405+
else if (PyArray_IsScalar(self, Floating)) {
406+
obj = Py_TYPE(self)->tp_as_number->nb_float(self);
407+
}
408+
else if (PyArray_IsScalar(self, ComplexFloating)) {
409+
double val[2];
410+
PyArray_Descr *dtype = PyArray_DescrFromScalar(self);
411+
412+
if (dtype == NULL) {
413+
return NULL;
414+
}
415+
if (PyArray_CastScalarDirect(self, dtype, &val[0], NPY_CDOUBLE) < 0) {
416+
Py_DECREF(dtype);
417+
return NULL;
418+
}
419+
obj = PyComplex_FromDoubles(val[0], val[1]);
420+
Py_DECREF(dtype);
421+
}
422+
else {
423+
obj = self;
424+
Py_INCREF(obj);
425+
}
426+
427+
if (obj == NULL) {
428+
return NULL;
429+
}
430+
431+
ret = PyObject_Format(obj, format_spec);
432+
Py_DECREF(obj);
433+
return ret;
434+
}
435+
#endif
436+
364437
#ifdef FORCE_NO_LONG_DOUBLE_FORMATTING
365438
#undef NPY_LONGDOUBLE_FMT
366439
#define NPY_LONGDOUBLE_FMT NPY_DOUBLE_FMT
@@ -1682,6 +1755,13 @@ static PyMethodDef gentype_methods[] = {
16821755
{"__round__",
16831756
(PyCFunction)gentype_round,
16841757
METH_VARARGS | METH_KEYWORDS, NULL},
1758+
#endif
1759+
#if PY_VERSION_HEX >= 0x02060000
1760+
/* For the format function */
1761+
{"__format__",
1762+
gentype_format,
1763+
METH_VARARGS,
1764+
"NumPy array scalar formatter"},
16851765
#endif
16861766
{"setflags",
16871767
(PyCFunction)gentype_setflags,

numpy/core/tests/test_print.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,32 @@ def test_complex_type_print():
198198
for t in [np.complex64, np.cdouble, np.clongdouble] :
199199
yield check_complex_type_print, t
200200

201+
@dec.skipif(sys.version_info < (2,6))
202+
def test_scalar_format():
203+
"""Test the str.format method with NumPy scalar types"""
204+
tests = [('{0}', True, np.bool_),
205+
('{0}', False, np.bool_),
206+
('{0:d}', 130, np.uint8),
207+
('{0:d}', 50000, np.uint16),
208+
('{0:d}', 3000000000, np.uint32),
209+
('{0:d}', 15000000000000000000, np.uint64),
210+
('{0:d}', -120, np.int8),
211+
('{0:d}', -30000, np.int16),
212+
('{0:d}', -2000000000, np.int32),
213+
('{0:d}', -7000000000000000000, np.int64),
214+
('{0:g}', 1.5, np.float16),
215+
('{0:g}', 1.5, np.float32),
216+
('{0:g}', 1.5, np.float64),
217+
('{0:g}', 1.5, np.longdouble),
218+
('{0:g}', 1.5+0.5j, np.complex64),
219+
('{0:g}', 1.5+0.5j, np.complex128),
220+
('{0:g}', 1.5+0.5j, np.clongdouble)]
221+
222+
for (fmat, val, valtype) in tests:
223+
assert_equal(fmat.format(val), fmat.format(valtype(val)),
224+
"failed with val %s, type %s" % (val, valtype))
225+
226+
201227
# Locale tests: scalar types formatting should be independent of the locale
202228
def in_foreign_locale(func):
203229
"""

0 commit comments

Comments
 (0)
0