forked from litestar-org/sqlspec
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprotocols.py
More file actions
507 lines (372 loc) · 15.5 KB
/
protocols.py
File metadata and controls
507 lines (372 loc) · 15.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
8000
div>
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
"""Runtime-checkable protocols for type safety and runtime checks.
This module provides protocols that can be used for static type checking
and runtime isinstance() checks.
"""
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
from typing_extensions import Self
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Iterator
from pathlib import Path
from sqlglot import exp
from sqlspec.typing import ArrowRecordBatch, ArrowTable
# Public API of this module: one entry per protocol class defined below,
# kept in sorted order.
__all__ = (
    "BytesConvertibleProtocol",
    "DictProtocol",
    "ExpressionWithAliasProtocol",
    "FilterAppenderProtocol",
    "FilterParameterProtocol",
    "HasExpressionProtocol",
    "HasExpressionsProtocol",
    "HasLimitProtocol",
    "HasOffsetProtocol",
    "HasOrderByProtocol",
    "HasParameterBuilderProtocol",
    "HasSQLGlotExpressionProtocol",
    "HasSQLMethodProtocol",
    "HasToStatementProtocol",
    "HasWhereProtocol",
    "IndexableRow",
    "IterableParameters",
    "ObjectStoreItemProtocol",
    "ObjectStoreProtocol",
    "ParameterValueProtocol",
    "SQLBuilderProtocol",
    "SelectBuilderProtocol",
    "StackResultProtocol",
    "SupportsArrowResults",
    "WithMethodProtocol",
)
@runtime_checkable
class IndexableRow(Protocol):
    """Structural type for rows addressable by integer position.

    Anything exposing both ``__getitem__`` and ``__len__`` (tuples, lists,
    driver row objects, ...) satisfies this protocol at runtime.
    """

    def __getitem__(self, index: int) -> Any:
        """Return the column value stored at position ``index``."""

    def __len__(self) -> int:
        """Return how many columns the row holds."""
@runtime_checkable
class IterableParameters(Protocol):
    """Structural type for sized, iterable parameter collections.

    Requires both ``__iter__`` and ``__len__``; a bare iterator (which has
    no length) therefore does not match.
    """

    def __iter__(self) -> Any:
        """Yield each parameter in turn."""

    def __len__(self) -> int:
        """Return the parameter count."""
@runtime_checkable
class WithMethodProtocol(Protocol):
    """Structural type for objects exposing a ``with_`` method.

    Intended for SQLGlot expressions that can attach a WITH (CTE) clause.
    """

    def with_(self, *args: Any, **kwargs: Any) -> Any:
        """Attach a WITH clause and return the resulting expression."""
@runtime_checkable
class HasWhereProtocol(Protocol):
    """Structural type for SQL expressions that accept a WHERE clause."""

    def where(self, *args: Any, **kwargs: Any) -> Any:
        """Attach a WHERE condition and return the resulting expression."""
@runtime_checkable
class HasLimitProtocol(Protocol):
    """Structural type for SQL expressions that accept a LIMIT clause."""

    def limit(self, *args: Any, **kwargs: Any) -> Any:
        """Attach a LIMIT and return the resulting expression."""
@runtime_checkable
class HasOffsetProtocol(Protocol):
    """Structural type for SQL expressions that accept an OFFSET clause."""

    def offset(self, *args: Any, **kwargs: Any) -> Any:
        """Attach an OFFSET and return the resulting expression."""
@runtime_checkable
class HasOrderByProtocol(Protocol):
    """Structural type for SQL expressions that accept an ORDER BY clause."""

    def order_by(self, *args: Any, **kwargs: Any) -> Any:
        """Attach ORDER BY terms and return the resulting expression."""
@runtime_checkable
class HasExpressionsProtocol(Protocol):
    """Structural type for SQL expressions carrying an ``expressions`` attribute."""

    # Child expressions; the element type is left open (Any).
    expressions: Any
@runtime_checkable
class HasSQLMethodProtocol(Protocol):
    """Structural type for objects that can render themselves via ``sql()``."""

    def sql(self, *args: Any, **kwargs: Any) -> str:
        """Produce the SQL text for this object."""
@runtime_checkable
class FilterParameterProtocol(Protocol):
    """Structural type for filters that expose their bind parameters."""

    def extract_parameters(self) -> tuple[list[Any], dict[str, Any]]:
        """Return the filter's parameters as a ``(positional, named)`` pair."""
@runtime_checkable
class FilterAppenderProtocol(Protocol):
    """Structural type for filters that can be applied to a SQL statement."""

    def append_to_statement(self, sql: Any) -> Any:
        """Apply this filter to ``sql`` and return the result."""
@runtime_checkable
class ParameterValueProtocol(Protocol):
    """Structural type for parameter wrappers carrying a value and a type hint."""

    # The wrapped parameter value.
    value: Any
    # Type hint associated with the value, as a string.
    type_hint: str
@runtime_checkable
class DictProtocol(Protocol):
    """Structural type for objects that expose a ``__dict__`` mapping.

    NOTE(review): ``typing`` excludes ``__dict__`` from the set of collected
    protocol members, so this protocol appears to have no runtime members and
    ``isinstance()`` against it accepts any object, even ones without a
    ``__dict__``. For a real runtime check use ``hasattr(obj, "__dict__")``.
    Confirm this on the target Python version.
    """

    __dict__: dict[str, Any]
@runtime_checkable
class BytesConvertibleProtocol(Protocol):
    """Structural type for objects supporting ``bytes(obj)`` via ``__bytes__``."""

    def __bytes__(self) -> bytes:
        """Return the object's byte representation."""
@runtime_checkable
class ExpressionWithAliasProtocol(Protocol):
    """Structural type for SQL expressions that can be aliased through ``as_()``."""

    def as_(self, alias: str, **kwargs: Any) -> "exp.Alias":
        """Wrap this expression under the name ``alias``."""
@runtime_checkable
class ObjectStoreItemProtocol(Protocol):
    """Structural type for object-store entries described by a path and an optional key."""

    # Location of the item.
    path: str
    # Optional key for the item; may be None.
    key: "str | None"
@runtime_checkable
class ObjectStoreProtocol(Protocol):
    """Protocol for object storage operations.

    Declares the synchronous, Arrow, and asynchronous operations a storage
    backend exposes. The bodies here are only placeholders:

    * the basic sync operations return empty/falsey values (``b""``, ``""``,
      ``False``, ``[]``, ``{}``) or ``None``;
    * the Arrow operations and every async operation raise
      ``NotImplementedError``.

    Classes that explicitly subclass this protocol inherit these placeholder
    bodies for any method they do not override.
    """

    # Storage protocol identifier handled by the backend (values not defined here).
    protocol: str
    # Name of the backend implementation.
    backend_type: str

    def __init__(self, uri: str, **kwargs: Any) -> None:
        """Initialize the backend for ``uri``; extra options arrive via ``kwargs``."""
        return

    # --- synchronous operations -------------------------------------------------

    def read_bytes(self, path: "str | Path", **kwargs: Any) -> bytes:
        """Read bytes from an object."""
        return b""

    def write_bytes(self, path: "str | Path", data: bytes, **kwargs: Any) -> None:
        """Write bytes to an object."""
        return

    def read_text(self, path: "str | Path", encoding: str = "utf-8", **kwargs: Any) -> str:
        """Read text from an object, decoding with ``encoding``."""
        return ""

    def write_text(self, path: "str | Path", data: str, encoding: str = "utf-8", **kwargs: Any) -> None:
        """Write text to an object, encoding with ``encoding``."""
        return

    def exists(self, path: "str | Path", **kwargs: Any) -> bool:
        """Check if an object exists."""
        return False

    def delete(self, path: "str | Path", **kwargs: Any) -> None:
        """Delete an object."""
        return

    def copy(self, source: "str | Path", destination: "str | Path", **kwargs: Any) -> None:
        """Copy an object from ``source`` to ``destination``."""
        return

    def move(self, source: "str | Path", destination: "str | Path", **kwargs: Any) -> None:
        """Move an object from ``source`` to ``destination``."""
        return

    def list_objects(self, prefix: str = "", recursive: bool = True, **kwargs: Any) -> list[str]:
        """List objects under an optional ``prefix``."""
        return []

    def glob(self, pattern: str, **kwargs: Any) -> list[str]:
        """Find objects matching a glob pattern."""
        return []

    def is_object(self, path: "str | Path") -> bool:
        """Check if path points to an object."""
        return False

    def is_path(self, path: "str | Path") -> bool:
        """Check if path points to a prefix (directory-like)."""
        return False

    def get_metadata(self, path: "str | Path", **kwargs: Any) -> dict[str, Any]:
        """Get object metadata."""
        return {}

    # --- Arrow operations -------------------------------------------------------

    def read_arrow(self, path: "str | Path", **kwargs: Any) -> "ArrowTable":
        """Read an Arrow table from storage.

        Raises:
            NotImplementedError: Always, in this placeholder body.
        """
        msg = "Arrow reading not implemented"
        raise NotImplementedError(msg)

    def write_arrow(self, path: "str | Path", table: "ArrowTable", **kwargs: Any) -> None:
        """Write an Arrow table to storage.

        Raises:
            NotImplementedError: Always, in this placeholder body.
        """
        msg = "Arrow writing not implemented"
        raise NotImplementedError(msg)

    def stream_arrow(self, pattern: str, **kwargs: Any) -> "Iterator[ArrowRecordBatch]":
        """Stream Arrow record batches from objects matching ``pattern``.

        Raises:
            NotImplementedError: Always, in this placeholder body.
        """
        msg = "Arrow streaming not implemented"
        raise NotImplementedError(msg)

    # --- asynchronous operations ------------------------------------------------
    # Every placeholder below raises NotImplementedError when awaited.

    async def read_bytes_async(self, path: "str | Path", **kwargs: Any) -> bytes:
        """Async read bytes from an object."""
        msg = "Async operations not implemented"
        raise NotImplementedError(msg)

    async def write_bytes_async(self, path: "str | Path", data: bytes, **kwargs: Any) -> None:
        """Async write bytes to an object."""
        msg = "Async operations not implemented"
        raise NotImplementedError(msg)

    async def read_text_async(self, path: "str | Path", encoding: str = "utf-8", **kwargs: Any) -> str:
        """Async read text from an object."""
        msg = "Async operations not implemented"
        raise NotImplementedError(msg)

    async def write_text_async(self, path: "str | Path", data: str, encoding: str = "utf-8", **kwargs: Any) -> None:
        """Async write text to an object."""
        # The message was previously split across lines (a broken string
        # literal); restored to match every other async placeholder.
        msg = "Async operations not implemented"
        raise NotImplementedError(msg)

    async def exists_async(self, path: "str | Path", **kwargs: Any) -> bool:
        """Async check if an object exists."""
        msg = "Async operations not implemented"
        raise NotImplementedError(msg)

    async def delete_async(self, path: "str | Path", **kwargs: Any) -> None:
        """Async delete an object."""
        msg = "Async operations not implemented"
        raise NotImplementedError(msg)

    async def list_objects_async(self, prefix: str = "", recursive: bool = True, **kwargs: Any) -> list[str]:
        """Async list objects with optional prefix."""
        msg = "Async operations not implemented"
        raise NotImplementedError(msg)

    async def copy_async(self, source: "str | Path", destination: "str | Path", **kwargs: Any) -> None:
        """Async copy an object."""
        msg = "Async operations not implemented"
        raise NotImplementedError(msg)

    async def move_async(self, source: "str | Path", destination: "str | Path", **kwargs: Any) -> None:
        """Async move an object."""
        msg = "Async operations not implemented"
        raise NotImplementedError(msg)

    async def get_metadata_async(self, path: "str | Path", **kwargs: Any) -> dict[str, Any]:
        """Async get object metadata."""
        msg = "Async operations not implemented"
        raise NotImplementedError(msg)

    async def read_arrow_async(self, path: "str | Path", **kwargs: Any) -> "ArrowTable":
        """Async read an Arrow table from storage."""
        msg = "Async arrow reading not implemented"
        raise NotImplementedError(msg)

    async def write_arrow_async(self, path: "str | Path", table: "ArrowTable", **kwargs: Any) -> None:
        """Async write an Arrow table to storage."""
        msg = "Async arrow writing not implemented"
        raise NotImplementedError(msg)

    def stream_arrow_async(self, pattern: str, **kwargs: Any) -> "AsyncIterator[ArrowRecordBatch]":
        """Async stream Arrow record batches from matching objects.

        Declared as a plain ``def`` returning an ``AsyncIterator`` (not
        ``async def``), so calling it raises immediately rather than on await.
        """
        msg = "Async arrow streaming not implemented"
        raise NotImplementedError(msg)
@runtime_checkable
class HasSQLGlotExpressionProtocol(Protocol):
    """Structural type for objects exposing a read-only ``sqlglot_expression``."""

    @property
    def sqlglot_expression(self) -> "exp.Expression | None":
        """The SQLGlot expression backing this object, or ``None``."""
@runtime_checkable
class HasParameterBuilderProtocol(Protocol):
    """Structural type for builders that register parameters and own an expression."""

    def add_parameter(self, value: Any, name: "str | None" = None) -> tuple[Any, str]:
        """Register ``value`` (optionally under ``name``) and return a ``(builder, name)`` pair."""

    def get_expression(self) -> "exp.Expression | None":
        """Fetch the SQLGlot expression currently held by the builder."""

    def set_expression(self, expression: "exp.Expression") -> None:
        """Swap in ``expression`` as the builder's SQLGlot expression."""
@runtime_checkable
class HasExpressionProtocol(Protocol):
    """Structural type for objects holding an ``_expression`` attribute."""

    # The wrapped SQLGlot expression, or None when not yet set.
    _expression: "exp.Expression | None"
@runtime_checkable
class HasToStatementProtocol(Protocol):
    """Structural type for objects convertible to a SQL statement via ``to_statement()``."""

    def to_statement(self) -> Any:
        """Produce the SQL statement represented by this object."""
@runtime_checkable
class SQLBuilderProtocol(Protocol):
    """Structural type describing the surface of a SQL query builder.

    NOTE(review): every annotated attribute below is a protocol member,
    including the ones commented as optional. A runtime ``isinstance()``
    check therefore requires all of them to be present on the object.
    """

    _expression: "exp.Expression | None"
    _parameters: dict[str, Any]
    _parameter_counter: int
    _columns: Any  # Optional attribute for some builders
    _table: Any  # Optional attribute for some builders
    _with_ctes: Any  # Optional attribute for some builders
    dialect: Any
    dialect_name: "str | None"

    @property
    def parameters(self) -> dict[str, Any]:
        """Read access to the builder's bound query parameters."""

    def add_parameter(self, value: Any, name: "str | None" = None) -> tuple[Any, str]:
        """Register ``value`` (optionally under ``name``) as a parameter."""

    def _generate_unique_parameter_name(self, base_name: str) -> str:
        """Derive a parameter name from ``base_name`` that is not yet in use."""

    def _parameterize_expression(self, expression: "exp.Expression") -> "exp.Expression":
        """Swap literal values in ``expression`` for bound parameters."""

    def build(self) -> "exp.Expression | Any":
        """Produce the finished expression."""

    def _merge_sql_object_parameters(self, sql_obj: Any) -> None:
        """Fold the parameters carried by ``sql_obj`` into this builder."""

    def _build_final_expression(self, *, copy: bool = False) -> "exp.Expression":
        """Return the expression with its CTEs attached."""

    def _spawn_like_self(self) -> "Self":
        """Create a fresh builder sharing this builder's configuration."""

    def get_expression(self) -> "exp.Expression | None":
        """Fetch the SQLGlot expression currently held by the builder."""

    def set_expression(self, expression: "exp.Expression") -> None:
        """Swap in ``expression`` as the builder's SQLGlot expression."""

    def generate_unique_parameter_name(self, base_name: str) -> str:
        """Public entry point for generating an unused parameter name."""
@runtime_checkable
class SelectBuilderProtocol(SQLBuilderProtocol, Protocol):
    """Protocol for SELECT query builders.

    Extends ``SQLBuilderProtocol`` with ``select()``. It is decorated with
    ``@runtime_checkable`` like every other protocol in this module. Static
    type checkers such as mypy treat the decorator as non-inherited, so
    without it ``isinstance()`` against this protocol is rejected even though
    the parent protocol is runtime-checkable.
    """

    def select(self, *columns: "str | exp.Expression") -> Self:
        """Add SELECT ``columns`` to the query; returns a builder of the same type."""
        ...
@runtime_checkable
class SupportsArrowResults(Protocol):
"""Protocol for adapters that support Arrow result format.
Adapters implementing this protocol can return query results in Apache Arrow
format via the select_to_arrow() method, enabling zero-copy data transfer and
efficient integration with data science tools.
"""
def select_to_arrow(
self,
statement: Any,
/,
*parameters: Any,
statement_config: Any | None = None,
return_format: str = "table",
native_only: bool = False,
batch_size: int | None = None,
arrow_schema: Any | None = None,
**kwargs: Any,
) -> "ArrowTable | ArrowRecordBatch":
"""Execute query and return results as Apache Arrow Table or RecordBatch.
Args:
statement: SQL statement to execute.
*parameters: Query parameters and filters.
statement_config: Optional statement configuration override.
return_format: Output format - "table", "reader", or "batches".
native_only: If True, raise error when native Arrow path unavailable.
batch_size: Chunk size for streaming modes.
arrow_schema: Optional target Arrow schema for type casting.
**kwargs: Additional keyword arguments.
Returns:
ArrowResult containing Arrow data.
"""
...
@runtime_checkable
class StackResultProtocol(Protocol):
"""Protocol describing stack execution results."""
result: Any
rows_affected: int
error: Exception | None
warning: Any | None
metadata: Mapping[str, Any] | None
result_type: str
@property
def rows(self) -> Sequence[Any]: ...
def is_error(self) -> bool: ...
def is_sql_result(self) -> bool: ...
def is_arrow_result(self) -> bool: ...
def get_result(self) -> Any: ...
You can’t perform that action at this time.