BUG: Various fixes to _dtype_from_pep3118 by eric-wieser · Pull Request #9054 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

BUG: Various fixes to _dtype_from_pep3118 #9054

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 9, 2017
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
MAINT: refactor _dtype_from_pep3118 in terms of a stream
  • Loading branch information
eric-wieser committed May 5, 2017
commit 7f6c95fe1d3873fae6b83530e9ee7a3f5f357504
122 changes: 79 additions & 43 deletions numpy/core/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,49 @@ def _view_is_safe(oldtype, newtype):
}
_pep3118_standard_typechars = ''.join(_pep3118_standard_map.keys())

def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False):
def _dtype_from_pep3118(spec):
    """
    Parse the format specifier `spec` into a dtype.

    The specifier is wrapped in a small consumable stream and handed to
    `__dtype_from_pep3118`, which does the actual parsing; its result for
    the top-level (non-sub-dtype) case is returned unchanged.
    """

    class Stream(object):
        """
        Cursor over the not-yet-parsed remainder of the specifier.

        `s` holds the remaining text and `byteorder` the byte-order
        character currently in effect (initially '@').
        """

        def __init__(self, s):
            self.s = s
            self.byteorder = '@'

        def advance(self, n):
            # Split off the first `n` items and keep only the tail.
            head, self.s = self.s[:n], self.s[n:]
            return head

        def consume(self, c):
            # Drop the prefix `c` if present; report whether it was there.
            if self.s[:len(c)] != c:
                return False
            self.advance(len(c))
            return True

        def consume_until(self, c):
            # `c` is either a predicate on single items or a literal
            # delimiter.  A literal delimiter is removed from the stream
            # but not returned; a missing one propagates the error raised
            # by `index`.
            if not callable(c):
                taken = self.advance(self.s.index(c))
                self.advance(len(c))
                return taken

            # Predicate form: stop in front of the first matching item,
            # or at the end of the stream if nothing matches.
            pos = 0
            while pos < len(self.s) and not c(self.s[pos]):
                pos += 1
            return self.advance(pos)

        @property
        def next(self):
            # Peek at the first remaining item without consuming it.
            return self.s[0]

        def __bool__(self):
            # A stream is truthy while there is input left to parse.
            return bool(self.s)
        __nonzero__ = __bool__

    return __dtype_from_pep3118(Stream(spec), is_subdtype=False)

def __dtype_from_pep3118(stream, is_subdtype):
fields = {}
offset = 0
explicit_name = False
Expand All @@ -429,6 +471,7 @@ def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False):

dummy_name_index = [0]


def next_dummy_name():
dummy_name_index[0] += 1

Expand All @@ -439,71 +482,66 @@ def get_dummy_name():
return name
next_dummy_name()


# Parse spec
while spec:
while stream:
value = None

# End of structure, bail out to upper level
if spec[0] == '}':
spec = spec[1:]
if stream.consume('}'):
break

# Sub-arrays (1)
shape = None
if spec[0] == '(':
j = spec.index(')')
shape = tuple(map(int, spec[1:j].split(',')))
spec = spec[j+1:]
if stream.consume('('):
shape = stream.consume_until(')')
shape = tuple(map(int, shape.split(',')))

# Byte order
if spec[0] in ('@', '=', '<', '>', '^', '!'):
byteorder = spec[0]
if stream.next in ('@', '=', '<', '>', '^', '!'):
byteorder = stream.advance(1)
if byteorder == '!':
byteorder = '>'
spec = spec[1:]
stream.byteorder = byteorder

# Byte order characters also control native vs. standard type sizes
if byteorder in ('@', '^'):
if stream.byteorder in ('@', '^'):
type_map = _pep3118_native_map
type_map_chars = _pep3118_native_typechars
else:
type_map = _pep3118_standard_map
type_map_chars = _pep3118_standard_typechars

# Item sizes
itemsize = 1
if spec[0].isdigit():
j = 1
for j in range(1, len(spec)):
if not spec[j].isdigit():
break
itemsize = int(spec[:j])
spec = spec[j:]
itemsize_str = stream.consume_until(lambda c: not c.isdigit())
if itemsize_str:
itemsize = int(itemsize_str)
else:
itemsize = 1

# Data types
is_padding = False

if spec[:2] == 'T{':
value, spec, align, next_byteorder = _dtype_from_pep3118(
spec[2:], byteorder=byteorder, is_subdtype=True)
elif spec[0] in type_map_chars:
next_byteorder = byteorder
if spec[0] == 'Z':
j = 2
if stream.consume('T{'):
value, align = __dtype_from_pep3118(
stream, is_subdtype=True)
elif stream.next in type_map_chars:
if stream.next == 'Z':
typechar = stream.advance(2)
else:
j = 1
typechar = spec[:j]
spec = spec[j:]
typechar = stream.advance(1)

is_padding = (typechar == 'x')
dtypechar = type_map[typechar]
if dtypechar in 'USV':
dtypechar += '%d' % itemsize
itemsize = 1
numpy_byteorder = {'@': '=', '^': '='}.get(byteorder, byteorder)
numpy_byteorder = {'@': '=', '^': '='}.get(
stream.byteorder, stream.byteorder)
value = dtype(numpy_byteorder + dtypechar)
align = value.alignment
else:
raise ValueError("Unknown PEP 3118 data type specifier %r" % spec)
raise ValueError("Unknown PEP 3118 data type specifier %r" % stream.s)

#
# Native alignment may require padding
Expand All @@ -512,7 +550,7 @@ def get_dummy_name():
# that the start of the array is *already* aligned.
#
extra_offset = 0
if byteorder == '@':
if stream.byteorder == '@':
start_padding = (-offset) % align
intra_padding = (-value.itemsize) % align

Expand All @@ -528,8 +566,7 @@ def get_dummy_name():
extra_offset += intra_padding

# Update common alignment
common_alignment = (align*common_alignment
/ _gcd(align, common_alignment))
common_alignment = _lcm(align, common_alignment)

# Convert itemsize to sub-array
if itemsize != 1:
Expand All @@ -541,10 +578,8 @@ def get_dummy_name():

# Field name
this_explicit_name = False
if spec and spec.startswith(':'):
i = spec[1:].index(':') + 1
name = spec[1:i]
spec = spec[i+1:]
if stream.consume(':'):
name = stream.consume_until(':')
explicit_name = True
this_explicit_name = True
else:
Expand All @@ -558,8 +593,6 @@ def get_dummy_name():
if not this_explicit_name:
next_dummy_name()

byteorder = next_byteorder

offset += value.itemsize
offset += extra_offset

Expand All @@ -572,14 +605,14 @@ def get_dummy_name():

# Trailing padding must be explicitly added
padding = offset - ret.itemsize
if byteorder == '@':
if stream.byteorder == '@':
padding += (-offset) % common_alignment
if is_padding and not this_explicit_name:
ret = _add_trailing_padding(ret, padding)

# Finished
if is_subdtype:
return ret, spec, common_alignment, byteorder
return ret, common_alignment
else:
return ret

Expand Down Expand Up @@ -626,6 +659,9 @@ def _gcd(a, b):
a, b = b, a % b
return a

def _lcm(a, b):
    """
    Return the least common multiple of the integers `a` and `b`.

    Floor division is used deliberately: with true division (Python 3, or
    ``from __future__ import division``) ``a / _gcd(a, b)`` would produce
    a float, and the result is used as an integer alignment.  The division
    is exact because `_gcd(a, b)` divides `a`.
    """
    return a // _gcd(a, b) * b

# Exception used in shares_memory()
class TooHardError(RuntimeError):
    """Exception used in shares_memory()."""
Expand Down
0