8000 #45 Add register_table/deregister_table and expose some public mod by jychen7 · Pull Request #46 · datafusion-contrib/datafusion-python · GitHub
[go: up one dir, main page]

Skip to content
This repository was archived by the owner on Jul 25, 2022. It is now read-only.

#45 Add register_table/deregister_table and expose some public mod #46

Merged
merged 2 commits into from
Apr 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ uuid = { version = "0.8", features = ["v4"] }
mimalloc = { version = "*", default-features = false }

[lib]
name = "_internal"
crate-type = ["cdylib"]
name = "datafusion_python"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

crate-type = ["cdylib", "rlib"]

[package.metadata.maturin]
name = "datafusion._internal"
Expand Down
33 changes: 33 additions & 0 deletions datafusion/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest
from datafusion import ExecutionContext
import pyarrow as pa


@pytest.fixture
def ctx():
return ExecutionContext()


@pytest.fixture
def database(ctx, tmp_path):
path = tmp_path / "test.csv"

table = pa.Table.from_arrays(
[
[1, 2, 3, 4],
["a", "b", "c", "d"],
[1.1, 2.2, 3.3, 4.4],
],
names=["int", "str", "float"],
)
pa.csv.write_csv(table, path)

ctx.register_csv("csv", path)
ctx.register_csv("csv1", str(path))
ctx.register_csv(
"csv2",
path,
has_header=True,
delimiter=",",
schema_infer_max_records=10,
)
32 changes: 0 additions & 32 deletions datafusion/tests/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,6 @@
import pyarrow as pa
import pytest

from datafusion import ExecutionContext


@pytest.fixture
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved to datafusion/tests/conftest.py for shared usage in test_context and test_sql

def ctx():
return ExecutionContext()


@pytest.fixture
def database(ctx, tmp_path):
path = tmp_path / "test.csv"

table = pa.Table.from_arrays(
[
[1, 2, 3, 4],
["a", "b", "c", "d"],
[1.1, 2.2, 3.3, 4.4],
],
names=["int", "str", "float"],
)
pa.csv.write_csv(table, path)

ctx.register_csv("csv", path)
ctx.register_csv("csv1", str(path))
ctx.register_csv(
"csv2",
path,
has_header=True,
delimiter=",",
schema_infer_max_records=10,
)


def test_basic(ctx, database):
with pytest.raises(KeyError):
Expand Down
27 changes: 19 additions & 8 deletions datafusion/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,6 @@
# under the License.

import pyarrow as pa
import pytest

from datafusion import ExecutionContext


@pytest.fixture
def ctx():
return ExecutionContext()


def test_register_record_batches(ctx):
Expand Down Expand Up @@ -61,3 +53,22 @@ def test_create_dataframe_registers_unique_table_name(ctx):
# only hexadecimal numbers
for c in tables[0][1:]:
assert c in "0123456789abcdef"


def test_register_table(ctx, database):
default = ctx.catalog()
public = default.database("public")
assert public.names() == {"csv", "csv1", "csv2"}
table = public.table("csv")

ctx.register_table("csv3", table)
assert public.names() == {"csv", "csv1", "csv2", "csv3"}


def test_deregister_table(ctx, database):
default = ctx.catalog()
public = default.database("public")
assert public.names() == {"csv", "csv1", "csv2"}

ctx.deregister_table("csv")
assert public.names() == {"csv1", "csv2"}
7 changes: 1 addition & 6 deletions datafusion/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,11 @@
import pyarrow as pa
import pytest

from datafusion import ExecutionContext, udf
from datafusion import udf

from . import generic as helpers


@pytest.fixture
def ctx():
return ExecutionContext()


def test_no_table(ctx):
with pytest.raises(Exception, match="DataFusion error"):
ctx.sql("SELECT a FROM b").collect()
Expand Down
6 changes: 5 additions & 1 deletion src/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub(crate) struct PyDatabase {
}

#[pyclass(name = "Table", module = "datafusion", subclass)]
pub(crate) struct PyTable {
pub struct PyTable {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

table: Arc<dyn TableProvider>,
}

Expand All @@ -58,6 +58,10 @@ impl PyTable {
pub fn new(table: Arc<dyn TableProvider>) -> Self {
Self { table }
}

pub fn table(&self) -> Arc<dyn TableProvider> {
self.table.clone()
}
}

#[pymethods]
Expand Down
16 changes: 15 additions & 1 deletion src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use datafusion::datasource::MemTable;
use datafusion::execution::context::ExecutionContext;
use datafusion::prelude::CsvReadOptions;

use crate::catalog::PyCatalog;
use crate::catalog::{PyCatalog, PyTable};
use crate::dataframe::PyDataFrame;
use crate::errors::DataFusionError;
use crate::udf::PyScalarUDF;
Expand Down Expand Up @@ -80,6 +80,20 @@ impl PyExecutionContext {
Ok(df)
}

fn register_table(&mut self, name: &str, table: &PyTable) -> PyResult<()> {
self.ctx
.register_table(name, table.table())
.map_err(DataFusionError::from)?;
Ok(())
}

fn deregister_table(&mut self, name: &str) -> PyResult<()> {
self.ctx
.deregister_table(name)
.map_err(DataFusionError::from)?;
Ok(())
}

fn register_record_batches(
&mut self,
name: &str,
Expand Down
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
use mimalloc::MiMalloc;
use pyo3::prelude::*;

mod catalog;
pub mod catalog;
mod context;
mod dataframe;
mod errors;
pub mod errors;
mod expression;
mod functions;
mod udaf;
mod udf;
mod utils;
pub mod utils;

#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;
Expand Down
2 changes: 1 addition & 1 deletion src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use datafusion::physical_plan::functions::Volatility;
use crate::errors::DataFusionError;

/// Utility to collect rust futures with GIL released
pub(crate) fn wait_for_future<F: Future>(py: Python, f: F) -> F::Output
pub fn wait_for_future<F: Future>(py: Python, f: F) -> F::Output
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where
F: Send,
F::Output: Send,
Expand Down
0