Merge pull request #9054 from eric-wieser/fix-pep3118 · numpy/numpy@8618066 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8618066

Browse files
authored
Merge pull request #9054 from eric-wieser/fix-pep3118
BUG: Various fixes to _dtype_from_pep3118
2 parents 5e78b88 + a4f435c commit 8618066

File tree

2 files changed

+172
-127
lines changed

2 files changed

+172
-127
lines changed

numpy/core/_internal.py

Lines changed: 134 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -432,91 +432,119 @@ def _view_is_safe(oldtype, newtype):
432432
}
433433
_pep3118_standard_typechars = ''.join(_pep3118_standard_map.keys())
434434

435-
def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False):
436-
fields = {}
435+
def _dtype_from_pep3118(spec):
436+
437+
class Stream(object):
438+
def __init__(self, s):
439+
self.s = s
440+
self.byteorder = '@'
441+
442+
def advance(self, n):
443+
res = self.s[:n]
444+
self.s = self.s[n:]
445+
return res
446+
447+
def consume(self, c):
448+
if self.s[:len(c)] == c:
449+
self.advance(len(c))
450+
return True
451+
return False
452+
453+
def consume_until(self, c):
454+
if callable(c):
455+
i = 0
456+
while i < len(self.s) and not c(self.s[i]):
457+
i = i + 1
458+
return self.advance(i)
459+
else:
460+
i = self.s.index(c)
461+
res = self.advance(i)
462+
self.advance(len(c))
463+
return res
464+
465+
@property
466+
def next(self):
467+
return self.s[0]
468+
469+
def __bool__(self):
470+
return bool(self.s)
471+
__nonzero__ = __bool__
472+
473+
stream = Stream(spec)
474+
475+
dtype, align = __dtype_from_pep3118(stream, is_subdtype=False)
476+
return dtype
477+
478+
def __dtype_from_pep3118(stream, is_subdtype):
479+
field_spec = dict(
480+
names=[],
481+
formats=[],
482+
offsets=[],
483+
itemsize=0
484+
)
437485
offset = 0
438-
explicit_name = False
439-
this_explicit_name = False
440486
common_alignment = 1
441487
is_padding = False
442488

443-
dummy_name_index = [0]
444-
445-
def next_dummy_name():
446-
dummy_name_index[0] += 1
447-
448-
def get_dummy_name():
449-
while True:
450-
name = 'f%d' % dummy_name_index[0]
451-
if name not in fields:
452-
return name
453-
next_dummy_name()
454-
455489
# Parse spec
456-
while spec:
490+
while stream:
457491
value = None
458492

459493
# End of structure, bail out to upper level
460-
if spec[0] == '}':
461-
spec = spec[1:]
494+
if stream.consume('}'):
462495
break
463496

464497
# Sub-arrays (1)
465498
shape = None
466-
if spec[0] == '(':
467-
j = spec.index(')')
468-
shape = tuple(map(int, spec[1:j].split(',')))
469-
spec = spec[j+1:]
499+
if stream.consume('('):
500+
shape = stream.consume_until(')')
501+
shape = tuple(map(int, shape.split(',')))
470502

471503
# Byte order
472-
if spec[0] in ('@', '=', '<', '>', '^', '!'):
473-
byteorder = spec[0]
504+
if stream.next in ('@', '=', '<', '>', '^', '!'):
505+
byteorder = stream.advance(1)
474506
if byteorder == '!':
475507
byteorder = '>'
476-
spec = spec[1:]
508+
stream.byteorder = byteorder
477509

478510
# Byte order characters also control native vs. standard type sizes
479-
if byteorder in ('@', '^'):
511+
if stream.byteorder in ('@', '^'):
480512
type_map = _pep3118_native_map
481513
type_map_chars = _pep3118_native_typechars
482514
else:
483515
type_map = _pep3118_standard_map
484516
type_map_chars = _pep3118_standard_typechars
485517

486518
# Item sizes
487-
itemsize = 1
488-
if spec[0].isdigit():
489-
j = 1
490-
for j in range(1, len(spec)):
491-
if not spec[j].isdigit():
492-
break
493-
itemsize = int(spec[:j])
494-
spec = spec[j:]
519+
itemsize_str = stream.consume_until(lambda c: not c.isdigit())
520+
if itemsize_str:
521+
itemsize = int(itemsize_str)
522+
else:
523+
itemsize = 1
495524

496525
# Data types
497526
is_padding = False
498527

499-
if spec[:2] == 'T{':
500-
value, spec, align, next_byteorder = _dtype_from_pep3118(
501-
spec[2:], byteorder=byteorder, is_subdtype=True)
502-
elif spec[0] in type_map_chars:
503-
next_byteorder = byteorder
504-
if spec[0] == 'Z':
505-
j = 2
528+
if stream.consume('T{'):
529+
value, align = __dtype_from_pep3118(
530+
stream, is_subdtype=True)
531+
elif stream.next in type_map_chars:
532+
if stream.next == 'Z':
533+
typechar = stream.advance(2)
506534
else:
507-
j = 1
508-
typechar = spec[:j]
509-
spec = spec[j:]
535+
typechar = stream.advance(1)
536+
510537
is_padding = (typechar == 'x')
511538
dtypechar = type_map[typechar]
512539
if dtypechar in 'USV':
513540
dtypechar += '%d' % itemsize
514541
itemsize = 1
515-
numpy_byteorder = {'@': '=', '^': '='}.get(byteorder, byteorder)
542+
numpy_byteorder = {'@': '=', '^': '='}.get(
543+
stream.byteorder, stream.byteorder)
516544
value = dtype(numpy_byteorder + dtypechar)
517545
align = value.alignment
518546
else:
519-
raise ValueError("Unknown PEP 3118 data type specifier %r" % spec)
547+
raise ValueError("Unknown PEP 3118 data type specifier %r" % stream.s)
520548

521549
#
522550
# Native alignment may require padding
@@ -525,7 +553,7 @@ def get_dummy_name():
525553
# that the start of the array is *already* aligned.
526554
#
527555
extra_offset = 0
528-
if byteorder == '@':
556+
if stream.byteorder == '@':
529557
start_padding = (-offset) % align
530558
intra_padding = (-value.itemsize) % align
531559

@@ -541,8 +569,7 @@ def get_dummy_name():
541569
extra_offset += intra_padding
542570

543571
# Update common alignment
544-
common_alignment = (align*common_alignment
545-
/ _gcd(align, common_alignment))
572+
common_alignment = _lcm(align, common_alignment)
546573

547574
# Convert itemsize to sub-array
548575
if itemsize != 1:
@@ -553,79 +580,77 @@ def get_dummy_name():
553580
value = dtype((value, shape))
554581

555582
# Field name
556-
this_explicit_name = False
557-
if spec and spec.startswith(':'):
558-
i = spec[1:].index(':') + 1
559-
name = spec[1:i]
560-
spec = spec[i+1:]
561-
explicit_name = True
562-
this_explicit_name = True
583+
if stream.consume(':'):
584+
name = stream.consume_until(':')
563585
else:
564-
name = get_dummy_name()
586+
name = None
565587

566-
if not is_padding or this_explicit_name:
567-
if name in fields:
588+
if not (is_padding and name is None):
589+
if name is not None and name in field_spec['names']:
568590
raise RuntimeError("Duplicate field name '%s' in PEP3118 format"
569591
% name)
570-
fields[name] = (value, offset)
571-
if not this_explicit_name:
572-
next_dummy_name()
573-
574-
byteorder = next_byteorder
592+
field_spec['names'].append(name)
593+
field_spec['formats'].append(value)
594+
field_spec['offsets'].append(offset)
575595

576596
offset += value.itemsize
577597
offset += extra_offset
578598

579-
# Check if this was a simple 1-item type
580-
if (len(fields) == 1 and not explicit_name and
581-
fields['f0'][1] == 0 and not is_subdtype):
582-
ret = fields['f0'][0]
583-
else:
584-
ret = dtype(fields)
599+
field_spec['itemsize'] = offset
585600

586-
# Trailing padding must be explicitly added
587-
padding = offset - ret.itemsize
588-
if byteorder == '@':
589-
padding += (-offset) % common_alignment
590-
if is_padding and not this_explicit_name:
591-
ret = _add_trailing_padding(ret, padding)
601+
# extra final padding for aligned types
602+
if stream.byteorder == '@':
603+
field_spec['itemsize'] += (-offset) % common_alignment
592604

593-
# Finished
594-
if is_subdtype:
595-
return ret, spec, common_alignment, byteorder
605+
# Check if this was a simple 1-item type, and unwrap it
606+
if (field_spec['names'] == [None]
607+
and field_spec['offsets'][0] == 0
608+
and field_spec['itemsize'] == field_spec['formats'][0].itemsize
609+
and not is_subdtype):
610+
ret = field_spec['formats'][0]
596611
else:
597-
return ret
612+
_fix_names(field_spec)
613+
ret = dtype(field_spec)
614+
615+
# Finished
616+
return ret, common_alignment
617+
618+
def _fix_names(field_spec):
619+
""" Replace names which are None with the next unused f%d name """
620+
names = field_spec['names']
621+
for i, name in enumerate(names):
622+
if name is not None:
623+
continue
624+
625+
j = 0
626+
while True:
627+
name = 'f{}'.format(j)
628+
if name not in names:
629+
break
630+
j = j + 1
631+
names[i] = name
598632

599633
def _add_trailing_padding(value, padding):
600634
"""Inject the specified number of padding bytes at the end of a dtype"""
601635
if value.fields is None:
602-
vfields = {'f0': (value, 0)}
603-
else:
604-
vfields = dict(value.fields)
605-
606-
if (value.names and value.names[-1] == '' and
607-
value[''].char == 'V'):
608-
# A trailing padding field is already present
609-
vfields[''] = ('V%d' % (vfields[''][0].itemsize + padding),
610-
vfields[''][1])
611-
value = dtype(vfields)
636+
field_spec = dict(
637+
names=['f0'],
638+
formats=[value],
639+
offsets=[0],
640+
itemsize=value.itemsize
641+
)
612642
else:
613-
# Get a free name for the padding field
614-
j = 0
615-
while True:
616-
name = 'pad%d' % j
617-
if name not in vfields:
618-
vfields[name] = ('V%d' % padding, value.itemsize)
619-
break
620-
j += 1
643+
fields = value.fields
644+
names = value.names
645+
field_spec = dict(
646+
names=names,
647+
formats=[fields[name][0] for name in names],
648+
offsets=[fields[name][1] for name in names],
649+
itemsize=value.itemsize
650+
)
621651

622-
value = dtype(vfields)
623-
if '' not in vfields:
624-
# Strip out the name of the padding field
625-
names = list(value.names)
626-
names[-1] = ''
627-
value.names = tuple(names)
628-
return value
652+
field_spec['itemsize'] += padding
653+
return dtype(field_spec)
629654

630655
def _prod(a):
631656
p = 1
@@ -639,6 +664,9 @@ def _gcd(a, b):
639664
a, b = b, a % b
640665
return a
641666

667+
def _lcm(a, b):
668+
return a // _gcd(a, b) * b
669+
642670
# Exception used in shares_memory()
643671
class TooHardError(RuntimeError):
644672
pass

0 commit comments

Comments
 (0)
0