feat: add missing scalar math functions (#465) · llama90/arrow-datafusion-python@bc62aaf · GitHub
[go: up one dir, main page]

Skip to content

Commit bc62aaf

Browse files
authored
feat: add missing scalar math functions (apache#465)
1 parent e24dc75 commit bc62aaf

File tree

2 files changed

+67
-1
lines changed

2 files changed

+67
-1
lines changed

datafusion/tests/test_functions.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
import math
1718

1819
import numpy as np
1920
import pyarrow as pa
@@ -85,12 +86,15 @@ def test_math_functions():
8586
ctx = SessionContext()
8687
# create a RecordBatch and a new DataFrame from it
8788
batch = pa.RecordBatch.from_arrays(
88-
[pa.array([0.1, -0.7, 0.55])], names=["value"]
89+
[pa.array([0.1, -0.7, 0.55]), pa.array([float("nan"), 0, 2.0])],
90+
names=["value", "na_value"],
8991
)
9092
df = ctx.create_dataframe([[batch]])
9193

9294
values = np.array([0.1, -0.7, 0.55])
95+
na_values = np.array([np.nan, 0, 2.0])
9396
col_v = column("value")
97+
col_nav = column("na_value")
9498
df = df.select(
9599
f.abs(col_v),
96100
f.sin(col_v),
@@ -113,6 +117,20 @@ def test_math_functions():
113117
f.sqrt(col_v),
114118
f.signum(col_v),
115119
f.trunc(col_v),
120+
f.asinh(col_v),
121+
f.acosh(col_v),
122+
f.atanh(col_v),
123+
f.cbrt(col_v),
124+
f.cosh(col_v),
125+
f.degrees(col_v),
126+
f.gcd(literal(9), literal(3)),
127+
f.lcm(literal(6), literal(4)),
128+
f.nanvl(col_nav, literal(5)),
129+
f.pi(),
130+
f.radians(col_v),
131+
f.sinh(col_v),
132+
f.tanh(col_v),
133+
f.factorial(literal(6)),
116134
)
117135
batches = df.collect()
118136
assert len(batches) == 1
@@ -151,6 +169,22 @@ def test_math_functions():
151169
np.testing.assert_array_almost_equal(result.column(18), np.sqrt(values))
152170
np.testing.assert_array_almost_equal(result.column(19), np.sign(values))
153171
np.testing.assert_array_almost_equal(result.column(20), np.trunc(values))
172+
np.testing.assert_array_almost_equal(result.column(21), np.arcsinh(values))
173+
np.testing.assert_array_almost_equal(result.column(22), np.arccosh(values))
174+
np.testing.assert_array_almost_equal(result.column(23), np.arctanh(values))
175+
np.testing.assert_array_almost_equal(result.column(24), np.cbrt(values))
176+
np.testing.assert_array_almost_equal(result.column(25), np.cosh(values))
177+
np.testing.assert_array_almost_equal(result.column(26), np.degrees(values))
178+
np.testing.assert_array_almost_equal(result.column(27), np.gcd(9, 3))
179+
np.testing.assert_array_almost_equal(result.column(28), np.lcm(6, 4))
180+
np.testing.assert_array_almost_equal(
181+
result.column(29), np.where(np.isnan(na_values), 5, na_values)
182+
)
183+
np.testing.assert_array_almost_equal(result.column(30), np.pi)
184+
np.testing.assert_array_almost_equal(result.column(31), np.radians(values))
185+
np.testing.assert_array_almost_equal(result.column(32), np.sinh(values))
186+
np.testing.assert_array_almost_equal(result.column(33), np.tanh(values))
187+
np.testing.assert_array_almost_equal(result.column(34), math.factorial(6))
154188

155189

156190
def test_string_functions(df):

src/functions.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,16 +198,20 @@ macro_rules! aggregate_function {
198198

199199
scalar_function!(abs, Abs);
200200
scalar_function!(acos, Acos);
201+
scalar_function!(acosh, Acosh);
201202
scalar_function!(ascii, Ascii, "Returns the numeric code of the first character of the argument. In UTF8 encoding, returns the Unicode code point of the character. In other multibyte encodings, the argument must be an ASCII character.");
202203
scalar_function!(asin, Asin);
204+
scalar_function!(asinh, Asinh);
203205
scalar_function!(atan, Atan);
206+
scalar_function!(atanh, Atanh);
204207
scalar_function!(atan2, Atan2);
205208
scalar_function!(
206209
bit_length,
207210
BitLength,
208211
"Returns number of bits in the string (8 times the octet_length)."
209212
);
210213
scalar_function!(btrim, Btrim, "Removes the longest string containing only characters in characters (a space by default) from the start and end of string.");
214+
scalar_function!(cbrt, Cbrt);
211215
scalar_function!(ceil, Ceil);
212216
scalar_function!(
213217
character_length,
@@ -219,9 +223,14 @@ scalar_function!(char_length, CharacterLength);
219223
scalar_function!(chr, Chr, "Returns the character with the given code.");
220224
scalar_function!(coalesce, Coalesce);
221225
scalar_function!(cos, Cos);
226+
scalar_function!(cosh, Cosh);
227+
scalar_function!(degrees, Degrees);
222228
scalar_function!(exp, Exp);
229+
scalar_function!(factorial, Factorial);
223230
scalar_function!(floor, Floor);
231+
scalar_function!(gcd, Gcd);
224232
scalar_function!(initcap, InitCap, "Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters.");
233+
scalar_function!(lcm, Lcm);
225234
scalar_function!(left, Left, "Returns first n characters in the string, or when n is negative, returns all but last |n| characters.");
226235
scalar_function!(ln, Ln);
227236
scalar_function!(log, Log);
@@ -235,9 +244,16 @@ scalar_function!(
235244
MD5,
236245
"Computes the MD5 hash of the argument, with the result written in hexadecimal."
237246
);
247+
scalar_function!(
248+
nanvl,
249+
Nanvl,
250+
"Returns the first argument if it is not NaN; otherwise returns the second argument."
251+
);
238252
scalar_function!(octet_length, OctetLength, "Returns number of bytes in the string. Since this version of the function accepts type character directly, it will not strip trailing spaces.");
253+
scalar_function!(pi, Pi);
239254
scalar_function!(power, Power);
240255
scalar_function!(pow, Power);
256+
scalar_function!(radians, Radians);
241257
scalar_function!(regexp_match, RegexpMatch);
242258
scalar_function!(
243259
regexp_replace,
@@ -269,6 +285,7 @@ scalar_function!(sha384, SHA384);
269285
scalar_function!(sha512, SHA512);
270286
scalar_function!(signum, Signum);
271287
scalar_function!(sin, Sin);
288+
scalar_function!(sinh, Sinh);
272289
scalar_function!(
273290
split_part,
274291
SplitPart,
@@ -283,6 +300,7 @@ scalar_function!(
283300
scalar_function!(strpos, Strpos, "Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)");
284301
scalar_function!(substr, Substr);
285302
scalar_function!(tan, Tan);
303+
scalar_function!(tanh, Tanh);
286304
scalar_function!(
287305
to_hex,
288306
ToHex,
@@ -343,6 +361,7 @@ aggregate_function!(var_samp, Variance);
343361
pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
344362
m.add_wrapped(wrap_pyfunction!(abs))?;
345363
m.add_wrapped(wrap_pyfunction!(acos))?;
364+
m.add_wrapped(wrap_pyfunction!(acosh))?;
346365
m.add_wrapped(wrap_pyfunction!(approx_distinct))?;
347366
m.add_wrapped(wrap_pyfunction!(alias))?;
348367
m.add_wrapped(wrap_pyfunction!(approx_median))?;
@@ -353,11 +372,14 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
353372
m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
354373
m.add_wrapped(wrap_pyfunction!(ascii))?;
355374
m.add_wrapped(wrap_pyfunction!(asin))?;
375+
m.add_wrapped(wrap_pyfunction!(asinh))?;
356376
m.add_wrapped(wrap_pyfunction!(atan))?;
377+
m.add_wrapped(wrap_pyfunction!(atanh))?;
357378
m.add_wrapped(wrap_pyfunction!(atan2))?;
358379
m.add_wrapped(wrap_pyfunction!(avg))?;
359380
m.add_wrapped(wrap_pyfunction!(bit_length))?;
360381
m.add_wrapped(wrap_pyfunction!(btrim))?;
382+
m.add_wrapped(wrap_pyfunction!(cbrt))?;
361383
m.add_wrapped(wrap_pyfunction!(ceil))?;
362384
m.add_wrapped(wrap_pyfunction!(character_length))?;
363385
m.add_wrapped(wrap_pyfunction!(chr))?;
@@ -369,25 +391,30 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
369391
m.add_wrapped(wrap_pyfunction!(concat))?;
370392
m.add_wrapped(wrap_pyfunction!(corr))?;
371393
m.add_wrapped(wrap_pyfunction!(cos))?;
394+
m.add_wrapped(wrap_pyfunction!(cosh))?;
372395
m.add_wrapped(wrap_pyfunction!(count))?;
373396
m.add_wrapped(wrap_pyfunction!(count_star))?;
374397
m.add_wrapped(wrap_pyfunction!(covar))?;
375398
m.add_wrapped(wrap_pyfunction!(covar_pop))?;
376399
m.add_wrapped(wrap_pyfunction!(covar_samp))?;
377400
m.add_wrapped(wrap_pyfunction!(current_date))?;
378401
m.add_wrapped(wrap_pyfunction!(current_time))?;
402+
m.add_wrapped(wrap_pyfunction!(degrees))?;
379403
m.add_wrapped(wrap_pyfunction!(date_bin))?;
380404
m.add_wrapped(wrap_pyfunction!(datepart))?;
381405
m.add_wrapped(wrap_pyfunction!(date_part))?;
382406
m.add_wrapped(wrap_pyfunction!(datetrunc))?;
383407
m.add_wrapped(wrap_pyfunction!(date_trunc))?;
384408
m.add_wrapped(wrap_pyfunction!(digest))?;
385409
m.add_wrapped(wrap_pyfunction!(exp))?;
410+
m.add_wrapped(wrap_pyfunction!(factorial))?;
386411
m.add_wrapped(wrap_pyfunction!(floor))?;
387412
m.add_wrapped(wrap_pyfunction!(from_unixtime))?;
413+
m.add_wrapped(wrap_pyfunction!(gcd))?;
388414
m.add_wrapped(wrap_pyfunction!(grouping))?;
389415
m.add_wrapped(wrap_pyfunction!(in_list))?;
390416
m.add_wrapped(wrap_pyfunction!(initcap))?;
417+
m.add_wrapped(wrap_pyfunction!(lcm))?;
391418
m.add_wrapped(wrap_pyfunction!(left))?;
392419
m.add_wrapped(wrap_pyfunction!(length))?;
393420
m.add_wrapped(wrap_pyfunction!(ln))?;
@@ -403,12 +430,15 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
403430
m.add_wrapped(wrap_pyfunction!(mean))?;
404431
m.add_wrapped(wrap_pyfunction!(median))?;
405432
m.add_wrapped(wrap_pyfunction!(min))?;
433+
m.add_wrapped(wrap_pyfunction!(nanvl))?;
406434
m.add_wrapped(wrap_pyfunction!(now))?;
407435
m.add_wrapped(wrap_pyfunction!(nullif))?;
408436
m.add_wrapped(wrap_pyfunction!(octet_length))?;
409437
m.add_wrapped(wrap_pyfunction!(order_by))?;
438+
m.add_wrapped(wrap_pyfunction!(pi))?;
410439
m.add_wrapped(wrap_pyfunction!(power))?;
411440
m.add_wrapped(wrap_pyfunction!(pow))?;
441+
m.add_wrapped(wrap_pyfunction!(radians))?;
412442
m.add_wrapped(wrap_pyfunction!(random))?;
413443
m.add_wrapped(wrap_pyfunction!(regexp_match))?;
414444
m.add_wrapped(wrap_pyfunction!(regexp_replace))?;
@@ -425,6 +455,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
425455
m.add_wrapped(wrap_pyfunction!(sha512))?;
426456
m.add_wrapped(wrap_pyfunction!(signum))?;
427457
m.add_wrapped(wrap_pyfunction!(sin))?;
458+
m.add_wrapped(wrap_pyfunction!(sinh))?;
428459
m.add_wrapped(wrap_pyfunction!(split_part))?;
429460
m.add_wrapped(wrap_pyfunction!(sqrt))?;
430461
m.add_wrapped(wrap_pyfunction!(starts_with))?;
@@ -436,6 +467,7 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
436467
m.add_wrapped(wrap_pyfunction!(substr))?;
437468
m.add_wrapped(wrap_pyfunction!(sum))?;
438469
m.add_wrapped(wrap_pyfunction!(tan))?;
470+
m.add_wrapped(wrap_pyfunction!(tanh))?;
439471
m.add_wrapped(wrap_pyfunction!(to_hex))?;
440472
m.add_wrapped(wrap_pyfunction!(to_timestamp))?;
441473
m.add_wrapped(wrap_pyfunction!(to_timestamp_millis))?;

0 commit comments

Comments
 (0)
0