Merge pull request #21352 from anntzer/hexbin · matplotlib/matplotlib@3331777 · GitHub
[go: up one dir, main page]

Skip to content

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

8000
Appearance settings

Commit 3331777

Browse files
authored
Merge pull request #21352 from anntzer/hexbin
Refactor hexbin().
2 parents cfcf737 + ee206a1 commit 3331777

File tree

1 file changed

+100
-160
lines changed

1 file changed

+100
-160
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 100 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -4669,110 +4669,88 @@ def reduce_C_function(C: array) -> float
46694669
nx = gridsize
46704670
ny = int(nx / math.sqrt(3))
46714671
# Count the number of data in each hexagon
4672-
x = np.array(x, float)
4673-
y = np.array(y, float)
4672+
x = np.asarray(x, float)
4673+
y = np.asarray(y, float)
46744674

4675-
if marginals:
4676-
xorig = x.copy()
4677-
yorig = y.copy()
4675+
# Will be log()'d if necessary, and then rescaled.
4676+
tx = x
4677+
ty = y
46784678

46794679
if xscale == 'log':
46804680
if np.any(x <= 0.0):
4681-
raise ValueError("x contains non-positive values, so can not"
4682-
" be log-scaled")
4683-
x = np.log10(x)
4681+
raise ValueError("x contains non-positive values, so can not "
4682+
"be log-scaled")
4683+
tx = np.log10(tx)
46844684
if yscale == 'log':
46854685
if np.any(y <= 0.0):
4686-
raise ValueError("y contains non-positive values, so can not"
4687-
" be log-scaled")
4688-
y = np.log10(y)
4686+
raise ValueError("y contains non-positive values, so can not "
4687+
"be log-scaled")
4688+
ty = np.log10(ty)
46894689
if extent is not None:
46904690
xmin, xmax, ymin, ymax = extent
46914691
else:
4692-
xmin, xmax = (np.min(x), np.max(x)) if len(x) else (0, 1)
4693-
ymin, ymax = (np.min(y), np.max(y)) if len(y) else (0, 1)
4692+
xmin, xmax = (tx.min(), tx.max()) if len(x) else (0, 1)
4693+
ymin, ymax = (ty.min(), ty.max()) if len(y) else (0, 1)
46944694

46954695
# to avoid issues with singular data, expand the min/max pairs
46964696
xmin, xmax = mtransforms.nonsingular(xmin, xmax, expander=0.1)
46974697
ymin, ymax = mtransforms.nonsingular(ymin, ymax, expander=0.1)
46984698

4699+
nx1 = nx + 1
4700+
ny1 = ny + 1
4701+
nx2 = nx
4702+
ny2 = ny
4703+
n = nx1 * ny1 + nx2 * ny2
4704+
46994705
# In the x-direction, the hexagons exactly cover the region from
47004706
# xmin to xmax. Need some padding to avoid roundoff errors.
47014707
padding = 1.e-9 * (xmax - xmin)
47024708
xmin -= padding
47034709
xmax += padding
47044710
sx = (xmax - xmin) / nx
47054711
sy = (ymax - ymin) / ny
4706-
4707-
x = (x - xmin) / sx
4708-
y = (y - ymin) / sy
4709-
ix1 = np.round(x).astype(int)
4710-
iy1 = np.round(y).astype(int)
4711-
ix2 = np.floor(x).astype(int)
4712-
iy2 = np.floor(y).astype(int)
4713-
4714-
nx1 = nx + 1
4715-
ny1 = ny + 1
4716-
nx2 = nx
4717-
ny2 = ny
4718-
n = nx1 * ny1 + nx2 * ny2
4719-
4720-
d1 = (x - ix1) ** 2 + 3.0 * (y - iy1) ** 2
4721-
d2 = (x - ix2 - 0.5) ** 2 + 3.0 * (y - iy2 - 0.5) ** 2
4712+
# Positions in hexagon index coordinates.
4713+
ix = (tx - xmin) / sx
4714+
iy = (ty - ymin) / sy
4715+
ix1 = np.round(ix).astype(int)
4716+
iy1 = np.round(iy).astype(int)
4717+
ix2 = np.floor(ix).astype(int)
4718+
iy2 = np.floor(iy).astype(int)
4719+
# flat indices, plus one so that out-of-range points go to position 0.
4720+
i1 = np.where((0 <= ix1) & (ix1 < nx1) & (0 <= iy1) & (iy1 < ny1),
4721+
ix1 * ny1 + iy1 + 1, 0)
4722+
i2 = np.where((0 <= ix2) & (ix2 < nx2) & (0 <= iy2) & (iy2 < ny2),
4723+
ix2 * ny2 + iy2 + 1, 0)
4724+
4725+
d1 = (ix - ix1) ** 2 + 3.0 * (iy - iy1) ** 2
4726+
d2 = (ix - ix2 - 0.5) ** 2 + 3.0 * (iy - iy2 - 0.5) ** 2
47224727
bdist = (d1 < d2)
4723-
if C is None:
4724-
lattice1 = np.zeros((nx1, ny1))
4725-
lattice2 = np.zeros((nx2, ny2))
4726-
c1 = (0 <= ix1) & (ix1 < nx1) & (0 <= iy1) & (iy1 < ny1) & bdist
4727-
c2 = (0 <= ix2) & (ix2 < nx2) & (0 <= iy2) & (iy2 < ny2) & ~bdist
4728-
np.add.at(lattice1, (ix1[c1], iy1[c1]), 1)
4729-
np.add.at(lattice2, (ix2[c2], iy2[c2]), 1)
4730-
if mincnt is not None:
4731-
lattice1[lattice1 < mincnt] = np.nan
4732-
lattice2[lattice2 < mincnt] = np.nan
4733-
accum = np.concatenate([lattice1.ravel(), lattice2.ravel()])
4734-
good_idxs = ~np.isnan(accum)
47354728

4729+
if C is None: # [1:] drops out-of-range points.
4730+
counts1 = np.bincount(i1[bdist], minlength=1 + nx1 * ny1)[1:]
4731+
counts2 = np.bincount(i2[~bdist], minlength=1 + nx2 * ny2)[1:]
4732+
accum = np.concatenate([counts1, counts2]).astype(float)
4733+
if mincnt is not None:
4734+
accum[accum < mincnt] = np.nan
4735+
C = np.ones(len(x))
47364736
else:
4737-
if mincnt is None:
4738-
mincnt = 0
4739-
4740-
# create accumulation arrays
4741-
lattice1 = np.empty((nx1, ny1), dtype=object)
4742-
for i in range(nx1):
4743-
for j in range(ny1):
4744-
lattice1[i, j] = []
4745-
lattice2 = np.empty((nx2, ny2), dtype=object)
4746-
for i in range(nx2):
4747-
for j in range(ny2):
4748-
lattice2[i, j] = []
4749-
4737+
# store the C values in a list per hexagon index
4738+
Cs_at_i1 = [[] for _ in range(1 + nx1 * ny1)]
4739+
Cs_at_i2 = [[] for _ in range(1 + nx2 * ny2)]
47504740
for i in range(len(x)):
47514741
if bdist[i]:
4752-
if 0 <= ix1[i] < nx1 and 0 <= iy1[i] < ny1:
4753-
lattice1[ix1[i], iy1[i]].append(C[i])
4742+
Cs_at_i1[i1[i]].append(C[i])
47544743
else:
4755-
if 0 <= ix2[i] < nx2 and 0 <= iy2[i] < ny2:
4756-
lattice2[ix2[i], iy2[i]].append(C[i])
4757-
4758-
for i in range(nx1):
4759-
for j in range(ny1):
4760-
vals = lattice1[i, j]
4761-
if len(vals) > mincnt:
4762-
lattice1[i, j] = reduce_C_function(vals)
4763-
else:
4764-
lattice1[i, j] = np.nan
4765-
for i in range(nx2):
4766-
for j in range(ny2):
4767-
vals = lattice2[i, j]
4768-
if len(vals) > mincnt:
4769-
lattice2[i, j] = reduce_C_function(vals)
4770-
else:
4771-
lattice2[i, j] = np.nan
4744+
Cs_at_i2[i2[i]].append(C[i])
4745+
if mincnt is None:
4746+
mincnt = 0
4747+
accum = np.array(
4748+
[reduce_C_function(acc) if len(acc) > mincnt else np.nan
4749+
for Cs_at_i in [Cs_at_i1, Cs_at_i2]
4750+
for acc in Cs_at_i[1:]], # [1:] drops out-of-range points.
4751+
float)
47724752

4773-
accum = np.concatenate([lattice1.astype(float).ravel(),
4774-
lattice2.astype(float).ravel()])
4775-
good_idxs = ~np.isnan(accum)
4753+
good_idxs = ~np.isnan(accum)
47764754

47774755
offsets = np.zeros((n, 2), float)
47784756
offsets[:nx1 * ny1, 0] = np.repeat(np.arange(nx1), ny1)
@@ -4830,8 +4808,7 @@ def reduce_C_function(C: array) -> float
48304808
vmin = vmax = None
48314809
bins = None
48324810

4833-
# autoscale the norm with current accum values if it hasn't
4834-
# been set
4811+
# autoscale the norm with current accum values if it hasn't been set
48354812
if norm is not None:
48364813
if norm.vmin is None and norm.vmax is None:
48374814
norm.autoscale(accum)
@@ -4861,92 +4838,55 @@ def reduce_C_function(C: array) -> float
48614838
return collection
48624839

48634840
# Process marginals
4864-
if C is None:
4865-
C = np.ones(len(x))
4841+
bars = []
4842+
for zname, z, zmin, zmax, zscale, nbins in [
4843+
("x", x, xmin, xmax, xscale, nx),
4844+
("y", y, ymin, ymax, yscale, 2 * ny),
4845+
]:
48664846

4867-
def coarse_bin(x, y, bin_edges):
4868-
"""
4869-
Sort x-values into bins defined by *bin_edges*, then for all the
4870-
corresponding y-values in each bin use *reduce_c_function* to
4871-
compute the bin value.
4872-
"""
4873-
nbins = len(bin_edges) - 1
4874-
# Sort x-values into bins
4875-
bin_idxs = np.searchsorted(bin_edges, x) - 1
4876-
mus = np.zeros(nbins) * np.nan
4847+
if zscale == "log":
4848+
bin_edges = np.geomspace(zmin, zmax, nbins + 1)
4849+
else:
4850+
bin_edges = np.linspace(zmin, zmax, nbins + 1)
4851+
4852+
verts = np.empty((nbins, 4, 2))
4853+
verts[:, 0, 0] = verts[:, 1, 0] = bin_edges[:-1]
4854+
verts[:, 2, 0] = verts[:, 3, 0] = bin_edges[1:]
4855+
verts[:, 0, 1] = verts[:, 3, 1] = .00
4856+
verts[:, 1, 1] = verts[:, 2, 1] = .05
4857+
if zname == "y":
4858+
verts = verts[:, :, ::-1] # Swap x and y.
4859+
4860+
# Sort z-values into bins defined by bin_edges.
4861+
bin_idxs = np.searchsorted(bin_edges, z) - 1
4862+
values = np.empty(nbins)
48774863
for i in range(nbins):
4878-
# Get y-values for each bin
4879-
yi = y[bin_idxs == i]
4880-
if len(yi) > 0:
4881-
mus[i] = reduce_C_function(yi)
4882-
return mus
4883-
4884-
if xscale == 'log':
4885-
bin_edges = np.geomspace(xmin, xmax, nx + 1)
4886-
else:
4887-
bin_edges = np.linspace(xmin, xmax, nx + 1)
4888-
xcoarse = coarse_bin(xorig, C, bin_edges)
4889-
4890-
verts, values = [], []
4891-
for bin_left, bin_right, val in zip(
4892-
bin_edges[:-1], bin_edges[1:], xcoarse):
4893-
if np.isnan(val):
4894-
continue
4895-
verts.append([(bin_left, 0),
4896-
(bin_left, 0.05),
4897-
(bin_right, 0.05),
4898-
(bin_right, 0)])
4899-
values.append(val)
4900-
4901-
values = np.array(values)
4902-
trans = self.get_xaxis_transform(which='grid')
4903-
4904-
hbar = mcoll.PolyCollection(verts, transform=trans, edgecolors='face')
4905-
4906-
hbar.set_array(values)
4907-
hbar.set_cmap(cmap)
4908-
hbar.set_norm(norm)
4909-
hbar.set_alpha(alpha)
4910-
hbar.update(kwargs)
4911-
self.add_collection(hbar, autolim=False)
4912-
4913-
if yscale == 'log':
4914-
bin_edges = np.geomspace(ymin, ymax, 2 * ny + 1)
4915-
else:
4916-
bin_edges = np.linspace(ymin, ymax, 2 * ny + 1)
4917-
ycoarse = coarse_bin(yorig, C, bin_edges)
4918-
4919-
verts, values = [], []
4920-
for bin_bottom, bin_top, val in zip(
4921-
bin_edges[:-1], bin_edges[1:], ycoarse):
4922-
if np.isnan(val):
4923-
continue
4924-
verts.append([(0, bin_bottom),
4925-
(0, bin_top),
4926-
(0.05, bin_top),
4927-
(0.05, bin_bottom)])
4928-
values.append(val)
4929-
4930-
values = np.array(values)
4931-
4932-
trans = self.get_yaxis_transform(which='grid')
4933-
4934-
vbar = mcoll.PolyCollection(verts, transform=trans, edgecolors='face')
4935-
vbar.set_array(values)
4936-
vbar.set_cmap(cmap)
4937-
vbar.set_norm(norm)
4938-
vbar.set_alpha(alpha)
4939-
vbar.update(kwargs)
4940-
self.add_collection(vbar, autolim=False)
4941-
4942-
collection.hbar = hbar
4943-
collection.vbar = vbar
4864+
# Get C-values for each bin, and compute bin value with
4865+
# reduce_C_function.
4866+
ci = C[bin_idxs == i]
4867+
values[i] = reduce_C_function(ci) if len(ci) > 0 else np.nan
4868+
4869+
mask = ~np.isnan(values)
4870+
verts = verts[mask]
4871+
values = values[mask]
4872+
4873+
trans = getattr(self, f"get_{zname}axis_transform")(which="grid")
4874+
bar = mcoll.PolyCollection(
4875+
verts, transform=trans, edgecolors="face")
4876+
bar.set_array(values)
4877+
bar.set_cmap(cmap)
4878+
bar.set_norm(norm)
4879+
bar.set_alpha(alpha)
4880+
bar.update(kwargs)
4881+
bars.append(self.add_collection(bar, autolim=False))
4882+
4883+
collection.hbar, collection.vbar = bars
49444884

49454885
def on_changed(collection):
    """
    Keep the marginal bars in sync with the main hexbin collection.

    Registered as the 'changed' callback of *collection*.  Copies the
    collection's current colormap and color limits onto both
    ``collection.hbar`` and ``collection.vbar``.

    Parameters
    ----------
    collection
        The object whose ``get_cmap()`` / ``get_clim()`` values are
        propagated; it must expose ``hbar`` and ``vbar`` attributes that
        provide ``set_cmap`` and ``set_clim``.
    """
    # Each bar needs BOTH properties.  Setting the cmap twice on hbar and
    # the clim twice on vbar would leave hbar with stale limits and vbar
    # with a stale colormap.
    for bar in (collection.hbar, collection.vbar):
        bar.set_cmap(collection.get_cmap())
        bar.set_clim(collection.get_clim())
49504890

49514891
collection.callbacks.connect('changed', on_changed)
49524892

0 commit comments

Comments
 (0)
0