feat: reads using global ctx (#982) · satwikmishra11/datafusion-python@acd7040 · GitHub

Commit acd7040

feat: reads using global ctx (apache#982)
* feat: reads using global ctx

* Add text to io methods to describe the context they are using

---------

Co-authored-by: Tim Saucer <timsaucer@gmail.com>
1 parent 9027b4d commit acd7040

File tree

6 files changed: +319 −2 lines changed

python/datafusion/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -45,6 +45,7 @@
     Expr,
     WindowFrame,
 )
+from .io import read_avro, read_csv, read_json, read_parquet
 from .plan import ExecutionPlan, LogicalPlan
 from .record_batch import RecordBatch, RecordBatchStream
 from .udf import Accumulator, AggregateUDF, ScalarUDF, WindowUDF
@@ -81,6 +82,10 @@
     "functions",
     "object_store",
     "substrait",
+    "read_parquet",
+    "read_avro",
+    "read_csv",
+    "read_json",
 ]

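For reference, a minimal usage sketch of what these re-exports enable (the data file names here are hypothetical, not part of the commit): the four readers become importable directly from the datafusion package and run against the process-wide global context.

from datafusion import read_csv, read_parquet

# Both helpers delegate to the shared global context under the hood.
csv_df = read_csv("my_data.csv")              # hypothetical file
parquet_df = read_parquet("my_data.parquet")  # hypothetical file
print(csv_df.count(), parquet_df.count())
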
python/datafusion/io.py

Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""IO read functions using global context."""
+
+import pathlib
+
+import pyarrow
+
+from datafusion.dataframe import DataFrame
+from datafusion.expr import Expr
+
+from ._internal import SessionContext as SessionContextInternal
+
+
+def read_parquet(
+    path: str | pathlib.Path,
+    table_partition_cols: list[tuple[str, str]] | None = None,
+    parquet_pruning: bool = True,
+    file_extension: str = ".parquet",
+    skip_metadata: bool = True,
+    schema: pyarrow.Schema | None = None,
+    file_sort_order: list[list[Expr]] | None = None,
+) -> DataFrame:
+    """Read a Parquet source into a :py:class:`~datafusion.dataframe.DataFrame`.
+
+    This function will use the global context. Any functions or tables registered
+    with another context may not be accessible when used with a DataFrame created
+    using this function.
+
+    Args:
+        path: Path to the Parquet file.
+        table_partition_cols: Partition columns.
+        parquet_pruning: Whether the parquet reader should use the predicate
+            to prune row groups.
+        file_extension: File extension; only files with this extension are
+            selected for data input.
+        skip_metadata: Whether the parquet reader should skip any metadata
+            that may be in the file schema. This can help avoid schema
+            conflicts due to metadata.
+        schema: An optional schema representing the parquet files. If None,
+            the parquet reader will try to infer it based on data in the
+            file.
+        file_sort_order: Sort order for the file.
+
+    Returns:
+        DataFrame representation of the read Parquet files
+    """
+    if table_partition_cols is None:
+        table_partition_cols = []
+    return DataFrame(
+        SessionContextInternal._global_ctx().read_parquet(
+            str(path),
+            table_partition_cols,
+            parquet_pruning,
+            file_extension,
+            skip_metadata,
+            schema,
+            file_sort_order,
+        )
+    )
+
+
+def read_json(
+    path: str | pathlib.Path,
+    schema: pyarrow.Schema | None = None,
+    schema_infer_max_records: int = 1000,
+    file_extension: str = ".json",
+    table_partition_cols: list[tuple[str, str]] | None = None,
+    file_compression_type: str | None = None,
+) -> DataFrame:
+    """Read a line-delimited JSON data source.
+
+    This function will use the global context. Any functions or tables registered
+    with another context may not be accessible when used with a DataFrame created
+    using this function.
+
+    Args:
+        path: Path to the JSON file.
+        schema: The data source schema.
+        schema_infer_max_records: Maximum number of rows to read from JSON
+            files for schema inference if needed.
+        file_extension: File extension; only files with this extension are
+            selected for data input.
+        table_partition_cols: Partition columns.
+        file_compression_type: File compression type.
+
+    Returns:
+        DataFrame representation of the read JSON files.
+    """
+    if table_partition_cols is None:
+        table_partition_cols = []
+    return DataFrame(
+        SessionContextInternal._global_ctx().read_json(
+            str(path),
+            schema,
+            schema_infer_max_records,
+            file_extension,
+            table_partition_cols,
+            file_compression_type,
+        )
+    )
+
+
+def read_csv(
+    path: str | pathlib.Path | list[str] | list[pathlib.Path],
+    schema: pyarrow.Schema | None = None,
+    has_header: bool = True,
+    delimiter: str = ",",
+    schema_infer_max_records: int = 1000,
+    file_extension: str = ".csv",
+    table_partition_cols: list[tuple[str, str]] | None = None,
+    file_compression_type: str | None = None,
+) -> DataFrame:
+    """Read a CSV data source.
+
+    This function will use the global context. Any functions or tables registered
+    with another context may not be accessible when used with a DataFrame created
+    using this function.
+
+    Args:
+        path: Path to the CSV file
+        schema: An optional schema representing the CSV files. If None, the
+            CSV reader will try to infer it based on data in the file.
+        has_header: Whether the CSV file has a header. If schema inference
+            is run on a file with no headers, default column names are
+            created.
+        delimiter: An optional column delimiter.
+        schema_infer_max_records: Maximum number of rows to read from CSV
+            files for schema inference if needed.
+        file_extension: File extension; only files with this extension are
+            selected for data input.
+        table_partition_cols: Partition columns.
+        file_compression_type: File compression type.
+
+    Returns:
+        DataFrame representation of the read CSV files
+    """
+    if table_partition_cols is None:
+        table_partition_cols = []
+
+    path = [str(p) for p in path] if isinstance(path, list) else str(path)
+
+    return DataFrame(
+        SessionContextInternal._global_ctx().read_csv(
+            path,
+            schema,
+            has_header,
+            delimiter,
+            schema_infer_max_records,
+            file_extension,
+            table_partition_cols,
+            file_compression_type,
+        )
+    )
+
+
+def read_avro(
+    path: str | pathlib.Path,
+    schema: pyarrow.Schema | None = None,
+    file_partition_cols: list[tuple[str, str]] | None = None,
+    file_extension: str = ".avro",
+) -> DataFrame:
+    """Create a :py:class:`DataFrame` for reading an Avro data source.
+
+    This function will use the global context. Any functions or tables registered
+    with another context may not be accessible when used with a DataFrame created
+    using this function.
+
+    Args:
+        path: Path to the Avro file.
+        schema: The data source schema.
+        file_partition_cols: Partition columns.
+        file_extension: File extension to select.
+
+    Returns:
+        DataFrame representation of the read Avro file
+    """
+    if file_partition_cols is None:
+        file_partition_cols = []
+    return DataFrame(
+        SessionContextInternal._global_ctx().read_avro(
+            str(path), schema, file_partition_cols, file_extension
+        )
+    )

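As a hedged sketch of how the optional arguments above combine (file names, column names, and the partition column are illustrative, not taken from the commit), each call below goes through SessionContextInternal._global_ctx() exactly as the module does:

import pyarrow as pa

from datafusion.io import read_csv, read_parquet

# Provide an explicit schema so no inference pass over the CSV is needed.
schema = pa.schema([("id", pa.int64()), ("name", pa.string())])
csv_df = read_csv("events.csv", schema=schema, delimiter=",", has_header=True)

# A partitioned Parquet layout such as year=2024/... with pruning enabled.
parquet_df = read_parquet(
    "events_parquet/",
    table_partition_cols=[("year", "string")],
    parquet_pruning=True,
)
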
python/tests/test_io.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+import pathlib
+
+import pyarrow as pa
+from datafusion import column
+from datafusion.io import read_avro, read_csv, read_json, read_parquet
+
+
+def test_read_json_global_ctx(ctx):
+    path = os.path.dirname(os.path.abspath(__file__))
+
+    # Default
+    test_data_path = os.path.join(path, "data_test_context", "data.json")
+    df = read_json(test_data_path)
+    result = df.collect()
+
+    assert result[0].column(0) == pa.array(["a", "b", "c"])
+    assert result[0].column(1) == pa.array([1, 2, 3])
+
+    # Schema
+    schema = pa.schema(
+        [
+            pa.field("A", pa.string(), nullable=True),
+        ]
+    )
+    df = read_json(test_data_path, schema=schema)
+    result = df.collect()
+
+    assert result[0].column(0) == pa.array(["a", "b", "c"])
+    assert result[0].schema == schema
+
+    # File extension
+    test_data_path = os.path.join(path, "data_test_context", "data.json")
+    df = read_json(test_data_path, file_extension=".json")
+    result = df.collect()
+
+    assert result[0].column(0) == pa.array(["a", "b", "c"])
+    assert result[0].column(1) == pa.array([1, 2, 3])
+
+
+def test_read_parquet_global():
+    parquet_df = read_parquet(path="parquet/data/alltypes_plain.parquet")
+    parquet_df.show()
+    assert parquet_df is not None
+
+    path = pathlib.Path.cwd() / "parquet/data/alltypes_plain.parquet"
+    parquet_df = read_parquet(path=path)
+    assert parquet_df is not None
+
+
+def test_read_csv():
+    csv_df = read_csv(path="testing/data/csv/aggregate_test_100.csv")
+    csv_df.select(column("c1")).show()
+
+
+def test_read_csv_list():
+    csv_df = read_csv(path=["testing/data/csv/aggregate_test_100.csv"])
+    expected = csv_df.count() * 2
+
+    double_csv_df = read_csv(
+        path=[
+            "testing/data/csv/aggregate_test_100.csv",
+            "testing/data/csv/aggregate_test_100.csv",
+        ]
+    )
+    actual = double_csv_df.count()
+
+    double_csv_df.select(column("c1")).show()
+    assert actual == expected
+
+
+def test_read_avro():
+    avro_df = read_avro(path="testing/data/avro/alltypes_plain.avro")
+    avro_df.show()
+    assert avro_df is not None
+
+    path = pathlib.Path.cwd() / "testing/data/avro/alltypes_plain.avro"
+    avro_df = read_avro(path=path)
+    assert avro_df is not None

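As a cross-check of the behavior these tests exercise, a small sketch that is not part of the commit: reading the same file through the global-context helper and through an explicitly created SessionContext should produce the same row count.

from datafusion import SessionContext
from datafusion.io import read_csv

ctx = SessionContext()  # a separate, explicitly created context
explicit_df = ctx.read_csv("testing/data/csv/aggregate_test_100.csv")
global_df = read_csv("testing/data/csv/aggregate_test_100.csv")

assert explicit_df.count() == global_df.count()
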
python/tests/test_wrapper_coverage.py

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,8 @@ def missing_exports(internal_obj, wrapped_obj) -> None:
         return

     for attr in dir(internal_obj):
+        if attr in ["_global_ctx"]:
+            continue
         assert attr in dir(wrapped_obj)

         internal_attr = getattr(internal_obj, attr)

src/context.rs

Lines changed: 10 additions & 2 deletions
@@ -44,7 +44,7 @@ use crate::store::StorageContexts;
 use crate::udaf::PyAggregateUDF;
 use crate::udf::PyScalarUDF;
 use crate::udwf::PyWindowUDF;
-use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
+use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_for_future};
 use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
 use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::arrow::record_batch::RecordBatch;
@@ -69,7 +69,7 @@ use datafusion::prelude::{
     AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
 };
 use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
-use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
+use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
 use tokio::task::JoinHandle;

 /// Configuration options for a SessionContext
@@ -306,6 +306,14 @@ impl PySessionContext {
         })
     }

+    #[classmethod]
+    #[pyo3(signature = ())]
+    fn _global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
+        Ok(Self {
+            ctx: get_global_ctx().clone(),
+        })
+    }
+
     /// Register an object store with the given name
     #[pyo3(signature = (scheme, store, host=None))]
     pub fn register_object_store(

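From Python, this classmethod is what io.py calls. A short sketch (assuming the internal module layout used elsewhere in this commit) showing that repeated calls hand back wrappers over the same shared context:

from datafusion._internal import SessionContext as SessionContextInternal

# Each call wraps the single process-wide SessionContext created by
# get_global_ctx() on the Rust side, so state such as registered tables
# is shared between the returned handles.
ctx_a = SessionContextInternal._global_ctx()
ctx_b = SessionContextInternal._global_ctx()
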
src/utils.rs

Lines changed: 8 additions & 0 deletions
@@ -17,6 +17,7 @@

 use crate::errors::{PyDataFusionError, PyDataFusionResult};
 use crate::TokioRuntime;
+use datafusion::execution::context::SessionContext;
 use datafusion::logical_expr::Volatility;
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
@@ -37,6 +38,13 @@ pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
     RUNTIME.get_or_init(|| TokioRuntime(tokio::runtime::Runtime::new().unwrap()))
 }

+/// Utility to get the global DataFusion context
+#[inline]
+pub(crate) fn get_global_ctx() -> &'static SessionContext {
+    static CTX: OnceLock<SessionContext> = OnceLock::new();
+    CTX.get_or_init(|| SessionContext::new())
+}
+
 /// Utility to collect rust futures with GIL released
 pub fn wait_for_future<F>(py: Python, f: F) -> F::Output
 where

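The OnceLock above lazily creates one SessionContext per process and hands out references to it on every call. A rough Python analogue of the same pattern, purely illustrative rather than how the binding actually works (the real path goes through the Rust static above):

from functools import lru_cache

from datafusion import SessionContext


@lru_cache(maxsize=1)
def get_global_ctx() -> SessionContext:
    # Created on first use, then reused for every later call,
    # mirroring OnceLock::get_or_init on the Rust side.
    return SessionContext()


assert get_global_ctx() is get_global_ctx()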