diff --git a/numpy/core/defchararray.py b/numpy/core/defchararray.py index 6d0a0add58da..0a8c7bbec398 100644 --- a/numpy/core/defchararray.py +++ b/numpy/core/defchararray.py @@ -22,6 +22,7 @@ from .numeric import ndarray, compare_chararrays from .numeric import array as narray from numpy.core.multiarray import _vec_string +from numpy.core.overrides import array_function_dispatch from numpy.compat import asbytes, long import numpy @@ -95,6 +96,11 @@ def _get_num_chars(a): return a.itemsize +def _binary_op_dispatcher(x1, x2): + return (x1, x2) + + +@array_function_dispatch(_binary_op_dispatcher) def equal(x1, x2): """ Return (x1 == x2) element-wise. @@ -119,6 +125,8 @@ def equal(x1, x2): """ return compare_chararrays(x1, x2, '==', True) + +@array_function_dispatch(_binary_op_dispatcher) def not_equal(x1, x2): """ Return (x1 != x2) element-wise. @@ -143,6 +151,8 @@ def not_equal(x1, x2): """ return compare_chararrays(x1, x2, '!=', True) + +@array_function_dispatch(_binary_op_dispatcher) def greater_equal(x1, x2): """ Return (x1 >= x2) element-wise. @@ -168,6 +178,8 @@ def greater_equal(x1, x2): """ return compare_chararrays(x1, x2, '>=', True) + +@array_function_dispatch(_binary_op_dispatcher) def less_equal(x1, x2): """ Return (x1 <= x2) element-wise. @@ -192,6 +204,8 @@ def less_equal(x1, x2): """ return compare_chararrays(x1, x2, '<=', True) + +@array_function_dispatch(_binary_op_dispatcher) def greater(x1, x2): """ Return (x1 > x2) element-wise. @@ -216,6 +230,8 @@ def greater(x1, x2): """ return compare_chararrays(x1, x2, '>', True) + +@array_function_dispatch(_binary_op_dispatcher) def less(x1, x2): """ Return (x1 < x2) element-wise. @@ -240,6 +256,12 @@ def less(x1, x2): """ return compare_chararrays(x1, x2, '<', True) + +def _unary_op_dispatcher(a): + return (a,) + + +@array_function_dispatch(_unary_op_dispatcher) def str_len(a): """ Return len(a) element-wise. @@ -259,6 +281,8 @@ def str_len(a): """ return _vec_string(a, integer, '__len__') + +@array_function_dispatch(_binary_op_dispatcher) def add(x1, x2): """ Return element-wise string concatenation for two arrays of str or unicode. @@ -285,6 +309,12 @@ def add(x1, x2): dtype = _use_unicode(arr1, arr2) return _vec_string(arr1, (dtype, out_size), '__add__', (arr2,)) + +def _multiply_dispatcher(a, i): + return (a,) + + +@array_function_dispatch(_multiply_dispatcher) def multiply(a, i): """ Return (a * i), that is string multiple concatenation, @@ -313,6 +343,12 @@ def multiply(a, i): return _vec_string( a_arr, (a_arr.dtype.type, out_size), '__mul__', (i_arr,)) + +def _mod_dispatcher(a, values): + return (a, values) + + +@array_function_dispatch(_mod_dispatcher) def mod(a, values): """ Return (a % i), that is pre-Python 2.6 string formatting @@ -339,6 +375,8 @@ def mod(a, values): return _to_string_or_unicode_array( _vec_string(a, object_, '__mod__', (values,))) + +@array_function_dispatch(_unary_op_dispatcher) def capitalize(a): """ Return a copy of `a` with only the first character of each element @@ -377,6 +415,11 @@ def capitalize(a): return _vec_string(a_arr, a_arr.dtype, 'capitalize') +def _center_dispatcher(a, width, fillchar=None): + return (a,) + + +@array_function_dispatch(_center_dispatcher) def center(a, width, fillchar=' '): """ Return a copy of `a` with its elements centered in a string of @@ -413,6 +456,11 @@ def center(a, width, fillchar=' '): a_arr, (a_arr.dtype.type, size), 'center', (width_arr, fillchar)) +def _count_dispatcher(a, sub, start=None, end=None): + return (a,) + + +@array_function_dispatch(_count_dispatcher) def count(a, sub, start=0, end=None): """ Returns an array with the number of non-overlapping occurrences of @@ -459,6 +507,11 @@ def count(a, sub, start=0, end=None): return _vec_string(a, integer, 'count', [sub, start] + _clean_args(end)) +def _code_dispatcher(a, encoding=None, errors=None): + return (a,) + + +@array_function_dispatch(_code_dispatcher) def decode(a, encoding=None, errors=None): """ Calls `str.decode` element-wise. @@ -505,6 +558,7 @@ def decode(a, encoding=None, errors=None): _vec_string(a, object_, 'decode', _clean_args(encoding, errors))) +@array_function_dispatch(_code_dispatcher) def encode(a, encoding=None, errors=None): """ Calls `str.encode` element-wise. @@ -540,6 +594,11 @@ def encode(a, encoding=None, errors=None): _vec_string(a, object_, 'encode', _clean_args(encoding, errors))) +def _endswith_dispatcher(a, suffix, start=None, end=None): + return (a,) + + +@array_function_dispatch(_endswith_dispatcher) def endswith(a, suffix, start=0, end=None): """ Returns a boolean array which is `True` where the string element @@ -584,6 +643,11 @@ def endswith(a, suffix, start=0, end=None): a, bool_, 'endswith', [suffix, start] + _clean_args(end)) +def _expandtabs_dispatcher(a, tabsize=None): + return (a,) + + +@array_function_dispatch(_expandtabs_dispatcher) def expandtabs(a, tabsize=8): """ Return a copy of each string element where all tab characters are @@ -619,6 +683,7 @@ def expandtabs(a, tabsize=8): _vec_string(a, object_, 'expandtabs', (tabsize,))) +@array_function_dispatch(_count_dispatcher) def find(a, sub, start=0, end=None): """ For each element, return the lowest index in the string where @@ -654,6 +719,7 @@ def find(a, sub, start=0, end=None): a, integer, 'find', [sub, start] + _clean_args(end)) +@array_function_dispatch(_count_dispatcher) def index(a, sub, start=0, end=None): """ Like `find`, but raises `ValueError` when the substring is not found. @@ -681,6 +747,8 @@ def index(a, sub, start=0, end=None): return _vec_string( a, integer, 'index', [sub, start] + _clean_args(end)) + +@array_function_dispatch(_unary_op_dispatcher) def isalnum(a): """ Returns true for each element if all characters in the string are @@ -705,6 +773,8 @@ def isalnum(a): """ return _vec_string(a, bool_, 'isalnum') + +@array_function_dispatch(_unary_op_dispatcher) def isalpha(a): """ Returns true for each element if all characters in the string are @@ -729,6 +799,8 @@ def isalpha(a): """ return _vec_string(a, bool_, 'isalpha') + +@array_function_dispatch(_unary_op_dispatcher) def isdigit(a): """ Returns true for each element if all characters in the string are @@ -753,6 +825,8 @@ def isdigit(a): """ return _vec_string(a, bool_, 'isdigit') + +@array_function_dispatch(_unary_op_dispatcher) def islower(a): """ Returns true for each element if all cased characters in the @@ -778,6 +852,8 @@ def islower(a): """ return _vec_string(a, bool_, 'islower') + +@array_function_dispatch(_unary_op_dispatcher) def isspace(a): """ Returns true for each element if there are only whitespace @@ -803,6 +879,8 @@ def isspace(a): """ return _vec_string(a, bool_, 'isspace') + +@array_function_dispatch(_unary_op_dispatcher) def istitle(a): """ Returns true for each element if the element is a titlecased @@ -827,6 +905,8 @@ def istitle(a): """ return _vec_string(a, bool_, 'istitle') + +@array_function_dispatch(_unary_op_dispatcher) def isupper(a): """ Returns true for each element if all cased characters in the @@ -852,6 +932,12 @@ def isupper(a): """ return _vec_string(a, bool_, 'isupper') + +def _join_dispatcher(sep, seq): + return (sep, seq) + + +@array_function_dispatch(_join_dispatcher) def join(sep, seq): """ Return a string which is the concatenation of the strings in the @@ -877,6 +963,12 @@ def join(sep, seq): _vec_string(sep, object_, 'join', (seq,))) + +def _just_dispatcher(a, width, fillchar=None): + return (a,) + + +@array_function_dispatch(_just_dispatcher) def ljust(a, width, fillchar=' '): """ Return an array with the elements of `a` left-justified in a @@ -912,6 +1004,7 @@ def ljust(a, width, fillchar=' '): a_arr, (a_arr.dtype.type, size), 'ljust', (width_arr, fillchar)) +@array_function_dispatch(_unary_op_dispatcher) def lower(a): """ Return an array with the elements converted to lowercase. @@ -948,6 +1041,11 @@ def lower(a): return _vec_string(a_arr, a_arr.dtype, 'lower') +def _strip_dispatcher(a, chars=None): + return (a,) + + +@array_function_dispatch(_strip_dispatcher) def lstrip(a, chars=None): """ For each element in `a`, return a copy with the leading characters @@ -1005,6 +1103,11 @@ def lstrip(a, chars=None): return _vec_string(a_arr, a_arr.dtype, 'lstrip', (chars,)) +def _partition_dispatcher(a, sep): + return (a,) + + +@array_function_dispatch(_partition_dispatcher) def partition(a, sep): """ Partition each element in `a` around `sep`. @@ -1040,6 +1143,11 @@ def partition(a, sep): _vec_string(a, object_, 'partition', (sep,))) +def _replace_dispatcher(a, old, new, count=None): + return (a,) + + +@array_function_dispatch(_replace_dispatcher) def replace(a, old, new, count=None): """ For each element in `a`, return a copy of the string with all @@ -1072,6 +1180,7 @@ def replace(a, old, new, count=None): a, object_, 'replace', [old, new] + _clean_args(count))) +@array_function_dispatch(_count_dispatcher) def rfind(a, sub, start=0, end=None): """ For each element in `a`, return the highest index in the string @@ -1104,6 +1213,7 @@ def rfind(a, sub, start=0, end=None): a, integer, 'rfind', [sub, start] + _clean_args(end)) +@array_function_dispatch(_count_dispatcher) def rindex(a, sub, start=0, end=None): """ Like `rfind`, but raises `ValueError` when the substring `sub` is @@ -1133,6 +1243,7 @@ def rindex(a, sub, start=0, end=None): a, integer, 'rindex', [sub, start] + _clean_args(end)) +@array_function_dispatch(_just_dispatcher) def rjust(a, width, fillchar=' '): """ Return an array with the elements of `a` right-justified in a @@ -1168,6 +1279,7 @@ def rjust(a, width, fillchar=' '): a_arr, (a_arr.dtype.type, size), 'rjust', (width_arr, fillchar)) +@array_function_dispatch(_partition_dispatcher) def rpartition(a, sep): """ Partition (split) each element around the right-most separator. @@ -1203,6 +1315,11 @@ def rpartition(a, sep): _vec_string(a, object_, 'rpartition', (sep,))) +def _split_dispatcher(a, sep=None, maxsplit=None): + return (a,) + + +@array_function_dispatch(_split_dispatcher) def rsplit(a, sep=None, maxsplit=None): """ For each element in `a`, return a list of the words in the @@ -1240,6 +1357,11 @@ def rsplit(a, sep=None, maxsplit=None): a, object_, 'rsplit', [sep] + _clean_args(maxsplit)) +def _strip_dispatcher(a, chars=None): + return (a,) + + +@array_function_dispatch(_strip_dispatcher) def rstrip(a, chars=None): """ For each element in `a`, return a copy with the trailing @@ -1284,6 +1406,7 @@ def rstrip(a, chars=None): return _vec_string(a_arr, a_arr.dtype, 'rstrip', (chars,)) +@array_function_dispatch(_split_dispatcher) def split(a, sep=None, maxsplit=None): """ For each element in `a`, return a list of the words in the @@ -1318,6 +1441,11 @@ def split(a, sep=None, maxsplit=None): a, object_, 'split', [sep] + _clean_args(maxsplit)) +def _splitlines_dispatcher(a, keepends=None): + return (a,) + + +@array_function_dispatch(_splitlines_dispatcher) def splitlines(a, keepends=None): """ For each element in `a`, return a list of the lines in the @@ -1347,6 +1475,11 @@ def splitlines(a, keepends=None): a, object_, 'splitlines', _clean_args(keepends)) +def _startswith_dispatcher(a, prefix, start=None, end=None): + return (a,) + + +@array_function_dispatch(_startswith_dispatcher) def startswith(a, prefix, start=0, end=None): """ Returns a boolean array which is `True` where the string element @@ -1378,6 +1511,7 @@ def startswith(a, prefix, start=0, end=None): a, bool_, 'startswith', [prefix, start] + _clean_args(end)) +@array_function_dispatch(_strip_dispatcher) def strip(a, chars=None): """ For each element in `a`, return a copy with the leading and @@ -1426,6 +1560,7 @@ def strip(a, chars=None): return _vec_string(a_arr, a_arr.dtype, 'strip', _clean_args(chars)) +@array_function_dispatch(_unary_op_dispatcher) def swapcase(a): """ Return element-wise a copy of the string with @@ -1463,6 +1598,7 @@ def swapcase(a): return _vec_string(a_arr, a_arr.dtype, 'swapcase') +@array_function_dispatch(_unary_op_dispatcher) def title(a): """ Return element-wise title cased version of string or unicode. @@ -1502,6 +1638,11 @@ def title(a): return _vec_string(a_arr, a_arr.dtype, 'title') +def _translate_dispatcher(a, table, deletechars=None): + return (a,) + + +@array_function_dispatch(_translate_dispatcher) def translate(a, table, deletechars=None): """ For each element in `a`, return a copy of the string where all @@ -1538,6 +1679,7 @@ def translate(a, table, deletechars=None): a_arr, a_arr.dtype, 'translate', [table] + _clean_args(deletechars)) +@array_function_dispatch(_unary_op_dispatcher) def upper(a): """ Return an array with the elements converted to uppercase. @@ -1574,6 +1716,11 @@ def upper(a): return _vec_string(a_arr, a_arr.dtype, 'upper') +def _zfill_dispatcher(a, width): + return (a,) + + +@array_function_dispatch(_zfill_dispatcher) def zfill(a, width): """ Return the numeric string left-filled with zeros @@ -1604,6 +1751,7 @@ def zfill(a, width): a_arr, (a_arr.dtype.type, size), 'zfill', (width_arr,)) +@array_function_dispatch(_unary_op_dispatcher) def isnumeric(a): """ For each element, return True if there are only numeric @@ -1635,6 +1783,7 @@ def isnumeric(a): return _vec_string(a, bool_, 'isnumeric') +@array_function_dispatch(_unary_op_dispatcher) def isdecimal(a): """ For each element, return True if there are only decimal