forked from litestar-org/sqlspec
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_update.py
More file actions
161 lines (125 loc) · 4.78 KB
/
_update.py
File metadata and controls
161 lines (125 loc) · 4.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""UPDATE statement builder.
Provides a fluent interface for building SQL UPDATE queries with
parameter binding and validation.
"""
from typing import TYPE_CHECKING, Any, cast
from sqlglot import exp
from typing_extensions import Self
from sqlspec.builder._base import QueryBuilder, SafeQuery
from sqlspec.builder._dml import UpdateFromClauseMixin, UpdateSetClauseMixin, UpdateTableClauseMixin
from sqlspec.builder._join import build_join_clause
from sqlspec.builder._select import ReturningClauseMixin, WhereClauseMixin
from sqlspec.core import SQLResult
from sqlspec.exceptions import SQLBuilderError
if TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
from sqlspec.builder._select import Select
from sqlspec.protocols import SQLBuilderProtocol
# Names exported on ``from <module> import *``; only the UPDATE builder is public.
__all__ = ("Update",)
class Update(
    QueryBuilder,
    WhereClauseMixin,
    ReturningClauseMixin,
    UpdateSetClauseMixin,
    UpdateFromClauseMixin,
    UpdateTableClauseMixin,
):
    """Builder for UPDATE statements.

    Constructs SQL UPDATE statements with parameter binding and validation.

    Example:
        ```python
        update_query = (
            Update()
            .table("users")
            .set_(name="John Doe")
            .set_(email="john@example.com")
            .where("id = 1")
        )

        update_query = (
            Update("users").set_(name="John Doe").where("id = 1")
        )

        update_query = (
            Update()
            .table("users")
            .set_(status="active")
            .where_eq("id", 123)
        )

        update_query = (
            Update()
            .table("users", "u")
            .set_(name="Updated Name")
            .from_("profiles", "p")
            .where("u.id = p.user_id AND p.is_verified = true")
        )
        ```
    """

    __slots__ = ("_table",)
    # Declared for type checkers. It may be None or a non-UPDATE expression,
    # so the methods below re-check it with ``isinstance`` before use.
    _expression: exp.Expression | None

    def __init__(self, table: str | None = None, **kwargs: Any) -> None:
        """Initialize UPDATE with optional table.

        Args:
            table: Target table name. A falsy value (``None`` or ``""``) leaves
                the target unset so it can be supplied later via ``table()``.
            **kwargs: Additional QueryBuilder arguments.
        """
        super().__init__(**kwargs)
        self._initialize_expression()
        if table:
            self.table(table)

    @property
    def _expected_result_type(self) -> "type[SQLResult]":
        """Return the expected result type for this builder."""
        return SQLResult

    def _create_base_expression(self) -> exp.Update:
        """Create a base UPDATE expression.

        Returns:
            A new sqlglot Update expression with no target table and empty
            ``expressions`` (SET assignments) and ``joins`` lists.
        """
        return exp.Update(this=None, expressions=[], joins=[])

    def join(
        self,
        table: "str | exp.Expression | Select",
        on: "str | exp.Expression",
        alias: "str | None" = None,
        join_type: str = "INNER",
    ) -> "Self":
        """Add JOIN clause to the UPDATE statement.

        Args:
            table: The table name, expression, or subquery to join.
            on: The JOIN condition.
            alias: Optional alias for the joined table.
            join_type: Type of join (INNER, LEFT, RIGHT, FULL).

        Returns:
            The current builder instance for method chaining.

        Raises:
            SQLBuilderError: If the current expression is missing or is not an
                UPDATE statement.
        """
        # isinstance() is False for None too, so this single check covers the
        # "not initialized" case as well.
        if not isinstance(self._expression, exp.Update):
            msg = "Cannot add JOIN clause to non-UPDATE expression."
            raise SQLBuilderError(msg)
        join_expr = build_join_clause(cast("SQLBuilderProtocol", self), table, on, alias, join_type)
        # Expression.append creates the "joins" list if it is missing and
        # records this UPDATE node as the join's parent. A raw list.append on
        # ``args["joins"]`` skipped that bookkeeping, leaving ``join_expr.parent``
        # unset.
        self._expression.append("joins", join_expr)
        return self

    def build(self, dialect: "DialectType" = None) -> "SafeQuery":
        """Build the UPDATE query with validation.

        Args:
            dialect: Optional dialect override for SQL generation.

        Returns:
            SafeQuery: The built query with SQL and parameters.

        Raises:
            SQLBuilderError: If the expression is not initialized, is not an
                UPDATE, has no target table, or has no SET clause.
        """
        if self._expression is None:
            msg = "UPDATE expression not initialized."
            raise SQLBuilderError(msg)
        if not isinstance(self._expression, exp.Update):
            msg = "No UPDATE expression to build or expression is of the wrong type."
            raise SQLBuilderError(msg)
        if self._expression.this is None:
            msg = "No table specified for UPDATE statement."
            raise SQLBuilderError(msg)
        # The "expressions" arg holds the SET assignments; an UPDATE without
        # any is not valid SQL.
        if not self._expression.args.get("expressions"):
            msg = "At least one SET clause must be specified for UPDATE statement."
            raise SQLBuilderError(msg)
        return super().build(dialect=dialect)