`AutoModel.from_pretrained(...)` fails under `with torch.device("meta")` with PyTorch 2.7.0 · Issue #153332 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content
`AutoModel.from_pretrained(...)` fails under `with torch.device("meta")` with PyTorch 2.7.0 #153332
@vadimkantorov

Description

@vadimkantorov
# from torch.nn.attention.flex_attention import BlockMask, flex_attention
from transformers import AutoModel
import torch

with torch.device('meta'):
    AutoModel.from_pretrained('Qwen/Qwen2.5-0.5B', trust_remote_code=True)

fails with:

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
[<ipython-input-1-00ba4c43be18>](https://localhost:8080/#) in <cell line: 0>()
      4 
      5 with torch.device('meta'):
----> 6     AutoModel.from_pretrained('Qwen/Qwen2.5-0.5B', trust_remote_code=True)

6 frames
[/usr/local/lib/python3.11/dist-packages/transformers/models/auto/auto_factory.py](https://localhost:8080/#) in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    569             if model_class.config_class == config.sub_configs.get("text_config", None):
    570                 config = config.get_text_config()
--> 571             return model_class.from_pretrained(
    572                 pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    573             )

[/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py](https://localhost:8080/#) in _wrapper(*args, **kwargs)
    277         old_dtype = torch.get_default_dtype()
    278         try:
--> 279             return func(*args, **kwargs)
    280         finally:
    281             torch.set_default_dtype(old_dtype)

[/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py](https://localhost:8080/#) in from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)
   4397                 offload_index,
   4398                 error_msgs,
-> 4399             ) = cls._load_pretrained_model(
   4400                 model,
   4401                 state_dict,

[/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py](https://localhost:8080/#) in _load_pretrained_model(cls, model, state_dict, checkpoint_files, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, device_map, disk_offload_folder, offload_state_dict, dtype, hf_quantizer, keep_in_fp32_regex, device_mesh, key_mapping, weights_only)
   4831             # Skip it with fsdp on ranks other than 0
   4832             elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
-> 4833                 disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
   4834                     model_to_load,
   4835                     state_dict,

[/usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py](https://localhost:8080/#) in decorate_context(*args, **kwargs)
    114     def decorate_context(*args, **kwargs):
    115         with ctx_factory():
--> 116             return func(*args, **kwargs)
    117 
    118     return decorate_context

[/usr/local/lib/python3.11/dist-packages/transformers/modeling_utils.py](https://localhost:8080/#) in _load_state_dict_into_meta_model(model, state_dict, shard_file, expected_keys, reverse_renaming_mapping, device_map, disk_offload_folder, disk_offload_index, cpu_offload_folder, cpu_offload_index, hf_quantizer, is_safetensors, keep_in_fp32_regex, unexpected_keys, device_mesh)
    822                     param_device = "cpu" if is_local_dist_rank_0() else "meta"
    823 
--> 824                 _load_parameter_into_model(model, param_name, param.to(param_device))
    825 
    826             else:

[/usr/local/lib/python3.11/dist-packages/torch/utils/_device.py](https://localhost:8080/#) in __torch_function__(self, func, types, args, kwargs)
    102         if func in _device_constructors() and kwargs.get('device') is None:
    103             kwargs['device'] = self.device
--> 104         return func(*args, **kwargs)
    105 
    106 # NB: This is directly called from C++ in torch/csrc/Device.cpp

NotImplementedError: Cannot copy out of meta tensor; no data!

Also, unless the first line (the commented-out `flex_attention` import) is uncommented, the same snippet fails on PyTorch 2.6.0 as well, with `RuntimeError: Tensor.item() cannot be called on meta tensors`.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions

      0