InferenceNexus/lmdeploy · Commit d8f9e35
Check whether device support bfloat16 (InternLM#2653)
* fallback to float16 if torch.cuda.is_bf16_supported is False
* fix failure when converting qwen2-awq
* update
* update
* rollback load.py
* update
1 parent 89f52bc commit d8f9e35

File tree: 8 files changed (+62 −16 lines)
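In short: every call site that previously asked `torch.cuda.is_bf16_supported()` now goes through a device-aware helper in `lmdeploy.utils`. A minimal sketch of the resulting selection pattern (assumes lmdeploy and a CUDA build of PyTorch with a visible GPU; this snippet is illustrative, not part of the diff):

import torch

from lmdeploy.utils import is_bf16_supported  # helper added by this commit

# Prefer bfloat16, fall back to float16 on devices without bf16 kernels.
dtype = torch.bfloat16 if is_bf16_supported() else torch.float16
print(f'selected dtype: {dtype}')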

lmdeploy/pytorch/check_env/__init__.py

Lines changed: 7 additions & 4 deletions

@@ -208,18 +208,21 @@ def __check_model_dtype_support(config):
         import torch
 
         from lmdeploy.pytorch.config import ModelConfig
+        from lmdeploy.utils import is_bf16_supported
 
         try:
             model_config = ModelConfig.from_hf_config(config,
                                                       model_path=model_path,
                                                       dtype=dtype)
             if model_config.dtype == torch.bfloat16:
-                assert torch.cuda.is_bf16_supported(), (
+                assert is_bf16_supported(), (
                     'bf16 is not supported on your device')
         except AssertionError as e:
-            message = (f'Your device does not support `{model_config.dtype}`. '
-                       'Try edit `torch_dtype` in `config.json`.\n'
-                       'Note that this might have negative effect!')
+            message = (
+                f'Your device does not support `{model_config.dtype}`. '
+                'You can set `dtype` to float16 in PyTorchEngineConfig or '
+                '`--dtype float16` to api_server.\n'
+                'Note that this might have negative effect!')
             _handle_exception(e, 'Model', logger, message=message)
         except Exception as e:
             message = (f'Checking failed with error {e}',

lmdeploy/pytorch/configurations/cogvlm.py

Lines changed: 2 additions & 3 deletions

@@ -14,12 +14,11 @@ def condition(cls, hf_config):
     @classmethod
     def build(cls, hf_config, model_path: str = None):
         """build."""
-        import torch
+        from lmdeploy.utils import is_bf16_supported
         cfg = DefaultModelConfigBuilder.build(hf_config)
         if getattr(hf_config, 'num_multi_query_heads', None):
             cfg.num_key_value_heads = hf_config.num_multi_query_heads
             cfg.cogvlm_style = True
-        torch_dtype = 'bfloat16' if torch.cuda.is_bf16_supported(
-        ) else 'float16'
+        torch_dtype = 'bfloat16' if is_bf16_supported() else 'float16'
         hf_config.torch_dtype = torch_dtype
         return cfg

lmdeploy/pytorch/configurations/qwen.py

Lines changed: 3 additions & 4 deletions

@@ -13,16 +13,15 @@ def condition(cls, hf_config):
     @classmethod
     def build(cls, hf_config, model_path: str = None):
         """build."""
-        import torch
+        from lmdeploy.utils import is_bf16_supported
         cfg = DefaultModelConfigBuilder.build(hf_config)
         if cfg.bos_token_id is None:
             cfg.bos_token_id = 151644
         if cfg.eos_token_id is None:
             cfg.eos_token_id = 151645
 
-        is_bf16_supported = torch.cuda.is_bf16_supported()
-        torch_dtype = 'bfloat16' if is_bf16_supported else 'float16'
-        if hf_config.bf16 and is_bf16_supported:
+        torch_dtype = 'bfloat16' if is_bf16_supported() else 'float16'
+        if hf_config.bf16 and is_bf16_supported():
             torch_dtype = 'bfloat16'
         elif hf_config.fp16:
             torch_dtype = 'float16'
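The Qwen builder now resolves the dtype in one pass: start from whatever the device supports, honor an explicit `bf16` flag only when bf16 actually works, and let an explicit `fp16` flag force float16. A runnable sketch of that decision with a stubbed support check, so it runs without a GPU (`pick_dtype` is illustrative, not part of the diff):

# A bf16 request on a device without bf16 support now lands on float16.
def pick_dtype(cfg_bf16, cfg_fp16, bf16_ok):
    torch_dtype = 'bfloat16' if bf16_ok else 'float16'
    if cfg_bf16 and bf16_ok:
        torch_dtype = 'bfloat16'
    elif cfg_fp16:
        torch_dtype = 'float16'
    return torch_dtype


for bf16_ok in (True, False):
    print(bf16_ok, pick_dtype(cfg_bf16=True, cfg_fp16=False, bf16_ok=bf16_ok))
# True  -> bfloat16
# False -> float16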

lmdeploy/turbomind/deploy/converter.py

Lines changed: 6 additions & 2 deletions

@@ -11,7 +11,7 @@
 from lmdeploy.model import MODELS, best_match_model
 from lmdeploy.utils import get_logger, get_model
 
-from ...utils import _get_and_verify_max_len
+from ...utils import _get_and_verify_max_len, is_bf16_supported
 from ..supported_models import SUPPORTED_ARCHS, is_supported
 from .config import TurbomindModelConfig
 from .exporter import get_exporter_factory

@@ -138,6 +138,10 @@ def get_output_model_registered_name_and_config(model_path: str,
     else:
         assert 0, f'unsupported specified data type {dtype}'
 
+    if weight_type == 'bfloat16' and not is_bf16_supported():
+        logger.warn('data type fallback to float16 since '
+                    'torch.cuda.is_bf16_supported is False')
+        weight_type = 'float16'
     config.model_config.model_arch = model_arch
     config.model_config.weight_type = weight_type
     config.model_config.model_format = model_format

@@ -226,7 +230,7 @@ def get_tm_model(model_path,
             f'mismatched quant method: user input ' \
             f'"{engine_config.model_format}" ' \
             f'vs model quant_config "{quant_method}"'
-        assert group_size is None or group_size == _group_size, \
+        assert not group_size or group_size == _group_size, \
             f'mismatched quant group size: user input "{group_size}" ' \
             f'vs model quant_config "{_group_size}"'
 
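Two behavioral notes on these hunks: the converter now downgrades bfloat16 weights to float16 (with a warning) when the device lacks bf16 support, and the group-size check treats 0 the same as None. A small runnable illustration of the second point (128 stands in for the value read from the model's quant_config):

# `not group_size` is True for both None and 0, so an unspecified or zero
# group size no longer trips the mismatch assertion.
_group_size = 128
for group_size in (None, 0, 128, 64):
    print(group_size, not group_size or group_size == _group_size)
# None -> True, 0 -> True, 128 -> True, 64 -> False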

lmdeploy/utils.py

Lines changed: 37 additions & 0 deletions

@@ -352,3 +352,40 @@ def get_max_batch_size(device_type: str):
         return 16
     elif device_type == 'maca':
         return 128
+
+
+def is_bf16_supported(device_type: str = 'cuda'):
+    """Check if the device supports bfloat16.
+
+    Args:
+        device_type (str): the type of device
+    """
+
+    if device_type == 'cuda':
+        import torch
+        device = torch.cuda.current_device()
+
+        # Check for CUDA version and device compute capability.
+        # This is a fast way to check for it.
+        cuda_version = torch.version.cuda
+        if (cuda_version is not None and int(cuda_version.split('.')[0]) >= 11
+                and torch.cuda.get_device_properties(device).major >= 8):
+            return True
+        else:
+            return False
+    elif device_type == 'ascend':
+        # The following API doesn't work reliably on multi-NPU devices. Since
+        # the `ascend910` device supports bfloat16, return True as a
+        # workaround.
+        return True
+        # import torch_npu
+        # device_name = torch_npu.npu.get_device_name(0)[:10]
+        # device_name = device_name.lower()
+        # if device_name.startswith('ascend910'):
+        #     return True
+        # else:
+        #     return False
+    elif device_type == 'maca':
+        return True
+    else:
+        return False
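A quick way to exercise the new helper (hedged: the 'cuda' branch assumes a CUDA build of PyTorch with a visible GPU, since it queries the current device; 'cpu' is just an example of an unknown device type):

from lmdeploy.utils import is_bf16_supported

print(is_bf16_supported())          # True on compute capability >= 8.0 with CUDA >= 11
print(is_bf16_supported('ascend'))  # always True (see the workaround comment above)
print(is_bf16_supported('cpu'))     # False: unknown device types fall through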

tests/pytorch/kernel/test_apply_rotary.py

Lines changed: 2 additions & 1 deletion

@@ -2,6 +2,7 @@
 import torch
 
 from lmdeploy.pytorch.kernels import apply_rotary_pos_emb
+from lmdeploy.utils import is_bf16_supported
 
 
 def _rotate_half(x):

@@ -12,7 +13,7 @@ def _rotate_half(x):
 
 
 def _bf16_mark():
-    return pytest.mark.skipif(not torch.cuda.is_bf16_supported(),
+    return pytest.mark.skipif(not is_bf16_supported(),
                               reason='bf16 not supported.')
 
 
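For context, this is how `_bf16_mark()` is typically consumed in these kernel tests (a hedged sketch, not the exact parametrization in the file): bf16 cases are skipped on devices where `is_bf16_supported()` returns False.

import pytest
import torch

from lmdeploy.utils import is_bf16_supported


def _bf16_mark():
    return pytest.mark.skipif(not is_bf16_supported(),
                              reason='bf16 not supported.')


@pytest.mark.parametrize('dtype', [
    torch.float16,
    pytest.param(torch.bfloat16, marks=_bf16_mark()),
])
def test_dtype_is_usable(dtype):
    assert torch.ones(2, dtype=dtype).dtype == dtype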

tests/pytorch/kernel/test_multinomial_sampling.py

Lines changed: 2 additions & 1 deletion

@@ -2,10 +2,11 @@
 import torch
 
 from lmdeploy.pytorch.kernels import multinomial_sampling
+from lmdeploy.utils import is_bf16_supported
 
 
 def _bf16_mark():
-    return pytest.mark.skipif(not torch.cuda.is_bf16_supported(),
+    return pytest.mark.skipif(not is_bf16_supported(),
                               reason='bf16 not supported.')
 
 

tests/pytorch/kernel/test_rms_norm.py

Lines changed: 3 additions & 1 deletion

@@ -1,9 +1,11 @@
 import pytest
 import torch
 
+from lmdeploy.utils import is_bf16_supported
+
 
 def _bf16_mark():
-    return pytest.mark.skipif(not torch.cuda.is_bf16_supported(),
+    return pytest.mark.skipif(not is_bf16_supported(),
                               reason='bf16 not supported.')
 
 
