8000 [multigraph] use backend specializations in compile_and_call_fx_graph by bobrenjc93 · Pull Request #152601 · pytorch/pytorch · GitHub

[multigraph] use backend specializations in compile_and_call_fx_graph #152601


Closed
wants to merge 25 commits

Conversation

Contributor
@bobrenjc93 commented May 1, 2025

Stack from ghstack (oldest at bottom):

The goal of this multigraph work is to enable a compiled region that has a single Dynamo trace but multiple backend specializations. This work was inspired by vLLM, which does this in a somewhat hacky way: it uses a custom backend to capture a Dynamo graph and then manually invokes compile_fx multiple times to get specialized graphs.
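Roughly, that workaround pattern looks like the sketch below (illustrative only, not vLLM's actual code; the function f, the backend name, and the shapes are made up):

```python
import torch
from torch._inductor.compile_fx import compile_fx

captured = {}

def capturing_backend(gm: torch.fx.GraphModule, example_inputs):
    # Custom backend that simply stashes the Dynamo-traced graph.
    captured["gm"] = gm
    return gm.forward  # run the captured graph as-is for now

def f(x):
    return torch.relu(x) + 1

fn = torch.compile(f, backend=capturing_backend)
fn(torch.randn(8))

# Later: manually invoke compile_fx once per shape of interest to get
# several specialized compilations out of the single captured trace.
specialized = {n: compile_fx(captured["gm"], [torch.randn(n)]) for n in (1, 2)}
```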

There are really two parts to this work:

The frontend changes:

  1. We introduce an optional kwarg backend_specializations to mark_dynamic that takes a list of specializations (a usage sketch follows). I debated other approaches, including specifying specializations via decorators, but ultimately decided this one was more harmonious. The big issue with decorators is the difficulty of composing well with the rest of the torch.compile ecosystem, including graph breaks, lazy initialization of variable trackers and symbolic variables, etc.
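For illustration, a minimal sketch of how the proposed kwarg might be invoked, using the (value, check_fn) tuple form that appears in this PR's test further down (f is a placeholder function; the exact signature is the part under discussion):

```python
import torch
from torch._dynamo import mark_dynamic

def f(x):
    return x @ x.T

x = torch.randn(8, 64)
# Dim 0 stays dynamic, but the backend is also asked to emit graphs
# specialized for the listed sizes.
mark_dynamic(x, 0, backend_specializations=[(1, lambda s: s == 1), (16, lambda s: s == 16)])
torch.compile(f)(x)
```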

The backend changes (this PR):

  1. We capture the backend_specialization specified in the mark_dynamic API into a SymbolicContext. See the changes in /_dynamo/variables/builder.py.
  2. After Dynamo tracing is done, we invoke call_user_compiler N + 1 times: once per specialization, plus once for the generic graph. Under the hood this calls compile_fx, which composes nicely with both Async Compile and AOTAutogradCache.
  3. When we have specializations, we install a specialized dispatch function that checks each specialization and dispatches to the first one that matches; if none match, we dispatch to the generic graph (see the sketch after this list). I decided to do this rather than returning N different GuardedCodes because 1) it doesn't pollute the dynamo cache (e.g. with 8 specializations you would hit the cache limit) and 2) it naturally incorporates the hierarchical lattice structure of the guards, since the specializations are always necessarily stricter than the generic region's guards.
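A minimal sketch of that dispatch wrapper, assuming a list of (check_fn, compiled_fn) pairs produced by the N + 1 compiles described above (names are illustrative, not the PR's actual code):

```python
def make_specialized_dispatch(specialized_compiles, generic_compiled_fn):
    # specialized_compiles: list of (check_fn, compiled_fn) pairs, where
    # check_fn inspects the runtime args for one specialization.
    def dispatch(*args):
        for check_fn, compiled_fn in specialized_compiles:
            if check_fn(args):
                # First matching specialization wins.
                return compiled_fn(*args)
        # No specialization matched: fall back to the generic graph.
        return generic_compiled_fn(*args)

    return dispatch
```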

I benchmarked this PR stack with #152596 and found around a 50% reduction when dispatching to the specialized regions:

![495269647_576053105510082_9189856138964956774_n](https://github.com/user-attachments/assets/66030fed-d62e-4d87-940f-aa13c99b1a73)

pytorch-bot bot commented May 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152601

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 2 New Failures

As of commit ee88cb9 with merge base 8f54e56:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@bobrenjc93 bobrenjc93 requested review from zou3519 and eellison May 5, 2025 12:45
@bobrenjc93 bobrenjc93 marked this pull request as ready for review May 5, 2025 12:46
Comment on lines +293 to +295
raise RuntimeError(
    "Backend specializations are only supported for contiguous tensors."
)
Contributor

what's going on here?

Comment on lines +1506 to +1527
for specialization in old_fake_mode.shape_env.backend_specializations:
    source_index = sources.index(specialization.source)
    check_fn_source = inspect.getsource(specialization.check_fn).strip()
    check_fn = guards.LAMBDA_GUARD(  # type: ignore[attr-defined]
        specialization.check_fn,
        [check_fn_source],
    )

    log.debug(
        "Compiling backend specialized graph with specialization=%s",
        check_fn_source,
    )

    specialized_compiles.append(
        (
            functools.partial(
                lambda idx, args, check_fn=check_fn: check_fn(args[idx]),
                source_index,
            ),
            self.call_user_compiler(gm, specialization=specialization),
Contributor
@zou3519 commented May 12, 2025

This calls the backend compiler with the (tensor_args,) for this graph and the specialization argument, right?

I'm not sure this is the right design. The (tensor_args,) don't have the same shape as the specialization -- will that be a problem?

An alternative design is that there is some lazy dispatching layer right after Dynamo but before AOTAutograd. Let's say the user calls the following for the first time:

# A
x = torch.randn(3)
mark_dynamic(x, 0, backend_specializations=[1, 2])
torch.compile(f)(x)

Then this traces out a graph from Dynamo with dynamic shapes.

Then, on future calls to torch.compile:

# B
y = torch.randn(1)
torch.compile(f)(y)
  • On seeing a specialized shape for the first time: this skips Dynamo but directly forwards the args (y,) to the backend to compile a graph
# C
z = torch.randn(1)
torch.compile(f)(z)
  • On seeing a specialized shape again: this pulls up the graph the backend compiled for said shape.

One way to implement this is (see the sketch after this list):

  • Let's think about the Dynamo cache as a mapping from guards to a callable
  • After (A), there is a guard for each of the specializations: {"batch_size==1": call_backend_compile(), "batch_size==2": call_backend_compile(), "batch_size==anything_else": compiled_artifact}
  • (B) hits the call_backend_compile() function, which will compile a backend function and replace the Dynamo cache entry with {"batch_size==1": compiled_artifact}
  • Future hits to this guard (e.g. C) will just hit the compiled artifact.
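A minimal sketch of that lazy cache-entry idea (names like LazyBackendEntry and backend_compile are hypothetical, not Dynamo internals):

```python
class LazyBackendEntry:
    """Cache entry that compiles on first hit, then memoizes the artifact."""

    def __init__(self, backend_compile):
        self.backend_compile = backend_compile  # compiles a graph for this shape
        self.compiled = None

    def __call__(self, *args):
        if self.compiled is None:
            # First hit for this specialized shape: skip Dynamo and compile
            # the already-captured graph against these concrete args.
            self.compiled = self.backend_compile(args)
        return self.compiled(*args)
```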

Contributor

The benefit of the alternative lazy design is that the backend doesn't need to work hard to figure out how to do the specialization: it's almost like calling regular torch.compile again, except it is able to skip Dynamo.

One side effect is that we don't have to impose constraints on the strides (this PR needs to do that because it needs to figure out how to create a FakeTensor, right?)

Contributor

I think this makes sense. cc @anijain2305 for thoughts as well

Contributor
@anijain2305 commented May 13, 2025

There are a few details that we need to think about:

  1. We will have multiple cache entries per code object here. For example, our cache size limit is 8, but the specialization here will require us to raise the cache size limit for certain code objects.

  2. Dynamo cache as a mapping from guards to a callable - this is true, but there is a subtle difference. Dynamo maps guards to bytecode. This bytecode contains the call to the compiled_graph (not an FX graph, a compiled graph). So in this design, we will have to figure out how to (1) stash the bytecode, and (2) stash the Dynamo graph.

  3. Overwriting a cache entry is also questionable.

Maybe we have the bytecode call the backend_compile. Then the backend_compile internally checks if there is already compiled code. If yes, it runs the compiled code; otherwise it runs the AOT + Inductor compilation.

Contributor

@anijain2305 thoughts on #153449 ?

Contributor

What is the issue with the current implementation? It's not bad. It gives the hierarchical feel, which kind of makes sense in this case.

Contributor

Consider the following code:

x = torch.randn(3)
mark_dynamic(x, 0, backend_specializations=[1, 2])
torch.compile(f)(x)
x = torch.randn(1)
torch.compile(f)(x)
x = torch.randn(2)
torch.compile(f)(x)

On the first torch.compile call, we will attempt to compile all of the backend specializations. That torch.compile call only has one set of sample inputs (of shape [3]). The problems I'm worried about are:
a) Compile time will be slow up front. On the first torch.compile it looks like we call the backend compiler three times.
b) Because there are no real tensor inputs of shape [1] and shape [2], we need to guess at those tensors and assume that they're contiguous. This doesn't seem very good.

The lazier design (#153449) solves this by (a) deferring compilation of shape [1] and shape [2] until we actually see inputs of those shapes, and (b) treating a change in strides as a recompile.

Contributor
@eellison left a comment

Agree with @zou3519's comments.

Comment on lines +10498 to +10500
backend_specializations=[
    (16, lambda x0: x0 == 16),
],
Contributor

Do we have an API flow for when you want to specify conditions on multiple vars?

E.g.

lambda x, y: x == 1 and y == 1
lambda x, y: x % 16 and y % 16

You don't necessarily want to specialize on x == 1 and y % 16, which I assume would fall out of the pairwise specializations.

Contributor Author
@bobrenjc93 commented May 13, 2025

Not at the moment. vLLM actually only has one symbolic variable (https://www.anyscale.com/blog/continuous-batching-llm-inference) so we don't need to worry about that for our first customer. That being said, I'm happy to bikeshed what a better multi-var API may look like during composability.

Comment on lines +10508 to +10510
dynamic_specialized = do_bench(
    lambda: inductor_matmul(dynamic_specialized_a, b)
)
Contributor

We should check the output code
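One way to do that in a test is torch._inductor.utils.run_and_get_code; a minimal sketch against the variables in this test (the assertion string is illustrative, not the PR's actual check):

```python
from torch._inductor.utils import run_and_get_code

# Run the compiled matmul once and capture the Inductor-generated source,
# then check that the expected specialized code was actually emitted.
result, codes = run_and_get_code(inductor_matmul, dynamic_specialized_a, b)
assert any("16" in code for code in codes)  # illustrative; a real test would use FileCheck
```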

@@ -1 +1 @@
-14256e6040d9e14698a877924456cdd92bfcd01d
+8eeef7f5b5363e9f35576184659226cc082311d6
Contributor

intentional?

@bobrenjc93 bobrenjc93 requested a review from anijain2305 May 12, 2025 23:12

m = 16
k = 1280
dynamic_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
Collaborator

Hi, the function is decorated with requires_gpu, which means the case will run on GPUs like CUDA/XPU, but the hard-coded cuda here will fail on other GPUs like XPU.

Suggested change
dynamic_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
dynamic_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16)

m = 16
k = 1280
dynamic_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
dynamic_specialized_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
Collaborator

Suggested change
dynamic_specialized_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
dynamic_specialized_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16)

k = 1280
dynamic_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
dynamic_specialized_a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
b = torch.randn(k, m, device="cuda", dtype=torch.bfloat16)
Collaborator

Suggested change
b = torch.randn(k, m, device="cuda", dtype=torch.bfloat16)
b = torch.randn(k, m, device=GPU_TYPE, dtype=torch.bfloat16)

Contributor Author
@bobrenjc93 commented

Abandoning in favor of lazy approach: #153449

@bobrenjc93 bobrenjc93 closed this May 13, 2025
@github-actions github-actions bot deleted the gh/bobrenjc93/331/head branch June 15, 2025 02:23