-
Notifications
You must be signed in to change notification settings - Fork 63
Expand file tree
/
Copy pathlegend.py
More file actions
372 lines (280 loc) · 11.7 KB
/
legend.py
File metadata and controls
372 lines (280 loc) · 11.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
from functools import partial
from collections import OrderedDict
from typing import Iterable
import numpy as np
import pygfx
from ..utils.enums import RenderQueue
from ..graphics import Graphic
from ..graphics.features import GraphicFeatureEvent
from ..graphics import LineGraphic, ScatterGraphic, ImageGraphic
from ..utils import mesh_masks
class LegendItem:
    """Base class for a single legend entry: a text label and a color."""

    def __init__(
        self,
        label: str,
        color: "pygfx.Color",
    ):
        """
        Parameters
        ----------
        label: str
            text label of the legend entry

        color: pygfx.Color
            color shown for the legend entry
        """
        self._label = label
        self._color = color

    @property
    def label(self) -> str:
        """Text label of this legend entry."""
        return self._label

    @property
    def color(self) -> "pygfx.Color":
        """Color of this legend entry."""
        return self._color
class LineLegendItem(LegendItem):
    """Legend entry for a single-colored ``LineGraphic``: a short line swatch followed by a text label."""

    def __init__(
        self, parent, graphic: LineGraphic, label: str, position: tuple[float, float]
    ):
        """
        Parameters
        ----------
        parent: Legend
            legend that owns this item, used for label uniqueness checks and the highlight color

        graphic: LineGraphic
            graphic represented by this item, must have a single color

        label: str or None
            text label, if ``None`` the graphic's ``name`` is used

        position: [x, y]
            position of the item, relative to its parent once it is added to one

        Raises
        ------
        ValueError
            if neither ``label`` nor ``graphic.name`` is set, or if the graphic has more than one color
        """
        if label is None:
            # auto-use the graphic's name as the label
            label = graphic.name

        if label is None:
            raise ValueError(
                "Must specify `label` or Graphic must have a `name` to auto-use as the label"
            )

        # for now only support lines with a single color
        unique_colors = np.unique(graphic.colors.value, axis=0)
        if unique_colors.shape[0] > 1:
            raise ValueError("Use colorbars for multi-colored lines, not legends")

        color = pygfx.Color(unique_colors.ravel())

        self._parent = parent

        super().__init__(label, color)

        # keep the legend swatch in sync when the graphic's colors change
        graphic.colors.add_event_handler(self._update_color)

        # construct Line WorldObject, a short horizontal segment used as the color swatch
        data = np.array([[0, 0, 0], [3, 0, 0]], dtype=np.float32)

        self._line_world_object = pygfx.Line(
            geometry=pygfx.Geometry(positions=data),
            material=pygfx.LineMaterial(
                alpha_mode="blend",
                render_queue=RenderQueue.overlay,
                thickness=8,
                color=self._color,
                depth_write=False,
                depth_test=False,
            ),
        )

        self._label_world_object = pygfx.Text(
            text=str(label),
            font_size=6,
            screen_space=False,
            anchor="middle-left",
            material=pygfx.TextMaterial(
                alpha_mode="blend",
                aa=True,
                render_queue=RenderQueue.overlay,
                color="w",
                outline_color="w",
                outline_thickness=0,
                depth_write=False,
                depth_test=False,
            ),
        )

        self.world_object = pygfx.Group()
        self.world_object.add(self._line_world_object, self._label_world_object)

        self.world_object.local.x = position[0]
        self.world_object.local.y = position[1]

        # offset the label 10 units to the right of the group origin to leave room for the line swatch
        self._label_world_object.local.x = 10

        # clicking the item toggles the highlight of the graphic
        self.world_object.add_event_handler(
            partial(self._highlight_graphic, graphic), "click"
        )

    @property
    def label(self) -> str:
        """Text label of this legend entry, must be unique within the parent legend."""
        return self._label

    @label.setter
    def label(self, text: str):
        if text == self._label:
            # re-setting the current label would otherwise collide with this item itself
            return

        self._parent._check_label_unique(text)

        self._label = text
        self._label_world_object.geometry.set_text(text)

    def _update_color(self, ev: GraphicFeatureEvent):
        """
        Event handler for the graphic's colors, recolors the line swatch.

        Raises
        ------
        ValueError
            if the new colors contain more than one unique color
        """
        new_color = ev.info["value"]
        if np.unique(new_color, axis=0).shape[0] > 1:
            raise ValueError(
                "LegendError: LineGraphic colors no longer appropriate for legend"
            )

        # store as pygfx.Color, consistent with __init__ and _highlight_graphic
        self._color = pygfx.Color(new_color[0])
        self._line_world_object.material.color = self._color

    def _highlight_graphic(self, graphic: Graphic, ev):
        """Click handler: toggle the graphic between its original color and the legend's highlight color."""
        graphic_color = pygfx.Color(np.unique(graphic.colors.value, axis=0).ravel())

        if graphic_color == self._parent.highlight_color:
            # currently highlighted, restore the original color
            graphic.colors = self._color
        else:
            # hacky but fine for now: assigning graphic.colors triggers the _update_color handler
            # registered in __init__, which overwrites self._color with the highlight color,
            # so keep the original color to restore it on the next click
            orig_color = pygfx.Color(self._color)
            graphic.colors = self._parent.highlight_color
            self._color = orig_color
class Legend(Graphic):
    """
    Draggable legend that shows a labelled entry for each added graphic.

    Entries are laid out top-to-bottom in columns of at most ``max_rows`` items,
    on top of a dark background mesh that is resized to enclose them.
    """

    def __init__(
        self,
        plot_area,
        highlight_color: str | tuple | np.ndarray = "w",
        max_rows: int = 5,
        *args,
        **kwargs,
    ):
        """
        Parameters
        ----------
        plot_area: Union[Plot, Subplot, Dock]
            plot area to put the legend in

        highlight_color: Union[str, tuple, np.ndarray], default "w"
            color applied to a graphic when its legend entry is clicked

        max_rows: int, default 5
            maximum number of rows allowed in the legend, additional items wrap into a new column

        Raises
        ------
        ValueError
            if ``max_rows`` is less than 1
        """
        if max_rows < 1:
            raise ValueError(f"`max_rows` must be >= 1, you passed: {max_rows}")

        self._graphics: list[Graphic] = list()

        # hex id of Graphic, i.e. graphic._fpl_address are the keys
        self._items: OrderedDict[str, LegendItem] = OrderedDict()

        super().__init__(*args, **kwargs)

        group = pygfx.Group()
        self._legend_items_group = pygfx.Group()
        self._set_world_object(group)

        self._mesh = pygfx.Mesh(
            pygfx.box_geometry(50, 10, 1),
            pygfx.MeshBasicMaterial(
                alpha_mode="blend",
                render_queue=RenderQueue.overlay,
                color=pygfx.Color([0.1, 0.1, 0.1, 1]),
                wireframe_thickness=10,
                depth_write=False,
                depth_test=False,
            ),
        )
        # Plane gets rendered before text and line
        self._mesh.render_order = -1

        self.world_object.add(self._mesh)
        self.world_object.add(self._legend_items_group)

        self.highlight_color = pygfx.Color(highlight_color)

        self._plot_area = plot_area
        self._plot_area.add_graphic(self)

        if self._plot_area.__class__.__name__ == "Dock":
            # give a collapsed Dock a default size so the legend is visible
            if self._plot_area.size < 1:
                self._plot_area.size = 100

        # TODO: refactor with "moveable graphic" base class once that's done
        self._mesh.add_event_handler(self._pointer_down, "pointer_down")
        self._plot_area.renderer.add_event_handler(self._pointer_move, "pointer_move")
        self._plot_area.renderer.add_event_handler(self._pointer_up, "pointer_up")

        # world position of the pointer while dragging, None when not dragging
        self._last_position = None
        self._initial_controller_state = self._plot_area.controller.enabled

        self._max_rows = max_rows

        # number of items in the last column and the index of that column
        self._row_counter = 0
        self._col_counter = 0

    def graphics(self) -> tuple[Graphic, ...]:
        """Return the graphics in the legend, in the order they were added."""
        return tuple(self._graphics)

    def _check_label_unique(self, label):
        """Raise ``ValueError`` if ``label`` is already used by an item in this legend."""
        for legend_item in self._items.values():
            if legend_item.label == label:
                raise ValueError(
                    f"You have passed the label '{label}' which is already used for another legend item. "
                    f"All labels within a legend must be unique."
                )

    def add_graphic(self, graphic: Graphic, label: str | None = None):
        """
        Add a legend entry for a graphic.

        Parameters
        ----------
        graphic: Graphic
            graphic to add, currently only ``LineGraphic`` is supported

        label: str, optional
            text label of the entry, defaults to ``graphic.name``

        Raises
        ------
        KeyError
            if the graphic is already in the legend

        ValueError
            if the label is already used, if no label can be determined,
            or if the graphic type is not supported
        """
        if graphic in self._graphics:
            raise KeyError(
                f"Graphic already exists in legend with label: '{self._items[graphic._fpl_address].label}'"
            )

        # resolve the label shown in the legend first so uniqueness is checked on that label
        if label is None:
            label = graphic.name

        # with no label at all, the legend item constructor raises a descriptive error
        if label is not None:
            self._check_label_unique(label)

        if isinstance(graphic, LineGraphic):
            # created at the origin, _reset_item_positions() moves it to its place in the grid
            legend_item = LineLegendItem(self, graphic, label, position=(0, 0))
        else:
            raise ValueError("Legend only supported for LineGraphic for now.")

        self._legend_items_group.add(legend_item.world_object)

        self._graphics.append(graphic)
        self._items[graphic._fpl_address] = legend_item

        graphic.add_event_handler(
            partial(self._on_graphic_deleted, graphic), "deleted"
        )

        self._reset_item_positions()

    @staticmethod
    def _max_item_width(items) -> float:
        """Return the width of the widest legend item in ``items``, 0 if none has a bounding box."""
        max_width = 0.0
        for item in items:
            bbox = item.world_object.get_world_bounding_box()
            if bbox is None:
                continue
            max_width = max(max_width, np.ptp(bbox, axis=0)[0])
        return max_width

    def _reset_mesh_dims(self):
        """Resize the background mesh so that it encloses all legend items."""
        bbox = self._legend_items_group.get_world_bounding_box()

        if bbox is None:
            # no items, nothing to enclose
            width, height = 0.0, 0.0
        else:
            width, height, _ = np.ptp(bbox, axis=0)

        positions = self._mesh.geometry.positions
        positions.data[mesh_masks.x_right] = width + 7
        positions.data[mesh_masks.x_left] = -5
        # the top edge is never modified, it stays at the initial top of the 50 x 10 box geometry
        positions.data[mesh_masks.y_bottom] = -height - 3
        positions.update_range()

    def remove_graphic(self, graphic: Graphic):
        """
        Remove the legend entry of a graphic and re-layout the remaining entries.

        Raises
        ------
        ValueError
            if the graphic is not in the legend
        """
        self._graphics.remove(graphic)
        legend_item = self._items.pop(graphic._fpl_address)
        self._legend_items_group.remove(legend_item.world_object)

        self._reset_item_positions()

    def _on_graphic_deleted(self, graphic: Graphic, ev=None):
        """"deleted" event handler, removes the graphic's entry if it is still in the legend."""
        # event handlers are called with the event object, which remove_graphic() does not take,
        # and the graphic may already have been removed from the legend manually
        if graphic in self._graphics:
            self.remove_graphic(graphic)

    def _reset_item_positions(self):
        """
        Lay out all items in column-major order, ``max_rows`` items per column,
        then resize the background mesh.
        """
        items = list(self._items.values())

        x_pos = 0.0
        for col_start in range(0, len(items), self._max_rows):
            column = items[col_start : col_start + self._max_rows]

            for row_ix, legend_item in enumerate(column):
                # positions are relative to the items group so they remain valid after the legend is dragged
                legend_item.world_object.local.x = x_pos
                legend_item.world_object.local.y = row_ix * -10

            # next column starts after the widest item of this column, add 15 for spacing
            x_pos = x_pos + self._max_item_width(column) + 15

        if len(items) == 0:
            self._col_counter = 0
            self._row_counter = 0
        else:
            self._col_counter, last_row_ix = divmod(len(items) - 1, self._max_rows)
            self._row_counter = last_row_ix + 1

        self._reset_mesh_dims()

    def reorder(self, labels: Iterable[str]):
        """
        Reorder the legend entries.

        Parameters
        ----------
        labels: Iterable[str]
            all existing legend labels, in the desired order

        Raises
        ------
        ValueError
            if ``labels`` does not contain exactly the existing legend labels
        """
        # materialize once, a generator would otherwise be exhausted by the validation below
        labels = list(labels)

        key_by_label = {
            legend_item.label: graphic_loc
            for graphic_loc, legend_item in self._items.items()
        }

        if set(labels) != set(key_by_label):
            raise ValueError("Must pass all existing legend labels")

        new_items = OrderedDict()
        for label in labels:
            graphic_loc = key_by_label[label]
            new_items[graphic_loc] = self._items[graphic_loc]

        self._items = new_items
        self._reset_item_positions()

    def _pointer_down(self, ev):
        """Start a drag: store the pointer's world position and the current controller state."""
        self._last_position = self._plot_area.map_screen_to_world(ev)
        self._initial_controller_state = self._plot_area.controller.enabled

    def _pointer_move(self, ev):
        """While dragging, translate the legend by the pointer's displacement in world space."""
        if self._last_position is None:
            return

        self._plot_area.controller.enabled = False

        world_pos = self._plot_area.map_screen_to_world(ev)

        # outside viewport
        if world_pos is None:
            return

        delta = world_pos - self._last_position

        self.world_object.world.x = self.world_object.world.x + delta[0]
        self.world_object.world.y = self.world_object.world.y + delta[1]

        self._last_position = world_pos

        self._plot_area.controller.enabled = self._initial_controller_state

    def _pointer_up(self, ev):
        """End a drag and restore the controller state stored at pointer down."""
        self._last_position = None
        if self._initial_controller_state is not None:
            self._plot_area.controller.enabled = self._initial_controller_state

    def __getitem__(self, graphic: Graphic) -> LegendItem:
        """
        Return the legend item of ``graphic``.

        Raises
        ------
        TypeError
            if ``graphic`` is not a Graphic

        KeyError
            if ``graphic`` is not in the legend
        """
        if not isinstance(graphic, Graphic):
            raise TypeError("Must index Legend with Graphics")

        if graphic._fpl_address not in self._items:
            raise KeyError("Graphic not in legend")

        return self._items[graphic._fpl_address]