8000 fixes: #2943, allow overriding validates for inheritance · sqlalchemy/sqlalchemy@eab0c56 · GitHub
[go: up one dir, main page]

Skip to content

Commit eab0c56

Browse files
committed
fixes: #2943, allow overriding validates for inheritance
Applied the patch mentioned in #2943, to allow overriding the validates method of a given Model, Added tests for same in test_validators. If a Child class overrides the parent class validates method only child class validator will be invoked unless child class explicitly invokes parent class validator
1 parent 527fac5 commit eab0c56
8000

File tree

3 files changed

+95
-14
lines changed

3 files changed

+95
-14
lines changed

lib/sqlalchemy/orm/mapper.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4333,6 +4333,10 @@ def validates(
43334333
modify or replace the value before proceeding. The function should
43344334
otherwise return the given value.
43354335
4336+
Overriding validator method will invoke child validator method, in
4337+
order to also invoke parent validator method as well child validator
4338+
can explicitly invoke parent class validator(s).
4339+
43364340
Note that a validator for a collection **cannot** issue a load of that
43374341
collection within the validation routine - this usage raises
43384342
an assertion to avoid recursion overflows. This is a reentrant

lib/sqlalchemy/orm/strategies.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -75,30 +75,23 @@ def _register_attribute(
7575
impl_class=None,
7676
**kw,
7777
):
78-
listen_hooks = []
78+
pre_validate_hooks = []
79+
post_validate_hooks = []
7980

8081
uselist = useobject and prop.uselist
8182

8283
if useobject and prop.single_parent:
83-
listen_hooks.append(single_parent_validator)
84-
85-
if prop.key in prop.parent.validators:
86-
fn, opts = prop.parent.validators[prop.key]
87-
listen_hooks.append(
88-
lambda desc, prop: orm_util._validator_events(
89-
desc, prop.key, fn, **opts
90-
)
91-
)
84+
pre_validate_hooks.append(single_parent_validator)
9285

9386
if useobject:
94-
listen_hooks.append(unitofwork.track_cascade_events)
87+
post_validate_hooks.append(unitofwork.track_cascade_events)
9588

9689
# need to assemble backref listeners
9790
# after the singleparentvalidator, mapper validator
9891
if useobject:
9992
backref = prop.back_populates
10093
if backref and prop._effective_sync_backref:
101-
listen_hooks.append(
94+
post_validate_hooks.append(
10295
lambda desc, prop: attributes.backref_listeners(
10396
desc, backref, uselist
10497
)
@@ -114,7 +107,6 @@ def _register_attribute(
114107
# mapper here might not be prop.parent; also, a subclass mapper may
115108
# be called here before a superclass mapper. That is, can't depend
116109
# on mappers not already being set up so we have to check each one.
117-
118110
for m in mapper.self_and_descendants:
119111
if prop is m._props.get(
120112
prop.key
@@ -140,7 +132,16 @@ def _register_attribute(
140132
**kw,
141133
)
142134

143-
for hook in listen_hooks:
135+
for hook in pre_validate_hooks:
136+
hook(desc, prop)
137+
138+
for super_m in m.iterate_to_root():
139+
if prop.key in super_m.validators:
140+
fn, opts< 8000 /span> = super_m.validators[prop.key]
141+
orm_util._validator_events(desc, prop.key, fn, **opts)
142+
break
143+
144+
for hook in post_validate_hooks:
144145
hook(desc, prop)
145146

146147

test/orm/test_validators.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
from unittest.mock import call
22
from unittest.mock import Mock
33

4+
from sqlalchemy import Column
45
from sqlalchemy import exc
6+
from sqlalchemy import Integer
7+
from sqlalchemy import String
58
from sqlalchemy import testing
69
from sqlalchemy.orm import collections
10+
from sqlalchemy.orm import declarative_base
711
from sqlalchemy.orm import relationship
812
from sqlalchemy.orm import validates
913
from sqlalchemy.testing import assert_raises
1014
from sqlalchemy.testing import assert_raises_message
15+
from sqlalchemy.testing import expect_raises_message
1116
from sqlalchemy.testing import eq_
1217
from sqlalchemy.testing import ne_
1318
from sqlalchemy.testing.entities import ComparableEntity
@@ -447,3 +452,74 @@ def validate_user(self, key, item):
447452
call("user", User(addresses=[])),
448453
],
449454
)
455+
456+
def test_validator_inheritance_override_validator(self):
457+
Base = declarative_base()
458+
459+
class A(Base):
460+
__tablename__ = "a"
461+
id = Column(Integer, primary_key=True)
462+
data = Column(String)
463+
foo = Column(String)
464+
465+
@validates("data")
466+
def validate_data(self, key, value):
467+
return "Call from A : " + value
468+
469+
@validates("foo")
470+
def validate_foo(self, key, value):
471+
ne_(value, "exclude for A", "Message raised from A")
472+
return value
473+
474+
class B(A):
475+
foo2 = Column(String)
476+
bar = Column(String)
477+
478+
@validates("data")
479+
def validate_data(self, key, value):
480+
return "Call from B : " + value
481+
482+
@validates("foo")
483+
def validate_foo(self, key, value):
484+
# Test Calling both validators
485+
value = super().validate_foo(key, value)
486+
ne_(value, "exclude for B", "Message raised from B")
487+
return value
488+
489+
@validates("foo2", "bar")
490+
def validate_foobar(self, key, value):
491+
if key == "foo2":
492+
return value + "_"
493+
return "_" + value
494+
495+
class C(B):
496+
@validates("foo2", "bar")
497+
def validate_foobar(self, key, value):
498+
if key == "foo2":
499+
return value + "-"
500+
return "-" + value
501+
502+
obj = A(data="ed")
503+
eq_(obj.data, "Call from A : ed")
504+
with expect_raises_message(AssertionError, "Message raised from A"):
505+
obj.foo = "exclude for A"
506+
obj.foo = "exclude for B"
507+
508+
obj = B(data="ed")
509+
eq_(obj.data, "Call from B : ed")
510+
# Should call A's Validator
511+
with expect_raises_message(AssertionError, "Message raised from A"):
512+
obj.foo = "exclude for A"
513+
# Should call B's Validator
514+
with expect_raises_message(AssertionError, "Message raised from B"):
515+
obj.foo = "exclude for B"
516+
obj.foo = "Some other value"
517+
518+
obj.foo2 = "foo"
519+
obj.bar = "bar"
520+
eq_(obj.foo2 + obj.bar, "foo__bar")
521+
522+
obj = C(data="ed")
523+
obj.foo2 = "foo"
524+
obj.bar = "bar"
525+
eq_(obj.foo2 + obj.bar, "foo--bar")

0 commit comments

Comments
 (0)
0