From f9b78fa3180c5d6c20eaa3b6d0af7426d7084093 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sun, 27 Apr 2025 18:58:05 +0800 Subject: [PATCH 01/40] feat: add configurable max table bytes and min table rows for DataFrame display --- python/datafusion/html_formatter.py | 19 +++++++- src/dataframe.rs | 69 ++++++++++++++++------------- 2 files changed, 57 insertions(+), 31 deletions(-) diff --git a/python/datafusion/html_formatter.py b/python/datafusion/html_formatter.py index a50e14fd5..2eb116cab 100644 --- a/python/datafusion/html_formatter.py +++ b/python/datafusion/html_formatter.py @@ -98,6 +98,8 @@ class DataFrameHtmlFormatter: style_provider: Custom provider for cell and header styles use_shared_styles: Whether to load styles and scripts only once per notebook session + max_table_bytes: Maximum bytes to display for table presentation (default: 2MB) + min_table_rows: Minimum number of table rows to display (default: 20) """ # Class variable to track if styles have been loaded in the notebook @@ -113,6 +115,8 @@ def __init__( show_truncation_message: bool = True, style_provider: Optional[StyleProvider] = None, use_shared_styles: bool = True, + max_table_bytes: int = 2 * 1024 * 1024, # 2 MB + min_table_rows: int = 20, ) -> None: """Initialize the HTML formatter. @@ -135,11 +139,16 @@ def __init__( is used. use_shared_styles : bool, default True Whether to use shared styles across multiple tables. + max_table_bytes : int, default 2MB (2 * 1024 * 1024) + Maximum bytes to display for table presentation. + min_table_rows : int, default 20 + Minimum number of table rows to display. Raises: ------ ValueError - If max_cell_length, max_width, or max_height is not a positive integer. + If max_cell_length, max_width, max_height, max_table_bytes, or min_table_rows + is not a positive integer. TypeError If enable_cell_expansion, show_truncation_message, or use_shared_styles is not a boolean, @@ -158,6 +167,12 @@ def __init__( if not isinstance(max_height, int) or max_height <= 0: msg = "max_height must be a positive integer" raise ValueError(msg) + if not isinstance(max_table_bytes, int) or max_table_bytes <= 0: + msg = "max_table_bytes must be a positive integer" + raise ValueError(msg) + if not isinstance(min_table_rows, int) or min_table_rows <= 0: + msg = "min_table_rows must be a positive integer" + raise ValueError(msg) # Validate boolean parameters if not isinstance(enable_cell_expansion, bool): @@ -188,6 +203,8 @@ def __init__( self.show_truncation_message = show_truncation_message self.style_provider = style_provider or DefaultStyleProvider() self.use_shared_styles = use_shared_styles + self.max_table_bytes = max_table_bytes + self.min_table_rows = min_table_rows # Registry for custom type formatters self._type_formatters: dict[type, CellFormatter] = {} # Custom cell builders diff --git a/src/dataframe.rs b/src/dataframe.rs index 9b610b5d7..e9f73a70d 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -71,8 +71,6 @@ impl PyTableProvider { PyTable::new(table_provider) } } -const MAX_TABLE_BYTES_TO_DISPLAY: usize = 2 * 1024 * 1024; // 2 MB -const MIN_TABLE_ROWS_TO_DISPLAY: usize = 20; /// A PyDataFrame is a representation of a logical plan and an API to compose statements. /// Use it to build a plan and `.collect()` to execute the plan and collect the result. 
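PATCH 01 threads two new display knobs through DataFrameHtmlFormatter and validates them the same way as the existing size limits. A minimal usage sketch of the added validation, assuming a build with PATCH 01 applied (note this commit is reverted in PATCH 02 and superseded by the max_memory_bytes / min_rows_display / repr_rows options of PATCH 07):

    from datafusion.html_formatter import DataFrameHtmlFormatter

    # Defaults mirror the removed Rust constants: 2 MB byte budget, 20-row floor.
    formatter = DataFrameHtmlFormatter(max_table_bytes=1024 * 1024, min_table_rows=10)
    assert formatter.max_table_bytes == 1024 * 1024

    # Non-positive values trip the new checks in __init__.
    try:
        DataFrameHtmlFormatter(min_table_rows=0)
    except ValueError as exc:
        print(exc)  # "min_table_rows must be a positive integer"
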
@@ -81,12 +79,16 @@ const MIN_TABLE_ROWS_TO_DISPLAY: usize = 20; #[derive(Clone)] pub struct PyDataFrame { df: Arc, + display_config: Arc, } impl PyDataFrame { /// creates a new PyDataFrame - pub fn new(df: DataFrame) -> Self { - Self { df: Arc::new(df) } + pub fn new(df: DataFrame, display_config: PyDataframeDisplayConfig) -> Self { + Self { + df: Arc::new(df), + display_config: Arc::new(display_config), + } } } @@ -116,7 +118,12 @@ impl PyDataFrame { fn __repr__(&self, py: Python) -> PyDataFusionResult { let (batches, has_more) = wait_for_future( py, - collect_record_batches_to_display(self.df.as_ref().clone(), 10, 10), + collect_record_batches_to_display( + self.df.as_ref().clone(), + 10, + 10, + self.display_config.max_table_bytes, + ), )?; if batches.is_empty() { // This should not be reached, but do it for safety since we index into the vector below @@ -139,8 +146,9 @@ impl PyDataFrame { py, collect_record_batches_to_display( self.df.as_ref().clone(), - MIN_TABLE_ROWS_TO_DISPLAY, + self.display_config.min_table_rows, usize::MAX, + self.display_config.max_table_bytes, ), )?; if batches.is_empty() { @@ -181,7 +189,7 @@ impl PyDataFrame { fn describe(&self, py: Python) -> PyDataFusionResult { let df = self.df.as_ref().clone(); let stat_df = wait_for_future(py, df.describe())?; - Ok(Self::new(stat_df)) + Ok(Self::new(stat_df, (*self.display_config).clone())) } /// Returns the schema from the logical plan @@ -211,31 +219,31 @@ impl PyDataFrame { fn select_columns(&self, args: Vec) -> PyDataFusionResult { let args = args.iter().map(|s| s.as_ref()).collect::>(); let df = self.df.as_ref().clone().select_columns(&args)?; - Ok(Self::new(df)) + Ok(Self::new(df, (*self.display_config).clone())) } #[pyo3(signature = (*args))] fn select(&self, args: Vec) -> PyDataFusionResult { let expr = args.into_iter().map(|e| e.into()).collect(); let df = self.df.as_ref().clone().select(expr)?; - Ok(Self::new(df)) + Ok(Self::new(df, (*self.display_config).clone())) } #[pyo3(signature = (*args))] fn drop(&self, args: Vec) -> PyDataFusionResult { let cols = args.iter().map(|s| s.as_ref()).collect::>(); let df = self.df.as_ref().clone().drop_columns(&cols)?; - Ok(Self::new(df)) + Ok(Self::new(df, (*self.display_config).clone())) } fn filter(&self, predicate: PyExpr) -> PyDataFusionResult { let df = self.df.as_ref().clone().filter(predicate.into())?; - Ok(Self::new(df)) + Ok(Self::new(df, (*self.display_config).clone())) } fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult { let df = self.df.as_ref().clone().with_column(name, expr.into())?; - Ok(Self::new(df)) + Ok(Self::new(df, (*self.display_config).clone())) } fn with_columns(&self, exprs: Vec) -> PyDataFusionResult { @@ -245,7 +253,7 @@ impl PyDataFrame { let name = format!("{}", expr.schema_name()); df = df.with_column(name.as_str(), expr)? } - Ok(Self::new(df)) + Ok(Self::new(df, (*self.display_config).clone())) } /// Rename one column by applying a new projection. 
This is a no-op if the column to be @@ -256,27 +264,27 @@ impl PyDataFrame { .as_ref() .clone() .with_column_renamed(old_name, new_name)?; - Ok(Self::new(df)) + Ok(Self::new(df, (*self.display_config).clone())) } fn aggregate(&self, group_by: Vec, aggs: Vec) -> PyDataFusionResult { let group_by = group_by.into_iter().map(|e| e.into()).collect(); let aggs = aggs.into_iter().map(|e| e.into()).collect(); let df = self.df.as_ref().clone().aggregate(group_by, aggs)?; - Ok(Self::new(df)) + Ok(Self::new(df, (*self.display_config).clone())) } #[pyo3(signature = (*exprs))] fn sort(&self, exprs: Vec) -> PyDataFusionResult { let exprs = to_sort_expressions(exprs); let df = self.df.as_ref().clone().sort(exprs)?; - Ok(Self::new(df)) + Ok(Self::new(df, (*self.display_config).clone())) } #[pyo3(signature = (count, offset=0))] fn limit(&self, count: usize, offset: usize) -> PyDataFusionResult { let df = self.df.as_ref().clone().limit(offset, Some(count))?; - Ok(Self::new(df)) + Ok(Self::new(df, (*self.display_config).clone())) } /// Executes the plan, returning a list of `RecordBatch`es. @@ -293,7 +301,7 @@ impl PyDataFrame { /// Cache DataFrame. fn cache(&self, py: Python) -> PyDataFusionResult { let df = wait_for_future(py, self.df.as_ref().clone().cache())?; - Ok(Self::new(df)) + Ok(Self::new(df, (*self.display_config).clone())) } /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch @@ -318,7 +326,7 @@ impl PyDataFrame { /// Filter out duplicate rows fn distinct(&self) -> PyDataFusionResult { let df = self.df.as_ref().clone().distinct()?; - Ok(Self::new(df)) + Ok(Self::new(df, (*self.display_config).clone())) } fn join( @@ -352,7 +360,7 @@ impl PyDataFrame { &right_keys, None, )?; - Ok(Self::new(df)) + Ok(Self::new(df, (*self.display_config).clone())) } fn join_on( @@ -381,7 +389,7 @@ impl PyDataFrame { .as_ref() .clone() .join_on(right.df.as_ref().clone(), join_type, exprs)?; - Ok(Self::new(df)) + Ok(Self::new(df, (*self.display_config).clone())) } /// Print the query plan @@ -414,7 +422,7 @@ impl PyDataFrame { .as_ref() .clone() .repartition(Partitioning::RoundRobinBatch(num))?; - Ok(Self::new(new_df)) + Ok(Self::new(new_df, (*self.display_config).clone())) } /// Repartition a `DataFrame` based on a logical partitioning scheme. @@ -426,7 +434,7 @@ impl PyDataFrame { .as_ref() .clone() .repartition(Partitioning::Hash(expr, num))?; - Ok(Self::new(new_df)) + Ok(Self::new(new_df, (*self.display_config).clone())) } /// Calculate the union of two `DataFrame`s, preserving duplicate rows.The @@ -442,7 +450,7 @@ impl PyDataFrame { self.df.as_ref().clone().union(py_df.df.as_ref().clone())? }; - Ok(Self::new(new_df)) + Ok(Self::new(new_df, (*self.display_config).clone())) } /// Calculate the distinct union of two `DataFrame`s. The @@ -453,7 +461,7 @@ impl PyDataFrame { .as_ref() .clone() .union_distinct(py_df.df.as_ref().clone())?; - Ok(Self::new(new_df)) + Ok(Self::new(new_df, (*self.display_config).clone())) } #[pyo3(signature = (column, preserve_nulls=true))] @@ -494,13 +502,13 @@ impl PyDataFrame { .as_ref() .clone() .intersect(py_df.df.as_ref().clone())?; - Ok(Self::new(new_df)) + Ok(Self::new(new_df, (*self.display_config).clone())) } /// Calculate the exception of two `DataFrame`s. 
The two `DataFrame`s must have exactly the same schema fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult { let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?; - Ok(Self::new(new_df)) + Ok(Self::new(new_df, (*self.display_config).clone())) } /// Write a `DataFrame` to a CSV file. @@ -798,6 +806,7 @@ async fn collect_record_batches_to_display( df: DataFrame, min_rows: usize, max_rows: usize, + max_table_bytes: usize, ) -> Result<(Vec, bool), DataFusionError> { let partitioned_stream = df.execute_stream_partitioned().await?; let mut stream = futures::stream::iter(partitioned_stream).flatten(); @@ -806,7 +815,7 @@ async fn collect_record_batches_to_display( let mut record_batches = Vec::default(); let mut has_more = false; - while (size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY && rows_so_far < max_rows) + while (size_estimate_so_far < max_table_bytes && rows_so_far < max_rows) || rows_so_far < min_rows { let mut rb = match stream.next().await { @@ -821,8 +830,8 @@ async fn collect_record_batches_to_display( if rows_in_rb > 0 { size_estimate_so_far += rb.get_array_memory_size(); - if size_estimate_so_far > MAX_TABLE_BYTES_TO_DISPLAY { - let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / size_estimate_so_far as f32; + if size_estimate_so_far > max_table_bytes { + let ratio = max_table_bytes as f32 / size_estimate_so_far as f32; let total_rows = rows_in_rb + rows_so_far; let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize; From 4d8fa38007b7dfe689344fc44d5392a8734c64f5 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sun, 27 Apr 2025 18:58:23 +0800 Subject: [PATCH 02/40] Revert "feat: add configurable max table bytes and min table rows for DataFrame display" This reverts commit f9b78fa3180c5d6c20eaa3b6d0af7426d7084093. --- python/datafusion/html_formatter.py | 19 +------- src/dataframe.rs | 69 +++++++++++++---------------- 2 files changed, 31 insertions(+), 57 deletions(-) diff --git a/python/datafusion/html_formatter.py b/python/datafusion/html_formatter.py index 2eb116cab..a50e14fd5 100644 --- a/python/datafusion/html_formatter.py +++ b/python/datafusion/html_formatter.py @@ -98,8 +98,6 @@ class DataFrameHtmlFormatter: style_provider: Custom provider for cell and header styles use_shared_styles: Whether to load styles and scripts only once per notebook session - max_table_bytes: Maximum bytes to display for table presentation (default: 2MB) - min_table_rows: Minimum number of table rows to display (default: 20) """ # Class variable to track if styles have been loaded in the notebook @@ -115,8 +113,6 @@ def __init__( show_truncation_message: bool = True, style_provider: Optional[StyleProvider] = None, use_shared_styles: bool = True, - max_table_bytes: int = 2 * 1024 * 1024, # 2 MB - min_table_rows: int = 20, ) -> None: """Initialize the HTML formatter. @@ -139,16 +135,11 @@ def __init__( is used. use_shared_styles : bool, default True Whether to use shared styles across multiple tables. - max_table_bytes : int, default 2MB (2 * 1024 * 1024) - Maximum bytes to display for table presentation. - min_table_rows : int, default 20 - Minimum number of table rows to display. Raises: ------ ValueError - If max_cell_length, max_width, max_height, max_table_bytes, or min_table_rows - is not a positive integer. + If max_cell_length, max_width, or max_height is not a positive integer. 
TypeError If enable_cell_expansion, show_truncation_message, or use_shared_styles is not a boolean, @@ -167,12 +158,6 @@ def __init__( if not isinstance(max_height, int) or max_height <= 0: msg = "max_height must be a positive integer" raise ValueError(msg) - if not isinstance(max_table_bytes, int) or max_table_bytes <= 0: - msg = "max_table_bytes must be a positive integer" - raise ValueError(msg) - if not isinstance(min_table_rows, int) or min_table_rows <= 0: - msg = "min_table_rows must be a positive integer" - raise ValueError(msg) # Validate boolean parameters if not isinstance(enable_cell_expansion, bool): @@ -203,8 +188,6 @@ def __init__( self.show_truncation_message = show_truncation_message self.style_provider = style_provider or DefaultStyleProvider() self.use_shared_styles = use_shared_styles - self.max_table_bytes = max_table_bytes - self.min_table_rows = min_table_rows # Registry for custom type formatters self._type_formatters: dict[type, CellFormatter] = {} # Custom cell builders diff --git a/src/dataframe.rs b/src/dataframe.rs index e9f73a70d..9b610b5d7 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -71,6 +71,8 @@ impl PyTableProvider { PyTable::new(table_provider) } } +const MAX_TABLE_BYTES_TO_DISPLAY: usize = 2 * 1024 * 1024; // 2 MB +const MIN_TABLE_ROWS_TO_DISPLAY: usize = 20; /// A PyDataFrame is a representation of a logical plan and an API to compose statements. /// Use it to build a plan and `.collect()` to execute the plan and collect the result. @@ -79,16 +81,12 @@ impl PyTableProvider { #[derive(Clone)] pub struct PyDataFrame { df: Arc, - display_config: Arc, } impl PyDataFrame { /// creates a new PyDataFrame - pub fn new(df: DataFrame, display_config: PyDataframeDisplayConfig) -> Self { - Self { - df: Arc::new(df), - display_config: Arc::new(display_config), - } + pub fn new(df: DataFrame) -> Self { + Self { df: Arc::new(df) } } } @@ -118,12 +116,7 @@ impl PyDataFrame { fn __repr__(&self, py: Python) -> PyDataFusionResult { let (batches, has_more) = wait_for_future( py, - collect_record_batches_to_display( - self.df.as_ref().clone(), - 10, - 10, - self.display_config.max_table_bytes, - ), + collect_record_batches_to_display(self.df.as_ref().clone(), 10, 10), )?; if batches.is_empty() { // This should not be reached, but do it for safety since we index into the vector below @@ -146,9 +139,8 @@ impl PyDataFrame { py, collect_record_batches_to_display( self.df.as_ref().clone(), - self.display_config.min_table_rows, + MIN_TABLE_ROWS_TO_DISPLAY, usize::MAX, - self.display_config.max_table_bytes, ), )?; if batches.is_empty() { @@ -189,7 +181,7 @@ impl PyDataFrame { fn describe(&self, py: Python) -> PyDataFusionResult { let df = self.df.as_ref().clone(); let stat_df = wait_for_future(py, df.describe())?; - Ok(Self::new(stat_df, (*self.display_config).clone())) + Ok(Self::new(stat_df)) } /// Returns the schema from the logical plan @@ -219,31 +211,31 @@ impl PyDataFrame { fn select_columns(&self, args: Vec) -> PyDataFusionResult { let args = args.iter().map(|s| s.as_ref()).collect::>(); let df = self.df.as_ref().clone().select_columns(&args)?; - Ok(Self::new(df, (*self.display_config).clone())) + Ok(Self::new(df)) } #[pyo3(signature = (*args))] fn select(&self, args: Vec) -> PyDataFusionResult { let expr = args.into_iter().map(|e| e.into()).collect(); let df = self.df.as_ref().clone().select(expr)?; - Ok(Self::new(df, (*self.display_config).clone())) + Ok(Self::new(df)) } #[pyo3(signature = (*args))] fn drop(&self, args: Vec) -> PyDataFusionResult { let cols = 
args.iter().map(|s| s.as_ref()).collect::>(); let df = self.df.as_ref().clone().drop_columns(&cols)?; - Ok(Self::new(df, (*self.display_config).clone())) + Ok(Self::new(df)) } fn filter(&self, predicate: PyExpr) -> PyDataFusionResult { let df = self.df.as_ref().clone().filter(predicate.into())?; - Ok(Self::new(df, (*self.display_config).clone())) + Ok(Self::new(df)) } fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult { let df = self.df.as_ref().clone().with_column(name, expr.into())?; - Ok(Self::new(df, (*self.display_config).clone())) + Ok(Self::new(df)) } fn with_columns(&self, exprs: Vec) -> PyDataFusionResult { @@ -253,7 +245,7 @@ impl PyDataFrame { let name = format!("{}", expr.schema_name()); df = df.with_column(name.as_str(), expr)? } - Ok(Self::new(df, (*self.display_config).clone())) + Ok(Self::new(df)) } /// Rename one column by applying a new projection. This is a no-op if the column to be @@ -264,27 +256,27 @@ impl PyDataFrame { .as_ref() .clone() .with_column_renamed(old_name, new_name)?; - Ok(Self::new(df, (*self.display_config).clone())) + Ok(Self::new(df)) } fn aggregate(&self, group_by: Vec, aggs: Vec) -> PyDataFusionResult { let group_by = group_by.into_iter().map(|e| e.into()).collect(); let aggs = aggs.into_iter().map(|e| e.into()).collect(); let df = self.df.as_ref().clone().aggregate(group_by, aggs)?; - Ok(Self::new(df, (*self.display_config).clone())) + Ok(Self::new(df)) } #[pyo3(signature = (*exprs))] fn sort(&self, exprs: Vec) -> PyDataFusionResult { let exprs = to_sort_expressions(exprs); let df = self.df.as_ref().clone().sort(exprs)?; - Ok(Self::new(df, (*self.display_config).clone())) + Ok(Self::new(df)) } #[pyo3(signature = (count, offset=0))] fn limit(&self, count: usize, offset: usize) -> PyDataFusionResult { let df = self.df.as_ref().clone().limit(offset, Some(count))?; - Ok(Self::new(df, (*self.display_config).clone())) + Ok(Self::new(df)) } /// Executes the plan, returning a list of `RecordBatch`es. @@ -301,7 +293,7 @@ impl PyDataFrame { /// Cache DataFrame. fn cache(&self, py: Python) -> PyDataFusionResult { let df = wait_for_future(py, self.df.as_ref().clone().cache())?; - Ok(Self::new(df, (*self.display_config).clone())) + Ok(Self::new(df)) } /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch @@ -326,7 +318,7 @@ impl PyDataFrame { /// Filter out duplicate rows fn distinct(&self) -> PyDataFusionResult { let df = self.df.as_ref().clone().distinct()?; - Ok(Self::new(df, (*self.display_config).clone())) + Ok(Self::new(df)) } fn join( @@ -360,7 +352,7 @@ impl PyDataFrame { &right_keys, None, )?; - Ok(Self::new(df, (*self.display_config).clone())) + Ok(Self::new(df)) } fn join_on( @@ -389,7 +381,7 @@ impl PyDataFrame { .as_ref() .clone() .join_on(right.df.as_ref().clone(), join_type, exprs)?; - Ok(Self::new(df, (*self.display_config).clone())) + Ok(Self::new(df)) } /// Print the query plan @@ -422,7 +414,7 @@ impl PyDataFrame { .as_ref() .clone() .repartition(Partitioning::RoundRobinBatch(num))?; - Ok(Self::new(new_df, (*self.display_config).clone())) + Ok(Self::new(new_df)) } /// Repartition a `DataFrame` based on a logical partitioning scheme. @@ -434,7 +426,7 @@ impl PyDataFrame { .as_ref() .clone() .repartition(Partitioning::Hash(expr, num))?; - Ok(Self::new(new_df, (*self.display_config).clone())) + Ok(Self::new(new_df)) } /// Calculate the union of two `DataFrame`s, preserving duplicate rows.The @@ -450,7 +442,7 @@ impl PyDataFrame { self.df.as_ref().clone().union(py_df.df.as_ref().clone())? 
}; - Ok(Self::new(new_df, (*self.display_config).clone())) + Ok(Self::new(new_df)) } /// Calculate the distinct union of two `DataFrame`s. The @@ -461,7 +453,7 @@ impl PyDataFrame { .as_ref() .clone() .union_distinct(py_df.df.as_ref().clone())?; - Ok(Self::new(new_df, (*self.display_config).clone())) + Ok(Self::new(new_df)) } #[pyo3(signature = (column, preserve_nulls=true))] @@ -502,13 +494,13 @@ impl PyDataFrame { .as_ref() .clone() .intersect(py_df.df.as_ref().clone())?; - Ok(Self::new(new_df, (*self.display_config).clone())) + Ok(Self::new(new_df)) } /// Calculate the exception of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult { let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?; - Ok(Self::new(new_df, (*self.display_config).clone())) + Ok(Self::new(new_df)) } /// Write a `DataFrame` to a CSV file. @@ -806,7 +798,6 @@ async fn collect_record_batches_to_display( df: DataFrame, min_rows: usize, max_rows: usize, - max_table_bytes: usize, ) -> Result<(Vec, bool), DataFusionError> { let partitioned_stream = df.execute_stream_partitioned().await?; let mut stream = futures::stream::iter(partitioned_stream).flatten(); @@ -815,7 +806,7 @@ async fn collect_record_batches_to_display( let mut record_batches = Vec::default(); let mut has_more = false; - while (size_estimate_so_far < max_table_bytes && rows_so_far < max_rows) + while (size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY && rows_so_far < max_rows) || rows_so_far < min_rows { let mut rb = match stream.next().await { @@ -830,8 +821,8 @@ async fn collect_record_batches_to_display( if rows_in_rb > 0 { size_estimate_so_far += rb.get_array_memory_size(); - if size_estimate_so_far > max_table_bytes { - let ratio = max_table_bytes as f32 / size_estimate_so_far as f32; + if size_estimate_so_far > MAX_TABLE_BYTES_TO_DISPLAY { + let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / size_estimate_so_far as f32; let total_rows = rows_in_rb + rows_so_far; let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize; From a9178feb501c11c3c9ee0a20f71418a8ea4168f7 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sun, 27 Apr 2025 19:20:13 +0800 Subject: [PATCH 03/40] feat: add FormatterConfig for configurable DataFrame display options --- src/dataframe.rs | 53 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/src/dataframe.rs b/src/dataframe.rs index 9b610b5d7..cef950988 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -71,9 +71,62 @@ impl PyTableProvider { PyTable::new(table_provider) } } + +/// Configuration for DataFrame display formatting +#[derive(Debug, Clone)] +pub struct FormatterConfig { + /// Maximum memory in bytes to use for display (default: 2MB) + pub max_bytes: usize, + /// Minimum number of rows to display (default: 20) + pub min_rows: usize, + /// Number of rows to include in __repr__ output (default: 10) + pub repr_rows: usize, +} + +impl Default for FormatterConfig { + fn default() -> Self { + Self { + max_bytes: 2 * 1024 * 1024, // 2MB + min_rows: 20, + repr_rows: 10, + } + } +} + +// Keep constants for backward compatibility const MAX_TABLE_BYTES_TO_DISPLAY: usize = 2 * 1024 * 1024; // 2 MB const MIN_TABLE_ROWS_TO_DISPLAY: usize = 20; +fn get_formatter_config(py: Python) -> PyResult { + let formatter_module = py.import("datafusion.html_formatter")?; + let get_formatter = formatter_module.getattr("get_formatter")?; + let formatter = get_formatter.call0()?; + + // Get 
max_memory_bytes (or fallback to default) + let max_bytes = formatter + .getattr("max_memory_bytes") + .and_then(|v| v.extract::<usize>()) + .unwrap_or(FormatterConfig::default().max_bytes); + + // Get min_rows_display (or fallback to default) + let min_rows = formatter + .getattr("min_rows_display") + .and_then(|v| v.extract::<usize>()) + .unwrap_or(FormatterConfig::default().min_rows); + + // Get repr_rows (or fallback to default) + let repr_rows = formatter + .getattr("repr_rows") + .and_then(|v| v.extract::<usize>()) + .unwrap_or(FormatterConfig::default().repr_rows); + + Ok(FormatterConfig { + max_bytes, + min_rows, + repr_rows, + }) +} + /// A PyDataFrame is a representation of a logical plan and an API to compose statements. /// Use it to build a plan and `.collect()` to execute the plan and collect the result. /// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment. From d0209cf7d90400675f09b490cad0ca700d74f4c7 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sun, 27 Apr 2025 19:23:42 +0800 Subject: [PATCH 04/40] refactor: simplify attribute extraction in get_formatter_config function --- src/dataframe.rs | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/src/dataframe.rs b/src/dataframe.rs index cef950988..ea838d845 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -102,23 +102,22 @@ fn get_formatter_config(py: Python) -> PyResult<FormatterConfig> { let get_formatter = formatter_module.getattr("get_formatter")?; let formatter = get_formatter.call0()?; - // Get max_memory_bytes (or fallback to default) - let max_bytes = formatter - .getattr("max_memory_bytes") - .and_then(|v| v.extract::<usize>()) - .unwrap_or(FormatterConfig::default().max_bytes); - - // Get min_rows_display (or fallback to default) - let min_rows = formatter - .getattr("min_rows_display") - .and_then(|v| v.extract::<usize>()) - .unwrap_or(FormatterConfig::default().min_rows); - - // Get repr_rows (or fallback to default) - let repr_rows = formatter - .getattr("repr_rows") - .and_then(|v| v.extract::<usize>()) - .unwrap_or(FormatterConfig::default().repr_rows); + // Helper function to extract attributes with fallback to default + fn get_attr<'a>( + formatter: &'a Bound<'a, PyAny>, + attr_name: &str, + default_value: usize, + ) -> usize { + formatter + .getattr(attr_name) + .and_then(|v| v.extract::<usize>()) + .unwrap_or(default_value) + } + + let default_config = FormatterConfig::default(); + let max_bytes = get_attr(&formatter, "max_memory_bytes", default_config.max_bytes); + let min_rows = get_attr(&formatter, "min_rows_display", default_config.min_rows); + let repr_rows = get_attr(&formatter, "repr_rows", default_config.repr_rows); Ok(FormatterConfig { max_bytes, min_rows, repr_rows, From 2ef013f1d9f9af3113e2a16b5e92d2274f9cd3e3 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sun, 27 Apr 2025 19:33:01 +0800 Subject: [PATCH 05/40] refactor: remove hardcoded constants and use FormatterConfig for display options --- src/dataframe.rs | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/src/dataframe.rs b/src/dataframe.rs index ea838d845..e6dd4f70d 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -93,10 +93,6 @@ impl Default for FormatterConfig { } } -// Keep constants for backward compatibility -const MAX_TABLE_BYTES_TO_DISPLAY: usize = 2 * 1024 * 1024; // 2 MB -const MIN_TABLE_ROWS_TO_DISPLAY: usize = 20; fn get_formatter_config(py: Python) -> PyResult<FormatterConfig> { let formatter_module = py.import("datafusion.html_formatter")?; let get_formatter =
formatter_module.getattr("get_formatter")?; @@ -166,9 +162,14 @@ impl PyDataFrame { } fn __repr__(&self, py: Python) -> PyDataFusionResult { + let config = get_formatter_config(py)?; let (batches, has_more) = wait_for_future( py, - collect_record_batches_to_display(self.df.as_ref().clone(), 10, 10), + collect_record_batches_to_display( + self.df.as_ref().clone(), + config.repr_rows, + config.repr_rows, + ), )?; if batches.is_empty() { // This should not be reached, but do it for safety since we index into the vector below @@ -187,11 +188,12 @@ impl PyDataFrame { } fn _repr_html_(&self, py: Python) -> PyDataFusionResult { + let config = get_formatter_config(py)?; let (batches, has_more) = wait_for_future( py, collect_record_batches_to_display( self.df.as_ref().clone(), - MIN_TABLE_ROWS_TO_DISPLAY, + config.min_rows, usize::MAX, ), )?; @@ -851,6 +853,9 @@ async fn collect_record_batches_to_display( min_rows: usize, max_rows: usize, ) -> Result<(Vec, bool), DataFusionError> { + let config = FormatterConfig::default(); + let max_bytes = config.max_bytes; + let partitioned_stream = df.execute_stream_partitioned().await?; let mut stream = futures::stream::iter(partitioned_stream).flatten(); let mut size_estimate_so_far = 0; @@ -858,9 +863,7 @@ async fn collect_record_batches_to_display( let mut record_batches = Vec::default(); let mut has_more = false; - while (size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY && rows_so_far < max_rows) - || rows_so_far < min_rows - { + while (size_estimate_so_far < max_bytes && rows_so_far < max_rows) || rows_so_far < min_rows { let mut rb = match stream.next().await { None => { break; @@ -873,8 +876,8 @@ async fn collect_record_batches_to_display( if rows_in_rb > 0 { size_estimate_so_far += rb.get_array_memory_size(); - if size_estimate_so_far > MAX_TABLE_BYTES_TO_DISPLAY { - let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / size_estimate_so_far as f32; + if size_estimate_so_far > max_bytes { + let ratio = max_bytes as f32 / size_estimate_so_far as f32; let total_rows = rows_in_rb + rows_so_far; let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize; From bea52a31a3b6c2ee481a9d21d28ffa00674e9dd6 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sun, 27 Apr 2025 21:02:20 +0800 Subject: [PATCH 06/40] refactor: simplify record batch collection by using FormatterConfig for display options --- src/dataframe.rs | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/src/dataframe.rs b/src/dataframe.rs index e6dd4f70d..62069461a 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -165,11 +165,7 @@ impl PyDataFrame { let config = get_formatter_config(py)?; let (batches, has_more) = wait_for_future( py, - collect_record_batches_to_display( - self.df.as_ref().clone(), - config.repr_rows, - config.repr_rows, - ), + collect_record_batches_to_display(self.df.as_ref().clone(), config), )?; if batches.is_empty() { // This should not be reached, but do it for safety since we index into the vector below @@ -191,11 +187,7 @@ impl PyDataFrame { let config = get_formatter_config(py)?; let (batches, has_more) = wait_for_future( py, - collect_record_batches_to_display( - self.df.as_ref().clone(), - config.min_rows, - usize::MAX, - ), + collect_record_batches_to_display(self.df.as_ref().clone(), config), )?; if batches.is_empty() { // This should not be reached, but do it for safety since we index into the vector below @@ -850,11 +842,11 @@ fn record_batch_into_schema( /// rows, set min_rows == max_rows. 
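In PATCHES 03-06, get_formatter_config reads the display limits off the Python formatter object and falls back to the FormatterConfig defaults when an attribute is missing or cannot be extracted as an integer. A rough Python rendering of that fallback logic, for orientation only: the attribute names and the 2 MB / 20 / 10 defaults come from the diffs above, and the helper name simply mirrors the Rust get_attr.

    from datafusion.html_formatter import get_formatter

    DEFAULTS = {"max_memory_bytes": 2 * 1024 * 1024, "min_rows_display": 20, "repr_rows": 10}

    def get_attr(formatter, name):
        # Fall back to the FormatterConfig default, as the Rust unwrap_or does.
        value = getattr(formatter, name, None)
        return value if isinstance(value, int) else DEFAULTS[name]

    config = {name: get_attr(get_formatter(), name) for name in DEFAULTS}
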
async fn collect_record_batches_to_display( df: DataFrame, - min_rows: usize, - max_rows: usize, + config: FormatterConfig, ) -> Result<(Vec, bool), DataFusionError> { - let config = FormatterConfig::default(); let max_bytes = config.max_bytes; + let min_rows = config.min_rows; + let max_rows = config.repr_rows; let partitioned_stream = df.execute_stream_partitioned().await?; let mut stream = futures::stream::iter(partitioned_stream).flatten(); From ce15f1dcf8d595044cd0a90f76f3612871cbd80e Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sun, 27 Apr 2025 21:02:30 +0800 Subject: [PATCH 07/40] feat: add max_memory_bytes, min_rows_display, and repr_rows parameters to DataFrameHtmlFormatter --- python/datafusion/html_formatter.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/python/datafusion/html_formatter.py b/python/datafusion/html_formatter.py index a50e14fd5..065b7262c 100644 --- a/python/datafusion/html_formatter.py +++ b/python/datafusion/html_formatter.py @@ -91,6 +91,9 @@ class DataFrameHtmlFormatter: max_cell_length: Maximum characters to display in a cell before truncation max_width: Maximum width of the HTML table in pixels max_height: Maximum height of the HTML table in pixels + max_memory_bytes: Maximum memory in bytes for rendered data (default: 2MB) + min_rows_display: Minimum number of rows to display + repr_rows: Default number of rows to display in repr output enable_cell_expansion: Whether to add expand/collapse buttons for long cell values custom_css: Additional CSS to include in the HTML output @@ -108,6 +111,9 @@ def __init__( max_cell_length: int = 25, max_width: int = 1000, max_height: int = 300, + max_memory_bytes: int = 2 * 1024 * 1024, # 2 MB + min_rows_display: int = 20, + repr_rows: int = 10, enable_cell_expansion: bool = True, custom_css: Optional[str] = None, show_truncation_message: bool = True, @@ -124,6 +130,12 @@ def __init__( Maximum width of the displayed table in pixels. max_height : int, default 300 Maximum height of the displayed table in pixels. + max_memory_bytes : int, default 2097152 (2MB) + Maximum memory in bytes for rendered data. + min_rows_display : int, default 20 + Minimum number of rows to display. + repr_rows : int, default 10 + Default number of rows to display in repr output. enable_cell_expansion : bool, default True Whether to allow cells to expand when clicked. custom_css : str, optional @@ -139,7 +151,8 @@ def __init__( Raises: ------ ValueError - If max_cell_length, max_width, or max_height is not a positive integer. + If max_cell_length, max_width, max_height, max_memory_bytes, + min_rows_display, or repr_rows is not a positive integer. 
TypeError If enable_cell_expansion, show_truncation_message, or use_shared_styles is not a boolean, @@ -158,6 +171,15 @@ def __init__( if not isinstance(max_height, int) or max_height <= 0: msg = "max_height must be a positive integer" raise ValueError(msg) + if not isinstance(max_memory_bytes, int) or max_memory_bytes <= 0: + msg = "max_memory_bytes must be a positive integer" + raise ValueError(msg) + if not isinstance(min_rows_display, int) or min_rows_display <= 0: + msg = "min_rows_display must be a positive integer" + raise ValueError(msg) + if not isinstance(repr_rows, int) or repr_rows <= 0: + msg = "repr_rows must be a positive integer" + raise ValueError(msg) # Validate boolean parameters if not isinstance(enable_cell_expansion, bool): @@ -183,6 +205,9 @@ def __init__( self.max_cell_length = max_cell_length self.max_width = max_width self.max_height = max_height + self.max_memory_bytes = max_memory_bytes + self.min_rows_display = min_rows_display + self.repr_rows = repr_rows self.enable_cell_expansion = enable_cell_expansion self.custom_css = custom_css self.show_truncation_message = show_truncation_message From e089d7b282e53e587116b11d92760e6d292ec871 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sun, 27 Apr 2025 21:15:50 +0800 Subject: [PATCH 08/40] feat: add tests for HTML formatter row display settings and memory limit --- python/tests/test_dataframe.py | 136 ++++++++++++++++----------------- 1 file changed, 68 insertions(+), 68 deletions(-) diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 464b884db..2a6f7ec5a 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -679,6 +679,9 @@ def test_html_formatter_configuration(df, clean_formatter_state): max_width=500, max_height=200, enable_cell_expansion=False, + max_memory_bytes=1024 * 1024, # 1 MB + min_rows_display=15, + repr_rows=5, ) html_output = df._repr_html_() @@ -690,6 +693,71 @@ def test_html_formatter_configuration(df, clean_formatter_state): assert "expandable-container" not in html_output +def test_html_formatter_row_display_settings(clean_formatter_state): + """Test that min_rows_display and repr_rows affect the output.""" + ctx = SessionContext() + + # Create a dataframe with 30 rows + data = list(range(30)) + batch = pa.RecordBatch.from_arrays( + [pa.array(data)], + names=["value"], + ) + df = ctx.create_dataframe([[batch]]) + + # Test with default settings (should use repr_rows) + configure_formatter(repr_rows=7, min_rows_display=20) + html_default = df._repr_html_() + + # Verify we only show repr_rows (7) rows in the output + # by counting the number of value cells + value_cells = re.findall(r"]*>\s*\d+\s*", html_default) + assert len(value_cells) == 7 + assert "... with 23 more rows" in html_default + + # Configure to show all rows since it's below min_rows_display + reset_formatter() + configure_formatter(repr_rows=5, min_rows_display=50) + html_all = df._repr_html_() + + # Verify we show all rows + value_cells = re.findall(r"]*>\s*\d+\s*", html_all) + assert len(value_cells) == 30 + assert "... 
with" not in html_all + + +def test_html_formatter_memory_limit(clean_formatter_state): + """Test that max_memory_bytes limits the HTML rendering.""" + ctx = SessionContext() + + # Create a large string that will consume substantial memory when rendered + large_string = "x" * 100000 + + # Create a dataframe with 10 rows of large strings + batch = pa.RecordBatch.from_arrays( + [pa.array([large_string] * 10)], + names=["large_value"], + ) + df = ctx.create_dataframe([[batch]]) + + # Set very small memory limit + configure_formatter(max_memory_bytes=1000) # 1KB + + html_limited = df._repr_html_() + + # Verify that memory limit warning is included in the output + assert "Memory usage limit reached" in html_limited + + # Now with larger limit, should display normally + reset_formatter() + configure_formatter(max_memory_bytes=10 * 1024 * 1024) # 10MB + + html_full = df._repr_html_() + + # Verify no memory limit warning + assert "Memory usage limit reached" not in html_full + + def test_html_formatter_custom_style_provider(df, clean_formatter_state): """Test using custom style providers with the HTML formatter.""" @@ -771,74 +839,6 @@ def custom_cell_builder(value, row, col, table_id): r']*>(\d+)-low', html_output ) mid_cells = re.findall( - r']*>(\d+)-mid', html_output - ) - high_cells = re.findall( - r']*>(\d+)-high', html_output - ) - - # Sort the extracted values for consistent comparison - low_cells = sorted(map(int, low_cells)) - mid_cells = sorted(map(int, mid_cells)) - high_cells = sorted(map(int, high_cells)) - - # Verify specific values have the correct styling applied - assert low_cells == [1, 2] # Values < 3 - assert mid_cells == [3, 4, 5, 5] # Values 3-5 - assert high_cells == [6, 8, 8] # Values > 5 - - # Verify the exact content with styling appears in the output - assert ( - '1-low' - in html_output - ) - assert ( - '2-low' - in html_output - ) - assert ( - '3-mid' in html_output - ) - assert ( - '4-mid' in html_output - ) - assert ( - '6-high' - in html_output - ) - assert ( - '8-high' - in html_output - ) - - # Count occurrences to ensure all cells are properly styled - assert html_output.count("-low") == 2 # Two low values (1, 2) - assert html_output.count("-mid") == 4 # Four mid values (3, 4, 5, 5) - assert html_output.count("-high") == 3 # Three high values (6, 8, 8) - - # Create a custom cell builder that changes background color based on value - def custom_cell_builder(value, row, col, table_id): - # Handle numeric values regardless of their exact type - try: - num_value = int(value) - if num_value > 5: # Values > 5 get green background - return f'{value}' - if num_value < 3: # Values < 3 get light blue background - return f'{value}' - except (ValueError, TypeError): - pass - - # Default styling for other cells - return f'{value}' - - # Set our custom cell builder - formatter = get_formatter() - formatter.set_custom_cell_builder(custom_cell_builder) - - html_output = df._repr_html_() - - # Verify our custom cell styling was applied - assert "background-color: #d3e9f0" in html_output # For values 1,2 def test_html_formatter_custom_header_builder(df, clean_formatter_state): From a6792c9379c677f0a2456d1c886524136b5489de Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sun, 27 Apr 2025 21:19:43 +0800 Subject: [PATCH 09/40] refactor: extract Python formatter retrieval into a separate function --- src/dataframe.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/dataframe.rs b/src/dataframe.rs index 62069461a..98983473a 100644 --- 
a/src/dataframe.rs +++ b/src/dataframe.rs @@ -93,10 +93,15 @@ impl Default for FormatterConfig { } } -fn get_formatter_config(py: Python) -> PyResult { +/// Get the Python formatter from the datafusion.html_formatter module +fn get_python_formatter(py: Python) -> PyResult> { let formatter_module = py.import("datafusion.html_formatter")?; let get_formatter = formatter_module.getattr("get_formatter")?; - let formatter = get_formatter.call0()?; + get_formatter.call0() +} + +fn get_formatter_config(py: Python) -> PyResult { + let formatter = get_python_formatter(py)?; // Helper function to extract attributes with fallback to default fn get_attr<'a>( @@ -205,9 +210,7 @@ impl PyDataFrame { let py_schema = self.schema().into_pyobject(py)?; // Get the Python formatter module and call format_html - let formatter_module = py.import("datafusion.html_formatter")?; - let get_formatter = formatter_module.getattr("get_formatter")?; - let formatter = get_formatter.call0()?; + let formatter = get_python_formatter(py)?; // Call format_html method on the formatter let kwargs = pyo3::types::PyDict::new(py); From af678b526d0f6f735ce4f06232a30c54775e95fd Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sun, 27 Apr 2025 21:23:15 +0800 Subject: [PATCH 10/40] Revert "feat: add tests for HTML formatter row display settings and memory limit" This reverts commit e089d7b282e53e587116b11d92760e6d292ec871. --- python/tests/test_dataframe.py | 136 ++++++++++++++++----------------- 1 file changed, 68 insertions(+), 68 deletions(-) diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 2a6f7ec5a..464b884db 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -679,9 +679,6 @@ def test_html_formatter_configuration(df, clean_formatter_state): max_width=500, max_height=200, enable_cell_expansion=False, - max_memory_bytes=1024 * 1024, # 1 MB - min_rows_display=15, - repr_rows=5, ) html_output = df._repr_html_() @@ -693,71 +690,6 @@ def test_html_formatter_configuration(df, clean_formatter_state): assert "expandable-container" not in html_output -def test_html_formatter_row_display_settings(clean_formatter_state): - """Test that min_rows_display and repr_rows affect the output.""" - ctx = SessionContext() - - # Create a dataframe with 30 rows - data = list(range(30)) - batch = pa.RecordBatch.from_arrays( - [pa.array(data)], - names=["value"], - ) - df = ctx.create_dataframe([[batch]]) - - # Test with default settings (should use repr_rows) - configure_formatter(repr_rows=7, min_rows_display=20) - html_default = df._repr_html_() - - # Verify we only show repr_rows (7) rows in the output - # by counting the number of value cells - value_cells = re.findall(r"]*>\s*\d+\s*", html_default) - assert len(value_cells) == 7 - assert "... with 23 more rows" in html_default - - # Configure to show all rows since it's below min_rows_display - reset_formatter() - configure_formatter(repr_rows=5, min_rows_display=50) - html_all = df._repr_html_() - - # Verify we show all rows - value_cells = re.findall(r"]*>\s*\d+\s*", html_all) - assert len(value_cells) == 30 - assert "... 
with" not in html_all - - -def test_html_formatter_memory_limit(clean_formatter_state): - """Test that max_memory_bytes limits the HTML rendering.""" - ctx = SessionContext() - - # Create a large string that will consume substantial memory when rendered - large_string = "x" * 100000 - - # Create a dataframe with 10 rows of large strings - batch = pa.RecordBatch.from_arrays( - [pa.array([large_string] * 10)], - names=["large_value"], - ) - df = ctx.create_dataframe([[batch]]) - - # Set very small memory limit - configure_formatter(max_memory_bytes=1000) # 1KB - - html_limited = df._repr_html_() - - # Verify that memory limit warning is included in the output - assert "Memory usage limit reached" in html_limited - - # Now with larger limit, should display normally - reset_formatter() - configure_formatter(max_memory_bytes=10 * 1024 * 1024) # 10MB - - html_full = df._repr_html_() - - # Verify no memory limit warning - assert "Memory usage limit reached" not in html_full - - def test_html_formatter_custom_style_provider(df, clean_formatter_state): """Test using custom style providers with the HTML formatter.""" @@ -839,6 +771,74 @@ def custom_cell_builder(value, row, col, table_id): r']*>(\d+)-low', html_output ) mid_cells = re.findall( + r']*>(\d+)-mid', html_output + ) + high_cells = re.findall( + r']*>(\d+)-high', html_output + ) + + # Sort the extracted values for consistent comparison + low_cells = sorted(map(int, low_cells)) + mid_cells = sorted(map(int, mid_cells)) + high_cells = sorted(map(int, high_cells)) + + # Verify specific values have the correct styling applied + assert low_cells == [1, 2] # Values < 3 + assert mid_cells == [3, 4, 5, 5] # Values 3-5 + assert high_cells == [6, 8, 8] # Values > 5 + + # Verify the exact content with styling appears in the output + assert ( + '1-low' + in html_output + ) + assert ( + '2-low' + in html_output + ) + assert ( + '3-mid' in html_output + ) + assert ( + '4-mid' in html_output + ) + assert ( + '6-high' + in html_output + ) + assert ( + '8-high' + in html_output + ) + + # Count occurrences to ensure all cells are properly styled + assert html_output.count("-low") == 2 # Two low values (1, 2) + assert html_output.count("-mid") == 4 # Four mid values (3, 4, 5, 5) + assert html_output.count("-high") == 3 # Three high values (6, 8, 8) + + # Create a custom cell builder that changes background color based on value + def custom_cell_builder(value, row, col, table_id): + # Handle numeric values regardless of their exact type + try: + num_value = int(value) + if num_value > 5: # Values > 5 get green background + return f'{value}' + if num_value < 3: # Values < 3 get light blue background + return f'{value}' + except (ValueError, TypeError): + pass + + # Default styling for other cells + return f'{value}' + + # Set our custom cell builder + formatter = get_formatter() + formatter.set_custom_cell_builder(custom_cell_builder) + + html_output = df._repr_html_() + + # Verify our custom cell styling was applied + assert "background-color: #d3e9f0" in html_output # For values 1,2 def test_html_formatter_custom_header_builder(df, clean_formatter_state): From 4090fd2f7378855b045d6bfd1368d088cc9ada75 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sun, 27 Apr 2025 21:26:31 +0800 Subject: [PATCH 11/40] feat: add tests for HTML formatter row and memory limit configurations --- python/tests/test_dataframe.py | 1107 ++++++++++++++++++++++++++++++++ 1 file changed, 1107 insertions(+) diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py 
index 464b884db..64b53f491 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -690,6 +690,1063 @@ def test_html_formatter_configuration(df, clean_formatter_state): assert "expandable-container" not in html_output +def test_html_formatter_row_memory_limits(clean_formatter_state): + """Test the HTML formatter's row and memory limit parameters.""" + ctx = SessionContext() + + # Create a DataFrame with 50 rows and some wide string data + wide_data = ["x" * 1000] * 50 # 1000 character strings to test memory limits + ids = list(range(50)) + + batch = pa.RecordBatch.from_arrays( + [pa.array(ids), pa.array(wide_data)], + names=["id", "wide_data"], + ) + df = ctx.create_dataframe([[batch]]) + + # Test with custom repr_rows (show only 5 rows in repr) + configure_formatter( + repr_rows=5, + min_rows_display=20, # This should not override repr_rows + max_memory_bytes=2 * 1024 * 1024, # Default 2MB + ) + + html_output = df._repr_html_() + + # Only 5 rows should be rendered (first few rows + last few rows) + # The string "id: 4" should appear (last of first chunk) + # The string "id: 45" should appear (first of last chunk) + row_matches = re.findall(r"]*?>(\d+)", html_output) + assert len(row_matches) <= 10 # Should have at most 10 (5 from top, 5 from bottom) + + # Test with smaller memory limit + configure_formatter( + repr_rows=50, # Try to show all rows + max_memory_bytes=10 * 1000, # Only ~10 rows of our data should fit + ) + + html_output = df._repr_html_() + + # Memory limit should cause truncation despite higher repr_rows + truncation_message = re.search(r"Output truncated|memory limit", html_output, re.IGNORECASE) + assert truncation_message is not None + + # Test with min_rows_display + small_batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], + names=["id", "value"], + ) + small_df = ctx.create_dataframe([[small_batch]]) + + # Set min_rows_display higher than actual rows + configure_formatter( + min_rows_display=10, + repr_rows=5, + ) + + html_output = small_df._repr_html_() + + # All rows should be shown without truncation since fewer than min_rows_display + assert "truncated" not in html_output.lower() + + # All 3 rows should be present + row_count = len(re.findall(r" str: + return ( + "background-color: #f5f5f5; color: #333; padding: 8px; border: " + "1px solid #ddd;" + ) + + def get_header_style(self) -> str: + return ( + "background-color: #4285f4; color: white; font-weight: bold; " + "padding: 10px; border: 1px solid #3367d6;" + ) + + # Configure with custom style provider + configure_formatter(style_provider=CustomStyleProvider()) + + html_output = df._repr_html_() + + # Verify our custom styles were applied + assert "background-color: #4285f4" in html_output + assert "color: white" in html_output + assert "background-color: #f5f5f5" in html_output + + +def test_html_formatter_type_formatters(df, clean_formatter_state): + """Test registering custom type formatters for specific data types.""" + + # Get current formatter and register custom formatters + formatter = get_formatter() + + # Format integers with color based on value + # Using int as the type for the formatter will work since we convert + # Arrow scalar values to Python native types in _get_cell_value + def format_int(value): + return f' 2 else "blue"}">{value}' + + formatter.register_formatter(int, format_int) + + html_output = df._repr_html_() + + # Our test dataframe has values 1,2,3 so we should see: + assert '1' in html_output + + +def 
test_html_formatter_custom_cell_builder(df, clean_formatter_state): + """Test using a custom cell builder function.""" + + # Create a custom cell builder with distinct styling for different value ranges + def custom_cell_builder(value, row, col, table_id): + try: + num_value = int(value) + if num_value > 5: # Values > 5 get green background with indicator + return ( + '{value}-high' + ) + if num_value < 3: # Values < 3 get blue background with indicator + return ( + '{value}-low' + ) + except (ValueError, TypeError): + pass + + # Default styling for other cells (3, 4, 5) + return f'{value}-mid' + + # Set our custom cell builder + formatter = get_formatter() + formatter.set_custom_cell_builder(custom_cell_builder) + + html_output = df._repr_html_() + + # Extract cells with specific styling using regex + low_cells = re.findall( + r']*>(\d+)-low', html_output + ) + mid_cells = re.findall( + r']*>(\d+)-mid', html_output + ) + high_cells = re.findall( + r']*>(\d+)-high', html_output + ) + + # Sort the extracted values for consistent comparison + low_cells = sorted(map(int, low_cells)) + mid_cells = sorted(map(int, mid_cells)) + high_cells = sorted(map(int, high_cells)) + + # Verify specific values have the correct styling applied + assert low_cells == [1, 2] # Values < 3 + assert mid_cells == [3, 4, 5, 5] # Values 3-5 + assert high_cells == [6, 8, 8] # Values > 5 + + # Verify the exact content with styling appears in the output + assert ( + '1-low' + in html_output + ) + assert ( + '2-low' + in html_output + ) + assert ( + '3-mid' in html_output + ) + assert ( + '4-mid' in html_output + ) + assert ( + '6-high' + in html_output + ) + assert ( + '8-high' + in html_output + ) + + # Count occurrences to ensure all cells are properly styled + assert html_output.count("-low") == 2 # Two low values (1, 2) + assert html_output.count("-mid") == 4 # Four mid values (3, 4, 5, 5) + assert html_output.count("-high") == 3 # Three high values (6, 8, 8) + + # Create a custom cell builder that changes background color based on value + def custom_cell_builder(value, row, col, table_id): + # Handle numeric values regardless of their exact type + try: + num_value = int(value) + if num_value > 5: # Values > 5 get green background + return f'{value}' + if num_value < 3: # Values < 3 get light blue background + return f'{value}' + except (ValueError, TypeError): + pass + + # Default styling for other cells + return f'{value}' + + # Set our custom cell builder + formatter = get_formatter() + formatter.set_custom_cell_builder(custom_cell_builder) + + html_output = df._repr_html_() + + # Verify our custom cell styling was applied + assert "background-color: #d3e9f0" in html_output # For values 1,2 + + +def test_html_formatter_custom_header_builder(df, clean_formatter_state): + """Test using a custom header builder function.""" + + # Create a custom header builder with tooltips + def custom_header_builder(field): + tooltips = { + "a": "Primary key column", + "b": "Secondary values", + "c": "Additional data", + } + tooltip = tooltips.get(field.name, "") + return ( + f'{field.name}' + ) + + # Set our custom header builder + formatter = get_formatter() + formatter.set_custom_header_builder(custom_header_builder) + + html_output = df._repr_html_() + + # Verify our custom headers were applied + assert 'title="Primary key column"' in html_output + assert 'title="Secondary values"' in html_output + assert "background-color: #333; color: white" in html_output + + +def test_html_formatter_complex_customization(df, 
clean_formatter_state): + """Test combining multiple customization options together.""" + + # Create a dark mode style provider + class DarkModeStyleProvider: + def get_cell_style(self) -> str: + return ( + "background-color: #222; color: #eee; " + "padding: 8px; border: 1px solid #444;" + ) + + def get_header_style(self) -> str: + return ( + "background-color: #111; color: #fff; padding: 10px; " + "border: 1px solid #333;" + ) + + # Configure with dark mode style + configure_formatter( + max_cell_length=10, + style_provider=DarkModeStyleProvider(), + custom_css=""" + .datafusion-table { + font-family: monospace; + border-collapse: collapse; + } + .datafusion-table tr:hover td { + background-color: #444 !important; + } + """, + ) + + # Add type formatters for special formatting - now working with native int values + formatter = get_formatter() + formatter.register_formatter( + int, + lambda n: f'{n}', + ) + + html_output = df._repr_html_() + + # Verify our customizations were applied + assert "background-color: #222" in html_output + assert "background-color: #111" in html_output + assert ".datafusion-table" in html_output + assert "color: #5af" in html_output # Even numbers + + +def test_get_dataframe(tmp_path): + ctx = SessionContext() + + path = tmp_path / "test.csv" + table = pa.Table.from_arrays( + [ + [1, 2, 3, 4], + ["a", "b", "c", "d"], + [1.1, 2.2, 3.3, 4.4], + ], + names=["int", "str", "float"], + ) + write_csv(table, path) + + ctx.register_csv("csv", path) + + df = ctx.table("csv") + assert isinstance(df, DataFrame) + + +def test_struct_select(struct_df): + df = struct_df.select( + column("a")["c"] + column("b"), + column("a")["c"] - column("b"), + ) + + # execute and collect the first (and only) batch + result = df.collect()[0] + + assert result.column(0) == pa.array([5, 7, 9]) + assert result.column(1) == pa.array([-3, -3, -3]) + + +def test_explain(df): + df = df.select( + column("a") + column("b"), + column("a") - column("b"), + ) + df.explain() + + +def test_logical_plan(aggregate_df): + plan = aggregate_df.logical_plan() + + expected = "Projection: test.c1, sum(test.c2)" + + assert expected == plan.display() + + expected = ( + "Projection: test.c1, sum(test.c2)\n" + " Aggregate: groupBy=[[test.c1]], aggr=[[sum(test.c2)]]\n" + " TableScan: test" + ) + + assert expected == plan.display_indent() + + +def test_optimized_logical_plan(aggregate_df): + plan = aggregate_df.optimized_logical_plan() + + expected = "Aggregate: groupBy=[[test.c1]], aggr=[[sum(test.c2)]]" + + assert expected == plan.display() + + expected = ( + "Aggregate: groupBy=[[test.c1]], aggr=[[sum(test.c2)]]\n" + " TableScan: test projection=[c1, c2]" + ) + + assert expected == plan.display_indent() + + +def test_execution_plan(aggregate_df): + plan = aggregate_df.execution_plan() + + expected = ( + "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[sum(test.c2)]\n" + ) + + assert expected == plan.display() + + # Check the number of partitions is as expected. 
+ assert isinstance(plan.partition_count, int) + + expected = ( + "ProjectionExec: expr=[c1@0 as c1, SUM(test.c2)@1 as SUM(test.c2)]\n" + " Aggregate: groupBy=[[test.c1]], aggr=[[SUM(test.c2)]]\n" + " TableScan: test projection=[c1, c2]" + ) + + indent = plan.display_indent() + + # indent plan will be different for everyone due to absolute path + # to filename, so we just check for some expected content + assert "AggregateExec:" in indent + assert "CoalesceBatchesExec:" in indent + assert "RepartitionExec:" in indent + assert "DataSourceExec:" in indent + assert "file_type=csv" in indent + + ctx = SessionContext() + rows_returned = 0 + for idx in range(plan.partition_count): + stream = ctx.execute(plan, idx) + try: + batch = stream.next() + assert batch is not None + rows_returned += len(batch.to_pyarrow()[0]) + except StopIteration: + # This is one of the partitions with no values + pass + with pytest.raises(StopIteration): + stream.next() + + assert rows_returned == 5 + + +@pytest.mark.asyncio +async def test_async_iteration_of_df(aggregate_df): + rows_returned = 0 + async for batch in aggregate_df.execute_stream(): + assert batch is not None + rows_returned += len(batch.to_pyarrow()[0]) + + assert rows_returned == 5 + + +def test_repartition(df): + df.repartition(2) + + +def test_repartition_by_hash(df): + df.repartition_by_hash(column("a"), num=2) + + +def test_intersect(): + ctx = SessionContext() + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + df_a = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([3, 4, 5]), pa.array([6, 7, 8])], + names=["a", "b"], + ) + df_b = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([3]), pa.array([6])], + names=["a", "b"], + ) + df_c = ctx.create_dataframe([[batch]]).sort(column("a")) + + df_a_i_b = df_a.intersect(df_b).sort(column("a")) + + assert df_c.collect() == df_a_i_b.collect() + + +def test_except_all(): + ctx = SessionContext() + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + df_a = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([3, 4, 5]), pa.array([6, 7, 8])], + names=["a", "b"], + ) + df_b = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2]), pa.array([4, 5])], + names=["a", "b"], + ) + df_c = ctx.create_dataframe([[batch]]).sort(column("a")) + + df_a_e_b = df_a.except_all(df_b).sort(column("a")) + + assert df_c.collect() == df_a_e_b.collect() + + +def test_collect_partitioned(): + ctx = SessionContext() + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + + assert [[batch]] == ctx.create_dataframe([[batch]]).collect_partitioned() + + +def test_union(ctx): + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + df_a = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([3, 4, 5]), pa.array([6, 7, 8])], + names=["a", "b"], + ) + df_b = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3, 3, 4, 5]), pa.array([4, 5, 6, 6, 7, 8])], + names=["a", "b"], + ) + df_c = ctx.create_dataframe([[batch]]).sort(column("a")) + + df_a_u_b = df_a.union(df_b).sort(column("a")) + + assert df_c.collect() == df_a_u_b.collect() + + +def test_union_distinct(ctx): + batch = pa.RecordBatch.from_arrays( + 
+
+
+def test_union_distinct(ctx):
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+        names=["a", "b"],
+    )
+    df_a = ctx.create_dataframe([[batch]])
+
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([3, 4, 5]), pa.array([6, 7, 8])],
+        names=["a", "b"],
+    )
+    df_b = ctx.create_dataframe([[batch]])
+
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3, 4, 5]), pa.array([4, 5, 6, 7, 8])],
+        names=["a", "b"],
+    )
+    df_c = ctx.create_dataframe([[batch]]).sort(column("a"))
+
+    df_a_u_b = df_a.union(df_b, distinct=True).sort(column("a"))
+
+    assert df_c.collect() == df_a_u_b.collect()
+
+
+def test_cache(df):
+    assert df.cache().collect() == df.collect()
+
+
+def test_count(df):
+    # Get number of rows
+    assert df.count() == 3
+
+
+def test_to_pandas(df):
+    # Skip test if pandas is not installed
+    pd = pytest.importorskip("pandas")
+
+    # Convert datafusion dataframe to pandas dataframe
+    pandas_df = df.to_pandas()
+    assert isinstance(pandas_df, pd.DataFrame)
+    assert pandas_df.shape == (3, 3)
+    assert set(pandas_df.columns) == {"a", "b", "c"}
+
+
+def test_empty_to_pandas(df):
+    # Skip test if pandas is not installed
+    pd = pytest.importorskip("pandas")
+
+    # Convert empty datafusion dataframe to pandas dataframe
+    pandas_df = df.limit(0).to_pandas()
+    assert isinstance(pandas_df, pd.DataFrame)
+    assert pandas_df.shape == (0, 3)
+    assert set(pandas_df.columns) == {"a", "b", "c"}
+
+
+def test_to_polars(df):
+    # Skip test if polars is not installed
+    pl = pytest.importorskip("polars")
+
+    # Convert datafusion dataframe to polars dataframe
+    polars_df = df.to_polars()
+    assert isinstance(polars_df, pl.DataFrame)
+    assert polars_df.shape == (3, 3)
+    assert set(polars_df.columns) == {"a", "b", "c"}
+
+
+def test_empty_to_polars(df):
+    # Skip test if polars is not installed
+    pl = pytest.importorskip("polars")
+
+    # Convert empty datafusion dataframe to polars dataframe
+    polars_df = df.limit(0).to_polars()
+    assert isinstance(polars_df, pl.DataFrame)
+    assert polars_df.shape == (0, 3)
+    assert set(polars_df.columns) == {"a", "b", "c"}
+
+
+def test_to_arrow_table(df):
+    # Convert datafusion dataframe to pyarrow Table
+    pyarrow_table = df.to_arrow_table()
+    assert isinstance(pyarrow_table, pa.Table)
+    assert pyarrow_table.shape == (3, 3)
+    assert set(pyarrow_table.column_names) == {"a", "b", "c"}
+
+
+def test_execute_stream(df):
+    stream = df.execute_stream()
+    assert all(batch is not None for batch in stream)
+    assert not list(stream)  # after one iteration the generator must be exhausted
+
+
+@pytest.mark.asyncio
+async def test_execute_stream_async(df):
+    stream = df.execute_stream()
+    batches = [batch async for batch in stream]
+
+    assert all(batch is not None for batch in batches)
+
+    # After consuming all batches, the stream should be exhausted
+    remaining_batches = [batch async for batch in stream]
+    assert not remaining_batches
+
+
+@pytest.mark.parametrize("schema", [True, False])
+def test_execute_stream_to_arrow_table(df, schema):
+    stream = df.execute_stream()
+
+    if schema:
+        pyarrow_table = pa.Table.from_batches(
+            (batch.to_pyarrow() for batch in stream), schema=df.schema()
+        )
+    else:
+        pyarrow_table = pa.Table.from_batches(batch.to_pyarrow() for batch in stream)
+
+    assert isinstance(pyarrow_table, pa.Table)
+    assert pyarrow_table.shape == (3, 3)
+    assert set(pyarrow_table.column_names) == {"a", "b", "c"}
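+
+
+# Illustrative sketch (an addition, not part of the original patch): batches
+# from execute_stream() can also be reduced incrementally, e.g. counting rows
+# without materializing a full Table.
+def test_execute_stream_row_count(df):
+    total_rows = sum(batch.to_pyarrow().num_rows for batch in df.execute_stream())
+    assert total_rows == 3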
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("schema", [True, False])
+async def test_execute_stream_to_arrow_table_async(df, schema):
+    stream = df.execute_stream()
+
+    if schema:
+        pyarrow_table = pa.Table.from_batches(
+            [batch.to_pyarrow() async for batch in stream], schema=df.schema()
+        )
+    else:
+        pyarrow_table = pa.Table.from_batches(
+            [batch.to_pyarrow() async for batch in stream]
+        )
+
+    assert isinstance(pyarrow_table, pa.Table)
+    assert pyarrow_table.shape == (3, 3)
+    assert set(pyarrow_table.column_names) == {"a", "b", "c"}
+
+
+def test_execute_stream_partitioned(df):
+    streams = df.execute_stream_partitioned()
+    assert all(batch is not None for stream in streams for batch in stream)
+    assert all(
+        not list(stream) for stream in streams
+    )  # after one iteration all generators must be exhausted
+
+
+@pytest.mark.asyncio
+async def test_execute_stream_partitioned_async(df):
+    streams = df.execute_stream_partitioned()
+
+    for stream in streams:
+        batches = [batch async for batch in stream]
+        assert all(batch is not None for batch in batches)
+
+        # Ensure the stream is exhausted after iteration
+        remaining_batches = [batch async for batch in stream]
+        assert not remaining_batches
+
+
+def test_empty_to_arrow_table(df):
+    # Convert empty datafusion dataframe to pyarrow Table
+    pyarrow_table = df.limit(0).to_arrow_table()
+    assert isinstance(pyarrow_table, pa.Table)
+    assert pyarrow_table.shape == (0, 3)
+    assert set(pyarrow_table.column_names) == {"a", "b", "c"}
+
+
+def test_to_pylist(df):
+    # Convert datafusion dataframe to Python list
+    pylist = df.to_pylist()
+    assert isinstance(pylist, list)
+    assert pylist == [
+        {"a": 1, "b": 4, "c": 8},
+        {"a": 2, "b": 5, "c": 5},
+        {"a": 3, "b": 6, "c": 8},
+    ]
+
+
+def test_to_pydict(df):
+    # Convert datafusion dataframe to Python dictionary
+    pydict = df.to_pydict()
+    assert isinstance(pydict, dict)
+    assert pydict == {"a": [1, 2, 3], "b": [4, 5, 6], "c": [8, 5, 8]}
+
+
+def test_describe(df):
+    # Calculate statistics
+    df = df.describe()
+
+    # Collect the result
+    result = df.to_pydict()
+
+    assert result == {
+        "describe": [
+            "count",
+            "null_count",
+            "mean",
+            "std",
+            "min",
+            "max",
+            "median",
+        ],
+        "a": [3.0, 0.0, 2.0, 1.0, 1.0, 3.0, 2.0],
+        "b": [3.0, 0.0, 5.0, 1.0, 4.0, 6.0, 5.0],
+        "c": [3.0, 0.0, 7.0, 1.7320508075688772, 5.0, 8.0, 8.0],
+    }
+
+
+@pytest.mark.parametrize("path_to_str", [True, False])
+def test_write_csv(ctx, df, tmp_path, path_to_str):
+    path = str(tmp_path) if path_to_str else tmp_path
+
+    df.write_csv(path, with_header=True)
+
+    ctx.register_csv("csv", path)
+    result = ctx.table("csv").to_pydict()
+    expected = df.to_pydict()
+
+    assert result == expected
+
+
+@pytest.mark.parametrize("path_to_str", [True, False])
+def test_write_json(ctx, df, tmp_path, path_to_str):
+    path = str(tmp_path) if path_to_str else tmp_path
+
+    df.write_json(path)
+
+    ctx.register_json("json", path)
+    result = ctx.table("json").to_pydict()
+    expected = df.to_pydict()
+
+    assert result == expected
+
+
+@pytest.mark.parametrize("path_to_str", [True, False])
+def test_write_parquet(df, tmp_path, path_to_str):
+    path = str(tmp_path) if path_to_str else tmp_path
+
+    df.write_parquet(str(path))
+    result = pq.read_table(str(path)).to_pydict()
+    expected = df.to_pydict()
+
+    assert result == expected
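+
+
+# Illustrative sketch (an addition, not part of the original patch): the
+# written Parquet can also be read back through the SessionContext rather
+# than pyarrow, assuming read_parquet accepts the output directory.
+def test_write_parquet_read_back_via_ctx(ctx, df, tmp_path):
+    df.write_parquet(str(tmp_path))
+
+    result = ctx.read_parquet(str(tmp_path)).to_pydict()
+    assert result == df.to_pydict()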
+
+
+@pytest.mark.parametrize(
+    ("compression", "compression_level"),
+    [("gzip", 6), ("brotli", 7), ("zstd", 15)],
+)
+def test_write_compressed_parquet(df, tmp_path, compression, compression_level):
+    path = tmp_path
+
+    df.write_parquet(
+        str(path), compression=compression, compression_level=compression_level
+    )
+
+    # test that the actual compression scheme is the one written
+    for _root, _dirs, files in os.walk(path):
+        for file in files:
+            if file.endswith(".parquet"):
+                metadata = pq.ParquetFile(tmp_path / file).metadata.to_dict()
+                for row_group in metadata["row_groups"]:
+                    for columns in row_group["columns"]:
+                        assert columns["compression"].lower() == compression
+
+    result = pq.read_table(str(path)).to_pydict()
+    expected = df.to_pydict()
+
+    assert result == expected
+
+
+@pytest.mark.parametrize(
+    ("compression", "compression_level"),
+    [("gzip", 12), ("brotli", 15), ("zstd", 23), ("wrong", 12)],
+)
+def test_write_compressed_parquet_wrong_compression_level(
+    df, tmp_path, compression, compression_level
+):
+    path = tmp_path
+
+    with pytest.raises(ValueError):
+        df.write_parquet(
+            str(path),
+            compression=compression,
+            compression_level=compression_level,
+        )
+
+
+@pytest.mark.parametrize("compression", ["wrong"])
+def test_write_compressed_parquet_invalid_compression(df, tmp_path, compression):
+    path = tmp_path
+
+    with pytest.raises(ValueError):
+        df.write_parquet(str(path), compression=compression)
+
+
+# not testing lzo because it is not implemented yet
+# https://github.com/apache/arrow-rs/issues/6970
+@pytest.mark.parametrize("compression", ["zstd", "brotli", "gzip"])
+def test_write_compressed_parquet_default_compression_level(df, tmp_path, compression):
+    # Test write_parquet with the zstd, brotli, and gzip default compression
+    # levels, i.e. without specifying a compression level; it should complete
+    # without error.
+    path = tmp_path
+
+    df.write_parquet(str(path), compression=compression)
+
+
+def test_dataframe_export(df) -> None:
+    # Guarantees that we have the canonical implementation
+    # reading our dataframe export
+    table = pa.table(df)
+    assert table.num_columns == 3
+    assert table.num_rows == 3
+
+    desired_schema = pa.schema([("a", pa.int64())])
+
+    # Verify we can request a schema
+    table = pa.table(df, schema=desired_schema)
+    assert table.num_columns == 1
+    assert table.num_rows == 3
+
+    # Expect a table of nulls if the schemas don't overlap
+    desired_schema = pa.schema([("g", pa.string())])
+    table = pa.table(df, schema=desired_schema)
+    assert table.num_columns == 1
+    assert table.num_rows == 3
+    for i in range(3):
+        assert table[0][i].as_py() is None
+
+    # Expect an error when we cannot convert the schema
+    desired_schema = pa.schema([("a", pa.float32())])
+    failed_convert = False
+    try:
+        table = pa.table(df, schema=desired_schema)
+    except Exception:
+        failed_convert = True
+    assert failed_convert
+
+    # Expect an error for a non-nullable column that is not set
+    desired_schema = pa.schema([("g", pa.string(), False)])
+    failed_convert = False
+    try:
+        table = pa.table(df, schema=desired_schema)
+    except Exception:
+        failed_convert = True
+    assert failed_convert
+
+
+def test_dataframe_transform(df):
+    def add_string_col(df_internal) -> DataFrame:
+        return df_internal.with_column("string_col", literal("string data"))
+
+    def add_with_parameter(df_internal, value: Any) -> DataFrame:
+        return df_internal.with_column("new_col", literal(value))
+
+    df = df.transform(add_string_col).transform(add_with_parameter, 3)
+
+    result = df.to_pydict()
+
+    assert result["a"] == [1, 2, 3]
+    assert result["string_col"] == ["string data" for _i in range(3)]
+    assert result["new_col"] == [3 for _i in range(3)]
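+
+
+# Illustrative sketch (an addition, not part of the original patch): the
+# max_table_bytes and min_table_rows options introduced by this change are
+# assumed to validate like the other size options and to be accepted by
+# configure_formatter.
+def test_formatter_size_options(df, clean_formatter_state):
+    from datafusion.html_formatter import DataFrameHtmlFormatter
+
+    with pytest.raises(ValueError):
+        DataFrameHtmlFormatter(max_table_bytes=0)
+    with pytest.raises(ValueError):
+        DataFrameHtmlFormatter(min_table_rows=0)
+
+    configure_formatter(max_table_bytes=1024 * 1024, min_table_rows=10)
+    assert df._repr_html_() is not None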
+
+
+def test_dataframe_repr_html_structure(df) -> None:
+    """Test that DataFrame._repr_html_ produces expected HTML output structure."""
+    import re
+
+    output = df._repr_html_()
+
+    # Since we've added a fair bit of processing to the html output, let's just
+    # verify the values we are expecting in the table exist. Use regex and ignore
+    # everything between the <td> and </td>. We also don't want the closing > on
+    # the td and th segments because that is where the formatting data is written.
+
+    headers = ["a", "b", "c"]
+    headers = [f"<th(.*?)>{v}</th>" for v in headers]
+    header_pattern = "(.*?)".join(headers)
+    header_matches = re.findall(header_pattern, output, re.DOTALL)
+    assert len(header_matches) == 1
+
+    # Update the pattern to handle values that may be wrapped in spans
+    body_data = [[1, 4, 8], [2, 5, 5], [3, 6, 8]]
+
+    body_lines = [
+        f"<td(.*?)>(?:<span[^>]*?>)?{v}(?:</span>)?</td>"
+        for inner in body_data
+        for v in inner
+    ]
+    body_pattern = "(.*?)".join(body_lines)
+
+    body_matches = re.findall(body_pattern, output, re.DOTALL)
+
+    assert len(body_matches) == 1, "Expected pattern of values not found in HTML output"
+
+
+def test_dataframe_repr_html_values(df):
+    """Test that DataFrame._repr_html_ contains the expected data values."""
+    html = df._repr_html_()
+    assert html is not None
+
+    # Create a more flexible pattern that handles values being wrapped in spans.
+    # This pattern will match the sequence of values 1,4,8,2,5,5,3,6,8 regardless
+    # of formatting
+    pattern = re.compile(
+        r"<td[^>]*?>(?:<span[^>]*?>)?1(?:</span>)?</td>.*?"
+        r"<td[^>]*?>(?:<span[^>]*?>)?4(?:</span>)?</td>.*?"
+        r"<td[^>]*?>(?:<span[^>]*?>)?8(?:</span>)?</td>.*?"
+        r"<td[^>]*?>(?:<span[^>]*?>)?2(?:</span>)?</td>.*?"
+        r"<td[^>]*?>(?:<span[^>]*?>)?5(?:</span>)?</td>.*?"
+        r"<td[^>]*?>(?:<span[^>]*?>)?5(?:</span>)?</td>.*?"
+        r"<td[^>]*?>(?:<span[^>]*?>)?3(?:</span>)?</td>.*?"
+        r"<td[^>]*?>(?:<span[^>]*?>)?6(?:</span>)?</td>.*?"
+        r"<td[^>]*?>(?:<span[^>]*?>)?8(?:</span>)?</td>",
+        re.DOTALL,
+    )
+
+    # Print debug info if the test fails
+    matches = re.findall(pattern, html)
+    if not matches:
+        print(f"HTML output snippet: {html[:500]}...")  # noqa: T201
+
+    assert len(matches) > 0, "Expected pattern of values not found in HTML output"
+
+
+def test_html_formatter_shared_styles(df, clean_formatter_state):
+    """Test that shared styles work correctly across multiple tables."""
+
+    # First, ensure we're using shared styles
+    configure_formatter(use_shared_styles=True)
+
+    # Get HTML output for first table - should include styles
+    html_first = df._repr_html_()
+
+    # Verify styles are included in first render
+    assert "