@@ -156,55 +156,64 @@ def check_fields_input_spec(self):
156
156
157
157
"""
158
158
fields = attr_fields (self )
159
- names = []
160
- require_to_check = {}
161
- for fld in fields :
162
- mdata = fld .metadata
163
- # checking if the mandatory field is provided
164
- if getattr (self , fld .name ) is attr .NOTHING :
165
- if mdata .get ("mandatory" ):
166
- # checking if the mandatory field is provided elsewhere in the xor list
167
- in_exclusion_list = mdata .get ("xor" ) is not None
168
- alreday_populated = in_exclusion_list and [
169
- getattr (self , el )
170
- for el in mdata ["xor" ]
171
- if (getattr (self , el ) is not attr .NOTHING )
172
- ]
173
- if (
174
- alreday_populated
175
- ): # another input satisfies mandatory attribute via xor condition
176
- continue
177
- else :
178
- raise AttributeError (
179
- f"{ fld .name } is mandatory, but no value provided"
180
- )
181
- else :
182
- continue
183
- names .append (fld .name )
184
159
185
- # checking if fields meet the xor and requires are
186
- if "xor" in mdata :
187
- if [el for el in mdata ["xor" ] if (el in names and el != fld .name )]:
160
+ for field in fields :
161
+ field_is_mandatory = bool (field .metadata .get ("mandatory" ))
162
+ field_is_unset = getattr (self , field .name ) is attr .NOTHING
163
+
164
+ if field_is_unset and not field_is_mandatory :
165
+ continue
166
+
167
+ # Collect alternative fields associated with this field.
168
+ alternative_fields = {
169
+ name : getattr (self , name ) is not attr .NOTHING
170
+ for name in field .metadata .get ("xor" , [])
171
+ if name != field .name
172
+ }
173
+ alternatives_are_set = any (alternative_fields .values ())
174
+
175
+ # Raise error if no field in mandatory alternative group is set.
176
+ if field_is_unset :
177
+ if alternatives_are_set :
178
+ continue
179
+ message = f"{ field .name } is mandatory and unset."
180
+ if alternative_fields :
188
181
raise AttributeError (
189
- f"{ fld .name } is mutually exclusive with { mdata ['xor' ]} "
182
+ message [:- 1 ]
183
+ + f", but no alternative provided by { list (alternative_fields )} ."
190
184
)
185
+ else :
186
+ raise AttributeError (message )
187
+
188
+ # Raise error if multiple alternatives are set.
189
+ elif alternatives_are_set :
190
+ set_alternative_fields = [
191
+ name for name , is_set in alternative_fields .items () if is_set
192
+ ]
193
+ raise AttributeError (
194
+ f"{ field .name } is mutually exclusive with { set_alternative_fields } "
195
+ )
191
196
192
- if "requires" in mdata :
193
- if [el for el in mdata ["requires" ] if el not in names ]:
194
- # will check after adding all fields to names
195
- require_to_check [fld .name ] = mdata ["requires" ]
197
+ # Collect required fields associated with this field.
198
+ required_fields = {
199
+ name : getattr (self , name ) is not attr .NOTHING
200
+ for name in field .metadata .get ("requires" , [])
201
+ if name != field .name
202
+ }
203
+
204
+ # Raise error if any required field is unset.
205
+ if not all (required_fields .values ()):
206
+ unset_required_fields = [
207
+ name for name , is_set in required_fields .items () if not is_set
208
+ ]
209
+ raise AttributeError (f"{ field .name } requires { unset_required_fields } " )
196
210
197
211
if (
198
- fld .type in [File , Directory ]
199
- or "pydra.engine.specs.File" in str (fld .type )
200
- or "pydra.engine.specs.Directory" in str (fld .type )
212
+ field .type in [File , Directory ]
213
+ or "pydra.engine.specs.File" in str (field .type )
214
+ or "pydra.engine.specs.Directory" in str (field .type )
201
215
):
202
- self ._file_check (fld )
203
-
204
- for nm , required in require_to_check .items ():
205
- required_notfound = [el for el in required if el not in names ]
206
- if required_notfound :
207
- raise AttributeError (f"{ nm } requires { required_notfound } " )
216
+ self ._file_check (field )
208
217
209
218
def _file_check (self , field ):
210
219
"""checking if the file exists"""
0 commit comments