Internal uses of `torch.cuda.amp.autocast` raise FutureWarnings · Issue #130659 · pytorch/pytorch · GitHub

Closed · awaelchli opened this issue Jul 13, 2024 · 10 comments
Labels
module: amp (automated mixed precision) · autocast · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Milestone
2.4.1

Comments

@awaelchli
Contributor
awaelchli commented Jul 13, 2024

🐛 Describe the bug

PyTorch 2.4 deprecated the use of torch.cuda.amp.autocast in favor of torch.amp.autocast("cuda", ...), but the change missed updating internal uses inside PyTorch itself. For example, in DataParallel's parallel_apply here:

with torch.cuda.device(device), torch.cuda.stream(stream), autocast(
    enabled=autocast_enabled
):
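Presumably the non-deprecated spelling of that call would be the following (a sketch of the same fragment based on the deprecation message, not necessarily the exact merged patch):

with torch.cuda.device(device), torch.cuda.stream(stream), torch.amp.autocast(
    "cuda", enabled=autocast_enabled
):
    ...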

Repro:

import torch

# Wrapping the model in DataParallel routes the forward pass through
# parallel_apply, which enters the deprecated internal autocast context.
model = torch.nn.Linear(10, 10).cuda()
model = torch.nn.DataParallel(model, device_ids=[0, 1])
output = model(torch.randn(20, 10).cuda())

Produces:

/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py:79: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/nn/modules/linear.py:117: UserWarning: Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context... (Triggered internally at ../aten/src/ATen/cuda/CublasHandlePool.cpp:135.)
  return F.linear(input, self.weight, self.bias)

Since these warnings are raised internally, the user sees them but cannot do anything about them.

If desired, I can send a PR with a fix for this :)

Versions

PyTorch version: 2.4.0+cu121

cc @mcarilli @ptrblck @leslie-fang-intel @jgong5

@albanD added the module: amp (automated mixed precision), autocast, and triaged labels and removed the triage review label on Jul 15, 2024
@albanD added this to the 2.4.1 milestone on Jul 15, 2024
@albanD
Collaborator
albanD commented Jul 15, 2024

cc @guangyey could you please help fix this?

@guangyey
Collaborator

> cc @guangyey could you please help fix this?

I see @awaelchli already filed PR #130660 to fix this. I left a comment there to fix the lint.

mlazos pushed a commit that referenced this issue Jul 18, 2024
Fixes #130659

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Pull Request resolved: #130660
Approved by: https://github.com/guangyey, https://github.com/fegin, https://github.com/albanD
DiweiSun pushed a commit to DiweiSun/pytorch that referenced this issue Jul 22, 2024
Fixes pytorch#130659

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Pull Request resolved: pytorch#130660
Approved by: https://github.com/guangyey, https://github.com/fegin, https://github.com/albanD
xuhancn pushed a commit to xuhancn/pytorch that referenced this issue Jul 25, 2024
Fixes pytorch#130659

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Pull Request resolved: pytorch#130660
Approved by: https://github.com/guangyey, https://github.com/fegin, https://github.com/albanD
@function2-llx
function2-llx commented Jul 30, 2024

Also found a usage of torch.cpu.amp.autocast here in PyTorch 2.4.

What's worse, the warning for this line is produced every time during backward when using checkpoint, not only the first time.
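For reference, the non-deprecated CPU spelling follows the same pattern as the CUDA one (a sketch based on the deprecation message; the bfloat16 dtype here is illustrative):

import torch

# torch.cpu.amp.autocast(...) becomes torch.amp.autocast("cpu", ...)
with torch.amp.autocast("cpu", dtype=torch.bfloat16):
    ...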

@guangyey
Collaborator
guangyey commented Jul 31, 2024

@function2-llx I see this fix targets release 2.4.1, please refer to this comment: #128436 (comment).
By the way, I fixed the other deprecation warning in #132207.

pytorchbot pushed a commit that referenced this issue Aug 15, 2024
Fixes #130659

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Pull Request resolved: #130660
Approved by: https://github.com/guangyey, https://github.com/fegin, https://github.com/albanD

(cherry picked from commit bb62e9d)
atalman pushed a commit to atalman/pytorch that referenced this issue Aug 20, 2024
Fixes pytorch#130659

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Pull Request resolved: pytorch#130660
Approved by: https://github.com/guangyey, https://github.com/fegin, https://github.com/albanD
atalman added a commit that referenced this issue Aug 21, 2024
Fixes #130659

Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com>
Pull Request resolved: #130660
Approved by: https://github.com/guangyey, https://github.com/fegin, https://github.com/albanD

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
@naga24
naga24 commented Aug 22, 2024

> Also found a usage of torch.cpu.amp.autocast here in PyTorch 2.4.
>
> What's worse, the warning for this line is produced every time during backward when using checkpoint, not only the first time.

You can downgrade to torch==2.3.1 and torchvision==0.18.1 (pip install torch==2.3.1 torchvision==0.18.1); the warning stops appearing.

@guangyey
Collaborator

@function2-llx, you can try the workaround @naga24 mentioned, or get the fix in the upcoming 2.4.1 release.

@function2-llx

Thanks for the suggestions!
There is also a workaround that suppresses the warning by passing a warning filter to the interpreter; something like this works: python -W ignore::FutureWarning:torch.utils.checkpoint:1399.
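The same filter can also be installed programmatically (a sketch assuming the warning is raised from the torch.utils.checkpoint module; the line-number restriction from the -W flag is dropped so it survives version bumps):

import warnings

# Ignore FutureWarnings raised from torch.utils.checkpoint only,
# leaving all other FutureWarnings visible.
warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    module=r"torch\.utils\.checkpoint",
)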

@tgilbert09

Also found in engine.py, line 30. This FutureWarning appears in the TorchVision Object Detection Finetuning Tutorial on the tutorials website. Also mentioned in pytorch/tutorials#3007.

@kit1980
Contributor
kit1980 commented Aug 30, 2024

Validated with torch==2.4.1 RC.

FYI, the original repro modified to run on a single-GPU machine:

import torch

model = torch.nn.Linear(10, 10).cuda()
# Repeating device 0 makes DataParallel take its multi-device code path
# even though only one GPU is present.
model = torch.nn.DataParallel(model, device_ids=[0, 0])
output = model(torch.randn(20, 10).cuda())

@Sola-AIGithub

Hello there. I checked the model dtype and then used `with torch.autocast("cuda", torch.float32):` instead of `with torch.cuda.amp.autocast(autocast):`; it works and there is no FutureWarning, FYI.

P.S. torch.__version__ == 2.4.1, YOLOv5 model
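A sketch of that workaround (the model, input, and dtype choice are illustrative; per the comment above, passing the model's own dtype avoids the FutureWarning):

import torch

model = torch.nn.Linear(10, 10).cuda()
inputs = torch.randn(20, 10).cuda()

# Take the autocast dtype from the model instead of hard-coding it.
dtype = next(model.parameters()).dtype
with torch.autocast("cuda", dtype):
    output = model(inputs)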
