|
1 | 1 | import inspect
|
2 | 2 | import six
|
3 |
| -from functools import total_ordering |
| 3 | +from functools import total_ordering, wraps |
4 | 4 | from graphql.core.type import (
|
5 | 5 | GraphQLField,
|
6 | 6 | GraphQLList,
|
@@ -49,12 +49,26 @@ def contribute_to_class(self, cls, name):
|
49 | 49 | cls._meta.add_field(self)
|
50 | 50 |
|
51 | 51 | def resolve(self, instance, args, info):
|
| 52 | + resolve_fn = self.get_resolve_fn() |
| 53 | + if resolve_fn: |
| 54 | + return resolve_fn(instance, args, info) |
| 55 | + else: |
| 56 | + return instance.get_field(self.field_name) |
| 57 | + |
| 58 | + @memoize |
| 59 | + def get_resolve_fn(self): |
52 | 60 | if self.resolve_fn:
|
53 |
| - resolve_fn = self.resolve_fn |
| 61 | + return self.resolve_fn |
54 | 62 | else:
|
55 |
| - resolve_fn = lambda root, args, info: root.resolve( |
56 |
| - self.field_name, args, info) |
57 |
| - return resolve_fn(instance, args, info) |
| 63 | + custom_resolve_fn_name = 'resolve_%s' % self.field_name |
| 64 | + if hasattr(self.object_type, custom_resolve_fn_name): |
| 65 | + resolve_fn = getattr(self.object_type, custom_resolve_fn_name) |
| 66 | + |
| 67 | + @wraps(resolve_fn) |
| 68 | + def custom_resolve_fn(instance, args, info): |
| 69 | + custom_fn = getattr(instance, custom_resolve_fn_name) |
| 70 | + return custom_fn(args, info) |
| 71 | + return custom_resolve_fn |
58 | 72 |
|
59 | 73 | def get_object_type(self, schema):
|
60 | 74 | field_type = self.field_type
|
@@ -110,11 +124,18 @@ def internal_field(self, schema):
|
110 | 124 | if not internal_type:
|
111 | 125 | raise Exception("Internal type for field %s is None" % self)
|
112 | 126 |
|
| 127 | + resolve_fn = self.get_resolve_fn() |
| 128 | + if resolve_fn: |
| 129 | + @wraps(resolve_fn) |
| 130 | + def resolver(*args): |
| 131 | + return self.resolve(*args) |
| 132 | + else: |
| 133 | + resolver = self.resolve |
113 | 134 | return GraphQLField(
|
114 | 135 | internal_type,
|
115 | 136 | description=self.description,
|
116 | 137 | args=self.args,
|
117 |
| - resolver=self.resolve, |
| 138 | + resolver=resolver, |
118 | 139 | )
|
119 | 140 |
|
120 | 141 | def __str__(self):
|
@@ -144,7 +165,7 @@ def __lt__(self, other):
|
144 | 165 | return NotImplemented
|
145 | 166 |
|
146 | 167 | def __hash__(self):
|
147 |
| - return hash(self.creation_counter) |
| 168 | + return hash((self.creation_counter, self.object_type)) |
148 | 169 |
|
149 | 170 | def __copy__(self):
|
150 | 171 | # We need to avoid hitting __reduce__, so define this
|
|
0 commit comments