feature: Set table name from ctx functions by simicd · Pull Request #260 · apache/datafusion-python · GitHub
feature: Set table name from ctx functions #260


Merged
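
In short, the SessionContext constructors (create_dataframe, from_pylist, from_pydict, from_arrow_table, from_pandas, from_polars) gain an optional name argument for the registered table; when it is omitted, a random name is generated as before. A minimal usage sketch of the new behaviour, assuming a fresh SessionContext (the table name "t" is illustrative):

from datafusion import SessionContext

ctx = SessionContext()

# Pass an explicit table name when building a DataFrame from Python data;
# omitting `name` keeps the old behaviour of generating a random name.
df = ctx.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}, name="t")

# The named table can then be referenced directly in SQL.
result = ctx.sql("select a, b from t").collect()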
28 changes: 28 additions & 0 deletions datafusion/tests/test_context.py
@@ -96,6 +96,21 @@ def test_create_dataframe_registers_unique_table_name(ctx):
assert c in "0123456789abcdef"


def test_create_dataframe_registers_with_defined_table_name(ctx):
# create a RecordBatch and register it as memtable
batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3]), pa.array([4, 5, 6])],
names=["a", "b"],
)

df = ctx.create_dataframe([[batch]], name="tbl")
tables = list(ctx.tables())

assert df
assert len(tables) == 1
assert tables[0] == "tbl"


def test_from_arrow_table(ctx):
# create a PyArrow table
data = {"a": [1, 2, 3], "b": [4, 5, 6]}
@@ -112,6 +127,19 @@ def test_from_arrow_table(ctx):
assert df.collect()[0].num_rows == 3


def test_from_arrow_table_with_name(ctx):
# create a PyArrow table
data = {"a": [1, 2, 3], "b": [4, 5, 6]}
table = pa.Table.from_pydict(data)

# convert to DataFrame with optional name
df = ctx.from_arrow_table(table, name="tbl")
tables = list(ctx.tables())

assert df
assert tables[0] == "tbl"


def test_from_pylist(ctx):
# create a dataframe from Python list
data = [
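The tests above rely on a shared ctx pytest fixture defined elsewhere in the test suite; a minimal sketch of such a fixture, assuming it simply hands each test a fresh SessionContext:

import pytest

from datafusion import SessionContext


@pytest.fixture
def ctx():
    # A fresh context per test keeps table registrations isolated.
    return SessionContext()
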
5 changes: 2 additions & 3 deletions examples/sql-using-python-udaf.py
@@ -62,7 +62,7 @@ def evaluate(self) -> pa.Scalar:
ctx = SessionContext()

# Create a datafusion DataFrame from a Python dictionary
source_df = ctx.from_pydict({"a": [1, 1, 3], "b": [4, 5, 6]})
source_df = ctx.from_pydict({"a": [1, 1, 3], "b": [4, 5, 6]}, name="t")
# Dataframe:
# +---+---+
# | a | b |
@@ -76,9 +76,8 @@ def evaluate(self) -> pa.Scalar:
ctx.register_udaf(my_udaf)

# Query the DataFrame using SQL
table_name = ctx.catalog().database().names().pop()
result_df = ctx.sql(
f"select a, my_accumulator(b) as b_aggregated from {table_name} group by a order by a"
"select a, my_accumulator(b) as b_aggregated from t group by a order by a"
)
# Dataframe:
# +---+--------------+
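The change above drops the catalog lookup that was previously needed to recover the auto-generated table name; a minimal before/after sketch of that pattern, assuming a fresh SessionContext:

from datafusion import SessionContext

ctx = SessionContext()

# Before: the table gets a random name, so it has to be looked up in the catalog.
ctx.from_pydict({"a": [1, 1, 3], "b": [4, 5, 6]})
table_name = ctx.catalog().database().names().pop()
ctx.sql(f"select a, b from {table_name}").collect()

# After: name the table up front and reference it directly.
ctx.from_pydict({"a": [1, 1, 3], "b": [4, 5, 6]}, name="t")
ctx.sql("select a, b from t").collect()
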
5 changes: 2 additions & 3 deletions examples/sql-using-python-udf.py
@@ -38,7 +38,7 @@ def is_null(array: pa.Array) -> pa.Array:
ctx = SessionContext()

# Create a datafusion DataFrame from a Python dictionary
source_df = ctx.from_pydict({"a": [1, 2, 3], "b": [4, None, 6]})
ctx.from_pydict({"a": [1, 2, 3], "b": [4, None, 6]}, name="t")
# Dataframe:
# +---+---+
# | a | b |
@@ -52,8 +52,7 @@ def is_null(array: pa.Array) -> pa.Array:
ctx.register_udf(is_null_arr)

# Query the DataFrame using SQL
table_name = ctx.catalog().database().names().pop()
result_df = ctx.sql(f"select a, is_null(b) as b_is_null from {table_name}")
result_df = ctx.sql("select a, is_null(b) as b_is_null from t")
# Dataframe:
# +---+-----------+
# | a | b_is_null |
65 changes: 48 additions & 17 deletions src/context.rs
@@ -276,23 +276,29 @@ impl PySessionContext {
fn create_dataframe(
&mut self,
partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
name: Option<&str>,
py: Python,
) -> PyResult<PyDataFrame> {
let schema = partitions.0[0][0].schema();
let table = MemTable::try_new(schema, partitions.0).map_err(DataFusionError::from)?;

// generate a random (unique) name for this table
// generate a random (unique) name for this table if none is provided
// table name cannot start with numeric digit
let name = "c".to_owned()
+ Uuid::new_v4()
.simple()
.encode_lower(&mut Uuid::encode_buffer());
let table_name = match name {
Some(val) => val.to_owned(),
Contributor (review comment): nice!
None => {
"c".to_owned()
+ Uuid::new_v4()
.simple()
.encode_lower(&mut Uuid::encode_buffer())
}
};

self.ctx
.register_table(&*name, Arc::new(table))
.register_table(&*table_name, Arc::new(table))
.map_err(DataFusionError::from)?;

let table = wait_for_future(py, self._table(&name)).map_err(DataFusionError::from)?;
let table = wait_for_future(py, self._table(&table_name)).map_err(DataFusionError::from)?;

let df = PyDataFrame::new(table);
Ok(df)
@@ -305,37 +311,52 @@ impl PySessionContext {

/// Construct datafusion dataframe from Python list
#[allow(clippy::wrong_self_convention)]
fn from_pylist(&mut self, data: PyObject, _py: Python) -> PyResult<PyDataFrame> {
fn from_pylist(
&mut self,
data: PyObject,
name: Option<&str>,
_py: Python,
) -> PyResult<PyDataFrame> {
Python::with_gil(|py| {
// Instantiate pyarrow Table object & convert to Arrow Table
let table_class = py.import("pyarrow")?.getattr("Table")?;
let args = PyTuple::new(py, &[data]);
let table = table_class.call_method1("from_pylist", args)?.into();

// Convert Arrow Table to datafusion DataFrame
let df = self.from_arrow_table(table, py)?;
let df = self.from_arrow_table(table, name, py)?;
Ok(df)
})
}

/// Construct datafusion dataframe from Python dictionary
#[allow(clippy::wrong_self_convention)]
fn from_pydict(&mut self, data: PyObject, _py: Python) -> PyResult<PyDataFrame> {
fn from_pydict(
&mut self,
data: PyObject,
name: Option<&str>,
_py: Python,
) -> PyResult<PyDataFrame> {
Python::with_gil(|py| {
// Instantiate pyarrow Table object & convert to Arrow Table
let table_class = py.import("pyarrow")?.getattr("Table")?;
let args = PyTuple::new(py, &[data]);
let table = table_class.call_method1("from_pydict", args)?.into();

// Convert Arrow Table to datafusion DataFrame
let df = self.from_arrow_table(table, py)?;
let df = self.from_arrow_table(table, name, py)?;
Ok(df)
})
}

/// Construct datafusion dataframe from Arrow Table
#[allow(clippy::wrong_self_convention)]
fn from_arrow_table(&mut self, data: PyObject, _py: Python) -> PyResult<PyDataFrame> {
fn from_arrow_table(
&mut self,
data: PyObject,
name: Option<&str>,
_py: Python,
) -> PyResult<PyDataFrame> {
Python::with_gil(|py| {
// Instantiate pyarrow Table object & convert to batches
let table = data.call_method0(py, "to_batches")?;
@@ -345,34 +366,44 @@ impl PySessionContext {
// here we need to wrap the vector of record batches in an additional vector
let batches = table.extract::<PyArrowType<Vec<RecordBatch>>>(py)?;
let list_of_batches = PyArrowType::try_from(vec![batches.0])?;
self.create_dataframe(list_of_batches, py)
self.create_dataframe(list_of_batches, name, py)
})
}

/// Construct datafusion dataframe from pandas
#[allow(clippy::wrong_self_convention)]
fn from_pandas(&mut self, data: PyObject, _py: Python) -> PyResult<PyDataFrame> {
fn from_pandas(
&mut self,
data: PyObject,
name: Option<&str>,
_py: Python,
) -> PyResult<PyDataFrame> {
Python::with_gil(|py| {
// Instantiate pyarrow Table object & convert to Arrow Table
let table_class = py.import("pyarrow")?.getattr("Table")?;
let args = PyTuple::new(py, &[data]);
let table = table_class.call_method1("from_pandas", args)?.into();

// Convert Arrow Table to datafusion DataFrame
let df = self.from_arrow_table(table, py)?;
let df = self.from_arrow_table(table, name, py)?;
Ok(df)
})
}

/// Construct datafusion dataframe from polars
#[allow(clippy::wrong_self_convention)]
fn from_polars(&mut self, data: PyObject, _py: Python) -> PyResult<PyDataFrame> {
fn from_polars(
&mut self,
data: PyObject,
name: Option<&str>,
_py: Python,
) -> PyResult<PyDataFrame> {
Python::with_gil(|py| {
// Convert Polars dataframe to Arrow Table
let table = data.call_method0(py, "to_arrow")?;

// Convert Arrow Table to datafusion DataFrame
let df = self.from_arrow_table(table, py)?;
let df = self.from_arrow_table(table, name, py)?;
Ok(df)
})
}
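On the Rust side, create_dataframe now generates the random "c" + UUID table name only when no name is supplied; a minimal Python sketch of both paths, reusing the RecordBatch pattern from the tests above (the printed output is illustrative):

import pyarrow as pa

from datafusion import SessionContext

ctx = SessionContext()
batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
    names=["a", "b"],
)

# Explicit name: the table is registered exactly as given.
ctx.create_dataframe([[batch]], name="tbl")

# No name: a random identifier ("c" followed by 32 hex characters) is
# generated, preserving the pre-existing behaviour.
ctx.create_dataframe([[batch]])

print(ctx.tables())  # e.g. {'tbl', 'c9f86f1e...'}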