8000 MNT: Use get_offsets() to handle the default/None case · matplotlib/matplotlib@780bc3c · GitHub
[go: up one dir, main page]

Skip to content

Commit 780bc3c

Browse files
committed
MNT: Use get_offsets() to handle the default/None case
This avoids carrying around an extra offsetsNone/has_offsets variable to keep track of the default case. Instead calling get_offsets() to return the default zeros case, and then internally checking the None/default case in get_datalim() which is the only place where this information is needed.
1 parent de4a94c commit 780bc3c

File tree

1 file changed

+44
-47
lines changed

1 file changed

+44
-47
lines changed

lib/matplotlib/collections.py

Lines changed: 44 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ class Collection(artist.Artist, cm.ScalarMappable):
6262
mappable will be used to set the ``facecolors`` and ``edgecolors``,
6363
ignoring those that were manually passed in.
6464
"""
65-
_offsets = np.zeros((0, 2))
6665
#: Either a list of 3x3 arrays or an Nx3x3 array (representing N
6766
#: transforms), suitable for the `all_transforms` argument to
6867
#: `~matplotlib.backend_bases.RendererBase.draw_path_collection`;
@@ -193,11 +192,8 @@ def __init__(self,
193192
else:
194193
self._joinstyle = None
195194

196-
# default to zeros
197-
self._offsets = np.zeros((1, 2))
198-
self._has_offsets = offsets is not None
199-
200-
if self._has_offsets:
195+
self._offsets = offsets
196+
if offsets is not None:
201197
offsets = np.asanyarray(offsets, float)
202198
# Broadcast (2,) -> (1, 2) but nothing else.
203199
if offsets.shape == (2,):
@@ -264,9 +260,12 @@ def get_datalim(self, transData):
264260
# if the offsets are in some coords other than data,
265261
# then don't use them for autoscaling.
266262
return transforms.Bbox.null()
267-
offsets = self._offsets
263+
offsets = self.get_offsets()
268264

269265
paths = self.get_paths()
266+
if not len(paths):
267+
# No paths to transform
268+
return transforms.Bbox.null()
270269

271270
if not transform.is_affine:
272271
paths = [transform.transform_path_non_affine(p) for p in paths]
@@ -275,35 +274,34 @@ def get_datalim(self, transData):
275274
# transforms.get_affine().contains_branch(transData). But later,
276275
# be careful to only apply the affine part that remains.
277276

278-
if isinstance(offsets, np.ma.MaskedArray):
279-
offsets = offsets.filled(np.nan)
280-
# get_path_collection_extents handles nan but not masked arrays
281-
282-
if len(paths):
283-
if any(transform.contains_branch_seperately(transData)):
284-
# collections that are just in data units (like quiver)
285-
# can properly have the axes limits set by their shape +
286-
# offset. LineCollections that have no offsets can
287-
# also use this algorithm (like streamplot).
288-
return mpath.get_path_collection_extents(
289-
transform.get_affine() - transData, paths,
290-
self.get_transforms(),
291-
offset_trf.transform_non_affine(offsets),
292-
offset_trf.get_affine().frozen())
293-
294-
if self._has_offsets:
295-
# this is for collections that have their paths (shapes)
296-
# in physical, axes-relative, or figure-relative units
297-
# (i.e. like scatter). We can't uniquely set limits based on
298-
# those shapes, so we just set the limits based on their
299-
# location.
300-
offsets = (offset_trf - transData).transform(offsets)
301-
# note A-B means A B^{-1}
302-
offsets = np.ma.masked_invalid(offsets)
303-
if not offsets.mask.all():
304-
bbox = transforms.Bbox.null()
305-
bbox.update_from_data_xy(offsets)
306-
return bbox
277+
if any(transform.contains_branch_seperately(transData)):
278+
# collections that are just in data units (like quiver)
279+
# can properly have the axes limits set by their shape +
280+
# offset. LineCollections that have no offsets can
281+
# also use this algorithm (like streamplot).
282+
if isinstance(offsets, np.ma.MaskedArray):
283+
offsets = offsets.filled(np.nan)
284+
# get_path_collection_extents handles nan but not masked arrays
285+
return mpath.get_path_collection_extents(
286+
transform.get_affine() - transData, paths,
287+
self.get_transforms(),
288+
offset_trf.transform_non_affine(offsets),
289+
offset_trf.get_affine().frozen())
290+
291+
# NOTE: None is the default case where no offsets were passed in
292+
if self._offsets is not None:
293+
# this is for collections that have their paths (shapes)
294+
# in physical, axes-relative, or figure-relative units
295+
# (i.e. like scatter). We can't uniquely set limits based on
296+
# those shapes, so we just set the limits based on their
297+
# location.
298+
offsets = (offset_trf - transData).transform(offsets)
299+
# note A-B means A B^{-1}
300+
offsets = np.ma.masked_invalid(offsets)
301+
if not offsets.mask.all():
302+
bbox = transforms.Bbox.null()
303+
bbox.update_from_data_xy(offsets)
304+
return bbox
307305
return transforms.Bbox.null()
308306

309307
def get_window_extent(self, renderer):
@@ -316,7 +314,7 @@ def _prepare_points(self):
316314

317315
transform = self.get_transform()
318316
offset_trf = self.get_offset_transform()
319-
offsets = self._offsets
317+
offsets = self.get_offsets()
320318
paths = self.get_paths()
321319

322320
if self.have_units():
@@ -327,10 +325,9 @@ def _prepare_points(self):
327325
xs = self.convert_xunits(xs)
328326
ys = self.convert_yunits(ys)
329327
paths.append(mpath.Path(np.column_stack([xs, ys]), path.codes))
330-
if offsets.size:
331-
xs = self.convert_xunits(offsets[:, 0])
332-
ys = self.convert_yunits(offsets[:, 1])
333-
offsets = np.column_stack([xs, ys])
328+
xs = self.convert_xunits(offsets[:, 0])
329+
ys = self.convert_yunits(offsets[:, 1])
330+
offsets = np.column_stack([xs, ys])
334331

335332
if not transform.is_affine:
336333
paths = [transform.transform_path_non_affine(path)
@@ -559,7 +556,8 @@ def set_offsets(self, offsets):
559556

560557
def get_offsets(self):
561558
"""Return the offsets for the collection."""
562-
return self._offsets
559+
# Default to zeros in the no-offset (None) case
560+
return np.zeros((1, 2)) if self._offsets is None else self._offsets
563561

564562
def _get_default_linewidth(self):
565563
# This may be overridden in a subclass.
@@ -2156,13 +2154,12 @@ def draw(self, renderer):
21562154
renderer.open_group(self.__class__.__name__, self.get_gid())
21572155
transform = self.get_transform()
21582156
offset_trf = self.get_offset_transform()
2159-
offsets = self._offsets
2157+
offsets = self.get_offsets()
21602158

21612159
if self.have_units():
2162-
if len(self._offsets):
2163-
xs = self.convert_xunits(self._offsets[:, 0])
2164-
ys = self.convert_yunits(self._offsets[:, 1])
2165-
offsets = np.column_stack([xs, ys])
2160+
xs = self.convert_xunits(offsets[:, 0])
2161+
ys = self.convert_yunits(offsets[:, 1])
2162+
offsets = np.column_stack([xs, ys])
21662163

21672164
self.update_scalarmappable()
21682165

0 commit comments

Comments
 (0)
0