8000 Add Swin Transformer Example by sumantro93 · Pull Request #1346 · pytorch/examples · GitHub
[go: up one dir, main page]

Skip to content

Add Swin Transformer Example #1346

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open

Conversation

sumantro93
Copy link

Solves part of #1131

@msaroufim

Copy link
netlify bot commented May 16, 2025

Deploy Preview for pytorch-examples-preview canceled.

Name Link
🔨 Latest commit 5478049
🔍 Latest deploy log https://app.netlify.com/projects/pytorch-examples-preview/deploys/682c1e221868830008b29270

Comment on lines 175 to 176
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to use the accelerator API

Suggested change
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
use_accel = torch.accelerator.is_available()
device = torch.accelerator.current_accelerator() if use_accel else torch.device("cpu")
print(f"Using device: {device}")

Install dependencies:

```bash
pip install torch torchvision
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use requirements.txt to be consistent with other examples

Testing is done automatically after each epoch. To only test, run with:

```bash 8000
python swin_transformer.py --epochs 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this should be 1 epoch

@sumantro93
Copy link
Author

@jafraustro , Thanks for the review. I've updated my PR. Please have a look and lemme know if something else is needed. :))

Copy link
Contributor
@jafraustro jafraustro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@sumantro93
Copy link
Author

@jafraustro can this be merged ? or does it require something else from my end?

@jafraustro
Copy link
Contributor

Hi @msaroufim, could you give a look to this PR?

@@ -0,0 +1,2 @@
torch
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
torch
torch>=2.6

Due to usage of torch.accelerator.

@@ -192,6 +196,7 @@ function stop() {
word_language_model/model.pt \
gcn/cora/ \
gat/cora/ || error "couldn't clean up some files"
swin_transformer/swin_cifar10.pt || error "couldn't clean up some files"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will be failing with "command swin_transformer/swin_cifar10.pt" not found. You need to add to the list. I can't suggest the change since it touches non-modified cmdline, but code should be:

gat/cora/ \
swin_transformer/swin_cifar10.pt || error "couldn't clean up some files"

I.e. line break on gat/cora/ line.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, i ran the examples and yes I had to also change the swin to swin_transformer in Line 228

@@ -0,0 +1,207 @@
from __future__ import print_function
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems script does not actually use print_function. Can this be dropped?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed in 5478049

torch.save(model.state_dict(), "swin_cifar10.pt")

if __name__ == '__main__':
main()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
main()
main()

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed in 5478049

### Save the model

```bash
python swin_transformer.py --save-model
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is --save-model really needed? should not the trained model be always saved?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done! fixes in 5478049

@sumantro93
Copy link
Author

@dvrogozh Thanks for your reviwes, I've made the changes :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants
0