8000 Update get_data to take a ‘flatten’ argument. · pferate/python-api@c8cb566 · GitHub
[go: up one dir, main page]

Skip to content

Commit c8cb566

Browse files
committed
Update get_data to take a ‘flatten’ argument.
1 parent c8c093c commit c8cb566

File tree

2 files changed

+239
-126
lines changed

2 files changed

+239
-126
lines changed

plotly/graph_objs/graph_objs.py

Lines changed: 87 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -158,19 +158,36 @@ def strip_style(self):
158158
for plotly_dict in self:
159159
plotly_dict.strip_style()
160160

161-
def get_data(self):
162-
"""Returns the JSON for the plot with non-data elements stripped."""
161+
def get_data(self, flatten=False):
162+
"""
163+
Returns the JSON for the plot with non-data elements stripped.
164+
165+
Flattening may increase the utility of the result.
166+
167+
:param (bool) flatten: {'a': {'b': ''}} --> {'a.b': ''}
168+
:returns: (dict|list) Depending on (flat|unflat)
169+
170+
"""
163171
self.to_graph_objs()
164172
l = list()
165173
for _plotlydict in self:
166-
l += [_plotlydict.get_data()]
174+
l += [_plotlydict.get_data(flatten=flatten)]
167175
del_indicies = [index for index, item in enumerate(self)
168176
if len(item) == 0]
169177
del_ct = 0
170178
for index in del_indicies:
171179
del self[index - del_ct]
172180
del_ct += 1
173-
return l
181+
182+
if flatten:
183+
d = {}
184+
for i, e in enumerate(l):
185+
for k, v in e.items():
186+
key = "{}.{}".format(i, k)
187+
d[key] = v
188+
return d
189+
else:
190+
return l
174191

175192
def validate(self, caller=True):
176193
"""Recursively check the validity of the entries in a PlotlyList.
@@ -435,19 +452,25 @@ def strip_style(self):
435452
# print("'type' not in {0} for {1}".format(obj_key, key))
436453
pass
437454

438-
def get_data(self):
455+
def get_data(self, flatten=False):
439456
"""Returns the JSON for the plot with non-data elements stripped."""
440457
self.to_graph_objs()
441458
class_name = self.__class__.__name__
442459
obj_key = NAME_TO_KEY[class_name]
443460
d = dict()
444461
for key, val in list(self.items()):
445462
if isinstance(val, (PlotlyDict, PlotlyList)):
446-
d[key] = val.get_data()
463+
sub_data = val.get_data(flatten=flatten)
464+
if flatten:
465+
for sub_key, sub_val in sub_data.items():
466+
key_string = "{}.{}".format(key, sub_key)
467+
d[key_string] = sub_val
468+
else:
469+
d[key] = sub_data
447470
else:
448471
try:
449472
# TODO: Update the JSON
450-
if INFO[obj_key]['keymeta'][key]['key_type'] == 'data':
473+
if graph_objs_tools.value_is_data(obj_key, key, val):
451474
d[key] = val
452475
except KeyError:
453476
pass
@@ -456,8 +479,6 @@ def get_data(self):
456479
if isinstance(d[key], (dict, list)):
457480
if len(d[key]) == 0:
458481
del d[key]
459-
if len(d) == 1:
460-
d = list(d.values())[0]
461482
return d
462483

463484
def to_graph_objs(self, caller=True):
@@ -862,6 +883,40 @@ def to_graph_objs(self, caller=True): # TODO TODO TODO! check logic!
862883
)
863884
super(Data, self).to_graph_objs(caller=caller)
864885
Data.to_graph_objs = to_graph_objs # override method!
886+
887+
def get_data(self, flatten=False):
888+
"""
889+
890+
:param flatten:
891+
:return:
892+
893+
"""
894+
if flatten:
895+
self.to_graph_objs()
896+
data = [v.get_data(flatten=flatten) for v in self]
897+
d = {}
898+
for i, trace in enumerate(data):
899+
900+
# we want to give the traces helpful names
901+
# however, we need to be sure they're unique too...
902+
trace_name = trace.pop('name', 'trace_{}'.format(i))
903+
if trace_name in d:
904+
j = 1
905+
new_trace_name = "{}_{}".format(trace_name, j)
906+
while new_trace_name in d:
907+
new_trace_name = "{}_{}".format(trace_name, j)
908+
j += 1
909+
trace_name = new_trace_name
910+
911+
# finish up the dot-concatenation
912+
for k, v in trace.items():
913+
key = "{}.{}".format(trace_name, k)
914+
d[key] = v
915+
return d
916+
else:
917+
return super(< 341A span class=pl-v>Data, self).get_data(flatten=flatten)
918+
Data.get_data = get_data
919+
865920
return Data
866921

867922
Data = get_patched_data_class(Data)
@@ -936,6 +991,29 @@ def print_grid(self):
936991
print(grid_str)
937992
Figure.print_grid = print_grid
938993

994+
def get_data(self, flatten=False):
995+
"""
996+
Returns the JSON for the plot with non-data elements stripped.
997+
998+
Flattening may increase the utility of the result.
999+
1000+
:param (bool) flatten: {'a': {'b': ''}} --> {'a.b': ''}
1001+
:returns: (dict|list) Depending on (flat|unflat)
1002+
1003+
"""
1004+
data = super(Figure, self).get_data(flatten=flatten)
1005+
if flatten:
1006+
keys = data.keys()
1007+
for key in keys:
1008+
new_key = '.'.join(key.split('.')[1:])
1009+
old_data = data.pop(key)
1010+
if key.split('.')[0] == 'data':
1011+
data[new_key] = old_data
1012+
return data
1013+
else:
1014+
return data['data']
1015+
Figure.get_data = get_data
1016+
9391017
def append_trace(self, trace, row, col):
9401018
""" Helper function to add a data traces to your figure
9411019
that is bound to axes at the row, col index.

0 commit comments

Comments
 (0)
0