feature: Set table name from ctx functions (#260) · chenqin/arrow-datafusion-python@9fc5332

Commit 9fc5332

feature: Set table name from ctx functions (apache#260)
1 parent a3c108f commit 9fc5332

File tree

  datafusion/tests/test_context.py
  examples/sql-using-python-udaf.py
  examples/sql-using-python-udf.py
  src/context.rs

4 files changed: +80 -23 lines

datafusion/tests/test_context.py

Lines changed: 28 additions & 0 deletions

@@ -96,6 +96,21 @@ def test_create_dataframe_registers_unique_table_name(ctx):
         assert c in "0123456789abcdef"
 
 
+def test_create_dataframe_registers_with_defined_table_name(ctx):
+    # create a RecordBatch and register it as memtable
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+        names=["a", "b"],
+    )
+
+    df = ctx.create_dataframe([[batch]], name="tbl")
+    tables = list(ctx.tables())
+
+    assert df
+    assert len(tables) == 1
+    assert tables[0] == "tbl"
+
+
 def test_from_arrow_table(ctx):
     # create a PyArrow table
     data = {"a": [1, 2, 3], "b": [4, 5, 6]}
@@ -112,6 +127,19 @@ def test_from_arrow_table(ctx):
     assert df.collect()[0].num_rows == 3
 
 
+def test_from_arrow_table_with_name(ctx):
+    # create a PyArrow table
+    data = {"a": [1, 2, 3], "b": [4, 5, 6]}
+    table = pa.Table.from_pydict(data)
+
+    # convert to DataFrame with optional name
+    df = ctx.from_arrow_table(table, name="tbl")
+    tables = list(ctx.tables())
+
+    assert df
+    assert tables[0] == "tbl"
+
+
 def test_from_pylist(ctx):
     # create a dataframe from Python list
     data = [
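
For reference, the two registration modes exercised by the tests above can be driven from Python roughly as follows; this is a minimal sketch, assuming a build of datafusion-python that includes this commit:

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()
batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
    names=["a", "b"],
)

# explicit name: the MemTable is registered as "tbl"
named_df = ctx.create_dataframe([[batch]], name="tbl")

# no name: a unique name of the form "c" + hex UUID is generated,
# as before this change
anon_df = ctx.create_dataframe([[batch]])

print(list(ctx.tables()))  # e.g. ['tbl', 'c3f1a2...']; the generated name varies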

examples/sql-using-python-udaf.py

Lines changed: 2 additions & 3 deletions

@@ -62,7 +62,7 @@ def evaluate(self) -> pa.Scalar:
 ctx = SessionContext()
 
 # Create a datafusion DataFrame from a Python dictionary
-source_df = ctx.from_pydict({"a": [1, 1, 3], "b": [4, 5, 6]})
+source_df = ctx.from_pydict({"a": [1, 1, 3], "b": [4, 5, 6]}, name="t")
 # Dataframe:
 # +---+---+
 # | a | b |
@@ -76,9 +76,8 @@ def evaluate(self) -> pa.Scalar:
 ctx.register_udaf(my_udaf)
 
 # Query the DataFrame using SQL
-table_name = ctx.catalog().database().names().pop()
 result_df = ctx.sql(
-    f"select a, my_accumulator(b) as b_aggregated from {table_name} group by a order by a"
+    "select a, my_accumulator(b) as b_aggregated from t group by a order by a"
 )
 # Dataframe:
 # +---+--------------+

examples/sql-using-python-udf.py

Lines changed: 2 additions & 3 deletions

@@ -38,7 +38,7 @@ def is_null(array: pa.Array) -> pa.Array:
 ctx = SessionContext()
 
 # Create a datafusion DataFrame from a Python dictionary
-source_df = ctx.from_pydict({"a": [1, 2, 3], "b": [4, None, 6]})
+ctx.from_pydict({"a": [1, 2, 3], "b": [4, None, 6]}, name="t")
 # Dataframe:
 # +---+---+
 # | a | b |
@@ -52,8 +52,7 @@ def is_null(array: pa.Array) -> pa.Array:
 ctx.register_udf(is_null_arr)
 
 # Query the DataFrame using SQL
-table_name = ctx.catalog().database().names().pop()
-result_df = ctx.sql(f"select a, is_null(b) as b_is_null from {table_name}")
+result_df = ctx.sql("select a, is_null(b) as b_is_null from t")
 # Dataframe:
 # +---+-----------+
 # | a | b_is_null |
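
Both example scripts previously discovered the auto-generated table name through the catalog; with an explicit name that lookup is no longer needed. The older pattern removed above still works when no name is passed. A minimal sketch, using only calls that appear in the removed lines:

from datafusion import SessionContext

ctx = SessionContext()

# register without a name, then look the generated name up from the catalog
ctx.from_pydict({"a": [1, 2, 3], "b": [4, None, 6]})
table_name = ctx.catalog().database().names().pop()
result_df = ctx.sql(f"select a, b from {table_name}")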

src/context.rs

Lines changed: 48 additions & 17 deletions

@@ -276,23 +276,29 @@ impl PySessionContext {
     fn create_dataframe(
         &mut self,
         partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
+        name: Option<&str>,
         py: Python,
     ) -> PyResult<PyDataFrame> {
         let schema = partitions.0[0][0].schema();
         let table = MemTable::try_new(schema, partitions.0).map_err(DataFusionError::from)?;
 
-        // generate a random (unique) name for this table
+        // generate a random (unique) name for this table if none is provided
         // table name cannot start with numeric digit
-        let name = "c".to_owned()
-            + Uuid::new_v4()
-                .simple()
-                .encode_lower(&mut Uuid::encode_buffer());
+        let table_name = match name {
+            Some(val) => val.to_owned(),
+            None => {
+                "c".to_owned()
+                    + Uuid::new_v4()
+                        .simple()
+                        .encode_lower(&mut Uuid::encode_buffer())
+            }
+        };
 
         self.ctx
-            .register_table(&*name, Arc::new(table))
+            .register_table(&*table_name, Arc::new(table))
            .map_err(DataFusionError::from)?;
 
-        let table = wait_for_future(py, self._table(&name)).map_err(DataFusionError::from)?;
+        let table = wait_for_future(py, self._table(&table_name)).map_err(DataFusionError::from)?;
 
         let df = PyDataFrame::new(table);
         Ok(df)
@@ -305,37 +311,52 @@ impl PySessionContext {
 
     /// Construct datafusion dataframe from Python list
     #[allow(clippy::wrong_self_convention)]
-    fn from_pylist(&mut self, data: PyObject, _py: Python) -> PyResult<PyDataFrame> {
+    fn from_pylist(
+        &mut self,
+        data: PyObject,
+        name: Option<&str>,
+        _py: Python,
+    ) -> PyResult<PyDataFrame> {
         Python::with_gil(|py| {
             // Instantiate pyarrow Table object & convert to Arrow Table
             let table_class = py.import("pyarrow")?.getattr("Table")?;
             let args = PyTuple::new(py, &[data]);
             let table = table_class.call_method1("from_pylist", args)?.into();
 
             // Convert Arrow Table to datafusion DataFrame
-            let df = self.from_arrow_table(table, py)?;
+            let df = self.from_arrow_table(table, name, py)?;
             Ok(df)
         })
     }
 
     /// Construct datafusion dataframe from Python dictionary
     #[allow(clippy::wrong_self_convention)]
-    fn from_pydict(&mut self, data: PyObject, _py: Python) -> PyResult<PyDataFrame> {
+    fn from_pydict(
+        &mut self,
+        data: PyObject,
+        name: Option<&str>,
+        _py: Python,
+    ) -> PyResult<PyDataFrame> {
         Python::with_gil(|py| {
             // Instantiate pyarrow Table object & convert to Arrow Table
             let table_class = py.import("pyarrow")?.getattr("Table")?;
             let args = PyTuple::new(py, &[data]);
             let table = table_class.call_method1("from_pydict", args)?.into();
 
             // Convert Arrow Table to datafusion DataFrame
-            let df = self.from_arrow_table(table, py)?;
+            let df = self.from_arrow_table(table, name, py)?;
             Ok(df)
         })
     }
 
     /// Construct datafusion dataframe from Arrow Table
     #[allow(clippy::wrong_self_convention)]
-    fn from_arrow_table(&mut self, data: PyObject, _py: Python) -> PyResult<PyDataFrame> {
+    fn from_arrow_table(
+        &mut self,
+        data: PyObject,
+        name: Option<&str>,
+        _py: Python,
+    ) -> PyResult<PyDataFrame> {
         Python::with_gil(|py| {
             // Instantiate pyarrow Table object & convert to batches
             let table = data.call_method0(py, "to_batches")?;
@@ -345,34 +366,44 @@ impl PySessionContext {
             // here we need to wrap the vector of record batches in an additional vector
             let batches = table.extract::<PyArrowType<Vec<RecordBatch>>>(py)?;
             let list_of_batches = PyArrowType::try_from(vec![batches.0])?;
-            self.create_dataframe(list_of_batches, py)
+            self.create_dataframe(list_of_batches, name, py)
         })
     }
 
     /// Construct datafusion dataframe from pandas
     #[allow(clippy::wrong_self_convention)]
-    fn from_pandas(&mut self, data: PyObject, _py: Python) -> PyResult<PyDataFrame> {
+    fn from_pandas(
+        &mut self,
+        data: PyObject,
+        name: Option<&str>,
+        _py: Python,
+    ) -> PyResult<PyDataFrame> {
         Python::with_gil(|py| {
             // Instantiate pyarrow Table object & convert to Arrow Table
             let table_class = py.import("pyarrow")?.getattr("Table")?;
             let args = PyTuple::new(py, &[data]);
             let table = table_class.call_method1("from_pandas", args)?.into();
 
             // Convert Arrow Table to datafusion DataFrame
-            let df = self.from_arrow_table(table, py)?;
+            let df = self.from_arrow_table(table, name, py)?;
             Ok(df)
         })
     }
 
     /// Construct datafusion dataframe from polars
     #[allow(clippy::wrong_self_convention)]
-    fn from_polars(&mut self, data: PyObject, _py: Python) -> PyResult<PyDataFrame> {
+    fn from_polars(
+        &mut self,
+        data: PyObject,
+        name: Option<&str>,
+        _py: Python,
+    ) -> PyResult<PyDataFrame> {
         Python::with_gil(|py| {
             // Convert Polars dataframe to Arrow Table
             let table = data.call_method0(py, "to_arrow")?;
 
             // Convert Arrow Table to datafusion DataFrame
-            let df = self.from_arrow_table(table, py)?;
+            let df = self.from_arrow_table(table, name, py)?;
             Ok(df)
         })
     }
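
Taken together, every ctx constructor that funnels into create_dataframe now threads the optional name through to register_table. From the Python side, usage looks roughly like the sketch below, limited to the pyarrow-only constructors (from_pandas and from_polars accept the same optional name but require their respective libraries to be installed):

import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()

# each constructor registers its data under the given table name
ctx.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}, name="dict_tbl")
ctx.from_pylist([{"a": 1, "b": 4}, {"a": 2, "b": 5}], name="list_tbl")
ctx.from_arrow_table(pa.Table.from_pydict({"a": [1, 2, 3]}), name="arrow_tbl")

# the explicit names are visible to tables() and usable directly in SQL
print(sorted(ctx.tables()))  # ['arrow_tbl', 'dict_tbl', 'list_tbl']
df = ctx.sql("select a, b from dict_tbl where a > 1")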
