[Feat]: Support for CPU and Non-Tensor Data Storage of Yuanrong#107
Conversation
|
Important Review skippedAuto reviews are disabled on base/target branches other than the default branch. Please check the settings in the CodeRabbit UI or the You can disable this status message by setting the WalkthroughThis pull request introduces a partition-based data management system replacing the previous global-step model, adds extensible sampler abstractions for flexible data consumption strategies, implements a pluggable key-value storage backend architecture with support for multiple storage implementations (Yuanrong, KV stores), and updates APIs, documentation, and test coverage to reflect these foundational changes. Changes
Sequence DiagramsequenceDiagram
participant Client as TransferQueueClient
participant Controller as TransferQueueController
participant Sampler as BaseSampler
participant Storage as StorageManager
rect rgb(200, 230, 255)
Note over Client,Storage: Data Insertion Flow (partition_id="train_0")
Client->>Controller: put(data, partition_id="train_0")
Controller->>Controller: async_get_meta(partition_id)
Controller->>Controller: allocate_indexes via PartitionIndexManager
Controller->>Storage: notify_data_update(partition_id, fields, indexes, dtypes, shapes)
Storage->>Storage: put_data (KV serialization)
end
rect rgb(230, 200, 255)
Note over Client,Storage: Data Retrieval Flow (sampling)
Client->>Controller: get_meta(data_fields, batch_size, partition_id)
Controller->>Controller: get_partition(partition_id)
Controller->>Controller: scan_data_status (find ready indexes)
Controller->>Sampler: sample(ready_indexes, batch_size)
Sampler-->>Controller: (sampled_indexes, consumed_indexes)
Controller->>Controller: generate_batch_meta from sampled indexes
Controller-->>Client: BatchMeta
Client->>Storage: get_data(metadata)
Storage->>Storage: get_data (KV retrieval + merge)
Storage-->>Client: TensorDict
end
rect rgb(200, 255, 230)
Note over Client,Storage: Cleanup Flow
Client->>Controller: clear(partition_id="train_0")
Controller->>Storage: clear_data(partition_id)
Storage->>Storage: Remove KV entries
end
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes
Poem
Note 🎁 Summarized by CodeRabbit FreeYour organization is on the Free plan. CodeRabbit will generate a high-level summary and a walkthrough for each pull request. For a comprehensive line-by-line review, please upgrade your subscription to CodeRabbit Pro by visiting https://app.coderabbit.ai/login. Comment |
| "This results in the inability to quickly store and retrieve tensors on the NPU side," | ||
| "which may affect performance." | ||
| ) | ||
| elif not torch.npu.is_available(): |
There was a problem hiding this comment.
this will lead to AttributeError when torch_npu is not available
There was a problem hiding this comment.
if package torch_npu is not available, TORCH_NPU_IMPORTED will be False, codes in line:53 don't run.
| def _create_empty_tensorlist(self, shapes, dtypes): | ||
| self._npu_ds_client = None | ||
| self._cpu_ds_client = None | ||
| self.DS_CLIENT_KEYS_LIMIT = 1000 |
There was a problem hiding this comment.
set as environment variable?
| def npu_ds_client_is_available(self): | ||
| return False | ||
| # return self._npu_ds_client is not None |
There was a problem hiding this comment.
npu_ds_client_is_available should be a boolean expression, I think it's better to represent it using a method.
| total_count = (len(keys) + self.DS_CLIENT_KEYS_LIMIT - 1) // self.DS_CLIENT_KEYS_LIMIT | ||
| for i in range(total_count): | ||
| start_idx = self.DS_CLIENT_KEYS_LIMIT * i | ||
| end_idx = min(self.DS_CLIENT_KEYS_LIMIT * (i + 1), len(keys)) | ||
| self._batch_put(keys[start_idx:end_idx], values[start_idx:end_idx]) |
There was a problem hiding this comment.
btw, why there is a self.DS_CLIENT_KEYS_LIMIT? Is it related to performance?
There was a problem hiding this comment.
if we perform self._batch_put in parallel, will the performance be better?
There was a problem hiding this comment.
btw, why there is a self.DS_CLIENT_KEYS_LIMIT? Is it related to performance?
It's a constraint of the keys in the data system.
There was a problem hiding this comment.
I use asyncio to product self._batch_put in parallel like:
async def put(self, keys: list[str], values: list[Any]):
if not isinstance(keys, list) or not isinstance(values, list):
raise ValueError("keys and values must be lists")
if len(keys) != len(values):
raise ValueError("Number of keys must match number of values")
total_count = (len(keys) + DS_CLIENT_KEYS_LIMIT - 1) // DS_CLIENT_KEYS_LIMIT
tasks = []
for i in range(total_count):
start_idx = DS_CLIENT_KEYS_LIMIT * i
end_idx = min(DS_CLIENT_KEYS_LIMIT * (i + 1), len(keys))
tasks.append(self._batch_put(keys[start_idx:end_idx], values[start_idx:end_idx]))
await asyncio.gather(*tasks)The results of synchronous and asynchronous operations about cpu_tensors are very similar. The reason may be that DSTensorClient and KVClient is Synchronous blocking.
| def close(self) -> None: | ||
| """Close all ZMQ sockets and context to prevent resource leaks.""" | ||
| for sock in (self.controller_handshake_socket, self.data_status_update_socket): | ||
| try: | ||
| if sock and not sock.closed: | ||
| sock.close(linger=0) | ||
| except Exception as e: | ||
| logger.error(f"[{self.storage_manager_id}]: Error closing socket {sock}: {str(e)}") | ||
|
|
||
| try: | ||
| if self.zmq_context: | ||
| self.zmq_context.term() | ||
| except Exception as e: | ||
| logger.error(f"[{self.storage_manager_id}]: Error terminating zmq_context: {str(e)}") | ||
|
|
||
| def __del__(self): | ||
| """Destructor to ensure resources are cleaned up.""" | ||
| try: | ||
| self.close() | ||
| except Exception as e: | ||
| logger.error(f"[{self.storage_manager_id}]: Exception during __del__: {str(e)}") | ||
|
|
There was a problem hiding this comment.
these codes are required, otherwise in verl the code might break when switching from rollout model to train mode, or vllm goes to sleep mode
| logger.info("Missing 'client_name in config, using default value('Yuanrong')") | ||
| config["client_name"] = "Yuanrong" | ||
| elif client_name != "Yuanrong": | ||
| raise ValueError("Invalid 'client_name' in config") |
There was a problem hiding this comment.
| raise ValueError("Invalid 'client_name' in config") | |
| raise ValueError(f"Invalid 'client_name': {client_name} in config. Expecting 'Yuanrong' for YuanrongStorageManager".) |
|
|
||
|
|
||
| # TODO: DSTensorClient.dev_mget has wrong behavior: it may require stricter environment to execute | ||
| @StorageClientFactory.register("Yuanrong") |
There was a problem hiding this comment.
| @StorageClientFactory.register("Yuanrong") | |
| @StorageClientFactory.register("YRStorageClient") |
| merged_data = {} | ||
| for field, data_list in grouped_data.items(): | ||
| if all(isinstance(item, torch.Tensor) for item in data_list): | ||
| try: | ||
| merged_data[field] = torch.stack(data_list) | ||
| except RuntimeError: | ||
| try: | ||
| # Fallback to nested tensor if shapes are irregular | ||
| merged_data[field] = torch.nested.as_nested_tensor(data_list) | ||
| except Exception: | ||
| merged_data[field] = NonTensorStack(*data_list) | ||
| else: | ||
| merged_data[field] = NonTensorStack(*data_list) | ||
|
|
||
| return TensorDict(merged_data, batch_size=len(global_indexes)) |
There was a problem hiding this comment.
These codes are useful in many cases. I believe we'd better put it into utils to serve as a general function.
E.g.,
https://github.com/TransferQueue/TransferQueue/blob/dev/transfer_queue/metadata.py#L311
https://github.com/TransferQueue/TransferQueue/blob/dev/transfer_queue/storage/managers/simple_backend_manager.py#L278
| for field in field_names: | ||
| for _ in range(len(global_indexes)): | ||
| merged_data[field].append(values[value_idx]) | ||
| grouped_data[field].append(values[value_idx]) | ||
| value_idx += 1 |
There was a problem hiding this comment.
add TODO: performance optimize

Description
The current
YRStorageClientonly supports storage of NPU tensors and cannot handle CPU tensors or non-tensor data (e.g., bool, str, list, etc.). To improve its generality and system flexibility, this PR refactors theYRStorageClientintransfer_queue/storage/clients/yuanrong_client.pyand updates related storagemanagers to uniformly support the following data types:torch.Tensoron CPU)Key Changes
Introduced a dedicated
_cpu_ds_clientinYRStorageClientto handle all non-NPU data, including CPU tensors and generic Python objects.Removed the
device_idparameter fromYRStorageClient.__init__, decoupling the client from explicit device binding.Optimized the
put/get/clearinterfaces to efficiently support large-volume data operations.Adapted
KVStorageManagerandYuanrongStorageManagerto correctly interface with the newYRStorageClientAPI.Files Modified
storage/clients/yuanrong_client.pystorage/managers/base.pystorage/managers/yuanrong_manager.pyPerformance test
Tensor.shape=[1024]Codes (Click) : Test performance of three methods `get/put/clear`
Integration test
Modify the code in
async_demo.pyas follows:Codes (Click) : Integration test with the entire TransferQueue
Note: If you want to use the NPU and achieve faster tensor access speed, please install
torch_npuand specify the required resources inray.remote:TODO
YRStorageClient.