8000 feat: add fill_null/nan (#919) · vectorlink-ai/datafusion-python@0905f5f · GitHub
[go: up one dir, main page]

Skip to content

Commit 0905f5f

Browse files
authored
feat: add fill_null/nan (apache#919)
1 parent 494b89a commit 0905f5f

File tree

4 files changed

+54
-3
lines changed

4 files changed

+54
-3
lines changed

python/datafusion/expr.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,18 @@ def is_not_null(self) -> Expr:
406406
"""Returns ``True`` if this expression is not null."""
407407
return Expr(self.expr.is_not_null())
408408

409+
def fill_nan(self, value: Any | Expr | None = None) -> Expr:
410+
"""Fill NaN values with a provided value."""
411+
if not isinstance(value, Expr):
412+
value = Expr.literal(value)
413+
return Expr(functions_internal.nanvl(self.expr, value.expr))
414+
415+
def fill_null(self, value: Any | Expr | None = None) -> Expr:
416+
"""Fill NULL values with a provided value."""
417+
if not isinstance(value, Expr):
418+
value = Expr.literal(value)
419+
return Expr(functions_internal.nvl(self.expr, value.expr))
420+
409421
_to_pyarrow_types = {
410422
float: pa.float64(),
411423
int: pa.int64(),

python/datafusion/functions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@
186186
"min",
187187
"named_struct",
188188
"nanvl",
189+
"nvl",
189190
"now",
190191
"nth_value",
191192
"nullif",
@@ -673,6 +674,11 @@ def nanvl(x: Expr, y: Expr) -> Expr:
673674
return Expr(f.nanvl(x.expr, y.expr))
674675

675676

677+
def nvl(x: Expr, y: Expr) -> Expr:
678+
"""Returns ``x`` if ``x`` is not ``NULL``. Otherwise returns ``y``."""
679+
return Expr(f.nvl(x.expr, y.expr))
680+
681+
676682
def octet_length(arg: Expr) -> Expr:
677683
"""Returns the number of bytes of a string."""
678684
return Expr(f.octet_length(arg.expr))

python/tests/test_expr.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
import pyarrow
18+
import pyarrow as pa
1919
import pytest
2020
from datafusion import SessionContext, col
2121
from datafusion.expr import (
@@ -125,8 +125,8 @@ def test_sort(test_ctx):
125125
def test_relational_expr(test_ctx):
126126
ctx = SessionContext()
127127

128-
batch = pyarrow.RecordBatch.from_arrays(
129-
[pyarrow.array([1, 2, 3]), pyarrow.array(["alpha", "beta", "gamma"])],
128+
batch = pa.RecordBatch.from_arrays(
129+
[pa.array([1, 2, 3]), pa.array(["alpha", "beta", "gamma"])],
130130
names=["a", "b"],
131131
)
132132
df = ctx.create_dataframe([[batch]], name="batch_array")
@@ -216,3 +216,30 @@ def test_display_name_deprecation():
216216
# returns appropriate result
217217
assert name == expr.schema_name()
218218
assert name == "foo"
219+
220+
221+
@pytest.fixture
222+
def df():
223+
ctx = SessionContext()
224+
225+
# create a RecordBatch and a new DataFrame from it
226+
batch = pa.RecordBatch.from_arrays(
227+
[pa.array([1, 2, None]), pa.array([4, None, 6]), pa.array([None, None, 8])],
228+
names=["a", "b", "c"],
229+
)
230+
231+
return ctx.from_arrow(batch)
232+
233+
234+
def test_fill_null(df):
235+
df = df.select(
236+
col("a").fill_null(100).alias("a"),
237+
col("b").fill_null(25).alias("b"),
238+
col("c").fill_null(1234).alias("c"),
239+
)
240+
df.show()
241+
result = df.collect()[0]
242+
243+
assert result.column(0) == pa.array([1, 2, 100])
244+
assert result.column(1) == pa.array([4, 25, 6])
245+
assert result.column(2) == pa.array([1234, 1234, 8])

src/functions.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,11 @@ expr_fn!(
490490
x y,
491491
"Returns x if x is not NaN otherwise returns y."
492492
);
493+
expr_fn!(
494+
nvl,
495+
x y,
496+
"Returns x if x is not NULL otherwise returns y."
497+
);
493498
expr_fn!(nullif, arg_1 arg_2);
494499
expr_fn!(octet_length, args, "Returns number of bytes in the string. Since this version of the function accepts type character directly, it will not strip trailing spaces.");
495500
expr_fn_vec!(overlay);
@@ -913,6 62F2 +918,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
913918
m.add_wrapped(wrap_pyfunction!(min))?;
914919
m.add_wrapped(wrap_pyfunction!(named_struct))?;
915920
m.add_wrapped(wrap_pyfunction!(nanvl))?;
921+
m.add_wrapped(wrap_pyfunction!(nvl))?;
916922
m.add_wrapped(wrap_pyfunction!(now))?;
917923
m.add_wrapped(wrap_pyfunction!(nullif))?;
918924
m.add_wrapped(wrap_pyfunction!(octet_length))?;

0 commit comments

Comments
 (0)
0