Out of memory on H800 #7

Open

Lucas-TY opened this issue Jun 25, 2024 · 8 comments
@Lucas-TY
CUDA_VISIBLE_DEVICES=0 python test/on_chip.py --prefill 124928 --budget 4096 \
 --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  1.65s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.18it/s]
####################################### Config #######################################
Method: TriForce
Dataset: gs
Spec Args: {'budget': 4096, 'chunk_size': 8}
Draft: JackFram/llama-68m
Target: NousResearch/Yarn-Llama-2-7b-128k
Prefill Length: 124928
Generation Length: 256
Gamma: 6
Sampling Method: top_k = -1, top_p = 0.9, temperature = 0.6
Log CSV: None
######################################################################################

[draft run] capturing graph for 0 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 1 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 2 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 3 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 4 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 5 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 6 (probs=True, temp=0.6, top_p=0.9)...
Traceback (most recent call last):
  File "/home/lliee/workspace_tianyu/TriForce/test/on_chip.py", line 83, in <module>
    graph_engine.initialize_cuda_graph(gamma, probs=True, temperature=temperature, top_p=top_p)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 144, in initialize_cuda_graph
    self.callables[gamma_offset] = draft_run_capture_graph(
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 83, in draft_run_capture_graph
    static_logits = engine.draft_run(input_ids=static_input_ids, gamma_offset=gamma_offset, probs=probs, temperature=temperature, top_p=top_p)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 54, in draft_run
    logits = self.draft(input_ids=input_ids, kv_cache=self.draft_cache, graph_cache=self.draft_cache, gamma_offset=gamma_offset).logits
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 340, in forward
    outputs = self.model(
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 301, in forward
    layer_outputs = decoder_layer(
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 220, in forward
    hidden_states = self.self_attn(
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama_68m.py", line 141, in forward
    query_states = self.q_proj(hidden_states)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/medusa/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 MiB. GPU 
@preminstrel (Contributor) commented Jun 25, 2024

Hello, may I ask how much memory your device has? You can try decreasing prefill from 124928 to 122880 to see if it still OOMs. The code can run on a single A100-80GB.
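
For context, a back-of-the-envelope estimate shows why a ~125K-token prefill is tight on an 80GB card. The numbers below assume the standard Llama-2-7B layout (32 layers, 32 KV heads, head_dim 128, no GQA) and fp16 everywhere; it is a rough sketch, not a measurement:

# Rough fp16 memory estimate for the target model at --prefill 124928.
layers, kv_heads, head_dim, dtype_bytes = 32, 32, 128, 2  # assumed Llama-2-7B layout
prefill = 124928

kv_cache = 2 * layers * kv_heads * head_dim * prefill * dtype_bytes  # K and V
weights = 7e9 * dtype_bytes                                          # ~7B params in fp16

print(f"KV cache: {kv_cache / 2**30:.1f} GiB")  # ~61.0 GiB
print(f"weights:  {weights / 2**30:.1f} GiB")   # ~13.0 GiB
# ~74 GiB before activations, the draft model, and CUDA graph pools,
# which is already close to the ~80 GiB of an H800/A100.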

preminstrel self-assigned this Jun 25, 2024
@Lucas-TY (Author) commented Jun 26, 2024

With prefill set to 122880, I got this error:

CUDA_VISIBLE_DEVICES=0 python test/on_chip.py --prefill 122880 --budget 4096 \
 --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  1.85s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:07<00:00,  2.82it/s]
####################################### Config #######################################
Method: TriForce
Dataset: gs
Spec Args: {'budget': 4096, 'chunk_size': 8}
Draft: JackFram/llama-68m
Target: NousResearch/Yarn-Llama-2-7b-128k
Prefill Length: 122880
Generation Length: 256
Gamma: 6
Sampling Method: top_k = -1, top_p = 0.9, temperature = 0.6
Log CSV: None
######################################################################################

[draft run] capturing graph for 0 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 1 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 2 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 3 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 4 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 5 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 6 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 7 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 8 (probs=True, temp=0.6, top_p=0.9)...
Traceback (most recent call last):
  File "/home/lliee/workspace_tianyu/TriForce/test/on_chip.py", line 83, in <module>
    graph_engine.initialize_cuda_graph(gamma, probs=True, temperature=temperature, top_p=top_p)
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 154, in initialize_cuda_graph
    self.callable_model_verify = model_verify_capture_graph(
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 111, in model_verify_capture_graph
    static_logits = engine.model_verify(input_ids=static_input_ids, position_ids=static_position_ids, probs=probs, temperature=temperature, top_p=top_p)
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/utils/graph_infer.py", line 63, in model_verify
    logits = self.model(input_ids=input_ids, kv_cache=self.kv_cache, graph_cache=self.graph_cache, position_ids=position_ids, spec=True).logits
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama.py", line 396, in forward
    outputs = self.model(
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama.py", line 355, in forward
    layer_outputs = decoder_layer(
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama.py", line 276, in forward
    hidden_states = self.self_attn(
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lliee/workspace_tianyu/TriForce/models/modeling_llama.py", line 222, in forward
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
  File "/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 182, in apply_rotary_pos_emb
    q_embed = (q * cos) + (rotate_half(q) * sin)
RuntimeError: The size of tensor a (32) must match the size of tensor b (131072) at non-singleton dimension 1

@Lucas-TY (Author)
nvidia-smi
Wed Jun 26 10:40:05 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H800                    On  | 00000000:61:00.0 Off |                    0 |
| N/A   30C    P0              70W / 700W |      4MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

@preminstrel (Contributor) commented Jun 26, 2024

What is your transformers version? Can you set it to transformers==4.37.2? The apply_rotary_pos_emb API changed in recent versions.
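
For reference, the shape clash above is easy to reproduce outside the repo. Up to transformers 4.37.x, apply_rotary_pos_emb gathers rows of the full cos/sin tables by position_ids itself; newer releases expect cos/sin already gathered per position, so a full 131072-row table ends up broadcasting against the head dimension. A minimal sketch, with shapes taken from the traceback (the exact version boundary is stated from memory):

import torch

# Shapes from the traceback: q is [bsz, num_heads, q_len, head_dim];
# cos is the full RoPE table [max_positions, head_dim] of the 128K model.
q = torch.randn(1, 32, 6, 128)
cos = torch.randn(131072, 128)
position_ids = torch.arange(6).unsqueeze(0)

# transformers <= 4.37.x gathers rows by position_ids before broadcasting:
ok = q * cos[position_ids].unsqueeze(1)  # [1, 1, 6, 128] broadcasts over the 32 heads

# transformers >= 4.38 skips that indexing; the full table then collides
# with num_heads at dim 1, reproducing the error above:
try:
    bad = q * cos.unsqueeze(1)  # [131072, 1, 128] vs [1, 32, 6, 128]
except RuntimeError as e:
    print(e)  # The size of tensor a (32) must match the size of tensor b (131072) ...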

@Lucas-TY (Author)

Thanks, I think it's more likely an environment problem. It works now:

CUDA_VISIBLE_DEVICES=0 python test/on_chip.py --prefill 122880 --budget 4096 \
 --chunk_size 8 --top_p 0.9 --temp 0.6 --gamma 6
/home/lliee/miniconda3/envs/TriForce/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00,  1.65s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:06<00:00,  3.15it/s]
####################################### Config #######################################
Method: TriForce
Dataset: gs
Spec Args: {'budget': 4096, 'chunk_size': 8}
Draft: JackFram/llama-68m
Target: NousResearch/Yarn-Llama-2-7b-128k
Prefill Length: 122880
Generation Length: 256
Gamma: 6
Sampling Method: top_k = -1, top_p = 0.9, temperature = 0.6
Log CSV: None
######################################################################################

[draft run] capturing graph for 0 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 1 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 2 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 3 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 4 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 5 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 6 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 7 (probs=True, temp=0.6, top_p=0.9)...
[draft run] capturing graph for 8 (probs=True, temp=0.6, top_p=0.9)...
[model verify] capturing graph for spec len 6 (probs=True, temp=0.6, top_p=0.9)...
[Full Cache] Cached: 0 | Budget: 123152
[Retrieval Cache] Budget: 4096  | PreFill: 122880  | Chunk Size: 8  | Chunks: 15360  | Select Sets: 512
[StreamingLLM Cache] Start Size: 16 | Recent Size: 234 | Gamma: 6 | Real Budget: 259 | Cached: 0
tokenized_prompts length: 20
Autoregressive Warmup: 100%|████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:37<00:00, 37.44s/it]
Autoregressive Test: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:35<00:00, 35.70s/it]
[Autoregressive] average latency: 31.812156550586224 ms
TriForce Warmup: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [01:44<00:00, 34.95s/it]
TriForce Test: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [11:08<00:00, 33.43s/it]
average acceptance rate (NOT per token): 0.6893210335795382
[TriForce] average latency: 17.238558956780214 ms
[E2E Speedup]: 1.8454069525384529
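
A quick sanity check on the printed numbers, pure arithmetic lifted from the log above:

# Numbers lifted from the log; nothing measured here.
prefill, budget, chunk_size = 122880, 4096, 8
assert prefill // chunk_size == 15360  # "Chunks" in the retrieval-cache line
assert budget // chunk_size == 512     # "Select Sets"

autoregressive_ms = 31.812156550586224
triforce_ms = 17.238558956780214
print(autoregressive_ms / triforce_ms)  # ~1.8454, the reported [E2E Speedup]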

@Lucas-TY (Author)

I changed the transformers version because there were other environment issues, such as:

@preminstrel (Contributor)

Yeah, I am using CUDA 12.1. Here are my torch and flash_attn versions:

>>> import torch
>>> torch.__version__
'2.2.1+cu121'
>>> import flash_attn
>>> flash_attn.__version__
'2.5.7'
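
Putting the thread together, a minimal environment check; the versions below are the ones reported working in this thread, not an exhaustively tested matrix:

import flash_attn
import torch
import transformers

print(torch.__version__)         # '2.2.1+cu121' reported above
print(flash_attn.__version__)    # '2.5.7' reported above
print(transformers.__version__)  # '4.37.2' -- newer releases changed apply_rotary_pos_emb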

@preminstrel (Contributor)

I have added an FAQ to the README :)
