Fix HF loading when there's no metadata file to work with fsspec (#15… · pytorch/pytorch@916f6ba · GitHub
[go: up one dir, main page]

Skip to content

Commit 916f6ba

Browse files
ankitageorge and pytorchmergebot
authored and committed
Fix HF loading when there's no metadata file to work with fsspec (#152856)
Summary: HF loading when there is no metadata is an edge case for some users. We were previously calling safe_open(filename) to get the keys in the safetensors file, but this doesn't work with fsspec, when models have a different backend than local fs (ie. hf, s3 etc). This diff updates to open the file with fsspec.open() and then safetensors.deserialize() to get the keys Test Plan: unit test and e2e test reading from hf Differential Revision: D74181513 Pull Request resolved: #152856 Approved by: https://github.com/joecummings
1 parent e06a080 commit 916f6ba

File tree

2 files changed

+33
-14
lines changed

2 files changed

+33
-14
lines changed

test/distributed/checkpoint/test_hf_storage.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -192,20 +192,24 @@ def test_metadata_hf(self) -> None:
192192

193193
def test_read_metadata_when_metadata_file_does_not_exist(self) -> None:
194194
mock_module = MagicMock()
195-
sys.modules["safetensors.torch"] = mock_module
196195
sys.modules["huggingface_hub"] = mock_module
196+
197197
with tempfile.TemporaryDirectory() as path:
198198
reader = _HuggingFaceStorageReader(path=path)
199199
reader.fs = FileSystem()
200200
# there is one safetensor file, but no metadata file,
201201
# so we create metadata from the safetensor file
202-
file_name = "test.safetensors"
203-
open(os.path.join(path, file_name), "w").close()
204-
205202
keys = ["tensor_0", "tensor_1"]
206-
mock_module.safe_open.return_value.__enter__.return_value.keys.return_value = (
207-
keys
208-
)
203+
file_name = "test.safetensors"
204+
with open(os.path.join(path, file_name), "wb") as f:
205+
# write metadata the same way it would be in safetensors file
206+
metadata_contents = json.dumps(
207+
{"tensor_0": "value_0", "tensor_1": "value_1"}
208+
)
209+
metadata_bytes = metadata_contents.encode("utf-8")
210+
211+
f.write(len(metadata_bytes).to_bytes(8, byteorder="little"))
212+
f.write(metadata_bytes)
209213

210214
metadata = reader.read_metadata()
211215

torch/distributed/checkpoint/_hf_storage.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# mypy: allow-untyped-defs
22
import dataclasses
3+
import io
34
import json
45
import os
56
import queue
7+
import struct
68
from typing import Optional
79

810
import fsspec # type: ignore[import-untyped]
@@ -225,22 +227,22 @@ def read_metadata(self) -> Metadata:
225227

226228
if not self.fs.exists(metadata_path):
227229
# if metadata file doesn't exist, create it from the safetensors file
228-
from safetensors.torch import safe_open # type: ignore[import-not-found]
229-
230230
safetensors_files = []
231231
for file in self.fs.ls(self.path):
232232
if file.endswith(SUFFIX):
233-
safetensors_files.append(os.path.basename(file))
233+
safetensors_files.append(file)
234234

235235
if len(safetensors_files) != 1:
236236
raise ValueError(
237237
f"Need exactly one safetensors file to load without metadata, found {len(safetensors_files)} files"
238238
)
239239
storage_data = {}
240-
with safe_open(safetensors_files[0], framework="pt") as f:
241-
for k in f.keys():
242-
state_dict_metadata[k] = BytesStorageMetadata()
243-
storage_data[k] = safetensors_files[0]
240+
with self.fs.create_stream(safetensors_files[0], "rb") as f:
241+
keys = _get_safetensors_file_keys(f)
242+
243+
for key in keys:
244+
state_dict_metadata[key] = BytesStorageMetadata()
245+
storage_data[key] = os.path.basename(safetensors_files[0])
244246
else:
245247
with self.fs.create_stream(metadata_path, "r") as metadata_file:
246248
metadata = json.load(metadata_file)
@@ -259,3 +261,16 @@ def read_metadata(self) -> Metadata:
259261
metadata.storage_meta.load_id = self.load_id
260262

261263
return metadata
264+
265+
266+
def _get_safetensors_file_keys(file_bytes: io.IOBase) -> list[str]:
267+
# this uses the same logic that's done in HF code base
268+
# https://github.com/2404589803/huggingface_hub/blob/main/src/huggingface_hub/hf_api.py#L5308
269+
# and follows their documentation on how their files are serialized
270+
# https://huggingface.co/docs/safetensors/index#format
271+
272+
header_len_bytes = file_bytes.read(8)
273+
header_len = struct.unpack("<Q", header_len_bytes)[0]
274+
header_json = file_bytes.read(header_len)
275+
metadata = json.loads(header_json)
276+
return list(metadata.keys())

0 commit comments

Comments
 (0)
0