diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 7e9a6af23..3ef86435d 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -51,6 +51,45 @@
 from ._logger import set_verbose
 from ._utils import suppress_stdout_stderr
 
+from llama_cpp.llama_cpp import _lib  # ctypes handle for the shared llama library
+_rpc_devices_registered = False
+
+def add_rpc_devices(servers: str) -> None:
+    """
+    Register one ggml RPC device per endpoint in the given comma-separated host:port list.
+    """
+    global _rpc_devices_registered
+    if _rpc_devices_registered:
+        return  # already registered, so do nothing
+    lib = _lib  # use the imported shared library handle
+    # Bind ggml_backend_reg_by_name and look up the RPC backend registration.
+    lib.ggml_backend_reg_by_name.argtypes = [ctypes.c_char_p]
+    lib.ggml_backend_reg_by_name.restype = ctypes.c_void_p
+    rpc_reg = lib.ggml_backend_reg_by_name(b"RPC")
+    if not rpc_reg:
+        raise ValueError("failed to find RPC backend")
+    # Bind ggml_backend_reg_get_proc_address and resolve the device-add function.
+    lib.ggml_backend_reg_get_proc_address.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
+    lib.ggml_backend_reg_get_proc_address.restype = ctypes.c_void_p
+    rpc_add_fn_ptr = lib.ggml_backend_reg_get_proc_address(rpc_reg, b"ggml_backend_rpc_add_device")
+    if not rpc_add_fn_ptr:
+        raise ValueError("failed to find RPC device add function")
+    # Wrap the raw function pointer: takes a char* endpoint, returns a device handle.
+    PROTOTYPE = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.c_char_p)
+    rpc_add_fn = PROTOTYPE(rpc_add_fn_ptr)
+    # Bind ggml_backend_device_register.
+    lib.ggml_backend_device_register.argtypes = [ctypes.c_void_p]
+    lib.ggml_backend_device_register.restype = None
+    # For each server in the comma-separated list, create and register an RPC device.
+    for server in servers.split(","):
+        endpoint = server.strip().encode("utf-8")
+        dev = rpc_add_fn(endpoint)
+        if dev:
+            lib.ggml_backend_device_register(dev)
+        else:
+            raise ValueError(f"failed to register RPC device for server: {endpoint.decode('utf-8')}")
+    _rpc_devices_registered = True
+
 class Llama:
     """High-level Python wrapper for a llama.cpp model."""
 
@@ -227,7 +266,8 @@ def __init__(
         self.model_params.split_mode = split_mode
         self.model_params.main_gpu = main_gpu
         if rpc_servers is not None:
-            self.model_params.rpc_servers = rpc_servers.encode("utf-8")
+            # self.model_params.rpc_servers = rpc_servers.encode("utf-8")  # stopped working after commit 667d72846c06b2cf4f7c8a4265e210991a49706b
+            add_rpc_devices(rpc_servers)
             self._rpc_servers = rpc_servers
         else:
             self._rpc_servers = None
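
For reference, a minimal usage sketch of the patched code path. The model path and host:port endpoints below are placeholders, not taken from this patch: passing rpc_servers to the constructor now routes through add_rpc_devices(), which registers each endpoint as a ggml RPC device before the model is loaded.

    from llama_cpp import Llama

    # Placeholder model path and RPC endpoints; 50052 is the default
    # port of llama.cpp's rpc-server binary.
    llm = Llama(
        model_path="models/model.gguf",
        rpc_servers="192.168.1.10:50052,192.168.1.11:50052",
        n_gpu_layers=-1,  # offload layers so the registered RPC devices are used
    )

Note that add_rpc_devices() guards itself with a module-level flag, so registration runs at most once per process; a second Llama constructed with a different rpc_servers string will silently keep the first set of devices.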