SGLang HiCache NIXL Connector (#8488) · sgl-project/sglang@2cd2e27 · GitHub
Commit 2cd2e27

Authored by vvenkates27, mkhazraee, and xiezhq-hermann
SGLang HiCache NIXL Connector (#8488)
Signed-off-by: Vishwanath Venkatesan <vvenkatesan@nvidia.com>
Co-authored-by: Moein Khazraee <moein@nvidia.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
1 parent 743638b commit 2cd2e27

File tree: 8 files changed, +837 −9 lines

python/sglang/srt/managers/cache_controller.py: 11 additions, 2 deletions

```diff
@@ -265,6 +265,11 @@ def __init__(
         if storage_backend == "file":
             self.storage_backend = HiCacheFile()
             self.get_hash_str = get_hash_str
+        elif storage_backend == "nixl":
+            from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
+
+            self.storage_backend = HiCacheNixl()
+            self.get_hash_str = get_hash_str
         elif storage_backend == "mooncake":
             self.storage_backend = MooncakeStore()
             self.get_hash_str = get_hash_str_mooncake
@@ -545,7 +550,11 @@ def terminate_prefetch(self, operation):
     def generic_page_transfer(self, operation, batch_size=8):
         for i in range(0, len(operation.hash_value), batch_size):
             page_hashes = operation.hash_value[i : i + batch_size]
-            page_data = self.storage_backend.batch_get(page_hashes)
+            # todo: zero copy
+            dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
+                page_hashes
+            )
+            page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
             if page_data is None:
                 logger.warning(
                     f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
@@ -679,7 +688,7 @@ def generic_page_backup(self, operation, batch_size=8):
         for i in range(0, len(operation.hash_value), batch_size):
             page_hashes = operation.hash_value[i : i + batch_size]
             page_data = [
-                self.mem_pool_host.get_flat_data_pages(
+                self.mem_pool_host.get_flat_data_page(
                     operation.host_indices[j * self.page_size]
                 )
                 for j in range(i, i + len(page_hashes))
```
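The prefetch change above switches `batch_get` to a destination-buffer contract: the controller now preallocates one flat page per hash and the backend fills it in place, instead of the backend allocating return tensors. A minimal sketch of that contract, with a hypothetical in-memory backend standing in for `HiCacheNixl`:

```python
import torch


class InMemoryBackend:
    """Hypothetical stand-in illustrating the batch_get(keys, dests) contract."""

    def __init__(self):
        self.store = {}

    def batch_set(self, keys, values):
        for key, value in zip(keys, values):
            self.store[key] = value.clone()

    def batch_get(self, keys, dest_buffers):
        # Fill each caller-provided destination buffer in place.
        for key, dst in zip(keys, dest_buffers):
            if key not in self.store:
                return None  # mirrors the failure path logged above
            dst.copy_(self.store[key])
        return dest_buffers


backend = InMemoryBackend()
page = torch.randn(1024)
backend.batch_set(["hash_0"], [page])

# As in generic_page_transfer: one preallocated flat page per hash.
dests = [torch.empty_like(page)]
assert backend.batch_get(["hash_0"], dests) is not None
assert torch.equal(dests[0], page)
```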

python/sglang/srt/mem_cache/hicache_storage.py: 15 additions, 6 deletions

```diff
@@ -123,13 +123,22 @@ def get(
         key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
-            # todo: fixing the target_location logic to enable in-place loading
-            loaded_tensor = torch.load(tensor_path)
-            if isinstance(loaded_tensor, torch.Tensor):
-                return loaded_tensor
+            if target_location is not None:
+                # Load directly into target_location's memory buffer
+                with open(tensor_path, "rb") as f:
+                    target_location.set_(
+                        torch.frombuffer(f.read(), dtype=target_location.dtype)
+                        .reshape(target_location.shape)
+                        .storage()
+                    )
+                return target_location
             else:
-                logger.error(f"Loaded data for key {key} is not a tensor.")
-                return None
+                loaded_tensor = torch.load(tensor_path)
+                if isinstance(loaded_tensor, torch.Tensor):
+                    return loaded_tensor
+                else:
+                    logger.error(f"Loaded data for key {key} is not a tensor.")
+                    return None
         except FileNotFoundError:
             return None
 
```
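When `target_location` is provided, the bytes on disk are read straight into the caller's existing buffer rather than into a fresh tensor from `torch.load`. A standalone sketch of the same idea (file path and sizes invented for illustration; this variant uses `copy_` rather than `set_`):

```python
import torch

# Persist a tensor's raw bytes, then load them back in place.
src = torch.arange(8, dtype=torch.float32)
with open("/tmp/kv_page.bin", "wb") as f:
    f.write(src.numpy().tobytes())

dst = torch.empty(8, dtype=torch.float32)  # preallocated target_location
with open("/tmp/kv_page.bin", "rb") as f:
    # frombuffer wraps the bytes without copying; copy_ then writes them
    # into dst's existing storage, so no new tensor is allocated for dst.
    dst.copy_(torch.frombuffer(bytearray(f.read()), dtype=dst.dtype).reshape(dst.shape))

assert torch.equal(src, dst)
```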

python/sglang/srt/mem_cache/memory_pool_host.py: 29 additions, 0 deletions

```diff
@@ -105,6 +105,14 @@ def get_flat_data_page(self, index) -> torch.Tensor:
         """
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def get_dummy_flat_data_page(self) -> torch.Tensor:
+        """
+        Get a dummy flat data page from the host memory pool.
+        This is used for prefetching or initializing empty pages.
+        """
+        raise NotImplementedError()
+
     @abc.abstractmethod
     def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
         """
@@ -256,6 +264,14 @@ def init_kv_buffer(self):
     def get_flat_data_page(self, index) -> torch.Tensor:
         return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
 
+    def get_dummy_flat_data_page(self) -> torch.Tensor:
+        return torch.zeros(
+            (2, self.layer_num, self.page_size, self.head_num, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        ).flatten()
+
     def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
         self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape(
             2,
@@ -355,6 +371,19 @@ def init_kv_buffer(self):
     def get_flat_data_page(self, index) -> torch.Tensor:
         return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
 
+    def get_dummy_flat_data_page(self) -> torch.Tensor:
+        return torch.zeros(
+            (
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            ),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        ).flatten()
+
     def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
         self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
             self.layer_num,
```
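Each dummy page mirrors the flattened layout of a real page so it can serve as a `batch_get` destination. For the MHA pool, the flat length is `2 * layer_num * page_size * head_num * head_dim` (K and V for every layer, token slot, and head); a quick check with made-up dimensions:

```python
import torch

# Illustrative dimensions only; real values come from the model config.
layer_num, page_size, head_num, head_dim = 32, 64, 8, 128

# Same shape-then-flatten pattern as get_dummy_flat_data_page above.
dummy = torch.zeros(
    (2, layer_num, page_size, head_num, head_dim), dtype=torch.float16
).flatten()
assert dummy.numel() == 2 * layer_num * page_size * head_num * head_dim
```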
python/sglang/srt/mem_cache/nixl/README.md: 164 additions, 0 deletions (new file)

# NIXL Integration for HiCache

This directory contains the **NIXL (NVIDIA Inference Xfer Library)** integration for **HiCache**, enabling high-performance storage across multiple backends.

NIXL provides a unified API for accessing various storage plugins, including but not limited to:

- **Deepseek's 3FS APIs** for high-throughput file operations
- **GPU Direct Storage (GDS)** for direct data movement between storage and GPU memory, bypassing CPU memory copies
- **Amazon S3-compatible object storage** for key-value access patterns

Additional backend integrations are planned for future releases.

## NIXL Resources

- **Project Repository**: [NIXL on GitHub](https://github.com/ai-dynamo/nixl)
- **Documentation**: [NIXL Documentation](https://github.com/ai-dynamo/nixl/tree/main/docs)

## Overview

The NIXL integration consists of two main files:

- **`hicache_nixl.py`** - Main HiCache storage connector using NIXL
- **`nixl_utils.py`** - Utility classes for backend selection, registration, and file management

## Components

### HiCacheNixl
The main storage connector that provides:
- Single and batch tensor set/get operations
- Automatic backend selection (3FS > POSIX > GDS_MT > GDS > OBJ)
- High-performance file- or object-based storage access using NIXL

### NixlUtils
Consolidated utility classes:
- **NixlBackendSelection** - Handles backend selection and creation
- **NixlRegistration** - Manages memory registration for tensors, files, and objects
- **NixlFileManager** - Handles file system operations and NIXL tuple creation
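For orientation, a hypothetical usage sketch following the storage interface visible in the diffs above (the `set`/`batch_get` names and signatures are inferred from `cache_controller.py`; exact APIs may differ):

```python
import torch

from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl

storage = HiCacheNixl()  # backend auto-selected (3FS > POSIX > GDS_MT > GDS > OBJ)

key = "page_hash_0"
page = torch.randn(256)

# Store one page, then read it back into a preallocated destination,
# as generic_page_transfer does with dummy flat data pages.
storage.set(key, page)
dst = torch.empty_like(page)
result = storage.batch_get([key], [dst])  # one destination tensor per key
assert result is not None
```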
## Running Unit Tests

### Prerequisites
- NIXL library installed and available (latest main is required for object-query support)
- PyTorch installed
- Python 3.8+

### Unit Tests from the Project Root
Navigate to the project root directory (`/path/to/sglang`) and run:

#### Run all NIXL tests:
```bash
PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -o asyncio_mode=strict
```

#### Run with verbose output:
```bash
PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -v -o asyncio_mode=strict
```

Note: The `-v` flag provides more detailed output, showing each test case name and its result.

#### Run a specific test:
```bash
PYTHONPATH=. python -m pytest test/srt/test_hicache_nixl_storage.py -v -k test_single_set_get -o asyncio_mode=strict
```

### From the Tests Directory
Navigate to the tests directory and run:

```bash
cd test/srt
PYTHONPATH=../.. python -m pytest test_hicache_nixl_storage.py -o asyncio_mode=strict
```

Note: The `-o asyncio_mode=strict` flag suppresses warnings about asyncio configuration. It is not required for test functionality but produces cleaner output.

## Test Coverage

A test suite for this integration is available at `test_hicache_nixl_storage.py`; it covers:

### HiCache Integration Tests (4 tests)
- Single tensor set/get operations
- Batch tensor set/get operations
- Mixed single and batch operations
- Data integrity for various tensor types

### File Management Tests (5 tests)
- Basic file operations
- NIXL tuple creation
- Error handling in file operations

### Registration Tests (2 tests)
- Tensor registration with memory type detection
- File registration using NIXL tuples

## Expected Output

When tests run successfully, you should see:
- NIXL agent initialization messages
- Backend selection messages (e.g., "Backend POSIX was instantiated")
- Test results with "ok" for passed tests
- A summary showing "Ran X tests in Y seconds" and "OK"

## Troubleshooting

### Import Errors
If you encounter `ModuleNotFoundError`, ensure:
- You're running from the correct directory
- `PYTHONPATH` is set correctly
- The NIXL library is properly installed

### NIXL Errors
If NIXL operations fail:
- Check that NIXL is properly installed
- Verify that the required plugins are available
- Ensure file permissions are correct for test directories

## File Structure

```
python/sglang/srt/mem_cache/nixl/
├── hicache_nixl.py           # Main HiCache storage connector
├── nixl_utils.py             # All NIXL utility classes
├── README.md                 # This file
└── tests/
    └── test_nixl_unified.py  # All tests in one file
```

## Dependencies

- **NIXL**: NVIDIA Inference Xfer Library (version 0.4 or later)
  - Required plugins: POSIX (minimum), 3FS/GDS (optional, for better performance)
  - See the [NIXL Installation Guide](https://github.com/ai-dynamo/nixl/blob/main/README.md#installation)
- **PyTorch**: For tensor operations (version 1.8 or later)
- **Python 3.8+**: For type hints and modern language features

## Supported Features

### Memory Types
- **Tensor side**: multi-dimensional tensors of all numeric types (int32, int64, float32, float64) are supported.
  - Tensors can be on CPU or GPU (as long as a GPU-capable backend such as GDS_MT is available).
  - Currently each tensor is mapped to a single file or object key, but this could be extended to support multiple tensors per file or key.
- **Storage side**: files and objects are supported through their relevant backends (e.g., 3FS or OBJ).
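A quick illustrative round-trip over the dtypes listed above, mimicking how a tensor's raw bytes survive a store-and-load cycle (illustration only; the connector handles this internally):

```python
import torch

for dtype in (torch.int32, torch.int64, torch.float32, torch.float64):
    src = torch.arange(6, dtype=dtype).reshape(2, 3)
    raw = src.numpy().tobytes()  # what would be written to a file/object
    back = torch.frombuffer(bytearray(raw), dtype=dtype).reshape(src.shape)
    assert torch.equal(src, back)
```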
### Backend Priority

The NIXL backend selection follows this priority order (a selection sketch follows this list):

1. **3FS** - Highest performance (if available)
   - Best for high-throughput file operations using Deepseek 3FS APIs
2. **POSIX** - Standard file I/O (fallback)
   - Universal compatibility
   - Good for development and testing; leverages both libaio and liburing
3. **GDS_MT** - Multi-threaded GDS (if available)
   - Optimized for concurrent operations
   - Supports GPU Direct Storage with multiple lightweight threads
4. **GDS** - GPU Direct Storage (if available)
   - Direct GPU-storage data path
   - Best for filesystems benefiting from batch operations and smaller I/Os
5. **OBJ** - Amazon S3-based object storage
   - Key-value based storage

The system automatically selects the best available backend, with POSIX as the default fallback.
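A minimal sketch of that priority-ordered fallback (the helper below is hypothetical; the real logic lives in `nixl_utils.NixlBackendSelection`):

```python
# Plugin names from the priority list above.
PRIORITY = ["3FS", "POSIX", "GDS_MT", "GDS", "OBJ"]


def pick_backend(available_plugins: set) -> str:
    """Return the highest-priority available plugin, defaulting to POSIX."""
    for name in PRIORITY:
        if name in available_plugins:
            return name
    return "POSIX"


assert pick_backend({"POSIX", "GDS"}) == "POSIX"
assert pick_backend({"GDS_MT", "OBJ"}) == "GDS_MT"
```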
## Note

This is v0 of the NIXL connector. Future versions will focus on further performance optimizations, such as memory pre-registration (pre-allocating and registering memory buffers to reduce registration overhead during transfers) and block merging (combining related blocks as offsets within the same file to reduce file operations and improve throughput). These optimizations require changes at a higher layer, as the current HiCache API doesn't expose information such as block relationships or hash patterns that would enable them.
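To make the block-merging idea concrete, a future layout might pack related pages at fixed offsets within one file instead of one file per page; a hypothetical sketch:

```python
import torch

PAGE_BYTES = 4096  # 1024 float32 elements per page


def write_pages(path, pages):
    # Pack each page at a fixed offset within a single file.
    with open(path, "wb") as f:
        for i, page in enumerate(pages):
            f.seek(i * PAGE_BYTES)
            f.write(page.numpy().tobytes())


def read_page(path, i, dtype=torch.float32):
    # One seek+read retrieves a page; related pages share one file.
    with open(path, "rb") as f:
        f.seek(i * PAGE_BYTES)
        return torch.frombuffer(bytearray(f.read(PAGE_BYTES)), dtype=dtype)


pages = [torch.full((1024,), float(i)) for i in range(4)]
write_pages("/tmp/merged_pages.bin", pages)
assert torch.equal(read_page("/tmp/merged_pages.bin", 2), pages[2])
```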
