Add truncate option to data_color (#673) · posit-dev/great-tables@d163c3d · GitHub
[go: up one dir, main page]

Skip to content

Commit d163c3d

Browse files
mahdibaghbanzadeh, machow, and rich-iannone
authored
Add truncate option to data_color (#673)
* added truncate argument to data_color
* test module for data_color truncate updated
* refactor: reduce number of logical branches

---------

Co-authored-by: Michael Chow <mc_al_github@fastmail.com>
Co-authored-by: Richard Iannone <riannone@me.com>
1 parent 95c5cb8 commit d163c3d

File tree

3 files changed

+63
-9
lines changed

3 files changed

+63
-9
lines changed

great_tables/_data_color/base.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def data_color(
2929
alpha: int | float | None = None,
3030
reverse: bool = False,
3131
autocolor_text: bool = True,
32+
truncate: bool = False,
3233
) -> GTSelf:
3334
"""
3435
Perform data cell colorization.
@@ -74,6 +75,11 @@ def data_color(
7475
autocolor_text
7576
Whether or not to automatically color the text of the data values. If `True`, then the text
7677
will be colored according to the background color of the cell.
78+
truncate
79+
If `True`, then any values that fall outside of the domain will be truncated to the
80+
minimum or maximum value of the domain (will have the same color). If `False`, then any
81+
values that fall outside of the domain will be set to `NaN` and will follow the `na_color=`
82+
color.
7783
7884
Returns
7985
-------
@@ -238,7 +244,9 @@ def data_color(
238244
domain = _get_domain_numeric(df=data_table, vals=column_vals)
239245

240246
# Rescale only the non-NA values in `column_vals` to the range [0, 1]
241-
scaled_vals = _rescale_numeric(df=data_table, vals=column_vals, domain=domain)
247+
scaled_vals = _rescale_numeric(
248+
df=data_table, vals=column_vals, domain=domain, truncate=truncate
249+
)
242250

243251
elif all(isinstance(x, str) for x in filtered_column_vals):
244252
# If `domain` is not provided, then infer it from the data values
@@ -569,7 +577,7 @@ def _expand_short_hex(hex_color: str) -> str:
569577

570578

571579
def _rescale_numeric(
572-
df: DataFrameLike, vals: list[int | float], domain: list[float]
580+
df: DataFrameLike, vals: list[int | float], domain: list[float], truncate: bool = False
573581
) -> list[float]:
574582
"""
575583
Rescale numeric values
@@ -588,10 +596,19 @@ def _rescale_numeric(
588596
scaled_vals = [0.0 if not is_na(df, x) else x for x in vals]
589597
else:
590598
# Rescale the values in `vals` to the range [0, 1], pass through NA values
591-
scaled_vals = [(x - domain_min) / domain_range if not is_na(df, x) else x for x in vals]
599+
filled = [np.nan if is_na(df, x) else x for x in vals]
600+
scaled = [(x - domain_min) / domain_range for x in filled]
601+
602+
if truncate:
603+
# values outside domain set to 0 or 1
604+
min_val = 0.0
605+
max_val = 1.0
606+
else:
607+
# values outside domain set to missing
608+
min_val = np.nan
609+
max_val = np.nan
592610

593-
# Add NA values to any values in `scaled_vals` that are not in the [0, 1] range
594-
scaled_vals = [x if not is_na(df, x) and (x >= 0 and x <= 1) else np.nan for x in scaled_vals]
611+
scaled_vals = [min_val if x < 0 else max_val if x > 1 else x for x in scaled]
595612

596613
return scaled_vals
597614

tests/data_color/test_data_color.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pandas as pd
55
import polars as pl
66
import pytest
7+
78
from great_tables import GT, style
89
from great_tables._gt_data import CellStyle, StyleInfo
910
from great_tables._tbl_data import DataFrameLike
@@ -12,7 +13,10 @@
1213

1314
T_CellStyle = TypeVar("T_CellStyle", bound=CellStyle)
1415

15-
params_frames = [pytest.param(pd.DataFrame, id="pandas"), pytest.param(pl.DataFrame, id="polars")]
16+
params_frames = [
17+
pytest.param(pd.DataFrame, id="pandas"),
18+
pytest.param(pl.DataFrame, id="polars"),
19+
]
1620

1721

1822
@pytest.fixture(params=params_frames, scope="function")
@@ -111,7 +115,11 @@ def test_data_color_domain_na_color_snap(snapshot: str, df: DataFrameLike):
111115
def test_data_color_domain_na_color_reverse_snap(snapshot: str, df: DataFrameLike):
112116
"""`data_color` works with `domain`, `na_color`, and `reverse`."""
113117
gt = GT(df).data_color(
114-
columns="currency", palette=["red", "green"], domain=[0, 50], na_color="blue", reverse=True
118+
columns="currency",
119+
palette=["red", "green"],
120+
domain=[0, 50],
121+
na_color="blue",
122+
reverse=True,
115123
)
116124

117125
assert_rendered_body(snapshot, gt)
@@ -321,3 +329,20 @@ def test_all_values_have_zero_range_domain_pl(snapshot: str):
321329
new_gt = GT(df).data_color("x", palette=["green", "blue"], domain=[0, 0])
322330

323331
assert_rendered_body(snapshot, new_gt)
332+
333+
334+
# test for data_color with truncate=True
335+
def test_data_color_truncate(df: DataFrameLike):
336+
new_gt = GT(df).data_color(
337+
columns=["num", "currency"],
338+
domain=[10, 40],
339+
palette=["#654321", "white", "#123456"],
340+
truncate=True,
341+
)
342+
343+
# check if all cells are colored
344+
assert len(new_gt._styles) == 8
345+
# check if the last cell (out of range of domain) is colored with the last color in the palette
346+
assert get_first_style(new_gt._styles[-1], style.fill).color == "#123456"
347+
# check if the first cell (out of range of domain) is colored with the first color in the palette
348+
assert get_first_style(new_gt._styles[0], style.fill).color == "#654321"

tests/data_color/test_data_color_utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,18 @@
2323
_srgb,
2424
)
2525
from great_tables._data_color.palettes import GradientPalette
26+
from great_tables._tbl_data import is_na, Agnostic
27+
28+
29+
def assert_equal_with_na(x: list, y: list):
30+
"""Assert two lists are equal, evaluating all NAs as equivalent
31+
32+
Note that some cases like [np.nan] == [np.nan] will be True (since it checks id), but this
33+
function handles cases that trigger equality checks (since np.nan == np.nan is False).
34+
"""
35+
assert len(x) == len(y)
36+
for ii in range(len(x)):
37+
assert (is_na(Agnostic(), x[ii]) and is_na(Agnostic(), y[ii])) or (x[ii] == y[ii])
2638

2739

2840
def test_ideal_fgnd_color_dark_contrast():
@@ -468,15 +480,15 @@ def test_rescale_numeric():
468480
domain = [1, 5]
469481
expected_result = [np.nan, np.nan]
470482
result = _rescale_numeric(df, vals, domain)
471-
assert result == expected_result
483+
assert_equal_with_na(result, expected_result)
472484

473485
# Test case 3: Rescale values with NA values
474486
df = pd.DataFrame({"col": [1, 2, np.nan, 4, 5]})
475487
vals = [2, np.nan, 4]
476488
domain = [1, 5]
477489
expected_result = [0.25, np.nan, 0.75]
478490
result = _rescale_numeric(df, vals, domain)
479-
assert result == expected_result
491+
assert_equal_with_na(result, expected_result)
480492

481493

482494
def test_get_domain_numeric():

0 commit comments

Comments
 (0)
0