feat: add configurable max table bytes and min table rows for DataFrame display · kosiew/datafusion-python@f9b78fa · GitHub
[go: up one dir, main page]

Skip to content

Commit f9b78fa

Browse files
committed
feat: add configurable max table bytes and min table rows for DataFrame display
1 parent 818975b commit f9b78fa

File tree

2 files changed

+57
-31
lines changed

2 files changed

+57
-31
lines changed

python/datafusion/html_formatter.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ class DataFrameHtmlFormatter:
9898
style_provider: Custom provider for cell and header styles
9999
use_shared_styles: Whether to load styles and scripts only once per notebook
100100
session
101+
max_table_bytes: Maximum bytes to display for table presentation (default: 2MB)
102+
min_table_rows: Minimum number of table rows to display (default: 20)
101103
"""
102104

103105
# Class variable to track if styles have been loaded in the notebook
@@ -113,6 +115,8 @@ def __init__(
113115
show_truncation_message: bool = True,
114116
style_provider: Optional[StyleProvider] = None,
115117
use_shared_styles: bool = True,
118+
max_table_bytes: int = 2 * 1024 * 1024, # 2 MB
119+
min_table_rows: int = 20,
116120
) -> None:
117121
"""Initialize the HTML formatter.
118122
@@ -135,11 +139,16 @@ def __init__(
135139
is used.
136140
use_shared_styles : bool, default True
137141
Whether to use shared styles across multiple tables.
142+
max_table_bytes : int, default 2MB (2 * 1024 * 1024)
143+
Maximum bytes to display for table presentation.
144+
min_table_rows : int, default 20
145+
Minimum number of table rows to display.
138146
139147
Raises:
140148
------
141149
ValueError
142-
If max_cell_length, max_width, or max_height is not a positive integer.
150+
If max_cell_length, max_width, max_height, max_table_bytes, or min_table_rows
151+
is not a positive integer.
143152
TypeError
144153
If enable_cell_expansion, show_truncation_message, or use_shared_styles is
145154
not a boolean,
@@ -158,6 +167,12 @@ def __init__(
158167
if not isinstance(max_height, int) or max_height <= 0:
159168
msg = "max_height must be a positive integer"
160169
raise ValueError(msg)
170+
if not isinstance(max_table_bytes, int) or max_table_bytes <= 0:
171+
msg = "max_table_bytes must be a positive integer"
172+
raise ValueError(msg)
173+
if not isinstance(min_table_rows, int) or min_table_rows <= 0:
174+
msg = "min_table_rows must be a positive integer"
175+
raise ValueError(msg)
161176

162177
# Validate boolean parameters
163178
if not isinstance(enable_cell_expansion, bool):
@@ -188,6 +203,8 @@ def __init__(
188203
self.show_truncation_message = show_truncation_message
189204
self.style_provider = style_provider or DefaultStyleProvider()
190205
self.use_shared_styles = use_shared_styles
206+
self.max_table_bytes = max_table_bytes
207+
self.min_table_rows = min_table_rows
191208
# Registry for custom type formatters
192209
self._type_formatters: dict[type, CellFormatter] = {}
193210
# Custom cell builders

src/dataframe.rs

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,6 @@ impl PyTableProvider {
7171
PyTable::new(table_provider)
7272
}
7373
}
74-
const MAX_TABLE_BYTES_TO_DISPLAY: usize = 2 * 1024 * 1024; // 2 MB
75-
const MIN_TABLE_ROWS_TO_DISPLAY: usize = 20;
7674

7775
/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
7876
/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
@@ -81,12 +79,16 @@ const MIN_TABLE_ROWS_TO_DISPLAY: usize = 20;
8179
#[derive(Clone)]
8280
pub struct PyDataFrame {
8381
df: Arc<DataFrame>,
82+
display_config: Arc<PyDataframeDisplayConfig>,
8483
}
8584

8685
impl PyDataFrame {
8786
/// creates a new PyDataFrame
88-
pub fn new(df: DataFrame) -> Self {
89-
Self { df: Arc::new(df) }
87+
pub fn new(df: DataFrame, display_config: PyDataframeDisplayConfig) -> Self {
88+
Self {
89+
df: Arc::new(df),
90+
display_config: Arc::new(display_config),
91+
}
9092
}
9193
}
9294

@@ -116,7 +118,12 @@ impl PyDataFrame {
116118
fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
117119
let (batches, has_more) = wait_for_future(
118120
py,
119-
collect_record_batches_to_display(self.df.as_ref().clone(), 10, 10),
121+
collect_record_batches_to_display(
122+
self.df.as_ref().clone(),
123+
10,
124+
10,
125+
self.display_config.max_table_bytes,
126+
),
120127
)?;
121128
if batches.is_empty() {
122129
// This should not be reached, but do it for safety since we index into the vector below
@@ -139,8 +146,9 @@ impl PyDataFrame {
139146
py,
140147
collect_record_batches_to_display(
141148
self.df.as_ref().clone(),
142-
MIN_TABLE_ROWS_TO_DISPLAY,
149+
self.display_config.min_table_rows,
143150
usize::MAX,
151+
self.display_config.max_table_bytes,
144152
),
145153
)?;
146154
if batches.is_empty() {
@@ -181,7 +189,7 @@ impl PyDataFrame {
181189
fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
182190
let df = self.df.as_ref().clone();
183191
let stat_df = wait_for_future(py, df.describe())?;
184-
Ok(Self::new(stat_df))
192+
Ok(Self::new(stat_df, (*self.display_config).clone()))
185193
}
186194

187195
/// Returns the schema from the logical plan
@@ -211,31 +219,31 @@ impl PyDataFrame {
211219
fn select_columns(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
212220
let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
213221
let df = self.df.as_ref().clone().select_columns(&args)?;
214-
Ok(Self::new(df))
222+
Ok(Self::new(df, (*self.display_config).clone()))
215223
}
216224

217225
#[pyo3(signature = (*args))]
218226
fn select(&self, args: Vec<PyExpr>) -> PyDataFusionResult<Self> {
219227
let expr = args.into_iter().map(|e| e.into()).collect();
220228
let df = self.df.as_ref().clone().select(expr)?;
221-
Ok(Self::new(df))
229+
Ok(Self::new(df, (*self.display_config).clone()))
222230
}
223231

224232
#[pyo3(signature = (*args))]
225233
fn drop(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
226234
let cols = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
227235
let df = self.df.as_ref().clone().drop_columns(&cols)?;
228-
Ok(Self::new(df))
236+
Ok(Self::new(df, (*self.display_config).clone()))
229237
}
230238

231239
fn filter(&self, predicate: PyExpr) -> PyDataFusionResult<Self> {
232240
let df = self.df.as_ref().clone().filter(predicate.into())?;
233-
Ok(Self::new(df))
241+
Ok(Self::new(df, (*self.display_config).clone()))
234242
}
235243

236244
fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult<Self> {
237245
let df = self.df.as_ref().clone().with_column(name, expr.into())?;
238-
Ok(Self::new(df))
246+
Ok(Self::new(df, (*self.display_config).clone()))
239247
}
240248

241249
fn with_columns(&self, exprs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
@@ -245,7 +253,7 @@ impl PyDataFrame {
245253
let name = format!("{}", expr.schema_name());
246254
df = df.with_column(name.as_str(), expr)?
247255
}
248-
Ok(Self::new(df))
256+
Ok(Self::new(df, (*self.display_config).clone()))
249257
}
250258

251259
/// Rename one column by applying a new projection. This is a no-op if the column to be
@@ -256,27 +264,27 @@ impl PyDataFrame {
256264
.as_ref()
257265
.clone()
258266
.with_column_renamed(old_name, new_name)?;
259-
Ok(Self::new(df))
267+
Ok(Self::new(df, (*self.display_config).clone()))
260268
}
261269

262270
fn aggregate(&self, group_by: Vec<PyExpr>, aggs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
263271
let group_by = group_by.into_iter().map(|e| e.into()).collect();
264272
let aggs = aggs.into_iter().map(|e| e.into()).collect();
265273
let df = self.df.as_ref().clone().aggregate(group_by, aggs)?;
266-
Ok(Self::new(df))
274+
Ok(Self::new(df, (*self.display_config).clone()))
267275
}
268276

269277
#[pyo3(signature = (*exprs))]
270278
fn sort(&self, exprs: Vec<PySortExpr>) -> PyDataFusionResult<Self> {
271279
let exprs = to_sort_expressions(exprs);
272280
let df = self.df.as_ref().clone().sort(exprs)?;
273-
Ok(Self::new(df))
281+
Ok(Self::new(df, (*self.display_config).clone()))
274282
}
275283

276284
#[pyo3(signature = (count, offset=0))]
277285
fn limit(&self, count: usize, offset: usize) -> PyDataFusionResult<Self> {
278286
let df = self.df.as_ref().clone().limit(offset, Some(count))?;
279-
Ok(Self::new(df))
287+
Ok(Self::new(df, (*self.display_config).clone()))
280288
}
281289

282290
/// Executes the plan, returning a list of `RecordBatch`es.
@@ -293,7 +301,7 @@ impl PyDataFrame {
293301
/// Cache DataFrame.
294302
fn cache(&self, py: Python) -> PyDataFusionResult<Self> {
295303
let df = wait_for_future(py, self.df.as_ref().clone().cache())?;
296-
Ok(Self::new(df))
304+
Ok(Self::new(df, (*self.display_config).clone()))
297305
}
298306

299307
/// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
@@ -318,7 +326,7 @@ impl PyDataFrame {
318326
/// Filter out duplicate rows
319327
fn distinct(&self) -> PyDataFusionResult<Self> {
320328
let df = self.df.as_ref().clone().distinct()?;
321-
Ok(Self::new(df))
329+
Ok(Self::new(df, (*self.display_config).clone()))
322330
}
323331

324332
fn join(
@@ -352,7 +360,7 @@ impl PyDataFrame {
352360
&right_keys,
353361
None,
354362
)?;
355-
Ok(Self::new(df))
363+
Ok(Self::new(df, (*self.display_config).clone()))
356364
}
357365

358366
fn join_on(
@@ -381,7 +389,7 @@ impl PyDataFrame {
381389
.as_ref()
382390
.clone()
383391
.join_on(right.df.as_ref().clone(), join_type, exprs)?;
384-
Ok(Self::new(df))
392+
Ok(Self::new(df, (*self.display_config).clone()))
385393
}
386394

387395
/// Print the query plan
@@ -414,7 +422,7 @@ impl PyDataFrame {
414422
.as_ref()
415423
.clone()
416424
.repartition(Partitioning::RoundRobinBatch(num))?;
417-
Ok(Self::new(new_df))
425+
Ok(Self::new(new_df, (*self.display_config).clone()))
418426
}
419427

420428
/// Repartition a `DataFrame` based on a logical partitioning scheme.
@@ -426,7 +434,7 @@ impl PyDataFrame {
426434
.as_ref()
427435
.clone()
428436
.repartition(Partitioning::Hash(expr, num))?;
429-
Ok(Self::new(new_df))
437+
Ok(Self::new(new_df, (*self.display_config).clone()))
430438
}
431439

432440
/// Calculate the union of two `DataFrame`s, preserving duplicate rows.The
@@ -442,7 +450,7 @@ impl PyDataFrame {
442450
self.df.as_ref().clone().union(py_df.df.as_ref().clone())?
443451
};
444452

445-
Ok(Self::new(new_df))
453+
Ok(Self::new(new_df, (*self.display_config).clone()))
446454
}
447455

448456
/// Calculate the distinct union of two `DataFrame`s. The
@@ -453,7 +461,7 @@ impl PyDataFrame {
453461
.as_ref()
454462
.clone()
455463
.union_distinct(py_df.df.as_ref().clone())?;
456-
Ok(Self::new(new_df))
464+
Ok(Self::new(new_df, (*self.display_config).clone()))
457465
}
458466

459467
#[pyo3(signature = (column, preserve_nulls=true))]
@@ -494,13 +502,13 @@ impl PyDataFrame {
494502
.as_ref()
495503
.clone()
496504
.intersect(py_df.df.as_ref().clone())?;
497-
Ok(Self::new(new_df))
505+
Ok(Self::new(new_df, (*self.display_config).clone()))
498506
}
499507

500508
/// Calculate the exception of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema
501509
fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
502510
let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?;
503-
Ok(Self::new(new_df))
511+
Ok(Self::new(new_df, (*self.display_config).clone()))
504512
}
505513

506514
/// Write a `DataFrame` to a CSV file.
@@ -798,6 +806,7 @@ async fn collect_record_batches_to_display(
798806
df: DataFrame,
799807
min_rows: usize,
800808
max_rows: usize,
809+
max_table_bytes: usize,
801810
) -> Result<(Vec<RecordBatch>, bool), DataFusionError> {
802811
let partitioned_stream = df.execute_stream_partitioned().await?;
803812
let mut stream = futures::stream::iter(partitioned_stream).flatten();
@@ -806,7 +815,7 @@ async fn collect_record_batches_to_display(
806815
let mut record_batches = Vec::default();
807816
let mut has_more = false;
808817

809-
while (size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY && rows_so_far < max_rows)
818+
while (size_estimate_so_far < max_table_bytes && rows_so_far < max_rows)
810819
|| rows_so_far < min_rows
811820
{
812821
let mut rb = match stream.next().await {
@@ -821,8 +830,8 @@ async fn collect_record_batches_to_display(
821830
if rows_in_rb > 0 {
822831
size_estimate_so_far += rb.get_array_memory_size();
823832

824-
if size_estimate_so_far > MAX_TABLE_BYTES_TO_DISPLAY {
825-
let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / size_estimate_so_far as f32;
833+
if size_estimate_so_far > max_table_bytes {
834+
let ratio = max_table_bytes as f32 / size_estimate_so_far as f32;
826835
let total_rows = rows_in_rb + rows_so_far;
827836

828837
let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize;

0 commit comments

Comments (0)