This repository was archived by the owner on Jan 21, 2026. It is now read-only.

[Feat]: Support for CPU and Non-Tensor Data Storage of Yuanrong#107

Merged
0oshowero0 merged 16 commits into dev from dpj/nontensor on Nov 24, 2025

Conversation

@dpj135

@dpj135 dpj135 commented Nov 13, 2025

Contributor

Description

The current YRStorageClient only supports storing NPU tensors and cannot handle CPU tensors or non-tensor data (e.g., bool, str, list). To make it more general and flexible, this PR refactors YRStorageClient in transfer_queue/storage/clients/yuanrong_client.py and updates the related storage managers to uniformly support the following data types:

  • CPU tensors (torch.Tensor on CPU)
  • Non-tensor objects (arbitrary serializable Python objects)
  • (Future) NPU tensors (currently unstable due to limitations in the underlying storage system)

Key Changes

  • Introduced a dedicated _cpu_ds_client in YRStorageClient to handle all non-NPU data, including CPU tensors and generic Python objects.

  • Removed the device_id parameter from YRStorageClient.__init__, decoupling the client from explicit device binding.

  • Optimized the put / get / clear interfaces to efficiently support large-volume data operations.

  • Adapted KVStorageManager and YuanrongStorageManager to correctly interface with the new YRStorageClient API.
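
As a rough illustration of the non-tensor path described above, the sketch below stores arbitrary serializable Python objects under string keys by pickling them. `InMemoryKVClient` and its `put` / `get` / `clear` signatures are hypothetical stand-ins backed by a plain dict; this is not the YRStorageClient implementation and does not use the Yuanrong data system API.

```python
import pickle
from typing import Any


class InMemoryKVClient:
    """Hypothetical dict-backed stand-in for a key-value storage client."""

    def __init__(self) -> None:
        self._store: dict[str, bytes] = {}

    def put(self, keys: list[str], values: list[Any]) -> None:
        if len(keys) != len(values):
            raise ValueError("Number of keys must match number of values")
        for key, value in zip(keys, values):
            # Any picklable object (bool, str, list, ...) becomes bytes.
            self._store[key] = pickle.dumps(value)

    def get(self, keys: list[str]) -> list[Any]:
        # Missing keys raise KeyError in this sketch.
        return [pickle.loads(self._store[key]) for key in keys]

    def clear(self, keys: list[str]) -> None:
        for key in keys:
            self._store.pop(key, None)


client = InMemoryKVClient()
client.put(["flag", "name", "ids"], [True, "test_str", [1, 2, 3]])
restored = client.get(["flag", "name", "ids"])
client.clear(["flag"])
```

Pickling is only one possible serialization choice; it is used here because it handles mixed Python types without extra dependencies.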

Files Modified

  • storage/clients/yuanrong_client.py
  • storage/managers/base.py
  • storage/managers/yuanrong_manager.py

Performance test

  • Setup
    • Tensor.shape=[1024]
    • Data system single-machine deployment
  • Result
    • [NPU]: NPU tensors
    • [CPU]: CPU tensors
    • Note: Performance may fluctuate to some extent. For 100,000 keys, the put-get-clear cycle can complete in as little as 40 seconds. This time includes the cold start of the data system.
[image: benchmark results]
Code: performance test of the three methods `get/put/clear`
# benchmark.py
import statistics
import time

import torch
import torch_npu

from storageclient_dev import YuanrongStorageClient


def quick_tensor_list_equal(a, b):
    return len(a) == len(b) and all(torch.equal(x, y) for x, y in zip(a, b))


# config, tensor_shape and npu_id are module-level globals set in the __main__ block.
def test_npu(num_keys):
    client = YuanrongStorageClient(config)
    t1=[]
    for i in range(5):
        keys = [f"k_{i}" for i in range(num_keys)]
        #vals = [torch.randn(tensor_shape) for _ in range(num_keys)]
        vals = [torch.randn(tensor_shape).npu(npu_id) for _ in range(num_keys)]
        shapes = [v.shape for v in vals]   
        dtypes = [v.dtype for v in vals]
        start = time.time()
        client.put(keys, vals)
        ret_vals = client.get(keys, shapes,dtypes)
        assert quick_tensor_list_equal(vals, ret_vals)
        client.clear(keys)
        t1.append(time.time() - start)
    return statistics.mean(sorted(t1)[1:-1])

def test_cpu(num_keys):
    client = YuanrongStorageClient(config)
    t2=[]
    for i in range(5):
        keys = [f"k_{i}" for i in range(num_keys)]
        vals = [torch.randn(tensor_shape) for _ in range(num_keys)]
        #vals = [torch.randn(tensor_shape).npu(npu_id) for _ in range(num_keys)]
        shapes = [v.shape for v in vals]   
        dtypes = [v.dtype for v in vals]
        start = time.time()
        client.put(keys, vals)
        ret_vals = client.get(keys, shapes,dtypes)
        assert quick_tensor_list_equal(vals, ret_vals)
        client.clear(keys)
        t2.append(time.time() - start)
    return statistics.mean(sorted(t2)[1:-1])

if __name__ == "__main__":
    npu_id=5
    tensor_shape=1024
    torch_npu.npu.set_device(f'npu:{npu_id}')
    config={"host":"127.0.0.1", "port":31511}
    for N_BATCHES in [5, 20, 100, 200, 500, 800, 1000]:
        num_keys = N_BATCHES * 100
        print(f"Testing with {num_keys} keys ")
    
        # NPU
        avg_t1 = test_npu(num_keys)
        print(f"[NPU]          Time: {avg_t1:.3f}s")
        
        # CPU  
        avg_t2 = test_cpu(num_keys)
        print(f"[CPU]         Time: {avg_t2:.3f}s")
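
The benchmark averages each setting over five rounds after dropping the fastest and the slowest round, via `statistics.mean(sorted(t)[1:-1])`. The sketch below isolates that step; `trimmed_mean` is a hypothetical helper name and the timings are made-up illustrative values, not measurements from this PR.

```python
import statistics


def trimmed_mean(samples: list[float]) -> float:
    """Mean of the samples after dropping one minimum and one maximum."""
    if len(samples) < 3:
        raise ValueError("need at least three samples to trim both ends")
    return statistics.mean(sorted(samples)[1:-1])


# Made-up timings in seconds; the 9.0 mimics a cold-start outlier.
timings = [9.0, 2.0, 1.0, 3.0, 2.5]
print(trimmed_mean(timings))
```

Trimming one value from each end keeps a single cold-start round from dominating the reported average.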
    
    

Integration test

  • Environment
    • OS: eulerOS2r10
    • Architecture: aarch64
    • NPU: 910B
    • chariot-ds: 0.1.6
  • Test:
    Modify the code in async_demo.py as follows:
Code: integration test with the entire TransferQueue
# Copyright 2025 The TransferQueue Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging
import math
import os
import sys
import time
from pathlib import Path
try:
    import torch_npu
    print("NPU device count:", torch_npu.npu.device_count())
except Exception as e:
    print("NPU not available:", e)
import ray
import torch
from omegaconf import OmegaConf
from tensordict import NonTensorData, TensorDict

parent_dir = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(parent_dir))

from transfer_queue import (  # noqa: E402
    AsyncTransferQueueClient,
    BatchMeta,
    SimpleStorageUnit,
    TransferQueueController,
    process_zmq_server_info,
)
from transfer_queue.utils.utils import get_placement_group  # noqa: E402

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

os.environ["RAY_DEDUP_LOGS"] = "0"
os.environ["RAY_DEBUG"] = "1"
ray.init()


def compute_old_log_prob(data1, data2):
    time.sleep(3)
    return data1


def generate_sequences(data):
    time.sleep(3)
    return data


class ActorRolloutRefWorker:
    def actor_rollout_wg_generate_sequences(self, data_meta, data_system_client):
        # 1. Pull real data from the storage plane through client based on data_meta
        data = asyncio.run(data_system_client.async_get_data(data_meta))
        logger.info(f"demo get data->generate_sequences {data}")

        output = generate_sequences(data["input_ids"])

        output = TensorDict(
            {
                "generate_sequences_ids": output,
                "non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(output.size(0))]),
                "nested_tensor": torch.nested.as_nested_tensor([torch.randn(1, 2) for _ in range(output.size(0))]),
            },
            batch_size=output.size(0),
        )

        # 2. Write results back to the storage plane based on data_meta
        asyncio.run(data_system_client.async_put(data=output, metadata=data_meta))
        data_meta.add_fields(output)
        logger.info("demo put data to storages done")

        return data_meta

    def actor_rollout_wg_compute_old_log_prob(self, data_meta, data_system_client):
        # 1. Pull real data from the storage plane through client based on data_meta
        data = asyncio.run(data_system_client.async_get_data(data_meta))
        logger.info(f"demo get data->old_log_prob {data}")

        output = compute_old_log_prob(data["input_ids"], data["generate_sequences_ids"])

        output = TensorDict({"old_log_prob": output}, batch_size=output.size(0))

        # 2. Write results back to the storage plane based on data_meta
        asyncio.run(data_system_client.async_put(data=output, metadata=data_meta))
        data_meta.add_fields(output)
        logger.info("demo put data to storages done")

        return data_meta


@ray.remote(resources={"NPU": 1})
class AsyncvLLMServer:
    def __init__(self, config, data_system_controller_info):
        self.config = config
        self.data_system_client = AsyncTransferQueueClient(
            client_id="AsyncvLLMServer",
            controller_info=data_system_controller_info,
        )

        self.data_system_client.initialize_storage_manager(manager_type="YuanrongStorageManager", config=self.config)

    async def generate(self, data_meta):
        data = await self.data_system_client.async_get_data(data_meta)
        logger.info(f"demo get data->generate_sequences {data}")

        data = data["input_ids"]
        data += 1
        await asyncio.sleep(3)

        output = TensorDict(
            {
                "generate_sequences_ids": data,
                "non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(data.size(0))]),
                "nested_tensor": torch.nested.as_nested_tensor([torch.randn(1, 2) for _ in range(data.size(0))]),
            },
            batch_size=data.size(0),
        )

        await self.data_system_client.async_put(data=output, metadata=data_meta)
        logger.info("demo Async Server put data to storages done")

        return data_meta


@ray.remote(num_cpus=1, resources={"NPU": 1})
class AsyncRolloutWorker:
    def __init__(
        self,
        config,
        data_system_controller_info,
    ):
        self.async_vllm_server = AsyncvLLMServer.remote(
            config,
            data_system_controller_info,
        )

    async def generate_sequences(self, data_meta_chunk):
        tasks = []
        for i in range(data_meta_chunk.size):
            # asyncio.create_task cannot directly call Ray Actor methods,
            # otherwise an error will be reported:a coroutine was expected, got ObjectRef(xxx)
            tasks.append(asyncio.create_task(self.generate(data_meta_chunk[i])))
        data_metas = await asyncio.gather(*tasks)
        return BatchMeta.concat(data_metas)

    async def generate(self, data_meta):
        data_meta_new = await self.async_vllm_server.generate.remote(data_meta)
        return data_meta_new


class RolloutManager:
    def __init__(self, config, data_system_storage_unit_infos, data_system_controller_info):
        self.config = config

        self.data_system_client = AsyncTransferQueueClient(
            client_id="RolloutManager",
            controller_info=data_system_controller_info,
        )

        self.data_system_client.initialize_storage_manager(manager_type="YuanrongStorageManager", config=self.config)

        self.async_rollout_workers = []
        num_workers = self.config.rollout_agent_num_workers
        for i in range(num_workers):
            self.async_rollout_workers.append(AsyncRolloutWorker.remote(config, data_system_controller_info))

    def generate_sequences(self, data_meta):
        data_meta_chunkes = data_meta.chunk(len(self.async_rollout_workers))
        data_metas = ray.get(
            [
                worker.generate_sequences.remote(data_meta_chunk)
                for worker, data_meta_chunk in zip(self.async_rollout_workers, data_meta_chunkes, strict=True)
            ]
        )
        batch_meta = BatchMeta.concat(data_metas)
        logger.info(f"batch_meta: {batch_meta}")

        return batch_meta


class Trainer:
    def __init__(self, config):
        self.config = config
        self.data_system_client = self._initialize_data_system()
        self.actor_rollout_wg = ActorRolloutRefWorker()
        self.async_rollout_manager = RolloutManager(
            self.config,
            self.data_system_storage_unit_infos,
            self.data_system_controller_info,
        )

    def _initialize_data_system(self):
        # 1. Initialize TransferQueueStorage
        total_storage_size = self.config.global_batch_size * self.config.num_global_batch * self.config.num_n_samples
        self.data_system_storage_units = {}
        storage_placement_group = get_placement_group(self.config.num_data_storage_units, num_cpus_per_actor=1)
        for storage_unit_rank in range(self.config.num_data_storage_units):
            storage_node = SimpleStorageUnit.options(
                placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
            ).remote(storage_unit_size=math.ceil(total_storage_size / self.config.num_data_storage_units))
            self.data_system_storage_units[storage_unit_rank] = storage_node
            logger.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.")

        # 2. Initialize TransferQueueController (single controller only)

        # Sampler usage instructions:
        # For GRPO grouped sampling, you can initialize the controller with GRPOGroupNSampler:
        # Option 1: Pass sampler class (will be instantiated automatically)
        # self.data_system_controller = TransferQueueController.remote(sampler=GRPOGroupNSampler)

        # Option 2: Pass sampler instance (if you need custom configuration)
        # grpo_sampler = GRPOGroupNSampler()
        # self.data_system_controller = TransferQueueController.remote(sampler=grpo_sampler)

        # Then use sampling_config in get_meta calls:
        # sampling_config={"n_samples_per_prompt": 4}
        self.data_system_controller = TransferQueueController.remote()
        logger.info("TransferQueueController has been created.")

        # 3. Prepare necessary information
        self.data_system_controller_info = process_zmq_server_info(self.data_system_controller)
        self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units)

        tq_config = OmegaConf.create({}, flags={"allow_objects": True})  # Note: Need to generate a new DictConfig
        # with allow_objects=True to maintain ZMQServerInfo instance. Otherwise it will be flattened to dict
        tq_config.controller_info = self.data_system_controller_info
        tq_config.storage_unit_infos = self.data_system_storage_unit_infos
        self.config = OmegaConf.merge(tq_config, self.config)

        # 4. Create client
        self.data_system_client = AsyncTransferQueueClient(
            client_id="Trainer",
            controller_info=self.data_system_controller_info,
        )

        self.data_system_client.initialize_storage_manager(manager_type="YuanrongStorageManager", config=self.config)
        # Note: The client contains ZMQ objects. Currently, we cannot transmit the same client instance
        # to multiple places, as this will cause serialization errors in Ray.
        # Workaround: If you need to use a client in multiple Ray actors or processes, create a separate
        # AsyncTransferQueueClient instance for each actor/process instead of sharing or transmitting the same instance.
        return self.data_system_client

    def fit(self):
        for epoch in range(1):
            train_dataloader = 1
            for step in range(train_dataloader):
                input_ids = (
                    torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [10, 11], [100, 111], [200, 222], [300, 333]])
                ) * (step + 1)
                input_ids_repeated = torch.repeat_interleave(input_ids, self.config.num_n_samples, dim=0)
                prompt_batch = TensorDict(
                    {"input_ids": input_ids_repeated, "attention_mask": input_ids_repeated},
                    batch_size=input_ids_repeated.size(0),
                )

                asyncio.run(self.data_system_client.async_put(data=prompt_batch, partition_id=f"train_{step}"))

                logger.info("demo put prompts ok! ")
                time.sleep(5)

                batch_meta = asyncio.run(
                    self.data_system_client.async_get_meta(
                        data_fields=["input_ids", "attention_mask"],
                        batch_size=self.config.global_batch_size * self.config.num_n_samples,
                        partition_id=f"train_{step}",
                        task_name="generate_sequences",
                    )
                )
                logger.info(f"demo get meta {batch_meta}")

                # Simulate calling the generate sequences task of the worker group
                if not self.config.async_rollout_mode:
                    batch_meta = self.actor_rollout_wg.actor_rollout_wg_generate_sequences(
                        batch_meta, self.data_system_client
                    )
                else:
                    batch_meta = self.async_rollout_manager.generate_sequences(batch_meta)
                log_prob_meta = asyncio.run(
                    self.data_system_client.async_get_meta(
                        data_fields=["input_ids", "attention_mask", "generate_sequences_ids"],
                        batch_size=self.config.global_batch_size * self.config.num_n_samples,
                        partition_id=f"train_{step}",
                        task_name="compute_old_log_prob",
                    )
                )
                logger.info(f"demo get log prob meta: {log_prob_meta}")

                # Simulate calling the compute old log prob task of the worker group
                old_log_prob_meta = self.actor_rollout_wg.actor_rollout_wg_compute_old_log_prob(
                    log_prob_meta, self.data_system_client
                )

                batch_meta = batch_meta.union(old_log_prob_meta)

                # Client notifies controller to clear data status, controller returns metadata;
                # Client then notifies the storage plane to clear based on metadata
                asyncio.run(self.data_system_client.async_clear(partition_id=f"train_{step}"))
                logger.info("clear ok! ")
        logger.info("demo done!")

        # Cleanup resources
        self.data_system_client.close()
        return batch_meta


if __name__ == "__main__":
    # NOTE: you may choose to set async_rollout_mode=True to test the async rollout mode that mimics
    # AgentLoopManager in verl
    config_str = """
      global_batch_size: 8
      num_global_batch: 1
      num_data_storage_units: 2
      async_rollout_mode: True
      rollout_agent_num_workers: 2
      num_n_samples: 2
      host: '127.0.0.1'
      port: 31511
      client_name: 'YuanrongStorageClient'
    """
    dict_conf = OmegaConf.create(config_str)

    trainer = Trainer(dict_conf)
    trainer.fit()
    
    ray.shutdown()
   
    print('demo finished!')

Note: To use the NPU and get faster tensor access, install torch_npu and request the required resources in ray.remote:

@ray.remote(num_cpus=1, resources={"NPU": 1})
class AsyncRolloutWorker:

TODO

  • Solve NPU-related reliability issues and improve the robustness of YRStorageClient.
  • Test more non-tensor types.
  • Conduct performance tests on NPU tensors and CPU objects.
  • Run integration tests against the entire TransferQueue.

@coderabbitai

coderabbitai Bot commented Nov 13, 2025


Important

Review skipped

Auto reviews are disabled on base/target branches other than the default branch.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Walkthrough

This pull request introduces a partition-based data management system replacing the previous global-step model, adds extensible sampler abstractions for flexible data consumption strategies, implements a pluggable key-value storage backend architecture with support for multiple storage implementations (Yuanrong, KV stores), and updates APIs, documentation, and test coverage to reflect these foundational changes.

Changes

| Cohort | File(s) | Summary |
| --- | --- | --- |
| Partition ID Migration | transfer_queue/metadata.py, transfer_queue/client.py | Replaced global_step parameter/field with partition_id (string) across SampleMeta, AsyncTransferQueueClient, and TransferQueueClient public APIs; updated error handling and logging to use partition identifiers. |
| Sampler Abstraction | transfer_queue/sampler/base.py, transfer_queue/sampler/sequential_sampler.py, transfer_queue/sampler/grpo_group_n_sampler.py, transfer_queue/sampler/__init__.py, transfer_queue/__init__.py | Introduced BaseSampler abstract base class and two concrete implementations (SequentialSampler, GRPOGroupNSampler) for pluggable data consumption strategies; exported via package __init__.py files. |
| Partition-Based Controller | transfer_queue/controller.py, tests/test_controller.py, tests/test_controller_data_partitions.py | Implemented PartitionIndexManager and DataPartitionStatus classes supporting dynamic partition creation, index allocation, production/consumption tracking, and field metadata management; refactored TransferQueueController with partition-aware APIs and sampler integration. |
| KV Storage Backend | transfer_queue/storage/clients/base.py, transfer_queue/storage/clients/factory.py, transfer_queue/storage/clients/yuanrong_client.py, transfer_queue/storage/clients/__init__.py | Introduced TransferQueueStorageKVClient abstract interface, StorageClientFactory with decorator-based registration, and YRStorageClient for Yuanrong DataSystem integration with NPU/CPU fallback support. |
| Storage Manager Refactoring | transfer_queue/storage/managers/base.py, transfer_queue/storage/managers/factory.py, transfer_queue/storage/managers/simple_backend_manager.py, transfer_queue/storage/managers/yuanrong_manager.py | Added KVStorageManager with key-value serialization helpers, converted factory registration to decorator pattern, introduced AsyncSimpleStorageManager with partition awareness and configurable timeouts, and added YuanrongStorageManager for KV-backed Yuanrong storage. |
| Recipe Usage Updates | recipe/simple_use_case/async_demo.py, recipe/simple_use_case/sync_demo.py | Updated client invocations to use partition_id instead of global_step, removed get_n_samples parameter usage, added sampler usage documentation, and adjusted data push/clear operations to partition-based identifiers. |
| Test Suite Expansion | tests/test_async_simple_storage_manager.py, tests/test_client.py, tests/test_put.py, tests/test_samplers.py, tests/test_kv_storage_manager.py, tests/test_storage_client_factory.py | Updated existing tests to use partition_id instead of global_step; added comprehensive test suites for sampler behavior, partition controller functionality, KV storage utilities, async put operations, and storage client factory. |
| Supporting Changes | .gitignore, transfer_queue/storage/simple_backend.py, transfer_queue/version/version, README.md | Added PyCharm IDE ignore entries; adjusted TensorDict materialization in simple backend; bumped version to 0.1.1.dev2; extensively rewrote README documenting new customization, storage backends, integration patterns, and roadmap updates. |

Sequence Diagram

sequenceDiagram
    participant Client as TransferQueueClient
    participant Controller as TransferQueueController
    participant Sampler as BaseSampler
    participant Storage as StorageManager
    
    rect rgb(200, 230, 255)
    Note over Client,Storage: Data Insertion Flow (partition_id="train_0")
    Client->>Controller: put(data, partition_id="train_0")
    Controller->>Controller: async_get_meta(partition_id)
    Controller->>Controller: allocate_indexes via PartitionIndexManager
    Controller->>Storage: notify_data_update(partition_id, fields, indexes, dtypes, shapes)
    Storage->>Storage: put_data (KV serialization)
    end
    
    rect rgb(230, 200, 255)
    Note over Client,Storage: Data Retrieval Flow (sampling)
    Client->>Controller: get_meta(data_fields, batch_size, partition_id)
    Controller->>Controller: get_partition(partition_id)
    Controller->>Controller: scan_data_status (find ready indexes)
    Controller->>Sampler: sample(ready_indexes, batch_size)
    Sampler-->>Controller: (sampled_indexes, consumed_indexes)
    Controller->>Controller: generate_batch_meta from sampled indexes
    Controller-->>Client: BatchMeta
    Client->>Storage: get_data(metadata)
    Storage->>Storage: get_data (KV retrieval + merge)
    Storage-->>Client: TensorDict
    end
    
    rect rgb(200, 255, 230)
    Note over Client,Storage: Cleanup Flow
    Client->>Controller: clear(partition_id="train_0")
    Controller->>Storage: clear_data(partition_id)
    Storage->>Storage: Remove KV entries
    end
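
In the retrieval flow of the diagram, the sampler returns a pair of sampled and consumed indexes. A minimal sketch of a sequential strategy under that contract follows; the function name, the exact signature, and the "no partial batch" rule are assumptions for illustration and may differ from the repository's SequentialSampler.

```python
def sequential_sample(ready_indexes: list[int], batch_size: int) -> tuple[list[int], list[int]]:
    """Take the first batch_size ready indexes, in order.

    Returns (sampled_indexes, consumed_indexes). In this sketch every
    sampled index is also marked as consumed.
    """
    if batch_size < 0:
        raise ValueError("batch_size must be non-negative")
    if len(ready_indexes) < batch_size:
        # Not enough ready samples yet: return nothing rather than a partial batch.
        return [], []
    sampled = list(ready_indexes[:batch_size])
    return sampled, list(sampled)


sampled, consumed = sequential_sample([4, 7, 9, 12], batch_size=2)
```

Returning two lists leaves room for strategies where the consumed set differs from the sampled set.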

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

  • High-density logic changes: PartitionIndexManager, DataPartitionStatus, and new controller partition management require careful review of index allocation, consumption tracking, and dynamic expansion logic.
  • Multiple public API signature changes: partition_id replacing global_step across client, metadata, and controller layers affects call-site consistency and error handling.
  • New storage abstraction layer: KVStorageManager key-value mapping and StorageClientFactory registration patterns introduce new serialization and factory mechanisms requiring validation.
  • Sampler integration complexity: BaseSampler and concrete implementations interact with controller's batch generation—verify sampling correctness and edge case handling.
  • Extensive cross-file refactoring: Changes span 20+ files with interconnected concerns; verify partition_id threading is consistent throughout and that old global_step references are fully replaced.

Poem

🐰 Hop hop! Partitions bloom where steps once were,
Samplers dance with freedom's blur,
Storage backends stack like carrots tall,
One queue to manage them all!
~Rabbit's burrow echo'd with refactor's cheer


@dpj135 dpj135 changed the base branch from main to dev November 13, 2025 10:59
@dpj135 dpj135 changed the title [Feat]: Support for CPU and Non-Tensor Data Storage [Feat]: Support for CPU and Non-Tensor Data Storage of Yuanrong Nov 14, 2025
"This results in the inability to quickly store and retrieve tensors on the NPU side,"
"which may affect performance."
)
elif not torch.npu.is_available():

Member

this will lead to AttributeError when torch_npu is not available

Contributor Author

If the torch_npu package is not available, TORCH_NPU_IMPORTED will be False, so the code at line 53 doesn't run.
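
The exchange above turns on guarding an optional dependency: an import flag is set once, and any attribute access on the optional module is reached only when that flag is true. A minimal sketch of the pattern using only the standard library follows; `optional_import` and the module name are hypothetical, while the real client's flag is TORCH_NPU_IMPORTED as mentioned in the reply.

```python
import importlib
from types import ModuleType
from typing import Optional


def optional_import(name: str) -> tuple[Optional[ModuleType], bool]:
    """Try to import a module; return (module, True) or (None, False)."""
    try:
        return importlib.import_module(name), True
    except ImportError:
        return None, False


# "no_such_accelerator_pkg" is a deliberately missing, made-up module name.
accel, ACCEL_IMPORTED = optional_import("no_such_accelerator_pkg")

if not ACCEL_IMPORTED:
    status = "package missing, falling back to CPU path"
elif not accel.is_available():  # only evaluated when the import succeeded
    status = "package present but no device available"
else:
    status = "accelerator ready"
```

Because the `elif` branch is evaluated only after the first check fails, the attribute lookup cannot raise AttributeError on a missing module.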

def _create_empty_tensorlist(self, shapes, dtypes):
self._npu_ds_client = None
self._cpu_ds_client = None
self.DS_CLIENT_KEYS_LIMIT = 1000

Member

set as environment variable?

Contributor Author

done

Comment on lines +67 to +69
def npu_ds_client_is_available(self):
    return False
    # return self._npu_ds_client is not None

Member

always False?

Member

and maybe using @property would be better

Contributor Author

npu_ds_client_is_available should be a boolean expression, I think it's better to represent it using a method.
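
For comparison, the two spellings discussed here look like this. The class is a hypothetical stand-in, not the real client; both forms evaluate the same boolean expression and differ only at the call site.

```python
class ClientSketch:
    """Hypothetical stand-in holding an optional NPU client handle."""

    def __init__(self, npu_ds_client=None):
        self._npu_ds_client = npu_ds_client

    def npu_ds_client_is_available(self) -> bool:
        # Method form: called with parentheses.
        return self._npu_ds_client is not None

    @property
    def npu_available(self) -> bool:
        # Property form: read like an attribute, computed on each access.
        return self._npu_ds_client is not None


client = ClientSketch()
assert client.npu_ds_client_is_available() is False
assert client.npu_available is False
```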

Comment on lines +129 to +133
total_count = (len(keys) + self.DS_CLIENT_KEYS_LIMIT - 1) // self.DS_CLIENT_KEYS_LIMIT
for i in range(total_count):
    start_idx = self.DS_CLIENT_KEYS_LIMIT * i
    end_idx = min(self.DS_CLIENT_KEYS_LIMIT * (i + 1), len(keys))
    self._batch_put(keys[start_idx:end_idx], values[start_idx:end_idx])

Member

btw, why there is a self.DS_CLIENT_KEYS_LIMIT? Is it related to performance?

Member

if we perform self._batch_put in parallel, will the performance be better?

Contributor Author

btw, why there is a self.DS_CLIENT_KEYS_LIMIT? Is it related to performance?

It's a constraint on the number of keys, imposed by the data system.
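
The batching in the snippet under review can be isolated as a small helper. The sketch below splits a key list into chunks no larger than the limit and reads that limit from an environment variable, as suggested earlier in the review. The environment variable name and the `chunked` helper are assumptions; only the default of 1000 comes from the diff.

```python
import os

# Assumed environment variable name; the default matches the value in the diff.
KEYS_LIMIT = int(os.environ.get("DS_CLIENT_KEYS_LIMIT", "1000"))


def chunked(items: list, limit: int = KEYS_LIMIT):
    """Yield consecutive slices of items, each at most limit long."""
    if limit <= 0:
        raise ValueError("limit must be positive")
    for start in range(0, len(items), limit):
        yield items[start:start + limit]


keys = [f"k_{i}" for i in range(2500)]
sizes = [len(chunk) for chunk in chunked(keys, limit=1000)]
```

Stepping `range` by the limit avoids the manual ceiling division and `min` call while producing the same slices.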

@dpj135 dpj135 Nov 20, 2025

Contributor Author

I used asyncio to run self._batch_put in parallel, like this:

async def put(self, keys: list[str], values: list[Any]):
    if not isinstance(keys, list) or not isinstance(values, list):
        raise ValueError("keys and values must be lists")
    if len(keys) != len(values):
        raise ValueError("Number of keys must match number of values")

    total_count = (len(keys) + DS_CLIENT_KEYS_LIMIT - 1) // DS_CLIENT_KEYS_LIMIT

    tasks = []
    for i in range(total_count):
        start_idx = DS_CLIENT_KEYS_LIMIT * i
        end_idx = min(DS_CLIENT_KEYS_LIMIT * (i + 1), len(keys))
        tasks.append(self._batch_put(keys[start_idx:end_idx], values[start_idx:end_idx]))

    await asyncio.gather(*tasks)

The results are as follows:
[image: benchmark results]

The synchronous and asynchronous results for CPU tensors are very similar. The likely reason is that DSTensorClient and KVClient are synchronous and blocking, so the gathered tasks do not actually overlap.
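
One way to see why `asyncio.gather` alone brings little gain here: a coroutine that makes a blocking call never yields to the event loop, so the gathered tasks still run one after another. The sketch below uses `time.sleep` as a stand-in for a synchronous client call and contrasts it with offloading the same call via `asyncio.to_thread` (Python 3.9+). Whether the actual DSTensorClient and KVClient calls are thread-safe or release the GIL is not established in this thread, so the threaded variant is an experiment to try, not a recommendation.

```python
import asyncio
import time


def blocking_put(batch_id: int, log: list) -> int:
    """Stand-in for a synchronous, blocking client call."""
    log.append(("start", batch_id))
    time.sleep(0.1)
    log.append(("end", batch_id))
    return batch_id


async def put_inline(batch_id: int, log: list) -> int:
    # The blocking call runs on the event loop thread, so nothing overlaps.
    return blocking_put(batch_id, log)


async def put_in_thread(batch_id: int, log: list) -> int:
    # The blocking call runs in a worker thread; the event loop stays free.
    return await asyncio.to_thread(blocking_put, batch_id, log)


async def run_all(worker, n: int = 3):
    log: list = []
    begin = time.perf_counter()
    results = await asyncio.gather(*(worker(i, log) for i in range(n)))
    return results, log, time.perf_counter() - begin


inline_results, inline_log, inline_secs = asyncio.run(run_all(put_inline))
thread_results, thread_log, thread_secs = asyncio.run(run_all(put_in_thread))
print(f"inline: {inline_secs:.2f}s, threaded: {thread_secs:.2f}s")
```

In the inline variant the log shows each batch finishing before the next one starts, which matches the observation that the async and sync timings were close.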

Comment on lines -276 to -297
def close(self) -> None:
    """Close all ZMQ sockets and context to prevent resource leaks."""
    for sock in (self.controller_handshake_socket, self.data_status_update_socket):
        try:
            if sock and not sock.closed:
                sock.close(linger=0)
        except Exception as e:
            logger.error(f"[{self.storage_manager_id}]: Error closing socket {sock}: {str(e)}")

    try:
        if self.zmq_context:
            self.zmq_context.term()
    except Exception as e:
        logger.error(f"[{self.storage_manager_id}]: Error terminating zmq_context: {str(e)}")

def __del__(self):
    """Destructor to ensure resources are cleaned up."""
    try:
        self.close()
    except Exception as e:
        logger.error(f"[{self.storage_manager_id}]: Exception during __del__: {str(e)}")

Member

This code is required; otherwise, in verl, things might break when switching from rollout mode to train mode, or when vLLM goes into sleep mode.

Contributor Author

done

logger.info("Missing 'client_name in config, using default value('Yuanrong')")
config["client_name"] = "Yuanrong"
elif client_name != "Yuanrong":
raise ValueError("Invalid 'client_name' in config")

Member

Suggested change
raise ValueError("Invalid 'client_name' in config")
raise ValueError(f"Invalid 'client_name': {client_name} in config. Expecting 'Yuanrong' for YuanrongStorageManager.")

Contributor Author

done



# TODO: DSTensorClient.dev_mget has wrong behavior: it may require stricter environment to execute
@StorageClientFactory.register("Yuanrong")

Member

Suggested change
@StorageClientFactory.register("Yuanrong")
@StorageClientFactory.register("YRStorageClient")

Contributor Author

done
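
For readers unfamiliar with decorator-based registration, the pattern behind `@StorageClientFactory.register(...)` can be sketched as follows. Only the `register` decorator name comes from the diff; `ClientRegistry`, its `create` method, and the dummy client are hypothetical and may not match the repository's factory.

```python
from typing import Any, Callable


class ClientRegistry:
    """Hypothetical name-to-class registry filled in by a class decorator."""

    _registry: dict[str, type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[type], type]:
        def decorator(client_cls: type) -> type:
            cls._registry[name] = client_cls
            return client_cls  # the class itself is returned unchanged
        return decorator

    @classmethod
    def create(cls, name: str, config: dict[str, Any]) -> Any:
        if name not in cls._registry:
            raise ValueError(f"Unknown client name: {name!r}; known: {sorted(cls._registry)}")
        return cls._registry[name](config)


@ClientRegistry.register("DummyClient")
class DummyClient:
    def __init__(self, config: dict[str, Any]) -> None:
        self.config = config


client = ClientRegistry.create("DummyClient", {"host": "127.0.0.1", "port": 31511})
```

With this pattern the registered string is what a config value such as client_name has to match, which is why the registration name matters in the suggestion above.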

@dpj135 dpj135 requested a review from 0oshowero0 November 21, 2025 03:42
Comment on lines +379 to +393
merged_data = {}
for field, data_list in grouped_data.items():
    if all(isinstance(item, torch.Tensor) for item in data_list):
        try:
            merged_data[field] = torch.stack(data_list)
        except RuntimeError:
            try:
                # Fallback to nested tensor if shapes are irregular
                merged_data[field] = torch.nested.as_nested_tensor(data_list)
            except Exception:
                merged_data[field] = NonTensorStack(*data_list)
    else:
        merged_data[field] = NonTensorStack(*data_list)

return TensorDict(merged_data, batch_size=len(global_indexes))

Member

can add TODO now

Comment on lines 373 to 376
for field in field_names:
for _ in range(len(global_indexes)):
merged_data[field].append(values[value_idx])
grouped_data[field].append(values[value_idx])
value_idx += 1

Member

add TODO: performance optimize

@0oshowero0 0oshowero0 left a comment

Member

LGTM

@0oshowero0 0oshowero0 merged commit 7a1dc10 into dev Nov 24, 2025
3 checks passed