10000 chore: set validation and typehint (#983) · davisp/datafusion-python@63b13da · GitHub
[go: up one dir, main page]

Skip to content

Commit 63b13da

Browse files
authored
chore: set validation and typehint (apache#983)
1 parent 85fe35c commit 63b13da

File tree

4 files changed

+36
-23
lines changed

4 files changed

+36
-23
lines changed

python/datafusion/context.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,15 @@ def __arrow_c_array__( # noqa: D105
6363
) -> tuple[object, object]: ...
6464

6565

66+
class TableProviderExportable(Protocol):
67+
"""Type hint for object that has __datafusion_table_provider__ PyCapsule.
68+
69+
https://datafusion.apache.org/python/user-guide/io/table_provider.html
70+
"""
71+
72+
def __datafusion_table_provider__(self) -> object: ... # noqa: D105
73+
74+
6675
class SessionConfig:
6776
"""Session configuration options."""
6877

@@ -685,7 +694,9 @@ def deregister_table(self, name: str) -> None:
685694
"""Remove a table from the session."""
686695
self.ctx.deregister_table(name)
687696

688-
def register_table_provider(self, name: str, provider: Any) -> None:
697+
def register_table_provider(
698+
self, name: str, provider: TableProviderExportable
699+
) -> None:
689700
"""Register a table provider.
690701
691702
This table provider must have a method called ``__datafusion_table_provider__``

src/context.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ use crate::store::StorageContexts;
4343
use crate::udaf::PyAggregateUDF;
4444
use crate::udf::PyScalarUDF;
4545
use crate::udwf::PyWindowUDF;
46-
use crate::utils::{get_tokio_runtime, wait_for_future};
46+
use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
4747
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
4848
use datafusion::arrow::pyarrow::PyArrowType;
4949
use datafusion::arrow::record_batch::RecordBatch;
@@ -576,7 +576,7 @@ impl PySessionContext {
576576
if provider.hasattr("__datafusion_table_provider__")? {
577577
let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
578578
let capsule = capsule.downcast::<PyCapsule>()?;
579-
// validate_pycapsule(capsule, "arrow_array_stream")?;
579+
validate_pycapsule(capsule, "datafusion_table_provider")?;
580580

581581
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
582582
let provider: ForeignTableProvider = provider.into();

src/dataframe.rs

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ use crate::expr::sort_expr::to_sort_expressions;
4444
use crate::physical_plan::PyExecutionPlan;
4545
use crate::record_batch::PyRecordBatchStream;
4646
use crate::sql::logical::PyLogicalPlan;
47-
use crate::utils::{get_tokio_runtime, wait_for_future};
47+
use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
4848
use crate::{
4949
errors::DataFusionError,
5050
expr::{sort_expr::PySortExpr, PyExpr},
@@ -724,22 +724,3 @@ fn record_batch_into_schema(
724724

725725
RecordBatch::try_new(schema, data_arrays)
726726
}
727-
728-
fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
729-
let capsule_name = capsule.name()?;
730-
if capsule_name.is_none() {
731-
return Err(PyValueError::new_err(
732-
"Expected schema PyCapsule to have name set.",
733-
));
734-
}
735-
736-
let capsule_name = capsule_name.unwrap().to_str()?;
737-
if capsule_name != name {
738-
return Err(PyValueError::new_err(format!(
739-
"Expected name '{}' in PyCapsule, instead got '{}'",
740-
name, capsule_name
741-
)));
742-
}
743-
744-
Ok(())
745-
}

src/utils.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
use crate::errors::DataFusionError;
1919
use crate::TokioRuntime;
2020
use datafusion::logical_expr::Volatility;
21+
use pyo3::exceptions::PyValueError;
2122
use pyo3::prelude::*;
23+
use pyo3::types::PyCapsule;
2224
use std::future::Future;
2325
use std::sync::OnceLock;
2426
use tokio::runtime::Runtime;
@@ -58,3 +60,22 @@ pub(crate) fn parse_volatility(value: &str) -> Result<Volatility, DataFusionErro
5860
}
5961
})
6062
}
63+
64+
pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
65+
let capsule_name = capsule.name()?;
66+
if capsule_name.is_none() {
67+
return Err(PyValueError::new_err(
68+
"Expected schema PyCapsule to have name set.",
69+
));
70+
}
71+
72+
let capsule_name = capsule_name.unwrap().to_str()?;
73+
if capsule_name != name {
74+
return Err(PyValueError::new_err(format!(
75+
"Expected name '{}' in PyCapsule, instead got '{}'",
76+
name, capsule_name
77+
)));
78+
}
79+
80+
Ok(())
81+
}

0 commit comments

Comments
 (0)
0