 
 import torch
 import torch._dynamo.test_case
-from torch._dynamo.testing import CompileCounter, EagerAndRecordGraphs, normalize_gm
+from torch._dynamo.testing import (
+    CompileCounter,
+    CompileCounterWithBackend,
+    EagerAndRecordGraphs,
+    normalize_gm,
+)
 from torch.testing._internal.common_cuda import TEST_CUDA
 
 
@@ -130,6 +135,184 @@ def fn(x, y):
         # No recompile
         self.assertEqual(counter.frame_count, 1)
 
+    def test_vmapped_autograd_function(self):
+        eager = EagerAndRecordGraphs()
+
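+        # generate_vmap_rule=True below asks functorch to derive the vmap rule
+        # for Foo automatically, so torch.vmap(Foo.apply) works without a
+        # hand-written vmap staticmethod.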
+        class Foo(torch.autograd.Function):
+            generate_vmap_rule = True
+
+            @staticmethod
+            def forward(x):
+                return x * 2
+
+            @staticmethod
+            def setup_context(ctx, inputs, output):
+                pass
+
+            @staticmethod
+            def backward(ctx, grad):
+                return grad * 2
+
+        @torch.compile(backend=eager, fullgraph=True)
+        def fn(x):
+            return torch.vmap(Foo.apply)(x)
+
+        x = torch.randn(2, 3, requires_grad=True)
+        self.assertEqual(fn(x), torch.vmap(Foo.apply)(x))
+
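+        # The captured graph should contain the functorch batching plumbing and
+        # a single autograd_function_apply higher-order op for Foo.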
+        graph = eager.graphs[0]
+        actual = normalize_gm(graph.print_readable(False))
+        self.assertExpectedInline(
+            actual,
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, L_x_: "f32[2, 3]"):
+        l_x_ = L_x_
+
+        lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
+
+        _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None
+
+        a: "f32[3]" = torch._C._functorch._add_batch_dim(l_x_, 0, 1); l_x_ = None
+
+        _are_functorch_transforms_active = torch._C._are_functorch_transforms_active(); _are_functorch_transforms_active = None
+
+        _are_functorch_transforms_active_1 = torch._C._are_functorch_transforms_active(); _are_functorch_transforms_active_1 = None
+
+        child: "f32[3]" = torch._C._functorch.unwrap_if_dead(a); a = None
+
+        _unwrap_batched = torch._C._functorch._unwrap_batched(child, 1); child = None
+        getitem: "f32[2, 3]" = _unwrap_batched[0]; _unwrap_batched = None
+
+        pop_dynamic_layer_stack = torch._C._functorch.pop_dynamic_layer_stack()
+
+        _are_functorch_transforms_active_2 = torch._C._are_functorch_transforms_active(); _are_functorch_transforms_active_2 = None
+
+        function_ctx = torch.autograd.function.FunctionCtx(); function_ctx = None
+        fwd_body_0 = self.fwd_body_0
+        bwd_body_0 = self.bwd_body_0
+        autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, getitem, args_tensor_mask = [True], non_differentiable_idx = []); fwd_body_0 = bwd_body_0 = getitem = None
+        outputs: "f32[2, 3]" = autograd_function_apply[0]; autograd_function_apply = None
+
+        push_dynamic_layer_stack = torch._C._functorch.push_dynamic_layer_stack(pop_dynamic_layer_stack); pop_dynamic_layer_stack = push_dynamic_layer_stack = None
+
+        result: "f32[3]" = torch._C._functorch._add_batch_dim(outputs, 0, 1); outputs = None
+
+        _remove_batch_dim: "f32[2, 3]" = torch._C._functorch._remove_batch_dim(result, 1, 2, 0); result = None
+
+        _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
+        return (_remove_batch_dim,)
+
+    class fwd_body_0(torch.nn.Module):
+        def forward(self, function_ctx : torch.autograd.function.Function, getitem: "f32[2, 3]"):
+            _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
+            _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None
+
+            _add_batch_dim: "f32[3]" = torch._C._functorch._add_batch_dim(getitem, 0, 1)
+
+            batched_outputs: "f32[3]" = _add_batch_dim * 2; _add_batch_dim = None
+
+            _unwrap_batched = torch._C._functorch._unwrap_batched(batched_outputs, 1); batched_outputs = None
+            outputs: "f32[2, 3]" = _unwrap_batched[0]
+            getitem_2 = _unwrap_batched[1]; _unwrap_batched = getitem_2 = None
+
+            _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
+            _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting_1 = None
+
+            inp: "f32[3]" = torch._C._functorch._add_batch_dim(getitem, 0, 1); getitem = inp = None
+            _add_batch_dim_2: "f32[3]" = torch._C._functorch._add_batch_dim(outputs, 0, 1); _add_batch_dim_2 = None
+
+            _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
+            _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
+            return ((outputs, 0), [])
+
+    class bwd_body_0(torch.nn.Module):
+        def forward(self, function_ctx : torch.autograd.function.Function, outputs: "f32[2, 3]", const_unused : int):
+            _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
+            _vmap_increment_nesting = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting = None
+
+            _add_batch_dim: "f32[3]" = torch._C._functorch._add_batch_dim(outputs, 0, 1); outputs = None
+
+            batched_outputs: "f32[3]" = _add_batch_dim * 2; _add_batch_dim = None
+
+            _unwrap_batched = torch._C._functorch._unwrap_batched(batched_outputs, 1); batched_outputs = None
+            grad_ins: "f32[2, 3]" = _unwrap_batched[0]
+            getitem_1 = _unwrap_batched[1]; _unwrap_batched = getitem_1 = None
+
+            _vmap_decrement_nesting = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting = None
+
+            lazy_load_decompositions = torch._functorch.vmap.lazy_load_decompositions(); lazy_load_decompositions = None
+
+            _vmap_increment_nesting_1 = torch._C._functorch._vmap_increment_nesting(2, 'error'); _vmap_increment_nesting_1 = None
+
+            _add_batch_dim_1: "f32[3]" = torch._C._functorch._add_batch_dim(grad_ins, 0, 1); grad_ins = None
+
+            batched_outputs_1: "f32[3]" = _add_batch_dim_1.sum_to_size((3,)); _add_batch_dim_1 = None
+
+            _remove_batch_dim: "f32[2, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs_1, 1, 2, 0); batched_outputs_1 = None
+
+            _vmap_decrement_nesting_1 = torch._C._functorch._vmap_decrement_nesting(); _vmap_decrement_nesting_1 = None
+            _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
+            return (_remove_batch_dim,)
+""",  # NOQA: B950
+        )
+
+    def test_vmapped_autograd_function_fwd_and_bwd(self):
+        cnt = CompileCounterWithBackend("aot_eager")
+
+        class LinearFunction(torch.autograd.Function):
+            generate_vmap_rule = True
+
+            @staticmethod
+            def forward(input, weight, bias):
+                output = input.mm(weight.t())
+                if bias is not None:
+                    output += bias.unsqueeze(0).expand_as(output)
+                return output
+
+            @staticmethod
+            def setup_context(ctx, inputs, output):
+                input, weight, bias = inputs
+                ctx.save_for_backward(input, weight, bias)
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                input, weight, bias = ctx.saved_tensors
+                grad_input = grad_weight = grad_bias = None
+                if ctx.needs_input_grad[0]:
+                    grad_input = grad_output.mm(weight)
+                if ctx.needs_input_grad[1]:
+                    grad_weight = grad_output.t().mm(input)
+                if bias is not None and ctx.needs_input_grad[2]:
+                    grad_bias = grad_output.sum(0)
+
+                return grad_input, grad_weight, grad_bias
+
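+        # vmap maps LinearFunction.apply over the leading (batch) dimension of
+        # input, weight, and bias.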
+        def fn(input, weight, bias=None):
+            return torch.vmap(LinearFunction.apply)(input, weight, bias)
+
+        input1 = torch.randn(4, 2, 2, dtype=torch.double, requires_grad=True)
+        input2 = input1.clone().detach().requires_grad_(True)
+        weight1 = torch.randn(4, 3, 2, dtype=torch.double, requires_grad=True)
+        weight2 = weight1.clone().detach().requires_grad_(True)
+        bias1 = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
+        bias2 = bias1.clone().detach().requires_grad_(True)
+
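+        # Run the same inputs through the eager and compiled versions: outputs
+        # and all gradients must match, and everything compiles as one frame.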
+        compiled_fn = torch.compile(backend=cnt, fullgraph=True)(fn)
+
+        output1 = fn(input1, weight1, bias1)
+        output1.sum().backward()
+
+        output2 = compiled_fn(input2, weight2, bias2)
+        output2.sum().backward()
+
+        self.assertEqual(output1, output2)
+        self.assertEqual(input1.grad, input2.grad)
+        self.assertEqual(weight1.grad, weight2.grad)
+        self.assertEqual(bias1.grad, bias2.grad)
+        self.assertEqual(cnt.frame_count, 1)
+        self.assertEqual(cnt.op_count, 25)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests