[Feat]: Add general key-value storage interface (KVStorageManager) & yuanrong storage backend#96
Conversation
|
Important Review skippedAuto reviews are disabled on base/target branches other than the default branch. Please check the settings in the CodeRabbit UI or the You can disable this status message by setting the Note Other AI code review bot(s) detectedCodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review. WalkthroughThe PR introduces a key-value storage backend for the transfer queue system. It adds a StorageClientFactory for extensible storage client registration and instantiation, implements a YRStorageClient for remote Yuanrong tensor operations, and introduces KVStorageManager for metadata-to-key mapping and tensor serialization. Supporting tests validate factory behavior, key generation, value flattening, and bidirectional data reconstruction. Changes
Sequence DiagramsequenceDiagram
participant App
participant KVStorageManager
participant StorageClientFactory
participant YRStorageClient
participant YuanrongDS as Yuanrong DS
App->>StorageClientFactory: create("Yuanrong", config)
StorageClientFactory->>YRStorageClient: instantiate with config
YRStorageClient->>YuanrongDS: connect (host, port, device_id)
StorageClientFactory-->>App: YRStorageClient instance
App->>KVStorageManager: put_data(TensorDict, BatchMeta)
KVStorageManager->>KVStorageManager: _generate_yr_keys(BatchMeta)
KVStorageManager->>KVStorageManager: _generate_yr_values(TensorDict)
KVStorageManager->>YRStorageClient: put(keys, values)
YRStorageClient->>YuanrongDS: mset(keys, values)
YuanrongDS-->>YRStorageClient: ✓ stored
App->>KVStorageManager: get_data(BatchMeta)
KVStorageManager->>KVStorageManager: _generate_yr_keys(BatchMeta)
KVStorageManager->>YRStorageClient: get(keys, shapes, dtypes)
YRStorageClient->>YuanrongDS: mget(keys) [2000ms timeout]
YuanrongDS-->>YRStorageClient: tensor values
KVStorageManager->>KVStorageManager: _merge_kv_to_dict(BatchMeta, values)
KVStorageManager-->>App: reconstructed TensorDict
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes
Possibly related issues
Poem
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Comment |
| """ | ||
|
|
||
| # Class variable: maps client names to their corresponding classes | ||
| _registry: dict[str, Type['StorageClientFactory']] = {} |
There was a problem hiding this comment.
Better to provide a base class TransferQueueStorageClient for storage client?
| return decorator | ||
|
|
||
| @classmethod | ||
| def create(cls, name: str, config: dict) -> 'StorageClientFactory': |
There was a problem hiding this comment.
param name refer to TransferQueueStorageManagerFactory (name->client_type)
| raise ValueError(f"Unknown StorageClient: {name}") | ||
| return cls._registry[name](config) | ||
|
|
||
| @abstractmethod |
There was a problem hiding this comment.
As above comment, better to provide a base abstract class
| All tensors must reside on NPU device. | ||
| """ | ||
|
|
||
| def __init__(self, cfg: dict[str, Any]): |
There was a problem hiding this comment.
cfg -> config so we can align with other parts of the code~
| raise ValueError("Number of keys must match number of values") | ||
|
|
||
| # 约束:传入的key的数量不能超过1万。&Tensor的地址空间必须连续。 | ||
| assert len(keys) <= 10000 |
There was a problem hiding this comment.
better raise an error; how to make sure the address space is continues?
There was a problem hiding this comment.
better raise an error; how to make sure the address space is continues?
Yuanrong data system imposes a limit on the number of keys for a single query; this will be modified later to use multiple queries. Additionally, empty tensors are created by KVStorageManager._create_empty_tensorlist and should therefore be contiguous. If the additional performance overhead is acceptable, a check for tensor contiguity can be added in the future.
|
|
||
| import zmq | ||
| from tensordict import TensorDict | ||
| import torch |
There was a problem hiding this comment.
extra space; please use
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always to auto fix these issues
| self.storage_client = StorageClientFactory.create(client_name, config) | ||
|
|
||
| @staticmethod | ||
| def _generate_yr_keys(metadata: BatchMeta) -> list[str]: |
There was a problem hiding this comment.
it can be a general function. We can just call it _generate_keys. same for following function names
| return [ | ||
| f'{index}@{field}' | ||
| for field in sorted(metadata.field_names) | ||
| for index in metadata.global_indexes | ||
| ] |
There was a problem hiding this comment.
Maybe we can use itertools to speed up this process? there may be lots of keys.
There was a problem hiding this comment.
sorted_fields = sorted(metadata.field_names)
indexes = metadata.global_indexes
index_strs = [str(idx) for idx in indexes]
return [
f"{idx_str}@{field}"
for field, idx_str in itertools.product(sorted_fields, index_strs)
]There was a problem hiding this comment.
And is this key order can get better performance than the other order? Typically rows can be much more than columns
There was a problem hiding this comment.
And is this key order can get better performance than the other order? Typically rows can be much more than columns
Generally speaking, for two unrelated lists, I think placing the shorter list in the outer loop is better, as it means fewer conditional checks and fewer pointer jumps.
| ] | ||
|
|
||
| @staticmethod | ||
| def _generate_yr_values(data: TensorDict) -> list[Tensor]: |
| [data[field_a][0], data[field_a][1], data[field_a][2], ..., data[field_b][0], ...] | ||
| """ | ||
| for v in data.values(): | ||
| if not torch.is_tensor(v): |
There was a problem hiding this comment.
Add a TODO here. We need to support complex data types (NonTensorStack/NonTensorData/NestedTensor)
| ] | ||
|
|
||
| @staticmethod | ||
| def _merge_kv_to_dict(metadata: BatchMeta, values: list[Tensor]) -> TensorDict: |
| @@ -0,0 +1,15 @@ | |||
| from .base import KVStorageManager | |||
There was a problem hiding this comment.
suggest use from transfer_queue.storage.managers.base import KVStorageManager
| @@ -0,0 +1,113 @@ | |||
| from .factory import StorageClientFactory | |||
There was a problem hiding this comment.
suggest use from transfer_queue.storage.clients.base import TransferQueueStorageClient
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
There was a problem hiding this comment.
Pull Request Overview
This PR adds support for a key-value storage backend (YuanRong DataSystem) to the transfer queue system. It introduces a factory pattern for storage clients and implements KV-based storage management for tensor data.
- Introduces
StorageClientFactorywith registration pattern for storage clients - Implements
YRStorageClientfor YuanRong DataSystem with NPU tensor support - Adds
KVStorageManagerto handle metadata-to-KV mapping and tensor operations - Includes comprehensive test coverage for factory, client, and manager functionality
Reviewed Changes
Copilot reviewed 6 out of 7 changed files in this pull request and generated 17 comments.
Show a summary per file
| File | Description |
|---|---|
| transfer_queue/storage/clients/factory.py | Implements abstract factory base class with decorator-based client registration |
| transfer_queue/storage/clients/yuanrong_client.py | Adds YuanRong DataSystem client implementation with NPU tensor operations |
| transfer_queue/storage/managers/base.py | Adds KVStorageManager class with metadata-to-KV conversion logic and imports torch |
| transfer_queue/storage/managers/yuanrong_manager.py | Implements YuanrongStorageManager with config validation |
| tests/test_storage_client_factory.py | Adds unit tests for storage client factory and tensor list creation |
| tests/test_kv_storage_manager.py | Adds unit tests for KV storage manager key/value generation and merging |
| .gitignore | Adds PyCharm IDE files to ignore list |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
|
||
| import zmq | ||
| from tensordict import TensorDict | ||
| import torch |
There was a problem hiding this comment.
Extra space between 'import' and 'torch'. Should be 'import torch'.
| import torch | |
| import torch |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from abc import ABC,abstractmethod |
There was a problem hiding this comment.
Missing space after comma in import statement. Should be 'ABC, abstractmethod'.
| from abc import ABC,abstractmethod | |
| from abc import ABC, abstractmethod |
| """ | ||
| super().__init__(config) | ||
| client_name = config.get("client_name", "Yuanrong") | ||
| self.storage_client = StorageClientFactory.create(client_name, config) |
There was a problem hiding this comment.
Missing import for 'StorageClientFactory'. The class is used but not imported from 'transfer_queue.storage.clients.factory'.
| raise ValueError('Yuanrong DataSystem') | ||
| if dtypes is None: | ||
| raise ValueError('Yuanrong DataSystem') |
There was a problem hiding this comment.
Error messages at lines 90 and 92 are incomplete. They should clearly state what is missing, e.g., 'shapes parameter is required for Yuanrong DataSystem' and 'dtypes parameter is required for Yuanrong DataSystem'.
| raise ValueError('Yuanrong DataSystem') | |
| if dtypes is None: | |
| raise ValueError('Yuanrong DataSystem') | |
| raise ValueError('shapes parameter is required for Yuanrong DataSystem') | |
| if dtypes is None: | |
| raise ValueError('dtypes parameter is required for Yuanrong DataSystem') |
There was a problem hiding this comment.
had upated more prompts
| import torch_npu | ||
|
|
||
|
|
||
| # TODO: DSTensorClient.dev_mget has wrong |
There was a problem hiding this comment.
Incomplete TODO comment. Should clarify what is wrong with 'DSTensorClient.dev_mget', e.g., 'TODO: DSTensorClient.dev_mget has wrong behavior' or 'TODO: DSTensorClient.dev_mget has wrong return type'.
| # TODO: DSTensorClient.dev_mget has wrong | |
| # TODO: DSTensorClient.dev_mget has wrong behavior: it does not return tensors in the expected order when keys are missing. |
| client=StorageClientFactory.create("Yuanrong", self.cfg) | ||
|
|
There was a problem hiding this comment.
Variable client is not used.
| client=StorageClientFactory.create("Yuanrong", self.cfg) |
|
|
||
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | ||
| from transfer_queue.storage.managers.base import KVStorageManager | ||
| from transfer_queue.storage.managers.yuanrong_manager import YuanrongStorageManager |
There was a problem hiding this comment.
Import of 'YuanrongStorageManager' is not used.
| from transfer_queue.storage.managers.yuanrong_manager import YuanrongStorageManager |
| from .factory import StorageClientFactory | ||
| from typing import Any | ||
| from torch import Tensor | ||
| from tensordict import TensorDict |
There was a problem hiding this comment.
Import of 'TensorDict' is not used.
| from tensordict import TensorDict |
| from tensordict import TensorDict | ||
| import torch | ||
| import datasystem | ||
| import torch_npu |
There was a problem hiding this comment.
Import of 'torch_npu' is not used.
| import torch_npu |
There was a problem hiding this comment.
torch_npu is neccessary for yuanrong
|
|
||
| try: | ||
| from .yuanrong_client import YRStorageClient | ||
| except ImportError: |
There was a problem hiding this comment.
'except' clause does nothing but pass and there is no explanatory comment.
| except ImportError: | |
| except ImportError: | |
| # YRStorageClient is optional and may not be available in all environments. |
There was a problem hiding this comment.
it try to import YRStorageClient for registering
There was a problem hiding this comment.
Actionable comments posted: 7
♻️ Duplicate comments (4)
transfer_queue/storage/managers/base.py (4)
24-24: Formatting: Extra space in import statement.
289-306: Consider generalizing method name and using itertools.
309-334: Consider generalizing method name and adding TODO for complex types.
337-337: Consider renaming to_merge_kv_to_tensordict.
🧹 Nitpick comments (2)
.gitignore (1)
184-189: Consolidate duplicate PyCharm sections.The
.gitignorefile now contains two PyCharm-related sections: a commented-out one at lines 184–189 and an active one at lines 218–219. This redundancy creates maintenance confusion and violates the DRY principle.Consider consolidating by removing the old commented section and keeping the new active section, or by uncommenting the existing section and removing the duplicate at the end of the file.
If you prefer to keep the new active section at the end, apply this diff to clean up:
# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. # AbstraThen remove the duplicate section at lines 218–219.
Also applies to: 218-219
transfer_queue/storage/managers/base.py (1)
384-384: Add return type annotation for consistency.The method signature is missing a return type annotation, though the docstring documents it. Add the type hint for consistency with Python typing best practices.
@staticmethod - def _get_shape_type_list(metadata: BatchMeta): + def _get_shape_type_list(metadata: BatchMeta) -> tuple[list[torch.Size], list[torch.dtype]]:
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
.gitignore(1 hunks)tests/test_kv_storage_manager.py(1 hunks)tests/test_storage_client_factory.py(1 hunks)transfer_queue/storage/clients/factory.py(1 hunks)transfer_queue/storage/clients/yuanrong_client.py(1 hunks)transfer_queue/storage/managers/base.py(2 hunks)transfer_queue/storage/managers/yuanrong_manager.py(1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: pre-commit (3.10)
- GitHub Check: build (3.11)
| # 约束:传入的key的数量不能超过1万。&Tensor的地址空间必须连续。 | ||
| assert len(keys) <= 10000 | ||
|
|
||
| for value in values: | ||
| if not isinstance(value, torch.Tensor): | ||
| raise ValueError(f"Expected torch.Tensor, got {type(value)}") | ||
| if value.device.type != 'npu': | ||
| raise ValueError(f"Tensor is on {value.device}, not on NPU") | ||
|
|
There was a problem hiding this comment.
Replace the assertion with explicit validation.
assert len(keys) <= 10000 is skipped when Python runs with optimizations (python -O), so the critical safeguard on key count silently disappears. That leaves dev_mset exposed to oversized batches, violating the documented constraint. Please convert this to an explicit runtime check that raises a ValueError.
- # 约束:传入的key的数量不能超过1万。&Tensor的地址空间必须连续。
- assert len(keys) <= 10000
+ # 约束:传入的key的数量不能超过1万。&Tensor的地址空间必须连续。
+ if len(keys) > 10000:
+ raise ValueError("Number of keys cannot exceed 10_000")📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| # 约束:传入的key的数量不能超过1万。&Tensor的地址空间必须连续。 | |
| assert len(keys) <= 10000 | |
| for value in values: | |
| if not isinstance(value, torch.Tensor): | |
| raise ValueError(f"Expected torch.Tensor, got {type(value)}") | |
| if value.device.type != 'npu': | |
| raise ValueError(f"Tensor is on {value.device}, not on NPU") | |
| # 约束:传入的key的数量不能超过1万。&Tensor的地址空间必须连续。 | |
| if len(keys) > 10000: | |
| raise ValueError("Number of keys cannot exceed 10_000") | |
| for value in values: | |
| if not isinstance(value, torch.Tensor): | |
| raise ValueError(f"Expected torch.Tensor, got {type(value)}") | |
| if value.device.type != 'npu': | |
| raise ValueError(f"Tensor is on {value.device}, not on NPU") |
🤖 Prompt for AI Agents
In transfer_queue/storage/clients/yuanrong_client.py around lines 63 to 71,
replace the assert len(keys) <= 10000 with an explicit runtime validation that
checks if len(keys) > 10000 and raises a ValueError (including the actual key
count and the 10000 limit in the message); keep the subsequent tensor type and
device checks as-is so oversized batches are rejected reliably even when Python
is run with optimizations.
There was a problem hiding this comment.
@dpj135 Thank you for addressing this issue! The explicit validation will now properly enforce the 10,000 key limit even when Python runs with optimizations enabled. 🎉
| if shapes is None: | ||
| raise ValueError('Yuanrong DataSystem') | ||
| if dtypes is None: | ||
| raise ValueError('Yuanrong DataSystem') | ||
| if len(dtypes) != len(shapes): | ||
| raise ValueError('Length of dtypes must equal length of shapes') | ||
|
|
||
| values: list[Tensor] = self._create_empty_tensorlist(shapes=shapes, dtypes=dtypes) | ||
|
|
||
| # 约束:传入的key的数量不能超过1万。&Tensor的地址空间必须连续。 | ||
| # print(f'get_keys: {keys}') | ||
| assert len(keys) <= 10000 | ||
|
|
||
| # Timeout set to 2000ms | ||
| self._ds_client.dev_mget(keys, values, 2000) | ||
| return values |
There was a problem hiding this comment.
Validate key/metadata alignment and enforce the limit without assert.
get never checks that the number of keys matches the supplied shapes/dtypes, so callers can provide mismatched metadata and silently feed _ds_client.dev_mget with too few or too many output buffers. That risks read-overwrite errors or misaligned tensors. Additionally, the key-count guard again relies on assert, which is stripped with python -O. Please add explicit validation for both the length alignment and the 10 000-key limit.
if dtypes is None:
raise ValueError('Yuanrong DataSystem')
if len(dtypes) != len(shapes):
raise ValueError('Length of dtypes must equal length of shapes')
+ if len(keys) != len(shapes):
+ raise ValueError('Number of keys must match provided shapes/dtypes')
+ if len(keys) > 10000:
+ raise ValueError('Number of keys cannot exceed 10_000')
values: list[Tensor] = self._create_empty_tensorlist(shapes=shapes, dtypes=dtypes)
- # 约束:传入的key的数量不能超过1万。&Tensor的地址空间必须连续。
- # print(f'get_keys: {keys}')
- assert len(keys) <= 10000
+ # 约束:传入的key的数量不能超过1万。&Tensor的地址空间必须连续。
+ # print(f'get_keys: {keys}')🤖 Prompt for AI Agents
In transfer_queue/storage/clients/yuanrong_client.py around lines 89-104, the
method currently uses assert for the 10,000-key guard and does not validate that
len(keys) matches the supplied shapes and dtypes, which can cause misaligned
buffers or be skipped under python -O; replace the assert with an explicit check
that raises a ValueError if len(keys) > 10000, and add a validation that
len(keys) == len(shapes) == len(dtypes) (raising ValueError with a clear message
if not) before creating the output tensors and calling _ds_client.dev_mget.
There was a problem hiding this comment.
@dpj135, thank you for addressing the validation issues! The explicit checks will ensure robust error handling in production environments.
| def __init__(self, config: dict[str, Any]): | ||
| """ | ||
| Initialize the KVStorageManager with configuration. | ||
| """ | ||
| super().__init__(config) | ||
| client_name = config.get("client_name", "Yuanrong") | ||
| self.storage_client = StorageClientFactory.create(client_name, config) | ||
|
|
There was a problem hiding this comment.
Missing import causes NameError.
Line 286 references StorageClientFactory but it's not imported. This will cause a runtime error when instantiating KVStorageManager.
Add the missing import at the top of the file:
+from transfer_queue.storage.clients.factory import StorageClientFactoryCommittable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In transfer_queue/storage/managers/base.py around lines 280 to 287, the
constructor uses StorageClientFactory but it's not imported; add the missing
import at the top of the file (e.g., add a line importing StorageClientFactory —
use a relative import like "from ..clients import StorageClientFactory" or the
full package path "from transfer_queue.storage.clients import
StorageClientFactory") so the name is defined when KVStorageManager is
instantiated.
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify the import statement is present in the file
grep -n "from transfer_queue.storage.clients.factory import StorageClientFactory" transfer_queue/storage/managers/base.py
Length of output: 74
@dpj135 Thank you for adding the import! The StorageClientFactory import is now properly included on line 31, which resolves the issue. The KVStorageManager class can now be instantiated without any NameError. ✅
🧬 Code Graph Analysis Results
[transfer_queue/storage/managers/factory.py:26-40]
def register(cls, manager_type: str, manager_cls: type[TransferQueueStorageManager]):
if not issubclass(manager_cls, TransferQueueStorageManager):
raise TypeError(
f"manager_cls {getattr(manager_cls, 'name', repr(manager_cls))} must be "
f"a subclass of TransferQueueStorageManager"
)
cls._registry[manager_type] = manager_cls
End snippet
[transfer_queue/storage/managers/factory.py:35-40]
def create(cls, manager_type: str, config: dict[str, Any]) -> TransferQueueStorageManager:
if manager_type not in cls._registry:
raise ValueError(
f"Unknown manager_type: {manager_type}. Supported managers include: {list(cls._registry.keys())}"
)
return cls._registrymanager_type
End snippet
[transfer_queue/storage/clients/factory.py:35-40]
def register(cls, client_type: str):
"""
Decorator to register a concrete client class with the factory.
Args:
client_type (str): The name used to identify the client
Returns:
Callable: The decorator function that returns the original class
"""
def decorator(client_class: TransferQueueStorageClient) -> TransferQueueStorageClient:
cls._registry[client_type] = client_class
return client_class
return decorator
End snippet
[transfer_queue/storage/clients/factory.py:59-71]
def create(cls, client_type: str, config: dict) -> TransferQueueStorageClient:
"""
Create and return an instance of the storage client by name.
Args:
client_type (str): The registered name of the client
Returns:
StorageClientFactory: An instance of the requested client
Raises:
ValueError: If no client is registered with the given name
"""
if client_type not in cls._registry:
raise ValueError(f"Unknown StorageClient: {client_type}")
return cls._registryclient_type
End snippet
[transfer_queue/utils/zmq_utils.py:31-65]
class ZMQRequestType(ExplicitEnum):
# HANDSHAKE
HANDSHAKE = "HANDSHAKE" # TransferQueueStorageUnit -> TransferQueueController
HANDSHAKE_ACK = "HANDSHAKE_ACK" # TransferQueueController -> TransferQueueStorageUnit
# DATA_OPERATION
GET_DATA = "GET"
PUT_DATA = "PUT"
GET_DATA_RESPONSE = "GET_DATA_RESPONSE"
PUT_DATA_RESPONSE = "PUT_DATA_RESPONSE"
CLEAR_DATA = "CLEAR_DATA"
CLEAR_DATA_RESPONSE = "CLEAR_DATA_RESPONSE"
PUT_GET_OPERATION_ERROR = "PUT_GET_OPERATION_ERROR"
PUT_GET_ERROR = "PUT_GET_ERROR"
PUT_ERROR = "PUT_ERROR"
GET_ERROR = "GET_ERROR"
CLEAR_DATA_ERROR = "CLEAR_DATA_ERROR"
# META_OPERATION
GET_META = "GET_META"
GET_META_RESPONSE = "GET_META_RESPONSE"
GET_CLEAR_META = "GET_CLEAR_META"
GET_CLEAR_META_RESPONSE = "GET_CLEAR_META_RESPONSE"
CLEAR_META = "CLEAR_META"
CLEAR_META_RESPONSE = "CLEAR_META_RESPONSE"
# CHECK_CONSUMPTION
CHECK_CONSUMPTION = "CHECK_CONSUMPTION"
CONSUMPTION_RESPONSE = "CONSUMPTION_RESPONSE"
# NOTIFY_DATA_UPDATE
NOTIFY_DATA_UPDATE = "NOTIFY_DATA_UPDATE"
NOTIFY_DATA_UPDATE_ACK = "NOTIFY_DATA_UPDATE_ACK"
NOTIFY_DATA_UPDATE_ERROR = "NOTIFY_DATA_UPDATE_ERROR"
End snippet
[transfer_queue/utils/zmq_utils.py:68-87]
class ZMQServerInfo:
def init(self, role: TransferQueueRole, id: str, ip: str, ports: dict[str, str]):
self.role = role
self.id = id
self.ip = ip
self.ports = ports
def to_addr(self, port_name: str) -> str:
return f"tcp://{self.ip}:{self.ports[port_name]}"
def to_dict(self):
return {
"role": self.role,
"id": self.id,
"ip": self.ip,
"ports": self.ports,
}
def __str__(self) -> str:
return f"ZMQSocketInfo(role={self.role}, id={self.id}, ip={self.ip}, ports={self.ports})"
End snippet
[transfer_queue/utils/zmq_utils.py:100-114]
def create(
cls,
request_type: ZMQRequestType,
sender_id: str,
body: dict[str, Any],
receiver_id: Optional[str] = None,
) -> "ZMQMessage":
return cls(
request_type=request_type,
sender_id=sender_id,
receiver_id=receiver_id,
body=body,
request_id=str(uuid4().hex[:8]),
timestamp=time.time(),
)
End snippet
[transfer_queue/utils/zmq_utils.py:116-118]
def serialize(self) -> bytes:
"""Using pickle to serialize ZMQMessage objects"""
return pickle.dumps(self)
End snippet
[transfer_queue/utils/zmq_utils.py:121-131]
def deserialize(cls, data: bytes | list[bytes]):
"""Using pickle to deserialize ZMQMessage objects"""
if isinstance(data, list):
# Process multiple byte streams by deserializing each in sequence
result = []
for d in data:
result.append(pickle.loads(d))
return result
else:
# Single byte stream case
return pickle.loads(data)
End snippet
[transfer_queue/metadata.py:136-151]
class BatchMeta:
"""Records the metadata of a batch of data samples."""
samples: list[SampleMeta]
extra_info: dict[str, Any] = dataclasses.field(default_factory=dict)
def __post_init__(self):
"""Initialize all computed properties during initialization"""
# Basic properties
object.__setattr__(self, "_size", len(self.samples))
object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples))
# Pre-compute all list properties for better performance
if self.samples:
for idx, sample in enumerate(self.samples):
object.__setattr__(sample, "_batch_index", idx) # Ensure batch_index is set correctly
object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples])
# assume all samples have the same fields.
object.__setattr__(self, "_field_names", sorted(self.samples[0].field_names))
else:
object.__setattr__(self, "_global_indexes", [])
object.__setattr__(self, "_field_names", [])
End snippet
[transfer_queue/metadata.py:167-169]
def global_indexes(self) -> list[int]:
"""Get all global indexes in this batch"""
return getattr(self, "_global_indexes", [])
End snippet
[transfer_queue/metadata.py:74-76]
def get_field_by_name(self, name: str) -> Optional[FieldMeta]:
"""Get FieldMeta by field name"""
return self.fields.get(name)
End snippet
[transfer_queue/storage/clients/yuanrong_client.py:41-63]
def put(self, keys: list[str], values: list[Tensor]):
"""
Store tensors to remote storage.
Args:
keys (list): List of string keys
values (list): List of torch.Tensor on NPU
"""
if not isinstance(keys, list) or not isinstance(values, list):
raise ValueError("keys and values must be lists")
if len(keys) != len(values):
raise ValueError("Number of keys must match number of values")
# TODO: Support the situation when the number of keys is greater than 10000
if len(keys) > 10000:
raise NotImplementedError('We will support the number of keys greater than 10000 int the future')
for value in values:
if not isinstance(value, torch.Tensor):
raise ValueError(f"Expected torch.Tensor, got {type(value)}")
if value.device.type != 'npu':
raise ValueError(f"Tensor is on {value.device}, not on NPU")
self._ds_client.dev_mset(keys, values)
End snippet
[transfer_queue/storage/clients/yuanrong_client.py:65-90]
def get(self, keys: list[str], shapes=None, dtypes=None) -> list[Tensor]:
"""
Retrieve tensors from remote storage.
Args:
keys (list): List of keys to fetch
shapes (list): Expected shapes of returned tensors
dtypes (list): Expected dtypes of returned tensors
Returns:
list: List of retrieved NPU tensors
"""
if shapes is None:
raise ValueError('Yuanrong storage client needs Expected shapes of returned tensors')
if dtypes is None:
raise ValueError('Yuanrong storage client needs Expected dtypes of returned tensors')
if len(dtypes) != len(shapes):
raise ValueError('Length of dtypes must equal length of shapes')
values: list[Tensor] = self._create_empty_tensorlist(shapes=shapes, dtypes=dtypes)
# TODO: Support the situation when the number of keys is greater than 10000
if len(keys) > 10000:
raise NotImplementedError('We will support the number of keys greater than 10000 int the future')
# Timeout set to 2000ms
self._ds_client.dev_mget(keys, values, 2000)
return values
End snippet
[transfer_queue/storage/clients/yuanrong_client.py:92-98]
def clear(self, keys: list[str]):
"""
Delete entries from storage by keys.
Args:
keys (list): List of keys to delete
"""
self._ds_client.dev_delete(keys)
End snippet
[transfer_queue/storage/metadata.py:136-151]
class BatchMeta:
"""Records the metadata of a batch of data samples."""
samples: list[SampleMeta]
extra_info: dict[str, Any] = dataclasses.field(default_factory=dict)
def __post_init__(self):
"""Initialize all computed properties during initialization"""
# Basic properties
object.__setattr__(self, "_size", len(self.samples))
object.__setattr__(self, "_is_ready", all(sample.is_ready for sample in self.samples))
# Pre-compute all list properties for better performance
if self.samples:
for idx, sample in enumerate(self.samples):
object.__setattr__(sample, "_batch_index", idx) # Ensure batch_index is set correctly
object.__setattr__(self, "_global_indexes", [sample.global_index for sample in self.samples])
# assume all samples have the same fields.
object.__setattr__(self, "_field_names", sorted(self.samples[0].field_names))
else:
object.__setattr__(self, "_global_indexes", [])
object.__setattr__(self, "_field_names", [])
End snippet
[transfer_queue/storage/managers/base.py:— end —]
Note: The above snippets are provided as-is from the referenced files to help understand interactions with the file transfer_queue/storage/managers/base.py and its related components.
| try: | ||
| tensor_data[field] = torch.stack(tensor_list) | ||
| except RuntimeError as re: | ||
| # Fallback to nested tensor if shapes are irregular | ||
| tensor_data[field] = torch.nested.as_nested_tensor(tensor_list) |
There was a problem hiding this comment.
RuntimeError catch may be too broad.
The except RuntimeError clause assumes any RuntimeError during torch.stack is due to shape mismatches, but other runtime errors (e.g., CUDA errors, device mismatches) could also raise RuntimeError. Consider catching more specific exceptions or adding validation.
Consider validating tensor shapes before stacking, or at minimum logging the exception to aid debugging:
try:
tensor_data[field] = torch.stack(tensor_list)
- except RuntimeError as re:
+ except RuntimeError as e:
+ logger.warning(f"Cannot stack tensors for field {field}, using nested tensor. Error: {e}")
# Fallback to nested tensor if shapes are irregular
tensor_data[field] = torch.nested.as_nested_tensor(tensor_list)Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In transfer_queue/storage/managers/base.py around lines 375 to 379, the current
except RuntimeError: fallback to nested tensor is too broad; validate tensors
before stacking (ensure all tensors have identical .shape and .device), only
attempt torch.stack when validation passes, and if torch.stack still fails catch
the exception, log the exception details, and only use
torch.nested.as_nested_tensor when the failure is confirmed to be due to
irregular shapes/devices; otherwise re-raise or propagate the error so true
runtime (e.g., CUDA) issues aren’t silently swallowed.
There was a problem hiding this comment.
@dpj135, understood. I've noted your decision to skip this suggestion.
| async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: | ||
| keys = self._generate_yr_keys(metadata) | ||
| values = self._generate_yr_values(data) | ||
| self.storage_client.put(keys=keys, values=values) | ||
|
|
||
| async def get_data(self, metadata: BatchMeta) -> TensorDict: | ||
| keys = self._generate_yr_keys(metadata) | ||
| shapes, dtypes = self._get_shape_type_list(metadata) | ||
| values = self.storage_client.get(keys=keys, shapes=shapes, dtypes=dtypes) | ||
| return self._merge_kv_to_dict(metadata, values) | ||
|
|
||
| async def clear_data(self, metadata: BatchMeta) -> None: | ||
| keys = self._generate_yr_keys(metadata) | ||
| self.storage_client.clear(keys=keys) No newline at end of file |
There was a problem hiding this comment.
Methods marked async but perform no async operations.
All three methods (put_data, get_data, clear_data) are declared async but call synchronous storage_client methods without using await. This blocks the event loop and defeats the purpose of async/await.
Either:
- If
storage_clientmethods are synchronous, removeasyncfrom these method signatures (but this may break the abstract base class contract). - If
storage_clientmethods should be async, update them to be async and useawaithere. - If synchronous calls are unavoidable, run them in an executor to avoid blocking:
async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
keys = self._generate_yr_keys(metadata)
values = self._generate_yr_values(data)
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, self.storage_client.put, keys, values)🛠️ Refactor suggestion | 🟠 Major
Add error handling for storage client operations.
None of the data operations have error handling. If storage_client operations fail (network issues, storage full, etc.), exceptions will propagate unhandled. Add try-except blocks with appropriate logging and error recovery.
Example for put_data:
async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
- keys = self._generate_yr_keys(metadata)
- values = self._generate_yr_values(data)
- self.storage_client.put(keys=keys, values=values)
+ try:
+ keys = self._generate_yr_keys(metadata)
+ values = self._generate_yr_values(data)
+ self.storage_client.put(keys=keys, values=values)
+ except Exception as e:
+ logger.error(f"[{self.storage_manager_id}]: Failed to put data: {e}")
+ raiseApply similar error handling to get_data and clear_data.
Committable suggestion skipped: line range outside the PR's diff.
There was a problem hiding this comment.
@dpj135, understood! I've noted that you're choosing not to implement these suggestions for now. Feel free to reach out if you'd like to discuss them in the future.
(\_/)
( •_•)
/ >💔
| async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None: | ||
| keys = self._generate_yr_keys(metadata) | ||
| values = self._generate_yr_values(data) | ||
| self.storage_client.put(keys=keys, values=values) |
There was a problem hiding this comment.
Missing data update notification after put.
After successfully storing data, put_data should call notify_data_update (inherited from the base class) to inform the controller that new data is available. This is essential for the queue coordination workflow.
Add notification after the put operation:
async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
keys = self._generate_yr_keys(metadata)
values = self._generate_yr_values(data)
self.storage_client.put(keys=keys, values=values)
+
+ # Notify controller of new data
+ dtypes = {idx: {} for idx in metadata.global_indexes}
+ shapes = {idx: {} for idx in metadata.global_indexes}
+ for idx, sample in enumerate(metadata.samples):
+ for field_name in metadata.field_names:
+ field = sample.get_field_by_name(field_name)
+ dtypes[metadata.global_indexes[idx]][field_name] = field.dtype
+ shapes[metadata.global_indexes[idx]][field_name] = field.shape
+
+ await self.notify_data_update(
+ fields=list(metadata.field_names),
+ global_indexes=metadata.global_indexes,
+ dtypes=dtypes,
+ shapes=shapes,
+ )🤖 Prompt for AI Agents
In transfer_queue/storage/managers/base.py around lines 407 to 410, put_data
stores keys/values but does not notify the controller; after successfully
calling self.storage_client.put(keys=keys, values=values) call await
self.notify_data_update(metadata) so the inherited notification is sent to the
controller; ensure the notify call is awaited (and only run after put completes)
so the queue coordination workflow is triggered.
There was a problem hiding this comment.
@dpj135, thank you for implementing the fix! I can confirm that the put_data method now properly calls notify_data_update after storing the data, which ensures the controller is informed about the new data availability. The implementation looks correct, with the dtypes and shapes being properly extracted and passed to the notification method. Great work! ✨
🐰
|
| from torch import Tensor | ||
|
|
||
|
|
||
| class TransferQueueStorageClient(ABC): |
There was a problem hiding this comment.
This interface seems to fit KV storage client. We should call it TransferQueueStorageKVClient. And this class should be placed in a separated file named base.py
There was a problem hiding this comment.
And a general base class is needed
| # self.sm = YuanrongStorageManager(self.cfg) | ||
|
|
||
| def test_generate_keys(self): | ||
| """测试 _generate_keys 生成正确的 key 列表""" |
There was a problem hiding this comment.
Need to translate it into English
| @@ -0,0 +1,95 @@ | |||
| import unittest | |||
|
|
|||
| # from ..clients.factory import StorageClientFactory | |||
There was a problem hiding this comment.
Unnecessary comments should be removed
| # Register storage clients | ||
| try: | ||
| import transfer_queue.storage.clients.yuanrong_client as _ # noqa: F401 | ||
| except ImportError: |
| return [row_data for field in sorted(data.keys()) for row_data in data[field]] | ||
|
|
||
| @staticmethod | ||
| def _merge_kv_to_tensordict(metadata: BatchMeta, values: list[Tensor]) -> TensorDict: |
There was a problem hiding this comment.
Now this function is a little hard to understand. The input params is metadata and values, which are not keys and values.
Can we propose a better calling logic of its upstream function? Or simply change the name of it?
async def get_data(self, metadata: BatchMeta) -> TensorDict:
keys = self._generate_keys(metadata)
shapes, dtypes = self._get_shape_type_list(metadata)
values = self.storage_client.get(keys=keys, shapes=shapes, dtypes=dtypes)
return self._merge_kv_to_tensordict(metadata, values)There was a problem hiding this comment.
could we call it _merge_tensors_to_tensordict?
async def get_data(self, metadata: BatchMeta) -> TensorDict:
keys = self._generate_keys(metadata)
shapes, dtypes = self._get_shape_type_list(metadata)
values = self.storage_client.get(keys=keys, shapes=shapes, dtypes=dtypes)
return self._merge_tensors_to_tensordict(metadata, values)| dtypes.append(field.dtype) | ||
| return shapes, dtypes | ||
|
|
||
| # TODO: Test put_data/get_data/clear_data with YuanrongStorageClient |
There was a problem hiding this comment.
Is this TODO still valid? Or it has done
There was a problem hiding this comment.
this todo is still valid. Later, we need to integrate the entire TransferQueue to test what issues may arise when KVStorageManager calls the put/get methods of YRStorageClient.

Summary
Changes
transfer_queue/storage/managers/base.py: Added class KVStorageManager, implemented key generation and data restoration for ordinary tensors.transfer_queue/storage/managers/yuanrong_manager.py: Added class YuanrongStorageManager, Implemented the validation for the configuration dict.transfer_queue/storage/clients/yuanrong_clients.py: Added class YRStorageClient, encapsulated the calls to the underlying interfaces of Yuanrong DataSystem.transfer_queue/storage/clients/factory.py: Added class StorageClientFactory, implemented a table-driven registration pattern using decorators.tests/test_kv_storage_manager.pyandtests/test_storage_client_factory.py.Testing
Related Links
Summary by CodeRabbit
Release Notes
New Features
Tests
Chores