Add accelerate API support for Word Language Model example #1345
msaroufim merged 8 commits into pytorch:main
word_language_model/main.py
Outdated
| # this makes them a continuous chunk, and will speed up forward pass | ||
| # Currently, only rnn model supports flatten_parameters function. | ||
| if args.model in ['RNN_TANH', 'RNN_RELU', 'LSTM', 'GRU']: | ||
| if args.model in ['RNN_TANH', 'RNN_RELU', 'LSTM', 'GRU'] and device.type == 'cuda': |
what was the error you're getting?
Seems to be an oversight on my part. This was needed when I was trying a safe approach of only loading the weights, but apparently it is no longer needed. I will remove it to prevent any unwanted changes.
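For reference, a minimal sketch of the guarded `flatten_parameters()` call discussed in this thread. The helper name `maybe_flatten` and the tensor sizes are hypothetical and not taken from `main.py`; the sketch assumes the call is harmless on CPU, which matches the decision to drop the extra `device.type == 'cuda'` check.

```python
import torch
import torch.nn as nn

# Model names that map to RNN-style modules in the example (from the diff above).
RNN_MODELS = ("RNN_TANH", "RNN_RELU", "LSTM", "GRU")


def maybe_flatten(rnn: nn.Module, model_name: str) -> bool:
    """Call flatten_parameters() only for RNN-type models; return True if it was called."""
    if model_name in RNN_MODELS and hasattr(rnn, "flatten_parameters"):
        # Compacts the weights into one contiguous chunk where the backend supports it.
        rnn.flatten_parameters()
        return True
    return False


# Hypothetical sizes, chosen only to keep the sketch small.
lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2)
flattened = maybe_flatten(lstm, "LSTM")

# Input layout is (seq_len, batch, input_size) because batch_first defaults to False.
output, _ = lstm(torch.randn(5, 3, 8))
```

The Transformer model has no `flatten_parameters` method, so the helper simply skips it.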
word_language_model/main.py
Outdated
| # Load the best saved model. | ||
| with open(args.save, 'rb') as f: | ||
| model = torch.load(f) | ||
| torch.load(f, weights_only=False) |
Can you please extract this change into a separate PR? It also needs an update to the required torch version:
If I extract the change and update the requirements to 2.7, it won't work. This change allows the example to run with the simplest code change, since leaving the call as it was fails.
In PyTorch 2.6, the default value for weights_only was set to True, and PyTorch 2.7 introduced support for the accelerator API.
In this pull request, we can integrate the accelerator API. Meanwhile, we will address the update for saving and loading models using state_dict in a separate pull request.
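A minimal sketch of the state_dict approach mentioned above, as an alternative to `torch.load(f, weights_only=False)`. The `nn.Linear` model and the in-memory buffer are stand-ins chosen for illustration; they are not the model or the `args.save` file used by the example.

```python
import io

import torch
import torch.nn as nn

# Stand-in model; the real example builds an RNN or Transformer model instead.
model = nn.Linear(4, 2)

# Save only the parameters (a dict of tensors), not the pickled module object.
buffer = io.BytesIO()  # stands in for the file opened from args.save
torch.save(model.state_dict(), buffer)

# A plain state_dict can be read back with the restricted loader (weights_only=True).
buffer.seek(0)
state = torch.load(buffer, weights_only=True)

# The architecture has to be rebuilt in code before the weights are restored.
restored = nn.Linear(4, 2)
restored.load_state_dict(state)
```

The trade-off is that the loading script must know how to reconstruct the model, which is why this was proposed as a separate change.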
> PyTorch 2.7 introduced support for the accelerator API. <...> In this pull request, we can integrate the accelerator API.
From 2.6 actually. See https://docs.pytorch.org/docs/2.6/accelerator.html#module-torch.accelerator.
To integrate torch.accelerator we must update the requirement for torch to be >=2.6. Otherwise tests will simply fail. I suspect that you did not actually run the modified run_python_examples.sh.
> If I extract the change and update the requirements to 2.7 it won't work
I believe you are making changes in the wrong order. First, update the requirement to be able to use the latest PyTorch and fix the issues that appear. Next, as a second step, introduce new APIs.
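To illustrate why the version requirement matters, here is a minimal sketch of device selection with `torch.accelerator`. The helper name `pick_device` is hypothetical, and the `getattr` fallback for older torch versions is an assumption added for illustration rather than something the PR does.

```python
import torch


def pick_device(use_accel: bool) -> torch.device:
    """Return the current accelerator if requested and available, else the CPU."""
    # torch.accelerator is absent before torch 2.6, hence the defensive lookup.
    accelerator = getattr(torch, "accelerator", None)
    if use_accel and accelerator is not None and accelerator.is_available():
        return accelerator.current_accelerator()
    return torch.device("cpu")


device = pick_device(use_accel=True)

# The same code path then works for CUDA, MPS, XPU, or plain CPU.
x = torch.ones(2, 3, device=device)
```

Without the fallback, the attribute lookup itself fails on older releases, which is consistent with the point that the torch requirement has to be raised first.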
I did run the modified run_python_examples.sh but maybe I am doing this in the wrong order. So the suggestion here is to first update requirements and fix the issues in a separate PR, close this one and create a new one for the new API?
First, we need to run the example with latest PyTorch and fix any issue in a separate PR.
Thanks for the feedback @dvrogozh.
> the suggestion here is to first update requirements and fix the issues in a separate PR, close this one and create
Yes, but you don't need to close this PR. Just mark it as a draft while working on the update requirements PR.
Here is a PR to update torch version requirement as I would do it:
word_language_model/README.md
Outdated
| python main.py --cuda --epochs 6 --tied # Train a tied LSTM on Wikitext-2 with CUDA. | ||
| python main.py --cuda --tied # Train a tied LSTM on Wikitext-2 with CUDA for 40 epochs. | ||
| python main.py --cuda --epochs 6 --model Transformer --lr 5 | ||
| python main.py --accel --epochs 6 # Train a LSTM on Wikitext-2 with CUDA. |
> with CUDA
I suggest dropping this from the example command line and maybe adding a note that the example supports running on acceleration devices, listing those that were tried (CUDA, MPS, XPU).
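A minimal sketch of how a device-agnostic `--accel` flag could be declared with `argparse`. The flag names and the 40-epoch default come from the README lines above; the help strings and the `build_parser` helper are assumptions, not the actual text of `main.py`.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Build a reduced parser covering only the flags shown in the README snippet."""
    parser = argparse.ArgumentParser(description="Word language model (flag sketch)")
    parser.add_argument("--epochs", type=int, default=40,
                        help="upper epoch limit")
    parser.add_argument("--tied", action="store_true",
                        help="tie the word embedding and softmax weights")
    # One generic switch instead of a device-specific flag such as --cuda.
    parser.add_argument("--accel", action="store_true",
                        help="use an accelerator if one is available")
    return parser


# Equivalent to: python main.py --accel --epochs 6
args = build_parser().parse_args(["--accel", "--epochs", "6"])
defaults = build_parser().parse_args([])
```

Because the flag does not name a backend, the README comment can describe the run without mentioning CUDA.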
Co-authored-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
Error looks legit
Yep, it's legit: @framoncg, you've updated cmdline arguments replacing examples/run_python_examples.sh Lines 156 to 163 in 6f61614
Hi @msaroufim, I updated the flags in the CI script and tested it, and it works. Can you take a look?
Refactor Word Language Model example to utilize torch.accelerator API
The torch.accelerator API abstracts some of the accelerator specifics in user scripts. By leveraging this API, the code becomes more adaptable to various hardware accelerators.
Updated word_language_model/main.py with accelerator flag
Updated word_language_model/generate.py with accelerator flag
Updated README to match word_language_model/main.py flags
Updated run_python_examples.sh to add new accelerator flag
CC: @msaroufim, @malfet, @dvrogozh