@@ -4669,110 +4669,88 @@ def reduce_C_function(C: array) -> float
4669
4669
nx = gridsize
4670
4670
ny = int (nx / math .sqrt (3 ))
4671
4671
# Count the number of data in each hexagon
4672
- x = np .array (x , float )
4673
- y = np .array (y , float )
4672
+ x = np .asarray (x , float )
4673
+ y = np .asarray (y , float )
4674
4674
4675
- if marginals :
4676
- xorig = x . copy ()
4677
- yorig = y . copy ()
4675
+ # Will be log()'d if necessary, and then rescaled.
4676
+ tx = x
4677
+ ty = y
4678
4678
4679
4679
if xscale == 'log' :
4680
4680
if np .any (x <= 0.0 ):
4681
- raise ValueError ("x contains non-positive values, so can not"
4682
- " be log-scaled" )
4683
- x = np .log10 (x )
4681
+ raise ValueError ("x contains non-positive values, so can not "
4682
+ "be log-scaled" )
4683
+ tx = np .log10 (tx )
4684
4684
if yscale == 'log' :
4685
4685
if np .any (y <= 0.0 ):
4686
- raise ValueError ("y contains non-positive values, so can not"
4687
- " be log-scaled" )
4688
- y = np .log10 (y )
4686
+ raise ValueError ("y contains non-positive values, so can not "
4687
+ "be log-scaled" )
4688
+ ty = np .log10 (ty )
4689
4689
if extent is not None :
4690
4690
xmin , xmax , ymin , ymax = extent
4691
4691
else :
4692
- xmin , xmax = (np .min (x ), np .max (x )) if len (x ) else (0 , 1 )
4693
- ymin , ymax = (np .min (y ), np .max (y )) if len (y ) else (0 , 1 )
4692
+ xmin , xmax = (tx .min (), tx .max ()) if len (x ) else (0 , 1 )
4693
+ ymin , ymax = (ty .min (), ty .max ()) if len (y ) else (0 , 1 )
4694
4694
4695
4695
# to avoid issues with singular data, expand the min/max pairs
4696
4696
xmin , xmax = mtransforms .nonsingular (xmin , xmax , expander = 0.1 )
4697
4697
ymin , ymax = mtransforms .nonsingular (ymin , ymax , expander = 0.1 )
4698
4698
4699
+ nx1 = nx + 1
4700
+ ny1 = ny + 1
4701
+ nx2 = nx
4702
+ ny2 = ny
4703
+ n = nx1 * ny1 + nx2 * ny2
4704
+
4699
4705
# In the x-direction, the hexagons exactly cover the region from
4700
4706
# xmin to xmax. Need some padding to avoid roundoff errors.
4701
4707
padding = 1.e-9 * (xmax - xmin )
4702
4708
xmin -= padding
4703
4709
xmax += padding
4704
4710
sx = (xmax - xmin ) / nx
4705
4711
sy = (ymax - ymin ) / ny
4706
-
4707
- x = (x - xmin ) / sx
4708
- y = (y - ymin ) / sy
4709
- ix1 = np .round (x ).astype (int )
4710
- iy1 = np .round (y ).astype (int )
4711
- ix2 = np .floor (x ).astype (int )
4712
- iy2 = np .floor (y ).astype (int )
4713
-
4714
- nx1 = nx + 1
4715
- ny1 = ny + 1
4716
- nx2 = nx
4717
- ny2 = ny
4718
- n = nx1 * ny1 + nx2 * ny2
4719
-
4720
- d1 = (x - ix1 ) ** 2 + 3.0 * (y - iy1 ) ** 2
4721
- d2 = (x - ix2 - 0.5 ) ** 2 + 3.0 * (y - iy2 - 0.5 ) ** 2
4712
+ # Positions in hexagon index coordinates.
4713
+ ix = (tx - xmin ) / sx
4714
+ iy = (ty - ymin ) / sy
4715
+ ix1 = np .round (ix ).astype (int )
4716
+ iy1 = np .round (iy ).astype (int )
4717
+ ix2 = np .floor (ix ).astype (int )
4718
+ iy2 = np .floor (iy ).astype (int )
4719
+ # flat indices, plus one so that out-of-range points go to position 0.
4720
+ i1 = np .where ((0 <= ix1 ) & (ix1 < nx1 ) & (0 <= iy1 ) & (iy1 < ny1 ),
4721
+ ix1 * ny1 + iy1 + 1 , 0 )
4722
+ i2 = np .where ((0 <= ix2 ) & (ix2 < nx2 ) & (0 <= iy2 ) & (iy2 < ny2 ),
4723
+ ix2 * ny2 + iy2 + 1 , 0 )
4724
+
4725
+ d1 = (ix - ix1 ) ** 2 + 3.0 * (iy - iy1 ) ** 2
4726
+ d2 = (ix - ix2 - 0.5 ) ** 2 + 3.0 * (iy - iy2 - 0.5 ) ** 2
4722
4727
bdist = (d1 < d2 )
4723
- if C is None :
4724
- lattice1 = np .zeros ((nx1 , ny1 ))
4725
- lattice2 = np .zeros ((nx2 , ny2 ))
4726
- c1 = (0 <= ix1 ) & (ix1 < nx1 ) & (0 <= iy1 ) & (iy1 < ny1 ) & bdist
4727
- c2 = (0 <= ix2 ) & (ix2 < nx2 ) & (0 <= iy2 ) & (iy2 < ny2 ) & ~ bdist
4728
- np .add .at (lattice1 , (ix1 [c1 ], iy1 [c1 ]), 1 )
4729
- np .add .at (lattice2 , (ix2 [c2 ], iy2 [c2 ]), 1 )
4730
- if mincnt is not None :
4731
- lattice1 [lattice1 < mincnt ] = np .nan
4732
- lattice2 [lattice2 < mincnt ] = np .nan
4733
- accum = np .concatenate ([lattice1 .ravel (), lattice2 .ravel ()])
4734
- good_idxs = ~ np .isnan (accum )
4735
4728
4729
+ if C is None : # [1:] drops out-of-range points.
4730
+ counts1 = np .bincount (i1 [bdist ], minlength = 1 + nx1 * ny1 )[1 :]
4731
+ counts2 = np .bincount (i2 [~ bdist ], minlength = 1 + nx2 * ny2 )[1 :]
4732
+ accum = np .concatenate ([counts1 , counts2 ]).astype (float )
4733
+ if mincnt is not None :
4734
+ accum [accum < mincnt ] = np .nan
4735
+ C = np .ones (len (x ))
4736
4736
else :
4737
- if mincnt is None :
4738
- mincnt = 0
4739
-
4740
- # create accumulation arrays
4741
- lattice1 = np .empty ((nx1 , ny1 ), dtype = object )
4742
- for i in range (nx1 ):
4743
- for j in range (ny1 ):
4744
- lattice1 [i , j ] = []
4745
- lattice2 = np .empty ((nx2 , ny2 ), dtype = object )
4746
- for i in range (nx2 ):
4747
- for j in range (ny2 ):
4748
- lattice2 [i , j ] = []
4749
-
4737
+ # store the C values in a list per hexagon index
4738
+ Cs_at_i1 = [[] for _ in range (1 + nx1 * ny1 )]
4739
+ Cs_at_i2 = [[] for _ in range (1 + nx2 * ny2 )]
4750
4740
for i in range (len (x )):
4751
4741
if bdist [i ]:
4752
- if 0 <= ix1 [i ] < nx1 and 0 <= iy1 [i ] < ny1 :
4753
- lattice1 [ix1 [i ], iy1 [i ]].append (C [i ])
4742
+ Cs_at_i1 [i1 [i ]].append (C [i ])
4754
4743
else :
4755
- if 0 <= ix2 [i ] < nx2 and 0 <= iy2 [i ] < ny2 :
4756
- lattice2 [ix2 [i ], iy2 [i ]].append (C [i ])
4757
-
4758
- for i in range (nx1 ):
4759
- for j in range (ny1 ):
4760
- vals = lattice1 [i , j ]
4761
- if len (vals ) > mincnt :
4762
- lattice1 [i , j ] = reduce_C_function (vals )
4763
- else :
4764
- lattice1 [i , j ] = np .nan
4765
- for i in range (nx2 ):
4766
- for j in range (ny2 ):
4767
- vals = lattice2 [i , j ]
4768
- if len (vals ) > mincnt :
4769
- lattice2 [i , j ] = reduce_C_function (vals )
4770
- else :
4771
- lattice2 [i , j ] = np .nan
4744
+ Cs_at_i2 [i2 [i ]].append (C [i ])
4745
+ if mincnt is None :
4746
+ mincnt = 0
4747
+ accum = np .array (
4748
+ [reduce_C_function (acc ) if len (acc ) > mincnt else np .nan
4749
+ for Cs_at_i in [Cs_at_i1 , Cs_at_i2 ]
4750
+ for acc in Cs_at_i [1 :]], # [1:] drops out-of-range points.
4751
+ float )
4772
4752
4773
- accum = np .concatenate ([lattice1 .astype (float ).ravel (),
4774
- lattice2 .astype (float ).ravel ()])
4775
- good_idxs = ~ np .isnan (accum )
4753
+ good_idxs = ~ np .isnan (accum )
4776
4754
4777
4755
offsets = np .zeros ((n , 2 ), float )
4778
4756
offsets [:nx1 * ny1 , 0 ] = np .repeat (np .arange (nx1 ), ny1 )
@@ -4830,8 +4808,7 @@ def reduce_C_function(C: array) -> float
4830
4808
vmin = vmax = None
4831
4809
bins = None
4832
4810
4833
- # autoscale the norm with current accum values if it hasn't
4834
- # been set
4811
+ # autoscale the norm with current accum values if it hasn't been set
4835
4812
if norm is not None :
4836
4813
if norm .vmin is None and norm .vmax is None :
4837
4814
norm .autoscale (accum )
@@ -4861,92 +4838,55 @@ def reduce_C_function(C: array) -> float
4861
4838
return collection
4862
4839
4863
4840
# Process marginals
4864
- if C is None :
4865
- C = np .ones (len (x ))
4841
+ bars = []
4842
+ for zname , z , zmin , zmax , zscale , nbins in [
4843
+ ("x" , x , xmin , xmax , xscale , nx ),
4844
+ ("y" , y , ymin , ymax , yscale , 2 * ny ),
4845
+ ]:
4866
4846
4867
- def coarse_bin (x , y , bin_edges ):
4868
- """
4869
- Sort x-values into bins defined by *bin_edges*, then for all the
4870
- corresponding y-values in each bin use *reduce_c_function* to
4871
- compute the bin value.
4872
- """
4873
- nbins = len (bin_edges ) - 1
4874
- # Sort x-values into bins
4875
- bin_idxs = np .searchsorted (bin_edges , x ) - 1
4876
- mus = np .zeros (nbins ) * np .nan
4847
+ if zscale == "log" :
4848
+ bin_edges = np .geomspace (zmin , zmax , nbins + 1 )
4849
+ else :
4850
+ bin_edges = np .linspace (zmin , zmax , nbins + 1 )
4851
+
4852
+ verts = np .empty ((nbins , 4 , 2 ))
4853
+ verts [:, 0 , 0 ] = verts [:, 1 , 0 ] = bin_edges [:- 1 ]
4854
+ verts [:, 2 , 0 ] = verts [:, 3 , 0 ] = bin_edges [1 :]
4855
+ verts [:, 0 , 1 ] = verts [:, 3 , 1 ] = .00
4856
+ verts [:, 1 , 1 ] = verts [:, 2 , 1 ] = .05
4857
+ if zname == "y" :
4858
+ verts = verts [:, :, ::- 1 ] # Swap x and y.
4859
+
4860
+ # Sort z-values into bins defined by bin_edges.
4861
+ bin_idxs = np .searchsorted (bin_edges , z ) - 1
4862
+ values = np .empty (nbins )
4877
4863
for i in range (nbins ):
4878
- # Get y-values for each bin
4879
- yi = y [bin_idxs == i ]
4880
- if len (yi ) > 0 :
4881
- mus [i ] = reduce_C_function (yi )
4882
- return mus
4883
-
4884
- if xscale == 'log' :
4885
- bin_edges = np .geomspace (xmin , xmax , nx + 1 )
4886
- else :
4887
- bin_edges = np .linspace (xmin , xmax , nx + 1 )
4888
- xcoarse = coarse_bin (xorig , C , bin_edges )
4889
-
4890
- verts , values = [], []
4891
- for bin_left , bin_right , val in zip (
4892
- bin_edges [:- 1 ], bin_edges [1 :], xcoarse ):
4893
- if np .isnan (val ):
4894
- continue
4895
- verts .append ([(bin_left , 0 ),
4896
- (bin_left , 0.05 ),
4897
- (bin_right , 0.05 ),
4898
- (bin_right , 0 )])
4899
- values .append (val )
4900
-
4901
- values = np .array (values )
4902
- trans = self .get_xaxis_transform (which = 'grid' )
4903
-
4904
- hbar = mcoll .PolyCollection (verts , transform = trans , edgecolors = 'face' )
4905
-
4906
- hbar .set_array (values )
4907
- hbar .set_cmap (cmap )
4908
- hbar .set_norm (norm )
4909
- hbar .set_alpha (alpha )
4910
- hbar .update (kwargs )
4911
- self .add_collection (hbar , autolim = False )
4912
-
4913
- if yscale == 'log' :
4914
- bin_edges = np .geomspace (ymin , ymax , 2 * ny + 1 )
4915
- else :
4916
- bin_edges = np .linspace (ymin , ymax , 2 * ny + 1 )
4917
- ycoarse = coarse_bin (yorig , C , bin_edges )
4918
-
4919
- verts , values = [], []
4920
- for bin_bottom , bin_top , val in zip (
4921
- bin_edges [:- 1 ], bin_edges [1 :], ycoarse ):
4922
- if np .isnan (val ):
4923
- continue
4924
- verts .append ([(0 , bin_bottom ),
4925
- (0 , bin_top ),
4926
- (0.05 , bin_top ),
4927
- (0.05 , bin_bottom )])
4928
- values .append (val )
4929
-
4930
- values = np .array (values )
4931
-
4932
- trans = self .get_yaxis_transform (which = 'grid' )
4933
-
4934
- vbar = mcoll .PolyCollection (verts , transform = trans , edgecolors = 'face' )
4935
- vbar .set_array (values )
4936
- vbar .set_cmap (cmap )
4937
- vbar .set_norm (norm )
4938
- vbar .set_alpha (alpha )
4939
- vbar .update (kwargs )
4940
- self .add_collection (vbar , autolim = False )
4941
-
4942
- collection .hbar = hbar
4943
- collection .vbar = vbar
4864
+ # Get C-values for each bin, and compute bin value with
4865
+ # reduce_C_function.
4866
+ ci = C [bin_idxs == i ]
4867
+ values [i ] = reduce_C_function (ci ) if len (ci ) > 0 else np .nan
4868
+
4869
+ mask = ~ np .isnan (values )
4870
+ verts = verts [mask ]
4871
+ values = values [mask ]
4872
+
4873
+ trans = getattr (self , f"get_{ zname } axis_transform" )(which = "grid" )
4874
+ bar = mcoll .PolyCollection (
4875
+ verts , transform = trans , edgecolors = "face" )
4876
+ bar .set_array (values )
4877
+ bar .set_cmap (cmap )
4878
+ bar .set_norm (norm )
4879
+ bar .set_alpha (alpha )
4880
+ bar .update (kwargs )
4881
+ bars .append (self .add_collection (bar , autolim = False ))
4882
+
4883
+ collection .hbar , collection .vbar = bars
4944
4884
4945
4885
def on_changed (collection ):
4946
- hbar .set_cmap (collection .get_cmap ())
4947
- hbar .set_clim (collection .get_clim ())
4948
- vbar .set_cmap (collection .get_cmap ())
4949
- vbar .set_clim (collection .get_clim ())
4886
+ collection . hbar .set_cmap (collection .get_cmap ())
4887
+ collection . hbar .set_cmap (collection .get_cmap ())
4888
+ collection . vbar .set_clim (collection .get_clim ())
4889
+ collection . vbar .set_clim (collection .get_clim ())
4950
4890
4951
4891
collection .callbacks .connect ('changed' , on_changed )
4952
4892
0 commit comments