8000 added equality checks · googleapis/python-firestore@ae00f1d · GitHub
[go: up one dir, main page]

Skip to content

Commit ae00f1d

Browse files
committed
added equality checks
1 parent 3cd826b commit ae00f1d

File tree

2 files changed

+74
-0
lines changed

2 files changed

+74
-0
lines changed

google/cloud/firestore_v1/pipeline_expressions.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,12 @@ class Constant(Expr, Generic[CONSTANT_TYPE]):
292292
def __init__(self, value: CONSTANT_TYPE):
293293
self.value: CONSTANT_TYPE = value
294294

295+
def __eq__(self, other):
296+
if not isinstance(other, Constant):
297+
return other == self.value
298+
else:
299+
return other.value == self.value
300+
295301
@staticmethod
296302
def of(value: CONSTANT_TYPE) -> Constant[CONSTANT_TYPE]:
297303
"""Creates a constant expression from a Python value."""
@@ -310,6 +316,12 @@ class ListOfExprs(Expr):
310316
def __init__(self, exprs: List[Expr]):
311317
self.exprs: list[Expr] = exprs
312318

319+
def __eq__(self, other):
320+
if not isinstance(other, ListOfExprs):
321+
return False
322+
else:
323+
return other.exprs == self.exprs
324+
313325
def __repr__(self):
314326
return f"{self.__class__.__name__}({self.exprs})"
315327

@@ -324,6 +336,12 @@ def __init__(self, name: str, params: Sequence[Expr]):
324336
self.name = name
325337
self.params = list(params)
326338

339+
def __eq__(self, other):
340+
if not isinstance(other, Function):
341+
return False
342+
else:
343+
return other.name == self.name and other.params == self.params
344+
327345
def __repr__(self):
328346
return f"{self.__class__.__name__}({', '.join([repr(p) for p in self.params])})"
329347

@@ -339,6 +357,12 @@ def _to_pb(self):
339357
class Selectable(Expr):
340358
"""Base class for expressions that can be selected or aliased in projection stages."""
341359

360+
def __eq__(self, other):
361+
if not isinstance(other, type(self)):
362+
return False
363+
else:
364+
return other._to_map() == self._to_map()
365+
342366
@abstractmethod
343367
def _to_map(self) -> tuple[str, Value]:
344368
"""

tests/unit/v1/test_pipeline_expressions.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,24 @@ def test_repr(self, input_val, expected):
157157
assert repr_string == expected
158158

159159

160+
@pytest.mark.parametrize("first,second,expected", [
161+
(expr.Constant.of(1), expr.Constant.of(2), False),
162+
(expr.Constant.of(1), expr.Constant.of(1), True),
163+
(expr.Constant.of(1), 1, True),
164+
(expr.Constant.of(1), 2, False),
165+
(expr.Constant.of("1"), 1, False),
166+
(expr.Constant.of("1"), "1", True),
167+
(expr.Constant.of(None), expr.Constant.of(0), False),
168+
(expr.Constant.of(None), expr.Constant.of(None), True),
169+
(expr.Constant.of([1,2,3]), expr.Constant.of([1,2,3]), True),
170+
(expr.Constant.of([1,2,3]), expr.Constant.of([1,2]), False),
171+
(expr.Constant.of([1,2,3]), [1,2,3], True),
172+
(expr.Constant.of([1,2,3]), object(), False),
173+
])
174+
def test_equality(self, first, second, expected):
175+
assert (first == second) is expected
176+
177+
160178
class TestListOfExprs:
161179
def test_to_pb(self):
162180
instance = expr.ListOfExprs([expr.Constant(1), expr.Constant(2)])
@@ -178,6 +196,20 @@ def test_repr(self):
178196
empty_repr_string = repr(empty_instance)
179197
assert empty_repr_string == "ListOfExprs([])"
180198

199+
@pytest.mark.parametrize("first,second,expected", [
200+
(expr.ListOfExprs([]), expr.ListOfExprs([]), True),
201+
(expr.ListOfExprs([]), expr.ListOfExprs([expr.Constant(1)]), False),
202+
(expr.ListOfExprs([expr.Constant(1)]), expr.ListOfExprs([]), False),
203+
(expr.ListOfExprs([expr.Constant(1)]), expr.ListOfExprs([expr.Constant(1)]), True),
204+
(expr.ListOfExprs([expr.Constant(1)]), expr.ListOfExprs([expr.Constant(2)]), False),
205+
(expr.ListOfExprs([expr.Constant(1), expr.Constant(2)]), expr.ListOfExprs([expr.Constant(1), expr.Constant(2)]), True),
206+
(expr.ListOfExprs([expr.Constant(1)]), [expr.Constant(1)], False),
207+
(expr.ListOfExprs([expr.Constant(1)]), [1], False),
208+
(expr.ListOfExprs([expr.Constant(1)]), object(), False),
209+
])
210+
def test_equality(self, first, second, expected):
211+
assert (first == second) is expected
212+
181213

182214
class TestSelectable:
183215
def test_ctor(self):
@@ -194,6 +226,13 @@ def test_value_from_selectables(self):
194226
assert result.map_value.fields["field1"].field_reference_value == "field1"
195227
assert result.map_value.fields["field2"].field_reference_value == "field2"
196228

229+
@pytest.mark.parametrize("first,second,expected", [
230+
(expr.Field.of("field1"), expr.Field.of("field1"), True),
231+
(expr.Field.of("field1"), expr.Field.of("field2"), False),
232+
])
233+
def test_equality(self, first, second, expected):
234+
assert (first == second) is expected
235+
197236
class TestField:
198237
def test_repr(self):
199238
instance = expr.Field.of("field1")
@@ -217,6 +256,17 @@ def test_to_map(self):
217256

218257

219258
class TestFilterCondition:
259+
260+
@pytest.mark.parametrize("first,second,expected", [
261+
(expr.IsNaN(expr.Field.of("field1")), expr.IsNaN(expr.Field.of("field1")), True),
262+
(expr.IsNaN(expr.Field.of("real")), expr.IsNaN(expr.Field.of("fale")), False),
263+
(expr.Gt(0, 1), expr.Gt(0, 1), True),
264+
(expr.Gt(0, 1), expr.Gt(1, 0), False),
265+
(expr.Gt(0, 1), expr.Lt(0, 1), False),
266+
])
267+
def test_equality(self, first, second, expected):
268+
assert (first == second) is expected
269+
220270
def test__from_query_filter_pb_composite_filter_or(self, mock_client):
221271
"""
222272
test composite OR filters

0 commit comments

Comments
 (0)
0