一句话总结

通过引用导向存储(ROS)技术,TensorHub 将 LLM RL 训练的 GPU stall 时间减少最高 6.7x,跨数据中心权重同步提速 19x——核心思路:不复制权重,只追踪权重在哪里。

为什么需要这个?

RL 训练循环的结构

现代 LLM 强化学习(PPO/GRPO)是一个交替循环:

  1. Rollout 阶段:推理集群(rollout workers)用当前策略生成响应
  2. 评分阶段:奖励模型打分
  3. 训练阶段:梯度计算,更新训练集群(training workers)权重
  4. 同步阶段:🚨 将新权重传回推理集群 ← 性能瓶颈在这里

问题在第 4 步:一个 70B 模型的权重约 140 GB(BF16),每次训练迭代后都需要把这 140 GB 从训练 GPU 搬到推理 GPU。

场景 权重大小 推理副本数 每轮需传输
7B 模型 × 8 副本 14 GB 8 112 GB
70B 模型 × 16 副本 140 GB 16 2.24 TB
跨 DC 70B × 32 副本 140 GB 32 4.48 TB

在 400 Gbps RDMA 网络下,传输 2.24 TB 理论需要 44 秒——而训练一步可能只需要 10 秒。权重传输成了 RL 训练的真正瓶颈。

现有方案的问题

方案 A:Parameter Server — 维护中心化权重副本,写放大,中心节点成为瓶颈

方案 B:All-Reduce 广播 — 所有 worker 必须同时参与,无法动态扩缩容

方案 C:P2P 点对点复制 — 需要额外存储空间,不感知网络拓扑

根本问题:它们都在”复制”权重,而权重已经在训练 GPU 上存着了——为什么不直接用那份数据?

核心原理:引用导向存储(ROS)

直觉:图书馆的借阅系统

想象一本书已经有人在阅览室里读着。图书馆不会再复制一本放到参考书架——而是记录”某读者手上有这本书,要借的话去找他”。

ROS 做的是同样的事:

  • 训练完成后,训练 worker 持有新版本权重(就像读者拿着书)
  • ROS 不复制权重,而是在注册表中记录:”版本 V+1 的权重在 worker_0, worker_4 的 GPU 上”
  • 推理 worker 需要权重时,查注册表,直接从持有者处 pull
  • 没有额外存储,没有额外内存,只有元数据

硬件层面:为什么这样更快

GPU 集群中数据传输有明确的带宽层级:

NVLink(同机 GPU 间)     ~  600 GB/s
PCIe(同机 CPU-GPU)      ~   64 GB/s
InfiniBand RDMA(机器间)  ~   50 GB/s(400 Gbps)
跨数据中心网络             ~    5 GB/s(40 Gbps)

ROS 允许拓扑感知路由:推理 worker 优先从同机的训练 worker 拉取(NVLink,快 12x),只有无法避免时才走跨 DC 链路。传统 All-Reduce 不区分这些路径——慢的那个节点决定所有人的速度。

代码实现

Baseline:朴素的权重广播

import torch.distributed as dist

def naive_weight_broadcast(model, trainer_rank=0):
    """
    传统方案:All-Reduce 广播
    缺点:
    1. 所有 worker 必须同时在线(无法弹性扩缩)
    2. 慢的 worker 阻塞所有人
    3. 不感知网络拓扑,跨 DC 也走同一路径
    """
    for name, tensor in model.state_dict().items():
        # 强制全局同步:140GB 模型 × 16 副本 = 2.24TB 传输
        # 即使只有 1 个 worker 需要更新,其他人也必须参与
        dist.broadcast(tensor, src=trainer_rank)

性能分析(70B 模型,16 个推理 worker,400 Gbps RDMA):

  • 理论传输时间:140 GB × 16 / 50 GB/s ≈ 44 秒
  • GPU stall 占总训练时间 60-70%

ROS 注册表核心实现

import threading
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class WorkerInfo:
    worker_id: str
    host: str
    rdma_port: int
    nvlink_peers: List[str] = field(default_factory=list)  # 同机 NVLink peer

class ROSRegistry:
    """
    引用导向存储注册表
    核心:不存储权重,只追踪谁持有哪个版本
    """
    def __init__(self):
        self._lock = threading.RLock()
        self._version_holders: Dict[int, List[str]] = {}  # version -> [worker_id]
        self._workers: Dict[str, WorkerInfo] = {}
        self._latest_version = 0

    def commit_version(self, version: int, holder_ids: List[str]):
        """训练完成后,训练 worker 向注册表申报持有新版本"""
        with self._lock:
            self._version_holders[version] = holder_ids
            self._latest_version = max(self._latest_version, version)

    def lookup(self, version: int, requester_id: str) -> WorkerInfo:
        """
        为请求者找到最优的权重提供者
        优先级:同机 NVLink > 同机 PCIe > 跨机 RDMA > 跨 DC
        """
        with self._lock:
            holders = self._version_holders.get(version, [])
            if not holders:
                raise ValueError(f"Version {version} not found in registry")

            requester = self._workers.get(requester_id)

            # 优先选 NVLink peer(同机 GPU,带宽 600 GB/s)
            for holder_id in holders:
                if requester and holder_id in requester.nvlink_peers:
                    return self._workers[holder_id]

            # 其次同主机(PCIe,64 GB/s)
            for holder_id in holders:
                holder = self._workers[holder_id]
                if requester and holder.host == requester.host:
                    return holder

            # 最后走 RDMA(50 GB/s)
            return self._workers[holders[0]]

TensorHub 客户端:RDMA 直接拉取

class TensorHubClient:
    """
    推理 worker 侧客户端
    关键设计:pull 而非 push,由需要的人主动拉取
    这允许不同 worker 按各自节奏更新,无需全局同步
    """
    def __init__(self, registry: ROSRegistry, worker_id: str):
        self.registry = registry
        self.worker_id = worker_id

    def fetch_weights(self, version: int, model) -> None:
        state_dict = {}

        for name, param in model.named_parameters():
            # 每个张量独立查找最优来源(允许不同层来自不同 worker)
            source = self.registry.lookup(version, self.worker_id)

            # RDMA one-sided READ:直接读取远端 GPU 显存
            # 不需要远端 CPU 参与,延迟更低
            tensor = self._rdma_read(
                src_host=source.host,
                src_port=source.rdma_port,
                tensor_key=f"v{version}/{name}",
                shape=param.shape,
                dtype=param.dtype,
                device=param.device
            )
            state_dict[name] = tensor

        # 原子性加载:所有参数同时切换版本,不出现混合版本
        model.load_state_dict(state_dict, strict=True)

    def _rdma_read(self, src_host, src_port, tensor_key,
                   shape, dtype, device):
        # 实际依赖 UCX/libibverbs RDMA API
        # ... (完整实现省略)
        pass

弹性 Rollout:新 Worker 热加入

class ElasticRolloutCoordinator:
    """
    弹性 rollout 的核心价值:
    新 worker 加入时无需暂停训练,无需全局广播
    直接从注册表拉取当前版本,独立完成同步
    """
    def __init__(self, registry: ROSRegistry):
        self.registry = registry

    def add_worker(self, worker_info: WorkerInfo, current_version: int):
        self.registry._workers[worker_info.worker_id] = worker_info

        # 异步拉取当前版本,不阻塞训练主循环
        # 这是 ROS 相比 All-Reduce 快 4.8x 的核心:
        # 新 worker 只拉取最新版本,不需要全局 barrier
        client = TensorHubClient(self.registry, worker_info.worker_id)
        threading.Thread(
            target=client.fetch_weights,
            args=(current_version, self._get_model(worker_info)),
            daemon=True
        ).start()

常见错误:提交时机错误

# ❌ 错误:optimizer.step() 之前就提交版本
def wrong_commit(registry, version, holders):
    registry.commit_version(version, holders)  # 权重还未更新!
    optimizer.step()   # 这才真正改变权重

# ✅ 正确:等待所有训练 worker 完成 all-reduce 后再提交
def correct_commit(registry, version, holders):
    optimizer.step()
    dist.barrier()  # 确保所有训练 worker 的 all-reduce 完成
    registry.commit_version(version, holders)  # 此时权重已就绪

性能实测

测试环境:H100 集群,400 Gbps InfiniBand,CUDA 12.3,70B 模型 BF16

实现版本 场景 GPU Stall 时间 提速比
All-Reduce Standalone 44.2 s 1x(基准)
TensorHub ROS Standalone 6.6 s 6.7x
All-Reduce Elastic(新增 worker) 44.2 s 1x
TensorHub ROS Elastic 9.2 s 4.8x
点对点复制 Cross-DC 210 s 1x
TensorHub ROS Cross-DC 11.1 s 19x

为什么跨 DC 提速最显著(19x)?

传统方案必须等所有 DC 的 worker 完成同步后才能开始下一轮训练,而 DC 间带宽约 40 Gbps(DC 内的 1/10)。TensorHub 用异步流水线解耦传输和训练:

传统方案:[训练] ──── [等待跨DC传输 210s] ──── [下一轮训练]

TensorHub:[训练] ─→ [下一轮训练开始]
                       ↑ 推理 worker 在后台异步拉取,训练不等待

什么时候用 / 不用?

适用场景 不适用场景
训练和推理集群分离的 RL 训练 单机 RL(不涉及网络传输)
弹性扩缩容推理集群 权重更新频率极低(每小时一次)
跨数据中心分布式 RL 小模型(< 1B,传输不是瓶颈)
RDMA 网络可用(InfiniBand) 仅以太网环境(RDMA 特性无法发挥)

真实局限性

  • 训练 worker 的 GPU 显存需保留权重副本,直到所有推理 worker 拉取完成,增加显存压力
  • 版本垃圾回收需要精心设计:旧版本持有者不能过早释放 GPU 内存
  • 依赖 RDMA 基础设施,不是所有云环境都支持

调试技巧

版本不一致: 症状是推理集群不同 worker 生成质量波动不稳定。检查 commit_version 是否在 dist.barrier() 之后调用,确认没有提前提交未完成的版本。

显存 OOM: ROS 的隐患——训练 worker 必须持有旧版本直到推理 worker 确认拉取。监控持有版本数:

def evict_old_versions(registry: ROSRegistry, keep_versions: int = 2):
    """驱逐过旧版本的引用,释放 GPU 显存持有义务"""
    held = sorted(registry._version_holders.keys())
    for old_version in held[:-keep_versions]:
        del registry._version_holders[old_version]

RDMA 传输超时: 检查 IB 网卡的 port_rcv_errors。TensorHub 内置重试——先换一个 holder 重试,不依赖单点。

延伸阅读

  • 原论文TensorHub: Scalable and Elastic Weight Transfer for LLM RL Training,Section 4(ROS 实现细节)和 Section 5(容错设计)值得精读
  • OpenRLHF:开源 LLM RLHF 框架,其权重同步模块是 TensorHub 思路的简化版本,可对比研究
  • RDMA 基础:理解 one-sided READ(不需对端 CPU 参与)和 two-sided Send/Recv 的区别,是理解 TensorHub 高效原因的关键
  • veRL/HybridFlow:另一个生产级 RL 训练框架,采用了类似的 actor/rollout 分离架构,对比两者的权重同步策略很有价值