ephemeral GPU offload support by kallewoof · Pull Request #1857 · huggingface/peft · GitHub
ephemeral GPU offload support #1857


Merged — 21 commits merged into huggingface:main on Jul 2, 2024

Conversation

@kallewoof (Contributor) commented Jun 13, 2024

Adds the concept of ephemeral GPU offloading, i.e. where data in compute-intensive operations is copied onto the GPU before the operation is performed, after which the result is moved back to CPU memory.

This PR adds support in the dora initialization code, but the approach can be applied in a number of places: when the size of the data compared to the time to perform the operation on CPU memory is heavily time dominant, using ephemeral transfers has a fairly small VRAM overhead (depending on the size of the model/adapter) with orders of magnitude speed-up in certain operations.

For example, a Llama3-8B DoRA adapter with r=64 would add an overhead of 2 x (64 x 4096 x 2 + 4096 x 4096) bytes (assuming fp16), i.e. about 33 MB. A Llama3-70B adapter with r=32 would have 2 x (32 x 8192 x 2 + 8192 x 8192) bytes = about 130 MB.
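For reference, the arithmetic behind those figures (fp16 at 2 bytes per element):

    overhead_8b = 2 * (64 * 4096 * 2 + 4096 * 4096)   # 34,603,008 bytes
    overhead_70b = 2 * (32 * 8192 * 2 + 8192 * 8192)  # 135,266,304 bytes
    print(overhead_8b / 2**20, overhead_70b / 2**20)  # ~33 MiB and ~129 MiB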

By making use of ephemeral GPU offloading, more efficient juggling of data between GPU and CPU becomes possible: instead of always loading as much as we can onto the GPU and then enduring the CPU slowness for whatever happens to not fit in there, we intentionally leave a (modest) chunk of VRAM for optimizations like these, and the end result is a much (MUCH) faster experience.
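As a rough illustration of the pattern only (not code from this PR; ephemeral_op and the example call are made up for this sketch):

    import torch

    def ephemeral_op(op, *tensors, device="cuda"):
        # Copy CPU tensors to the GPU, run the heavy op there, and return the
        # (typically small) result to CPU so no extra VRAM stays allocated.
        if not torch.cuda.is_available():
            return op(*tensors)                  # plain CPU fallback
        moved = [t.to(device) for t in tensors]
        return op(*moved).to("cpu")

    # e.g. weight_norm = ephemeral_op(torch.linalg.norm, merged_weight)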

See examples/ephemeral_gpu_offload/load_with_dora.py for an example script demonstrating this feature. Example output using the script's defaults:
$ python load_with_dora.py
-------------------- Loading model --------------------
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:04<00:00,  1.03s/it]
-------------------- Loading PeftModel --------------------
-------------------- Done --------------------
Model loading time: 4.83s
PeftModel loading time: 28.14s
Use ephemeral GPU offloading: False

(Note: if this was the first time you ran the script, or if your cache was cleared, the times shown above are invalid, due to the time taken to download the model and DoRA files. Just re-run the script in this case.)

$ python load_with_dora.py --ephemeral_gpu_offload
-------------------- Loading model --------------------
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.11it/s]
-------------------- Loading PeftModel --------------------
-------------------- Done --------------------
Model loading time: 4.28s
PeftModel loading time: 16.59s
Use ephemeral GPU offloading: True

(Note: if this was the first time you ran the script, or if your cache was cleared, the times shown above are invalid, due to the time taken to download the model and DoRA files. Just re-run the script in this case.)

In this case, using ephemeral GPU offloading, we finished in about two-thirds of the time taken doing it purely on CPU.

I verified that the resulting merged models were identical using cmp on the model-000*.safetensors files for the variants using and not using ephemeral transfers.

@BenjaminBossan (Member)

Thanks for creating this draft PR.

when the size of the data compared to the time to perform the operation on CPU memory is heavily time dominant, using ephemeral transfers has a fairly small VRAM overhead (depending on the size of the model/adapter) with orders of magnitude speed-up in certain operations.

I have to admit that I didn't fully understand this yet. Could you please explain this in more details or give a reference?

From looking at the code, my understanding is that if lora_A or lora_B is on CPU, it is ensured that the DoRA weight norm is also kept on CPU. Where is this part then happening? Is it happening indirectly?

we intentionally leave a (modest) chunk of VRAM for optimizations like these, and the end result is a much (MUCH) faster experience.

Ephemeral transfers result in the data involved in intense operations being momentarily copied over to the GPU, and the results copied back to CPU.

Regarding performance:

Using a patched Axolotl installation, I ran a merge of a Llama3-8B-DoRA adapter with r=64, where I capped the two cards to only use 7 GB of VRAM each for model loading (i.e. 14 GB available, where the model alone needs something like 16 GB).

Do you have a script to reproduce the experiment? What exactly is it measuring: loading, merging, saving? How would this affect inference and training?

I am not happy with keeping the flag in LoraConfig, but I'm not sure how to propagate it to the right place, especially if we start using this in other places. It feels like this should be a runtime option that does not persist, i.e. is not saved in configs and re-used inadvertently when people load a LoRA.

As is, this is very LoRA-specific, as it only affects DoRA (there should be a check that this should not be enabled if use_dora is False). I agree that putting it into the config might not seem like the right place, but I don't really see anywhere else to put it. Extra code could be added that this flag is not persisted as True when the config is saved.

@kallewoof (Contributor, Author) commented Jun 14, 2024

when the size of the data compared to the time to perform the operation on CPU memory is heavily time dominant, using ephemeral transfers has a fairly small VRAM overhead (depending on the size of the model/adapter) with orders of magnitude speed-up in certain operations.

I have to admit that I didn't fully understand this yet. Could you please explain this in more details or give a reference?

Sorry, what I meant was the following rule of thumb:

  1. The smaller the data, the better
  2. The bigger the time difference between performing the operation(s) on CPU vs GPU, the better.

Anywhere in the code where these two conditions apply, we should do what I refer to above as ephemeral transfers, which is simply to copy the values to the GPU, perform the computation, and pull the results back into CPU memory. I am beginning to regret even giving it a name, but I can't think of a succinct way to describe it.
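A crude way to see the trade-off on your own machine (purely illustrative and requires a CUDA GPU; not part of this PR):

    import time
    import torch

    x = torch.randn(4096, 4096)                      # ~64 MB in fp32, a small slice of VRAM

    t0 = time.perf_counter()
    _ = x @ x                                        # heavy matmul on CPU
    cpu_s = time.perf_counter() - t0

    t0 = time.perf_counter()
    _ = (x.to("cuda") @ x.to("cuda")).to("cpu")      # copy over, compute on GPU, copy back
    torch.cuda.synchronize()
    gpu_s = time.perf_counter() - t0

    print(f"CPU: {cpu_s:.2f}s, ephemeral GPU: {gpu_s:.2f}s")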

From looking at the code, my understanding is that if lora_A or lora_B is on CPU, it is ensured that the DoRA weight norm is also kept placed on CPU. Where is this part then happening? Is it happening indirectly?

When using ephemeral transfers we put the LoRA layers on the GPU. Because of this, the resulting DoRA weight norm ends up on the GPU. We don't want that, since it will take up GPU space, and the whole idea is to be ephemeral. As such, we pass in a "put it on the CPU" flag for the case where the incoming LoRA A/B weights are initially on the CPU.

There are three cases:

  1. Lora A/B are on the CPU, and we are not using ephemeral transfers: they remain on CPU throughout the operations, and the weight norm comes back on CPU.
  2. Lora A/B are on the CPU, and we are using ephemeral transfers: copies of them are put on the GPU, the weight norm is calculated (ending up on the GPU), and is then explicitly moved to CPU, which is expected and desired.
  3. Lora A/B are on the GPU, and the weight norm comes back on the GPU as well. This is expected and desired. Use of ephemeral transfers is irrelevant here.
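A rough sketch of these three cases (simplified; the function, parameter names, and the norm formula are illustrative, not the actual PEFT implementation):

    import torch

    def dora_weight_norm(weight, lora_A, lora_B, scaling, ephemeral_transfers, device="cuda"):
        if lora_A.device.type == "cpu" and ephemeral_transfers and torch.cuda.is_available():
            # Case 2: copy to GPU, compute there, then move the result back to CPU.
            w = weight.to(device) + scaling * (lora_B.to(device) @ lora_A.to(device))
            return torch.linalg.norm(w, dim=1).to("cpu")
        # Cases 1 and 3: compute wherever the tensors already live; the result
        # stays on that device (CPU in case 1, GPU in case 3).
        w = weight + scaling * (lora_B @ lora_A)
        return torch.linalg.norm(w, dim=1)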

Regarding performance:

Using a patched Axolotl installation, I ran a merge of a Llama3-8B-DoRA adapter with r=64, where I capped the two cards to only use 7 GB of VRAM each for model loading (i.e. 14 GB available, where the model alone needs something like 16 GB).

Do you have a script to reproduce the experiment?

This is the relevant code inside Axolotl (patched to add ephemeral_transfers):

        model_kwargs: Any = {}
        if cfg.lora_on_cpu:
            model_kwargs["max_memory"] = {"cpu": "256GiB"}
            model_kwargs["device_map"] = {"": "cpu"}
        if cfg.ephemeral_transfers:
            model_kwargs["ephemeral_transfers"] = cfg.ephemeral_transfers
        model = PeftModel.from_pretrained(
            model,
            cfg.lora_model_dir,
            is_trainable=(not inference),
            **model_kwargs,
        )

The lora_on_cpu part can be ignored, and was unused in my tests.

What exactly is it measuring: loading, merging, saving?

In this case, loading. This is part of a merge operation though.

How would this affect inference and training?

I'm not sure. I would assume you would keep ephemeral transfers off, but it's possible that this might allow you to run with higher parameter values (LoRA rank etc.) during training, by keeping things on the CPU that aren't needed all the time.

I am not happy with keeping the flag in LoraConfig, but I'm not sure how to propagate it to the right place, especially if we start using this in other places. It feels like this should be a runtime option that does not persist, i.e. is not saved in configs and re-used inadvertently when people load a LoRA.

As is, this is very LoRA-specific, as it only affects DoRA (there should be a check that this should not be enabled if use_dora is False). I agree that putting it into the config might not seem like the right place, but I don't really see anywhere else to put it. Extra code could be added that this flag is not persisted as True when the config is saved.

If we had a good way to mark something as "do not store in final adapter_config.json file", that would be ideal and would make me happier about it being in the LoraConfig object.

@BenjaminBossan (Member)

Thanks a lot for explaining this in more detail, what you describe makes intuitive sense. Let's extend the help section of the new parameter to clearly indicate when users should set it and what they can expect.

what I refer to above as ephemeral transfers

Okay, I was trying to google this and didn't find any description :)

2. Lora A/B are on the CPU, and we are using ephemeral transfers: copies of them are put on the GPU, the weight norm is calculated (ending up on the GPU)

Okay, I assume that this part is automatically taken care of already...

and is then explicitly moved to CPU, which is expected and desired.

but it's not yet clear to me where this happens.

This is the relevant code inside Axolotl (patched to add ephemeral_transfers):

I haven't worked with axolotl yet, so this unfortunately doesn't help me a lot.

I wonder if we could craft a PEFT example, or even better, a unit test, where we can check that this works correctly (or at least results in the expected speed up). Having a unit test would be really great to ensure that there are no regressions of this with future versions of PEFT.

If we had a good way to mark something as "do not store in final adapter_config.json file", that would be ideal and would make me happier about it being in the LoraConfig object.

The first idea that comes to mind would be to have a private attribute on the config that stores all the attribute names that should not be persisted, something like _attributes_not_to_persist: list[str]. Then when the save_pretrained method of the config is called, we pop all those attributes from the dict, then the attribute itself. WDYT?
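A minimal sketch of that idea (names are illustrative, not the actual PEFT config classes):

    from dataclasses import asdict, dataclass, field

    @dataclass
    class ExampleConfig:
        r: int = 8
        ephemeral_transfers: bool = False
        # names listed here are stripped before writing adapter_config.json
        _attributes_not_to_persist: list = field(default_factory=lambda: ["ephemeral_transfers"])

        def to_dict(self):
            d = asdict(self)
            for name in d.pop("_attributes_not_to_persist"):
                d.pop(name, None)
            return d

    print(ExampleConfig(ephemeral_transfers=True).to_dict())  # {'r': 8}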

@kallewoof (Contributor, Author)

Thanks a lot for explaining this in more detail, what you describe makes intuitive sense. Let's extend the help section of the new parameter to clearly indicate when users should set it and what they can expect.

Will do. I also would like to test inference/training and see if it might have a positive impact there, but I believe e.g. for the dora_init case, it is only called during initialization and then never again, so its impact should be rather small. I also don't think CPU offloading for stuff (especially lora layers) is broadly supported atm, but I could be mistaken.

I am also beginning to believe that what I am doing is similar to what the unsloth people are doing. I haven't looked at their code, but IIUC, they juggle things to/from CPU/GPU to speed things up, which is exactly what this method is about.

what I refer to above as ephemeral transfers

Okay, I was trying to google this and didn't find any description :)

Sorry, should've been more clear!

  1. Lora A/B are on the CPU, and we are using ephemeral transfers: copies of them are put on the GPU, the weight norm is calculated (ending up on the GPU)

Okay, I assume that this part is automatically taken care of already...

Yes. A torch operation such as torch.linalg.norm() will produce results on the same device as the parameters, which for case 2 above is the GPU.

and is then explicitly moved to CPU, which is expected and desired.

but it's not yet clear to me where this happens.

When case 2 above is true, the place_on_cpu flag is set in the call to update_layer() in DoraLinearLayer. Before returning, the code below runs:

    if place_on_cpu:
        weight_norm = weight_norm.to("cpu")

I haven't worked with axolotl yet, so this unfortunately doesn't help me a lot.

Got it. I just meant to say that the speed-up happens inside the call to PeftModel.from_pretrained(), down the line, when it does the dora_init calls.

I wonder if we could craft a PEFT example, or even better, a unit test, where we can check that this works correctly (or at least results in the expected speed up). Having a unit test would be really great to ensure that there are no regressions of this with future versions of PEFT.

That sounds like a good idea, but it relies on there being GPUs available for the unit tests. I haven't played with unit tests in PEFT yet so I'm clueless on how that works.

If we had a good way to mark something as "do not store in final adapter_config.json file", that would be ideal and would make me happier about it being in the LoraConfig object.

The first idea that comes to mind would be to have a private attribute on the config that stores all the attribute names that should not be persisted, something like _attributes_not_to_persist: list[str]. Then when the save_pretrained method of the config is called, we pop all those attributes from the dict, then the attribute itself. WDYT?

That works for me! Sounds simple enough.

@BenjaminBossan (Member)

Will do. I also would like to test inference/training and see if it might have a positive impact there

Nice, thanks.

I am also beginning to believe that what I am doing is similar to what the unsloth people are doing. I haven't looked at their code, but IIUC, they juggle things to/from CPU/GPU to speed things up, which is exactly what this method is about.

I don't know either but I think there is more going on there than just this.

When case 2 above is true, the place_on_cpu flag is set in the call to update_layer() in DoraLinearLayer. Before return, the below code happens:

Okay, so this is purely for initialization. I wonder if it would even work for inference.

That sounds like a good idea, but it relies on there being GPUs available for the unit tests. I haven't played with unit tests in PEFT yet so I'm clueless on how that works.

We have unit tests in tests/test_gpu_examples.py that will be run on workers with access to NVIDIA GPUs. Therefore, it's best to put the test there. Check the other tests there for inspiration and feel free to ask if you need help.

That works for me! Sounds simple enough.

👍

@kallewoof (Contributor, Author)

That works for me! Sounds simple enough.

I thought about this a tiny bit more, and I think having a runtime dict inside the config that is dropped when present is even more straightforward.
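A minimal sketch of the runtime-dict variant (again with illustrative names; the merged implementation may differ in detail):

    from dataclasses import asdict, dataclass, field

    @dataclass
    class RuntimeOptions:
        ephemeral_transfers: bool = False     # runtime-only, never persisted

    @dataclass
    class AdapterConfig:
        r: int = 8
        runtime: RuntimeOptions = field(default_factory=RuntimeOptions)

        def to_dict(self):
            d = asdict(self)
            d.pop("runtime", None)            # dropped before adapter_config.json is written
            return d

    print(AdapterConfig(runtime=RuntimeOptions(ephemeral_transfers=True)).to_dict())  # {'r': 8}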

@kallewoof kallewoof force-pushed the 202406-ephemeral-transfers branch from dc2eae7 to e138ea2 Compare June 17, 2024 13:36
@BenjaminBossan (Member)

I thought about this a tiny bit more, and I think having a runtime dict inside the config that is dropped when present is even more straightforward.

I think this should also work. It's important that this is backwards compatible (new PEFT can load old configs) and ideally also forwards compatible (old PEFT can load new configs), but from the looks of it, it should work.

Let me know once this is ready for review. Please add a test for the new config option to ensure that it is indeed not stored. Testing the feature itself would also be nice as mentioned, if this isn't really feasible, an example would be nice.

@kallewoof (Contributor, Author)

@BenjaminBossan Thanks, I have added tests for the runtime config option, as well as a simple test that the ephemeral transfers actually do speed up the DoRA initialization as advertised. It's a bit flaky, though, due to the tiny size of the model (facebook/opt-125m), but I think with a ramped-up rank it should be OK. Feedback welcome.

@kallewoof kallewoof marked this pull request as ready for review June 18, 2024 01:12
@kallewoof kallewoof force-pushed the 202406-ephemeral-transfers branch from a9ec2c0 to 00ea733 Compare June 18, 2024 01:16
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan (Member) left a comment

Thanks a lot for the additions. The PR is headed in a good direction, but let's harden it a little bit more. Please check my comments.

Apart from those, some more general points:

It's a bit flaky, though, due to the tiny size of the model (facebook/opt-125m), but I think with a ramped up rank it should be OK.

How big is the variance in your tests? I currently can't test but I can do some testing next week. If it doesn't take too long, you could consider loading a few times and taking the average/median to compare. Or use a larger model if that helps.

Also, did you have the opportunity to check if/how this affects inference/training?

Finally, please run make style so that CI will pass successfully.

Args:
ephemeral_transfers (`bool`): Whether to use ephemeral transfers for models partially kept in CPU memory.
"""
ephemeral_transfers: bool = field(
Member

I'm still wondering about the name. Maybe we can find something that is a bit more specific. Just as an example, it could be optimize_device_memory or something like this. Of course, we may add more optimizations in the future, so it's hard to predict if that name will fit in the future. But I feel like users will be confused when reading "ephemeral transfers".

Contributor Author

I agree that the name is probably not good as is. optimize_device_memory sounds like it decreases VRAM usage, when in reality it gives a tiny momentary overhead. I'll think of "normal" succinct ways of saying it.

@BenjaminBossan (Member) commented Jun 19, 2024

I see your point. "optimize" does not necessarily mean "decrease", but I agree that this is what many users might think. LMK if you come up with a better name.

Contributor

This sounds more or less (though I agree not exactly) like some sort of temporary CPU offloading. What about something like ephemeral_cpu_offload?

Contributor Author

Well, the idea is to keep it in CPU memory all the time and, just as it is needed for some heavy calculation, copy it to GPU, and then remove it when finished. "Temporary CPU offloading" kind of sounds like the opposite of that, i.e. it's kept on GPU and then temporarily put in CPU memory for brief periods.

@younesbelkada (Contributor) left a comment

Thanks a lot for adding this very nice feature, I left some minor comments. Could you also document this feature for end users in developer_guides/lora?


@kallewoof (Contributor, Author)

Thanks a lot for the detailed feedback! Work pulled me away for a little bit, but I will address once that's out of the way.

@kallewoof kallewoof force-pushed the 202406-ephemeral-transfers branch from 00ea733 to c1e254e Compare June 27, 2024 01:53
@kallewoof kallewoof force-pushed the 202406-ephemeral-transfers branch from d3314f0 to df11b68 Compare June 27, 2024 05:27
@kallewoof (Contributor, Author)

@BenjaminBossan @younesbelkada I believe I addressed all feedback. Let me know if I missed something and I will address.

@@ -79,7 +79,8 @@ def __init__(self, base_layer: nn.Module) -> None:

     @property
     @abstractmethod
-    def _available_adapters(self) -> set[str]: ...
+    def _available_adapters(self) -> set[str]:
+        ...
Contributor Author

make style did this. Should I revert?

Member

Not necessary. This change should already be on latest main, so you could merge with/rebase on main to remove this diff.

Contributor Author

I don't think that's the case. I rebased on main before I started work on the changes earlier.

Member

Ah yes, then it's most likely because you have an older ruff version installed. Try v0.4.10.

Contributor Author

That was it!

with tempfile.TemporaryDirectory() as tmp_dirname:
    for model_name, revision in PEFT_MODELS_TO_TEST:
        cfg = config_class.from_pretrained(model_name, revision=revision)
        cfg.runtime.ephemeral_transfers = True
Contributor Author

I'm confused why this is not erroring since it should only work for LoraConfig.

@BenjaminBossan (Member) left a comment

Thanks for the updates. Not much is missing at this point. I have a few comments left, please take a look. Also, anything about what I wrote earlier:

How big is the variance in your tests? I currently can't test but I can do some testing next week. If it doesn't take too long, you could consider loading a few times and taking the average/median to compare. Or use a larger model if that helps.

Also, did you have the opportunity to check if/how this affects inference/training?

I'll also run some checks once I'm back at my machine.

ephemeral_transfers (`bool`): Whether to use ephemeral transfers for models partially kept in CPU memory.
"""

ephemeral_transfers: bool = field(
Member

Did you have some ideas for a possibly better name? Younes had suggested ephemeral_cpu_offload, WDYT?

Contributor Author

I gave my thoughts on that in #1857 (comment) -- I think maybe ephemeral_gpu_offload would be more accurate, but I'm not sure either of those makes the thing easier to understand.

Member

That name works for me.

@kallewoof (Contributor, Author)

Thanks for the updates. Not much is missing at this point. I have a few comments left, please take a look. Also, anything about what I wrote earlier:

How big is the variance in your tests? I currently can't test but I can do some testing next week. If it doesn't take too long, you could consider loading a few times and taking the average/median to compare. Or use a larger model if that helps.

After bumping r I stopped seeing intermittent errors. I'm not sure how to see the actual variance other than running a bunch of times and checking the values, and even then, that's all very environment specific. If things start erroring out intermittently on CI or elsewhere, using a larger model should address the problem.

(For the record, right before I wrote that code, I spent ~20 minutes waiting for an initialization of a 70B model and a CPU-offloaded DoRA before giving up. With the ephemeral transfers code, the operation took less than a minute to complete, which is what triggered me to suggest this fix in the first place.)

Also, did you have the opportunity to check if/how this affects inference/training?

I didn't do any checks, but I thought about the code and concluded the following: the dora_init method is only called when the LoRA adapters are initialized during the startup phase. Inference and training should not be affected, except perhaps during startup, in the highly unusual case of using CPU offloaded adapters during inference or training (can you even do that?).

@kallewoof kallewoof force-pushed the 202406-ephemeral-transfers branch from 7ffa65b to e735dda Compare June 27, 2024 14:18

@kallewoof kallewoof force-pushed the 202406-ephemeral-transfers branch from e735dda to 330fbb3 Compare June 27, 2024 14:22
@kallewoof (Contributor, Author)

I bumped the model to opt-350m after seeing an intermittent failure on my end. Feel free to at-bump me after merging if this is causing issues and I'll adjust it to be more stable.

@BenjaminBossan (Member) left a comment

Thanks for all the updates. I think if we get the last few kinks ironed out, this PR should be good.

def test_dora_ephemeral_transfers(self):
    torch.manual_seed(0)
    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-350m",
Member

I ran the tests on a 4090 and unfortunately, it is very flaky for me, failing in 4 out of 5 tries with

assert 2.1404892549999204 > (1.1 * 1.980610250000609)

and similar numbers. I also tried a bigger model, bloomz-560m, but that still failed. I wonder if we should scrap the performance test if the gain doesn't materialize on smaller models (and we don't want to load huge models during testing either).

Not sure what we can really test instead, maybe just that there is no error and the device, similar to the multi-GPU test?

As a "proof" that this works, I would suggest to instead add a short script to the examples/ directory that uses a bigger model where there is a very noticeable gain. WDYT?

Contributor Author

Yes, that seems to be a better approach. My intention is to expand this in other areas post-merge. Hopefully one of those is more obvious.

I'll rewrite this to check that the resulting weights are identical instead, which ensures that this code does not screw anything up.
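A hedged sketch of such a check (the adapter repo id is a placeholder and the keyword argument uses the flag's eventual name, so treat it as illustrative rather than the final API):

    import torch
    from transformers import AutoModelForCausalLM
    from peft import PeftModel

    def merged_state_dict(ephemeral_gpu_offload: bool):
        base = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
        peft_model = PeftModel.from_pretrained(
            base,
            "some-user/opt-350m-dora",               # placeholder DoRA adapter id
            ephemeral_gpu_offload=ephemeral_gpu_offload,
        )
        return peft_model.merge_and_unload().state_dict()

    ref = merged_state_dict(False)
    eph = merged_state_dict(True)
    assert all(torch.equal(ref[k], eph[k]) for k in ref)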

Contributor Author

Added an example that loads a DoRA adapter onto a model, and optionally merges and saves it to disk, while also timing it.

The one annoying part is that it will obviously not get the time right if the model is not yet in the HF hub cache and has to be downloaded first. I might look into whether I can force a "download only if necessary" call before the actual model/DoRA load part.

@kallewoof (Contributor, Author) commented Jul 1, 2024

I put in snapshot_download calls before the timers. It downloads some unnecessary stuff, but that isn't very large, and it looks like it works otherwise. It try/excepts, so if you provide a local model or something, it will error but continue and work as normal in the end:

python load_with_dora.py --ephemeral_gpu_offload --dora=./m7dora-example 
Fetching 13 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 72798.33it/s]
Failed to download DoRA: Repo id must use alphanumeric chars or '-', '_', '.', '--' and '..' are forbidden, '-' and '.' cannot start or end the name, max length is 96: './m7dora-example'.
-------------------- Loading model --------------------
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.20it/s]
-------------------- Loading PeftModel --------------------
-------------------- Done --------------------
Model loading time: 3.97s
PeftModel loading time: 16.95s
Use ephemeral GPU offloading: True
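A minimal sketch of that pre-download guard (repo ids below are placeholders):

    import time
    from huggingface_hub import snapshot_download

    for repo_id in ("meta-llama/Meta-Llama-3-8B", "some-user/llama3-8b-dora"):
        try:
            snapshot_download(repo_id)               # warm the local cache if reachable
        except Exception as e:
            # e.g. a local directory was passed instead of a hub repo id; keep going
            print(f"Failed to download {repo_id}: {e}")

    start = time.perf_counter()
    # ... the timed model / PeftModel loading goes here ...
    print(f"Loading time: {time.perf_counter() - start:.2f}s")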


@kallewoof kallewoof changed the title ephemeral transfer support ephemeral GPU offload support Jul 1, 2024
@kallewoof kallewoof force-pushed the 202406-ephemeral-transfers branch from 23a155c to 5e1c986 Compare July 1, 2024 02:12
@kallewoof (Contributor, Author)

Thanks for the review! I believe I addressed all issues.

@BenjaminBossan (Member) left a comment

I think the tests are better now without the flakiness and the added example works well enough to illustrate the advantage, thanks a lot. I have a few suggestions to improve that example, and still one issue regarding the config file. Please check.

main()

"""
Example outputs:
Member

I think it's better to put this docstring at the top of the file and integrate the comment into the docstring. Also, please add the license notice.

Contributor Author

I may have messed this up, but I tried doing what you asked.

-------------------- Loading PeftModel --------------------
-------------------- Done --------------------
Model loading time: 4.28s
PeftModel loading time: 16.59s
Member

As a data point, I got 20s vs 10s on my machine.

Contributor Author

Try a 70B model/DoRA next. ;)

# We continue anyway as this might be e.g. a local directory or something

start = time.perf_counter()
print("-" * 20 + " Loading model " + "-" * 20)
Member

IMO, the "-" * 20 parts are unnecessary.

Contributor Author

How about "--- text ---"?

@BenjaminBossan (Member) left a comment

I may have messed this up, but I tried doing what you asked.

No, looks good.

Try a 70B model/DoRA next. ;)

As soon as I have a sufficiently large GPU lying around :-p

@kallewoof (Contributor, Author) commented Jul 2, 2024

@BenjaminBossan The thread is buggy as hell, so I'm commenting here instead:

Hmm, but I don't see the harm in adding a check to PeftModel.from_pretrained and removing the runtime config there. The example code could be considered "outside PEFT", but using asdict on a PEFT config and getting a valid dict representation would be a reasonable assumption for a package that builds on top of PEFT.

OK, I think I got a good solution. How does f36ec50 look to you?

@kallewoof kallewoof force-pushed the 202406-ephemeral-transfers branch from 40ff66b to f36ec50 Compare July 2, 2024 00:03
@kallewoof (Contributor, Author)

I don't think these are related to my changes?

  • 3.10, windows-latest: FAILED tests/test_tuners_utils.py::TestModelAndLayerStatus::test_base_model_type_transformers_automodel - FileNotFoundError: No such file or directory: "C:\Users\runneradmin\.cache\huggingface\hub\models--google--flan-t5-small\snapshots\0fc9ddf78a1e988dac52e2dac162b0ede4fd74ab\model.safetensors"
  • 3.11, macos-12: FAILED tests/test_stablediffusion.py::StableDiffusionModelTester::test_add_weighted_adapter_base_unchanged_0_test_hf_internal_testing_tiny_stable_diffusion_torch_lora - huggingface_hub.utils._errors.LocalEntryNotFoundError: An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on.

@BenjaminBossan (Member) left a comment

Thanks for adding this nice runtime optimization for DoRA (and hopefully more in the future). Great work and thanks for your patience with the review.

I don't think these are related to my changes?

Yes, don't worry, these are just issues with the CI getting timeouts from the Hub.

@BenjaminBossan BenjaminBossan merged commit 1e2258d into huggingface:main Jul 2, 2024
14 checks passed
@kallewoof kallewoof deleted the 202406-ephemeral-transfers branch July 2, 2024 10:31
Guy-Bilitski pushed a commit to Guy-Bilitski/peft that referenced this pull request May 13, 2025