4040import os
4141
4242from . import numeric as sb
43- from .defchararray import chararray
4443from . import numerictypes as nt
4544from numpy .compat import isfileobj , bytes , long
4645
@@ -238,17 +237,15 @@ def __getattribute__(self, attr):
238237 res = fielddict .get (attr , None )
239238 if res :
240239 obj = self .getfield (* res [:2 ])
241- # if it has fields return a recarray,
242- # if it's a string ('SU') return a chararray
240+ # if it has fields return a record,
243241 # otherwise return the object
244242 try :
245243 dt = obj .dtype
246244 except AttributeError :
245+ #happens if field is Object type
247246 return obj
248247 if dt .fields :
249- return obj .view (obj .__class__ )
250- if dt .char in 'SU' :
251- return obj .view (chararray )
248+ return obj .view ((record , obj .dtype .descr ))
252249 return obj
253250 else :
254251 raise AttributeError ("'record' object has no "
@@ -418,29 +415,37 @@ def __new__(subtype, shape, dtype=None, buf=None, offset=0, strides=None,
418415 return self
419416
420417 def __getattribute__ (self , attr ):
418+ # See if ndarray has this attr, and return it if so. (note that this
419+ # means a field with the same name as an ndarray attr cannot be
420+ # accessed by attribute).
421421 try :
422422 return object .__getattribute__ (self , attr )
423423 except AttributeError : # attr must be a fieldname
424424 pass
425+
426+ # look for a field with this name
425427 fielddict = ndarray .__getattribute__ (self , 'dtype' ).fields
426428 try :
427429 res = fielddict [attr ][:2 ]
428430 except (TypeError , KeyError ):
429- raise AttributeError ("record array has no attribute %s" % attr )
431+ raise AttributeError ("recarray has no attribute %s" % attr )
430432 obj = self .getfield (* res )
431- # if it has fields return a recarray, otherwise return
432- # normal array
433- if obj .dtype .fields :
434- return obj
435- if obj .dtype .char in 'SU' :
436- return obj .view (chararray )
437- return obj .view (ndarray )
438433
439- # Save the dictionary
440- # If the attr is a field name and not in the saved dictionary
441- # Undo any "setting" of the attribute and do a setfield
442- # Thus, you can't create attributes on-the-fly that are field names.
434+ # At this point obj will always be a recarray, since (see
435+ # PyArray_GetField) the type of obj is inherited. Next, if obj.dtype is
436+ # non-structured, convert it to an ndarray. If obj is structured leave
437+ # it as a recarray, but make sure to convert to the same dtype.type (eg
438+ # to preserve numpy.record type if present), since nested structured
439+ # fields do not inherit type.
440+ if obj .dtype .fields :
441+ return obj .view (dtype = (self .dtype .type , obj .dtype .descr ))
442+ else :
443+ return obj .view (ndarray )
443444
445+ # Save the dictionary.
446+ # If the attr is a field name and not in the saved dictionary
447+ # Undo any "setting" of the attribute and do a setfield
448+ # Thus, you can't create attributes on-the-fly that are field names.
444449 def __setattr__ (self , attr , val ):
445450 newattr = attr not in self .__dict__
446451 try :
@@ -468,9 +473,17 @@ def __setattr__(self, attr, val):
468473
469474 def __getitem__ (self , indx ):
470475 obj = ndarray .__getitem__ (self , indx )
471- if (isinstance (obj , ndarray ) and obj .dtype .isbuiltin ):
472- return obj .view (ndarray )
473- return obj
476+
477+ # copy behavior of getattr, except that here
478+ # we might also be returning a single element
479+ if isinstance (obj , ndarray ):
480+ if obj .dtype .fields :
481+ return obj .view (dtype = (self .dtype .type , obj .dtype .descr ))
482+ else :
483+ return obj .view (type = ndarray )
484+ else :
485+ # return a single element
486+ return obj
474487
475488 def __repr__ (self ) :
476489 ret = ndarray .__repr__ (self )
@@ -489,8 +502,6 @@ def field(self, attr, val=None):
489502 obj = self .getfield (* res )
490503 if obj .dtype .fields :
491504 return obj
492- if obj .dtype .char in 'SU' :
493- return obj .view (chararray )
494505 return obj .view (ndarray )
495506 else :
496507 return self .setfield (val , * res )
@@ -601,7 +612,7 @@ def fromrecords(recList, dtype=None, shape=None, formats=None, names=None,
601612 >>> r.col1
602613 array([456, 2])
603614 >>> r.col2
604- chararray (['dbe', 'de'],
615+ array (['dbe', 'de'],
605616 dtype='|S3')
606617 >>> import pickle
607618 >>> print pickle.loads(pickle.dumps(r))
0 commit comments