Disable slow gradcheck for nn.Transformer ModuleInfo (#145531) · pytorch/pytorch@c7ca1df · GitHub

Commit c7ca1df

soulitzer authored and pytorchmergebot committed
Disable slow gradcheck for nn.Transformer ModuleInfo (#145531)
Fixes #117140

Pull Request resolved: #145531
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #145520
1 parent 9e0ee15 · commit c7ca1df
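For context, torch.autograd.gradcheck (and gradgradcheck) accepts a fast_mode flag: roughly speaking, slow mode numerically perturbs every differentiable input element to build the full Jacobian, while fast mode verifies a random projection of it, which is far cheaper for large inputs such as the Transformer samples in #117140. A minimal standalone sketch of the flag follows; the function and tensor are invented for illustration and are not part of this commit.

```python
import torch

# Standalone illustration of the fast_mode flag; the tensor and function are
# invented for the example, only the flag mirrors what this commit toggles
# for the nn.Transformer ModuleInfo tests.
x = torch.randn(4, 4, dtype=torch.double, requires_grad=True)

def fn(t):
    return (t * t).sum(dim=1)

# Slow mode: perturbs every input element to build the full numerical Jacobian.
torch.autograd.gradcheck(fn, (x,), fast_mode=False)

# Fast mode: checks a random projection of the Jacobian instead, which is far
# cheaper for large inputs such as the Transformer samples in issue #117140.
torch.autograd.gradcheck(fn, (x,), fast_mode=True)
```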

File tree: 2 files changed, +19 −4 lines changed

test/test_modules.py

Lines changed: 12 additions & 4 deletions
@@ -482,11 +482,19 @@ def fn_to_gradcheck(*flat_input_and_params):
                 output_flattened = torch.utils._pytree.tree_leaves(output)
                 return output_flattened
 
+            def do_check(flat_input):
+                self.assertTrue(
+                    check(
+                        fn_to_gradcheck,
+                        flat_input,
+                        nondet_tol=gradcheck_nondet_tol,
+                        fast_mode=module_info.gradcheck_fast_mode
+                    ))
+
             # check total derivative
             grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
             flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
-
-            self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
+            do_check(flat_input)
 
             # check partial derivatives
             old_params_requires_grad = [p.requires_grad for p in params]
@@ -501,14 +509,14 @@ def fn_to_gradcheck(*flat_input_and_params):
                 p.requires_grad = old
                 grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
                 flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
-                self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
+                do_check(flat_input)
                 p.requires_grad = False
 
             for (_, obj), old in zip(kwarg_tensors, old_kwargs_requires_grad):
                 obj.requires_grad = old
                 grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
                 flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
-                self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
+                do_check(flat_input)
                 obj.requires_grad = False
 
     @modules(module_db, allowed_dtypes=[torch.double])
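In plain terms, the refactor above pulls the three repeated self.assertTrue(check(...)) calls into a single do_check helper so the new fast_mode argument is threaded through in exactly one place, taken from module_info.gradcheck_fast_mode. A hedged sketch of that shape as a standalone function: make_do_check, the fake module_info namespace, and the bare assert are inventions for this example; the real helper is defined inline in the gradients test helper in test/test_modules.py.

```python
import types

import torch

# Hypothetical standalone version of the do_check helper introduced above.
# In test/test_modules.py it is defined inline and closes over fn_to_gradcheck,
# check (the gradcheck/gradgradcheck wrapper), gradcheck_nondet_tol, and
# module_info; here those are plain arguments and assert replaces assertTrue.
def make_do_check(fn_to_gradcheck, module_info, gradcheck_nondet_tol=0.0,
                  check=torch.autograd.gradcheck):
    def do_check(flat_input):
        # fast_mode is read from the ModuleInfo in exactly one place. With raw
        # torch.autograd.gradcheck a None value is simply falsy (slow mode);
        # the test-suite wrapper maps None to its own default instead, as the
        # new ModuleInfo comment notes.
        assert check(
            fn_to_gradcheck,
            flat_input,
            nondet_tol=gradcheck_nondet_tol,
            fast_mode=module_info.gradcheck_fast_mode,
        )
    return do_check


# Toy usage with a fake "module_info" that requests fast mode.
fake_info = types.SimpleNamespace(gradcheck_fast_mode=True)
x = torch.randn(3, 3, dtype=torch.double, requires_grad=True)
do_check = make_do_check(lambda t: t.sin(), fake_info)
do_check((x,))  # raises if the analytical and numerical gradients disagree
```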

torch/testing/_internal/common_modules.py

Lines changed: 7 additions & 0 deletions
@@ -222,6 +222,9 @@ def __init__(self,
                  # channels last output
                  train_and_eval_differ=False,  # whether the module has differing behavior between train and eval
                  module_error_inputs_func=None,  # Function to generate module inputs that error
+                 gradcheck_fast_mode=None,  # Whether to use the fast implementation for gradcheck/gradgradcheck.
+                                            # When set to None, defers to the default value provided by the wrapper
+                                            # function around gradcheck (testing._internal.common_utils.gradcheck)
                  ):
         self.module_cls = module_cls
         self.module_inputs_func = module_inputs_func
@@ -234,6 +237,7 @@ def __init__(self,
         self.module_memformat_affects_out = module_memformat_affects_out
         self.train_and_eval_differ = train_and_eval_differ
         self.module_error_inputs_func = module_error_inputs_func
+        self.gradcheck_fast_mode = gradcheck_fast_mode
         self.is_lazy = issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin)
 
     def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
@@ -4179,6 +4183,9 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad
                ),
     ModuleInfo(torch.nn.Transformer,
                module_inputs_func=module_inputs_torch_nn_Transformer,
+               # Inputs are too large to run with slow gradcheck
+               # https://github.com/pytorch/pytorch/issues/117140
+               gradcheck_fast_mode=True,
                decorators=[
                    # Not implemented for SDPA backward derivative
                    DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_gradgrad',
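The net effect is that the nn.Transformer entry opts into fast gradcheck at the ModuleInfo level rather than skipping the test outright, while every other ModuleInfo keeps the wrapper's default because gradcheck_fast_mode stays None. Below is a hypothetical, simplified mirror of just the fields involved: MiniModuleInfo and pick_fast_mode are invented names, the "None defers to the wrapper default" rule comes from the comment added above, and the assumption that the slow-gradcheck CI job flips that default via PYTORCH_TEST_WITH_SLOW_GRADCHECK is mine.

```python
from dataclasses import dataclass
from typing import Optional

import torch

# Hypothetical, stripped-down mirror of the two ModuleInfo fields this commit
# touches; the real class in torch/testing/_internal/common_modules.py has
# many more options.
@dataclass
class MiniModuleInfo:
    module_cls: type
    gradcheck_fast_mode: Optional[bool] = None  # None -> defer to wrapper default


def pick_fast_mode(info: MiniModuleInfo, wrapper_default: bool) -> bool:
    # Mirrors the rule documented in the new ModuleInfo comment: an explicit
    # True/False wins, None falls back to whatever the gradcheck wrapper would
    # use (assumed here: fast mode normally, slow mode on the CI job that sets
    # PYTORCH_TEST_WITH_SLOW_GRADCHECK=1).
    return wrapper_default if info.gradcheck_fast_mode is None else info.gradcheck_fast_mode


transformer_info = MiniModuleInfo(torch.nn.Transformer, gradcheck_fast_mode=True)
linear_info = MiniModuleInfo(torch.nn.Linear)

# Even when the wrapper default is slow (False), the Transformer entry stays fast:
print(pick_fast_mode(transformer_info, wrapper_default=False))  # True
# Entries that leave gradcheck_fast_mode as None follow the wrapper default:
print(pick_fast_mode(linear_info, wrapper_default=False))       # False
```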
