8000 feat/making global context accessible for users (#1060) · satwikmishra11/datafusion-python@3dcf7c7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3dcf7c7

Browse files
authored
feat/making global context accessible for users (apache#1060)
* Rename _global_ctx to global_ctx * Add global context to python wrapper code * Update context.py * singleton for global context * formatting * remove udf from import * remove _global_instance * formatting * formatting * unnecessary test * fix test_io.py * ran ruff * ran ruff format
1 parent b194a87 commit 3dcf7c7

File tree

4 files changed

+58
-37
lines changed

4 files changed

+58
-37
lines changed

python/datafusion/context.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,18 @@ def __init__(
496496

497497
self.ctx = SessionContextInternal(config, runtime)
498498

499+
@classmethod
500+
def global_ctx(cls) -> SessionContext:
501+
"""Retrieve the global context as a `SessionContext` wrapper.
502+
503+
Returns:
504+
A `SessionContext` object that wraps the global `SessionContextInternal`.
505+
"""
506+
internal_ctx = SessionContextInternal.global_ctx()
507+
wrapper = cls()
508+
wrapper.ctx = internal_ctx
509+
return wrapper
510+
499511
def enable_url_table(self) -> SessionContext:
500512
"""Control if local files can be queried as tables.
501513

python/datafusion/io.py

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,9 @@
2121

2222
from typing import TYPE_CHECKING
2323

24+
from datafusion.context import SessionContext
2425
from datafusion.dataframe import DataFrame
2526

26-
from ._internal import SessionContext as SessionContextInternal
27-
2827
if TYPE_CHECKING:
2928
import pathlib
3029

@@ -68,16 +67,14 @@ def read_parquet(
6867
"""
6968
if table_partition_cols is None:
7069
table_partition_cols = []
71-
return DataFrame(
72-
SessionContextInternal._global_ctx().read_parquet(
73-
str(path),
74-
table_partition_cols,
75-
parquet_pruning,
76-
file_extension,
77-
skip_metadata,
78-
schema,
79-
file_sort_order,
80-
)
70+
return SessionContext.global_ctx().read_parquet(
71+
str(path),
72+
table_partition_cols,
73+
parquet_pruning,
74+
file_extension,
75+
skip_metadata,
76+
schema,
77+
file_sort_order,
8178 10000
)
8279

8380

@@ -110,15 +107,13 @@ def read_json(
110107
"""
111108
if table_partition_cols is None:
112109
table_partition_cols = []
113-
return DataFrame(
114-
SessionContextInternal._global_ctx().read_json(
115-
str(path),
116-
schema,
117-
schema_infer_max_records,
118-
file_extension,
119-
table_partition_cols,
120-
file_compression_type,
121-
)
110+
return SessionContext.global_ctx().read_json(
111+
str(path),
112+
schema,
113+
schema_infer_max_records,
114+
file_extension,
115+
table_partition_cols,
116+
file_compression_type,
122117
)
123118

124119

@@ -161,17 +156,15 @@ def read_csv(
161156

162157
path = [str(p) for p in path] if isinstance(path, list) else str(path)
163158

164-
return DataFrame(
165-
SessionContextInternal._global_ctx().read_csv(
166-
path,
167-
schema,
168-
has_header,
169-
delimiter,
170-
schema_infer_max_records,
171-
file_extension,
172-
table_partition_cols,
173-
file_compression_type,
174-
)
159+
return SessionContext.global_ctx().read_csv(
160+
path,
161+
schema,
162+
has_header,
163+
delimiter,
164+
schema_infer_max_records,
165+
file_extension,
166+
table_partition_cols,
167+
file_compression_type,
175168
)
176169

177170

@@ -198,8 +191,6 @@ def read_avro(
198191
"""
199192
if file_partition_cols is None:
200193
file_partition_cols = []
201-
return DataFrame(
202-
SessionContextInternal._global_ctx().read_avro(
203-
str(path), schema, file_partition_cols, file_extension
204-
)
194+
return SessionContext.global_ctx().read_avro(
195+
str(path), schema, file_partition_cols, file_extension
205196
)

python/tests/test_context.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,3 +632,21 @@ def test_sql_with_options_no_statements(ctx):
632632
options = SQLOptions().with_allow_statements(allow=False)
633633
with pytest.raises(Exception, match="SetVariable"):
634634
ctx.sql_with_options(sql, options=options)
635+
636+
637+
@pytest.fixture
638+
def batch():
639+
return pa.RecordBatch.from_arrays(
640+
[pa.array([4, 5, 6])],
641+
names=["a"],
642+
)
643+
644+
645+
def test_create_dataframe_with_global_ctx(batch):
646+
ctx = SessionContext.global_ctx()
647+
648+
df = ctx.create_dataframe([[batch]])
649+
650+
result = df.collect()[0].column(0)
651+
652+
assert result == pa.array([4, 5, 6])

src/context.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ impl PySessionContext {
308308

309309
#[classmethod]
310310
#[pyo3(signature = ())]
311-
fn _global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
311+
fn global_ctx(_cls: &Bound<'_, PyType>) -> PyResult<Self> {
312312
Ok(Self {
313313
ctx: get_global_ctx().clone(),
314314
})

0 commit comments

Comments
 (0)
0