@@ -7518,33 +7518,37 @@ def forward(self, dict_input):
7518
7518
def test_checkpointing_preserves_torch_function_mode_stack(self):
7519
7519
log = []
7520
7520
7521
- class Func1(TorchFunctionMode):
7522
- def __torch_function__(self, func, types, args, kwargs=None):
7523
- kwargs = {} if kwargs is None else kwargs
7524
- log.append("func1")
7525
- return func(*args, **kwargs)
7521
+ def get_mode_class(n):
7522
+ class Func(TorchFunctionMode):
7523
+ def __torch_function__(self, func, types, args, kwargs=None):
7524
+ kwargs = {} if kwargs is None else kwargs
7525
+ log.append(f"mode{n}")
7526
+ return func(*args, **kwargs)
7526
7527
7527
- class Func2(TorchFunctionMode):
7528
- def __torch_function__(self, func, types, args, kwargs=None):
7529
- kwargs = {} if kwargs is None else kwargs
7530
- log.append("func2" )
7531
- return func(*args, **kwargs )
7528
+ return Func
7529
+
7530
+ Mode1 = get_mode_class(1)
7531
+ Mode2 = get_mode_class(2 )
7532
+ Mode3 = get_mode_class(3 )
7532
7533
7533
7534
def func(x):
7534
7535
return x.sin().cos()
7535
7536
7536
- with Func1():
7537
- with Func2():
7537
+ def context_fn():
7538
+ return Mode3(), Mode3()
7539
+
7540
+ with Mode1():
7541
+ with Mode2():
7538
7542
a = torch.tensor(1., requires_grad=True)
7539
7543
7540
7544
log = []
7541
- out = checkpoint(func, a, use_reentrant=False)
7542
- self.assertTrue(log[0] == "func2" and log[1] == "func1")
7545
+ out = checkpoint(func, a, use_reentrant=False, context_fn=context_fn)
7546
+ self.assertTrue(log[-3] == "mode3" and log[-2] == "mode2" and log[-1] == "mode1")
7547
+
7543
7548
7544
7549
log = []
7545
7550
out.backward()
7546
-
7547
- self.assertTrue(log[0] == "func2" and log[1] == "func1")
7551
+ self.assertTrue(log[-3] == "mode3" and log[-2] == "mode2" and log[1] == "mode1")
7548
7552
7549
7553
def test_callback_adds_callback(self):
7550
7554
called = [0]
0 commit comments