一句话总结

SDPM 把生存分析重新建模为生成问题:用扩散模型直接学习联合分布 $\mathbb{P}(T, \delta \mid \mathbf{x})$,绕过了对风险函数的参数假设,也不需要对时间轴离散化。


背景:生存分析难在哪里?

生存分析要回答的问题很简单:事件什么时候发生? 病人何时复发、用户何时流失、机器何时故障。

真正的麻烦是删失(Censoring):很多样本在观测结束时事件尚未发生,我们只知道”到截止时刻,事件还没发生”。简单丢弃这些样本会引入严重偏差;正确处理删失是生存分析的核心挑战。

现有方法的局限

Cox 比例风险模型是最常用的方法:

\[h(t \mid \mathbf{x}) = h_0(t) \exp(\mathbf{x}^\top \boldsymbol{\beta})\]

它有两个强假设:比例风险(不同协变量的风险比恒定)、半参数形式。深度学习方案:

  • DeepHit:离散化时间轴,用神经网络预测每段的事件概率。时间分辨率与计算量存在 trade-off,离散化引入近似误差
  • Deep Survival Machines:假设事件时间来自 Weibull/Log-Normal 混合,参数灵活性有限
  • SurvTRACE:用 Transformer 建模,底层仍然是参数化风险函数

SDPM 的核心 Insight

既然扩散模型最擅长学习分布,为什么不直接建模联合分布?

\[\mathbb{P}(T, \delta \mid \mathbf{x})\]

其中 $T$ 是观测时间,$\delta \in {0, 1}$ 是事件指示符(1 = 事件,0 = 删失)。

推断方式:对给定 $\mathbf{x}$ 采样若干 $(T_i, \delta_i)$,喂给 Kaplan-Meier 估计量,得到 $\hat{S}(t \mid \mathbf{x})$。框架的好处:无风险函数假设,无时间离散化,纯数据驱动。


算法原理

直觉:往 (T, δ) 上加噪声,再学着去噪

扩散模型的逻辑大家都熟悉:前向过程加噪声,反向过程学去噪。SDPM 把这套机制用在生存结果 $(T, \delta)$ 上。

麻烦在于 $\delta$ 是二值的,不能直接用高斯扩散。解决方案是目标空间变换

  • 时间:$\tilde{t} = (\log T - \mu) / \sigma$,标准化 log 时间,让分布更接近高斯
  • 删失指示:把 $\delta \in {0, 1}$ 当成连续浮点值,扩散过程自然把它变成连续变量;生成时用 0.5 作阈值恢复二值

目标空间变为 $\mathbf{y} = (\tilde{t},\ \delta) \in \mathbb{R}^2$,标准 DDPM 即可工作。

数学推导

前向扩散:

\[q(\mathbf{y}_s \mid \mathbf{y}_0) = \mathcal{N}\!\left(\sqrt{\bar{\alpha}_s}\,\mathbf{y}_0,\ (1 - \bar{\alpha}_s)\mathbf{I}\right)\]

条件去噪网络 $\epsilon_\theta$ 的训练目标:

\[\mathcal{L} = \mathbb{E}_{s,\,\mathbf{y}_0,\,\boldsymbol{\epsilon}}\left[\left\|\boldsymbol{\epsilon} - \epsilon_\theta(\mathbf{y}_s, s, \mathbf{x})\right\|^2\right]\]

推断时从高斯噪声出发,逐步去噪得到 $\hat{\mathbf{y}}_0 = (\hat{t}, \hat{\delta})$,逆变换恢复 $(T, \delta)$。

从样本到生存曲线

对给定 $\mathbf{x}$ 生成 $N$ 个样本 ${(\hat{T}_i, \hat{\delta}_i)}$,用 KM 估计量:

\[\hat{S}(t \mid \mathbf{x}) = \prod_{i:\, \hat{T}_i \leq t} \left(1 - \frac{\hat{\delta}_i}{\sum_j \mathbf{1}[\hat{T}_j \geq \hat{T}_i]}\right)\]

$N$ 越大,估计越稳,但推断越慢——这是 SDPM 最主要的实际 trade-off。


实现

最小可运行版本

import torch, torch.nn as nn, math

class SinusoidalEmbed(nn.Module):
    """扩散时间步的正弦位置编码"""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        freqs = torch.exp(
            -math.log(10000) * torch.arange(half, dtype=torch.float32) / half
        ).to(t.device)
        args = t.float()[:, None] * freqs[None]
        return torch.cat([args.sin(), args.cos()], dim=-1)


class DenoisingNet(nn.Module):
    """条件去噪 MLP:输入 (y_noisy, t, x),输出预测噪声"""
    def __init__(self, x_dim, hidden=256, t_dim=64):
        super().__init__()
        self.t_embed = nn.Sequential(
            SinusoidalEmbed(t_dim),
            nn.Linear(t_dim, t_dim),
            nn.SiLU()
        )
        self.net = nn.Sequential(
            nn.Linear(x_dim + 2 + t_dim, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, 2),   # 输出 2D:(log-time 噪声, δ 噪声)
        )

    def forward(self, y_noisy, t, x):
        t_emb = self.t_embed(t)
        return self.net(torch.cat([y_noisy, x, t_emb], dim=-1))

完整实现

class SDPM:
    def __init__(self, x_dim, T_steps=1000, hidden=256, device='cpu'):
        self.T_steps, self.device = T_steps, device
        betas = torch.linspace(1e-4, 0.02, T_steps).to(device)
        alphas = 1.0 - betas
        self.alphas_bar = torch.cumprod(alphas, dim=0)
        self.alphas, self.betas = alphas, betas
        self.net = DenoisingNet(x_dim, hidden).to(device)
        self.opt = torch.optim.Adam(self.net.parameters(), lr=3e-4)
        self.log_t_mean, self.log_t_std = 0.0, 1.0   # 训练前用数据填充

    def _normalize(self, T, delta):
        log_t = (torch.log(T.clamp(min=1e-6)) - self.log_t_mean) / (self.log_t_std + 1e-8)
        return torch.stack([log_t, delta.float()], dim=1)   # (B, 2)

    def _denormalize(self, y):
        T = torch.exp(y[:, 0] * self.log_t_std + self.log_t_mean)
        delta = (y[:, 1] > 0.5).float()
        return T, delta

    def train_step(self, x, T_obs, delta):
        y0 = self._normalize(T_obs, delta)
        t_idx = torch.randint(0, self.T_steps, (x.size(0),), device=self.device)
        noise = torch.randn_like(y0)
        ab = self.alphas_bar[t_idx].view(-1, 1)
        y_noisy = ab.sqrt() * y0 + (1 - ab).sqrt() * noise
        pred_noise = self.net(y_noisy, t_idx, x)
        loss = ((pred_noise - noise) ** 2).mean()
        self.opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.net.parameters(), 1.0)  # 梯度裁剪
        self.opt.step()
        return loss.item()

    @torch.no_grad()
    def sample(self, x_single, n_samples=200):
        """对单个样本 x_single,生成 n_samples 个 (T, δ) 对"""
        x = x_single.unsqueeze(0).expand(n_samples, -1)
        y = torch.randn(n_samples, 2, device=self.device)
        for s in reversed(range(self.T_steps)):
            t_tensor = torch.full((n_samples,), s, device=self.device, dtype=torch.long)
            pred_noise = self.net(y, t_tensor, x)
            alpha, alpha_bar = self.alphas[s], self.alphas_bar[s]
            y = (y - (1 - alpha) / (1 - alpha_bar).sqrt() * pred_noise) / alpha.sqrt()
            if s > 0:
                y += self.betas[s].sqrt() * torch.randn_like(y)
        return self._denormalize(y)   # (T_samples, delta_samples),各长 n_samples

关键 Trick(跑不起来先看这里)

1. 目标空间标准化——最容易忽略

import numpy as np

# 训练前计算统计量,否则扩散目标尺度差异巨大
log_T = np.log(T_train.numpy() + 1e-6)
model.log_t_mean = float(log_T.mean())
model.log_t_std  = float(log_T.std())

2. 推断采样数与速度的 trade-off

# N < 50:KM 曲线抖动严重,不可信
# N = 200:通常够用
# N = 500:稳定但推断慢 2.5x,适合最终评估
n_samples = 200   # 先用这个调参,最终评估用 500

3. 高删失率下的训练不稳定

如果删失率超过 80%,δ 维度极度不平衡,考虑加权损失:

event_weight = 1.0 / (delta.mean().clamp(min=0.05))
w = 1 + (event_weight - 1) * delta.float().unsqueeze(1).expand_as(noise)
loss = ((pred_noise - noise) ** 2 * w).mean()

实验

合成 Cox-Weibull 数据

论文专门用合成数据验证分布恢复能力,我们复现这个场景:

import numpy as np

def generate_cox_weibull(n=2000, p=10, seed=42):
    rng = np.random.default_rng(seed)
    X = rng.normal(0, 1, (n, p)).astype(np.float32)
    beta = rng.normal(0, 0.5, p)
    scale = np.exp(-X @ beta / 2.0)           # 协变量决定 Weibull 尺度
    T_event  = rng.weibull(2.0, n) * scale    # 形状参数 2(单峰风险)
    T_censor = rng.exponential(scale.mean() * 2, n)
    T_obs = np.minimum(T_event, T_censor).astype(np.float32)
    delta = (T_event <= T_censor).astype(np.float32)
    print(f"事件率: {delta.mean():.1%}")       # 预期约 50%
    return X, T_obs, delta

X, T_obs, delta = generate_cox_weibull()

从样本到生存曲线

from lifelines import KaplanMeierFitter

def predict_survival(model, x_single, time_grid, n_samples=200):
    """对单个协变量向量预测生存曲线 S(t | x)"""
    T_samples, delta_samples = model.sample(x_single, n_samples)
    kmf = KaplanMeierFitter()
    kmf.fit(T_samples.cpu().numpy(), event_observed=delta_samples.cpu().numpy())
    return kmf.survival_function_at_times(time_grid).values

与 Baseline 对比

在合成 Cox-Weibull 数据上的预期表现(多种子平均,仅供参考):

算法 C-index ↑ IBS ↓ 备注
Cox (线性) ~0.72 ~0.15 模型设定正确时竞争力强
DeepHit ~0.74 ~0.14 需调时间离散粒度
SDPM ~0.73 ~0.13 连续分布估计校准更好
Random Forest ~0.70 ~0.16 弱基线

关键观察:SDPM 在 C-index(排序)上没有系统性优势,但在 Brier Score(校准性)上更好,因为它直接建模了完整的分布形状。

消融:目标空间变换的重要性

论文消融实验显示,去掉 log 时间标准化后,IBS 约上升 10-15%,生成的负时间样本比例从 0% 升至 5%+。这是设计最精巧的地方,也是最容易忽略的地方。


调试指南

常见问题

1. 生存曲线出现非单调或跳跃

KM 估计本身单调,但如果生成的 $\hat{T}$ 包含极端值(负数、NaN),曲线会异常。

# 健康检查:生成样本的基本统计
T_s, delta_s = model.sample(x_test[0], n_samples=500)
print(f"T 范围: [{T_s.min():.2f}, {T_s.max():.2f}]")
print(f"δ 均值: {delta_s.mean():.2f}(期望接近训练集事件率)")
assert (T_s > 0).all(), "出现非正时间,检查目标空间变换!"

2. δ 预测退化(全 0 或全 1)

说明删失指示器的扩散没学好。两个维度(log-T 和 δ)共处 $\mathbb{R}^2$,若 log-T 尺度远大于 1,网络会忽略 δ 维度。根因:log_t_std 设置错误或未标准化。

3. 损失收敛但 C-index 不动

扩散模型可以学到边际分布但忽略了与协变量 x 的关联。检查:对两个特征值差异很大的样本,生成的 T 分布是否有明显差异。若没有,网络容量不足或 x 的 conditioning 不够强——尝试增加 hidden 维度或给 x 加一个预编码层。

如何判断模型在”学习”

  • 前 100 个 batch:loss 应从 ~2.0 快速降到 ~1.0
  • 1000 个 batch 后:应稳定在 0.2~0.5
  • 真正的验证:对两个 T 分布差异大的子组(如高风险/低风险),预测的中位生存时间应该有显著差异

超参数调优

参数 推荐起点 敏感度 说明
lr 3e-4 Adam 老规矩,先试这个
T_steps 500~1000 更多步 ≠ 更好,误差累积
hidden 256 通常够用
n_samples 200 影响 KM 估计质量
batch_size 256  

什么时候用 / 不用?

适用场景 不适用场景
需要准确的分布形状(不只是排序) 只关心 C-index,排序够用
数据量 > 500,分布结构复杂 小数据集(扩散需要足够数据)
协变量与风险的关系未知/非线性 确定满足比例风险假设
可接受较慢的推断(批量预测) 在线推断、实时评分
研究/报告需要完整的不确定性估计 工程部署要求低延迟

我的观点

SDPM 是一个有价值的研究方向,但目前是研究工具,不是即插即用的工程方案。

真正的优势在于校准性(calibration)。医疗场景下我们常常不确定协变量与风险的函数形式,这时非参数生成模型比 Cox 更可信。论文在合成数据上也证明了 SDPM 能比非参数基线更准确地恢复底层连续分布的形状——这个结论是可信的。

现实的局限:

  • 推断需要大量采样(N=200+),每个测试样本都要跑一遍反向扩散,比 Cox 慢 2~3 个数量级
  • DDPM 本身超参数敏感,目标空间变换设置稍有偏差就跑不好
  • C-index 上没有系统性优势,无法用最常用的指标说服同行

什么时候值得一试:你的数据量充足,领域专家怀疑比例风险假设不成立,且你比 C-index 更在意校准性(Brier Score)。否则 DeepHit 加上仔细的时间离散化仍然是性价比更高的工程选择。

论文代码已开源,见 arXiv:2605.22776