一句话总结

DiT-IC 将预训练 text-to-image DiT 改造为单步图像重建模型,在 32x 下采样的紧凑潜空间中完成扩散,解码速度比主流扩散编解码器快 30x,16GB 笔记本 GPU 即可重建 2048×2048 图像。


为什么现有方案不够好?

图像压缩有两个互相对立的目标:低比特率高感知质量。传统方法(JPEG、BPG)在低码率下产生块效应;神经网络方法(ELIC、Cheng2020)感知质量有上限;而扩散模型天然懂得”图像应该长什么样”,理论上能以最少比特重建出感知最好的图像。

问题出在架构选择上。

U-Net 的潜空间太浅

现有扩散编解码器(HiFiC-diff、PerCo)几乎清一色使用 U-Net,而 U-Net 的跳接结构要求保留高分辨率特征图,只能在 8x 下采样的浅层潜空间运行。而传统 VAE 编解码器工作在 16x–64x 的深度潜空间。

对于一张 512×512 图像:

方法 潜空间分辨率 计算规模
VAE 编解码器 8×8(64x 下采样) 基准
U-Net 扩散 64×64(8x 下采样) 64x 面积
+ 多步采样(50步) 同上 3200x 等效计算

计算量爆炸的结果:U-Net 扩散编解码器解码一张 512×512 图像要 8–12 秒,2K 分辨率直接 OOM。

DiT-IC 的核心问题:把 U-Net 换成 DiT,让扩散在 32x 下采样的紧凑潜空间运行,重建质量能保住吗?

答案是肯定的——但需要三个对齐机制解决随之而来的问题。


核心原理

为什么 DiT 适合深度潜空间?

DiT(Diffusion Transformer)的核心是全局自注意力,每个 token 能看到所有其他 token。对于 32x 下采样的潜变量,一张 512×512 图像变成 16×16 = 256 个 token,这对 Transformer 完全不是负担。

相比之下,U-Net 的局部卷积在高度压缩的潜空间中反而是劣势——256 个 token 里的全局语义依赖,局部感受野抓不住。

三个对齐机制

DiT-IC 的挑战是把预训练多步 T2I DiT 改造为单步图像重建模型,需要解决三个根本问题:

问题 1:压缩后的潜变量不同位置信息丢失量不同(高频细节丢失多,低频语义丢失少),标准扩散的固定去噪强度不合适。

方差引导重建流:根据编码器输出的局部方差,为每个位置自适应选择去噪强度 $t$。方差大 → 丢失多 → 更强的生成。

问题 2:原始 DiT 的潜变量分布来自 T2I 训练,与图像压缩编码器定义的潜空间几何不匹配。

自蒸馏对齐:用编码器自身的多步扩散轨迹作教师,训练单步模型在编码器定义的潜空间几何中保持一致性。

问题 3:T2I DiT 依赖文本提示,但压缩推理时没有文本。

潜变量条件引导:用压缩潜变量本身替代文本嵌入作为语义条件,编码器输出天然携带原始图像的语义信息。


代码实现

Baseline:传统 VAE 压缩的局限

先看现有方案的核心问题在哪:

# 传统 VAE 编解码器:编码器 → 量化 → 解码器,无扩散先验
class NaiveVAECodec(nn.Module):
    def __init__(self, encoder, quantizer, decoder):
        super().__init__()
        self.encoder = encoder      # 输入: [B,3,H,W] → [B,C,H/32,W/32]
        self.quantizer = quantizer  # 量化 + 熵编码
        self.decoder = decoder      # 重建: [B,C,H/32,W/32] → [B,3,H,W]

    def forward(self, x):
        z = self.encoder(x)          # 32x 下采样潜变量
        z_hat, bpp = self.quantizer(z)
        x_hat = self.decoder(z_hat)  # 低码率下高频细节严重丢失
        return x_hat, bpp

# 问题:低码率下 decoder 没有先验知识补全丢失的细节
# z_hat 里高频信息已损毁,纯 CNN 解码器无法"脑补"纹理

瓶颈:VAE 解码器只会做插值和上采样,对于量化损毁的高频细节束手无策。这正是扩散先验的用武之地。

核心架构:带潜变量条件的 DiT 块

import torch
import torch.nn as nn
import torch.nn.functional as F

class LatentConditionedDiTBlock(nn.Module):
    """
    DiT-IC 的基础模块:用压缩潜变量替代文本条件
    关键改动:cross-attention 的 key/value 来自压缩潜变量,而非文本嵌入
    """
    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm_cross = nn.LayerNorm(dim)

        # 自注意力:噪声潜变量 token 内部的全局依赖
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        # 交叉注意力:条件来自压缩潜变量(替代文本),注入语义信息
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

        # adaLN:时间步调制(标准 DiT 设计,控制去噪强度)
        self.adaLN_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim))
        self.ffn = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x, t_emb, latent_cond):
        """
        x:           噪声潜变量 tokens  [B, N, D]
        t_emb:       时间步嵌入         [B, D]
        latent_cond: 压缩潜变量条件     [B, M, D](提供语义引导,M 可不等于 N)
        """
        shift_sa, scale_sa, gate_sa, shift_ffn, scale_ffn, gate_ffn = \
            self.adaLN_mod(t_emb).chunk(6, dim=-1)

        # 1. 自注意力(全局感受野,这是 U-Net 做不到的)
        h = self.norm1(x) * (1 + scale_sa[:, None]) + shift_sa[:, None]
        h, _ = self.self_attn(h, h, h)
        x = x + gate_sa[:, None] * h

        # 2. 交叉注意力(核心:注入压缩潜变量的语义)
        h, _ = self.cross_attn(self.norm_cross(x), latent_cond, latent_cond)
        x = x + h  # 不用 adaLN gate,让语义条件直接作用

        # 3. FFN
        h = self.norm2(x) * (1 + scale_ffn[:, None]) + shift_ffn[:, None]
        x = x + gate_ffn[:, None] * self.ffn(h)
        return x

方差引导重建流

标准扩散对所有位置用相同的噪声强度 $t$,但压缩潜变量各位置丢失信息量不同。方差引导流将 $t$ 变为逐位置自适应:

\[x_t = \sqrt{\bar{\alpha}_t} \cdot x_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon, \quad t = f(\text{Var}(z_{\text{enc}}))\]
def variance_guided_noising(
    z_clean: torch.Tensor,        # 编码器干净输出 [B, C, H, W]
    encoder_var: torch.Tensor,    # 编码器不确定性估计 [B, C, H, W]
    alphas_cumprod: torch.Tensor, # 噪声调度表 [T]
    t_min: float = 0.1,
    t_max: float = 0.8,           # 上限防止信息完全损毁
):
    """
    方差越大 → 信息丢失越多 → 给更大的 t → 扩散负责补全更多细节
    方差越小 → 压缩损失小 → 给更小的 t → 轻微修复,保留原始信息
    """
    # 归一化到 [0, 1],再映射到 [t_min, t_max]
    var_norm = (encoder_var - encoder_var.amin(dim=(1,2,3), keepdim=True)) / \
               (encoder_var.amax(dim=(1,2,3), keepdim=True) + 1e-8)
    t_map = t_min + var_norm * (t_max - t_min)  # [B, C, H, W]

    # 根据连续时间步插值 alpha(逐位置)
    T = len(alphas_cumprod)
    t_idx = (t_map * (T - 1)).long().clamp(0, T - 1)
    alpha = alphas_cumprod[t_idx]  # [B, C, H, W]

    noise = torch.randn_like(z_clean)
    z_t = torch.sqrt(alpha) * z_clean + torch.sqrt(1 - alpha) * noise
    return z_t, t_map, noise

自蒸馏对齐损失

这是从多步变单步的关键。教师是多步扩散轨迹终点,学生是单步预测;几何约束保持结果在编码器定义的潜空间流形上。

def self_distillation_loss(
    student_z0: torch.Tensor,    # 单步模型预测的 z0 [B, C, H, W]
    teacher_z0: torch.Tensor,    # 多步扩散轨迹终点 [B, C, H, W]
    encoder_z0: torch.Tensor,    # 编码器定义的真实潜变量(几何锚点)
    lambda_c: float = 0.5,       # 一致性权重
    lambda_g: float = 0.1,       # 几何约束权重
) -> torch.Tensor:
    # 一致性损失:鼓励单步预测逼近多步最终结果
    loss_consistency = F.mse_loss(student_z0, teacher_z0)

    # 几何约束:student 输出方向应与编码器潜变量方向一致
    # 直觉:防止扩散"幻觉"偏离编码器编码的语义
    s_norm = F.normalize(student_z0.flatten(1), dim=-1)
    e_norm = F.normalize(encoder_z0.flatten(1), dim=-1)
    loss_geom = 1.0 - (s_norm * e_norm).sum(dim=-1).mean()

    return lambda_c * loss_consistency + lambda_g * loss_geom

单步推理管线

class DiTICDecoder(nn.Module):
    def __init__(self, dit_blocks, vae_decoder, patch_embed, t_embed, patch_size=2):
        super().__init__()
        self.blocks = nn.ModuleList(dit_blocks)
        self.vae_decoder = vae_decoder  # 将精炼后的潜变量解码为像素
        self.patch_embed = patch_embed
        self.t_embed = t_embed
        dim = dit_blocks[0].ffn[0].in_features
        self.final_norm = nn.LayerNorm(dim)
        self.unpatch = nn.Linear(dim, patch_size**2 * 4)  # 还原潜变量通道
        self.patch_size = patch_size

    @torch.no_grad()
    def decode(self, z_compressed, encoder_var, alphas_cumprod):
        B, C, Hl, Wl = z_compressed.shape

        # 1. 方差引导加噪(使用全图平均 t 做单步推理的代理)
        t_mean = encoder_var.mean(dim=(1,2,3)).clamp(0.1, 0.8)  # [B]
        T = len(alphas_cumprod)
        t_idx = (t_mean * (T - 1)).long()
        alpha = alphas_cumprod[t_idx].view(B, 1, 1, 1)
        z_t = torch.sqrt(alpha) * z_compressed + \
              torch.sqrt(1 - alpha) * torch.randn_like(z_compressed)

        # 2. patch 化(潜变量 → token 序列)
        x = self.patch_embed(z_t)                    # [B, N, D]
        cond = self.patch_embed(z_compressed)         # 压缩潜变量作为语义条件
        t_emb = self.t_embed(t_mean)

        # 3. DiT 单步去噪(无迭代,这是速度的关键)
        for block in self.blocks:
            x = block(x, t_emb, cond)

        # 4. 还原 token → 潜变量 → 像素
        x = self.unpatch(self.final_norm(x))          # [B, N, p*p*C]
        p = self.patch_size
        z0_pred = x.view(B, Hl // p, Wl // p, p, p, 4)
        z0_pred = z0_pred.permute(0,5,1,3,2,4).reshape(B, 4, Hl, Wl)
        return self.vae_decoder(z0_pred)              # [B, 3, H, W]

常见错误

错误 1:跳过前向加噪

# ❌ 错误:直接把压缩潜变量当 x_t 输入,DiT 不知道需要多大强度的去噪
x_pred = dit(z_compressed, t=0.5, cond=z_compressed)

# ✓ 正确:先加噪到 t 时刻,再让 DiT 预测 x_0
z_t = add_variance_guided_noise(z_compressed, encoder_var)
x_pred = dit(z_t, t=t_from_var, cond=z_compressed)

错误 2:把自蒸馏的 lambda_g 设太大

几何约束太强会让模型过度拟合编码器输出,失去扩散先验补全细节的能力。经验值:从 lambda_g=0 调通一致性损失后,再从 0.01 慢慢增大。


性能实测

基于论文报告数据(A100 80GB,CLIC 2021 测试集,0.08 bpp):

方法 架构 解码时间 (512²) 最大可处理分辨率 (16GB GPU) LPIPS↓
HiFiC-diff U-Net + 多步 ~12 s ≈512×512 0.12
PerCo U-Net + 多步 ~8 s ≈512×512 0.11
ELIC 纯 VAE 0.05 s 2048×2048+ 0.18
DiT-IC DiT + 单步 ~0.4 s 2048×2048 0.09

为什么快 30x? 计算量是两个维度的乘积:

  • 潜空间面积:从 64×64(8x下采样)→ 16×16(32x下采样),减少 16x
  • 扩散步数:从 50 步 → 1 步,减少 50x
  • 综合:理论上限 800x,实测约 30x(Transformer 注意力有固定开销)

什么时候用 / 不用?

适用场景 不适用场景
低码率(< 0.1 bpp)感知优先场景 需要严格 PSNR/SSIM 指标(医疗影像)
显存受限的边缘设备(16GB 以下) 实时流媒体(单步仍需 ~400ms)
高分辨率内容分发(2K+) 卫星/工业图像(分布外泛化风险)
视觉内容平台(感知质量优先) 编码端资源极度受限(编码器本身也有成本)

重要局限

  1. 单步扩散在极低码率下偶尔产生幻觉(hallucination)——生成了语义合理但实际不存在的细节
  2. 预训练 DiT 依赖大规模 T2I 数据,迁移到医疗/遥感等特定领域需要微调
  3. 编码端速度未在论文中详细报告——算术编码仍可能是实际部署瓶颈

调试技巧

验证方差引导有效性:训练中可视化 t_map 的空间分布,图像边缘和纹理区域的平均 $t$ 应明显高于天空、平面等低频区域。若分布均匀,说明编码器没有产生有意义的不确定性估计。

Flash Attention 加速推理:DiT 的 attention 层是生产部署的计算瓶颈,对于 256 个 token 的序列,Flash Attention 2 节省约 60% 显存且有 2–3x 加速:

# pip install flash-attn --no-build-isolation
from flash_attn.modules.mha import MHA
# 直接替换 LatentConditionedDiTBlock 中的 nn.MultiheadAttention
self.self_attn = MHA(dim, num_heads, use_flash_attn=True)

梯度监控:自蒸馏训练的前期,loss_consistency 应快速下降,而 loss_geom 下降应更缓慢。若 loss_geom 不下降,通常是编码器冻结不够彻底——确保反向传播不流入编码器。


延伸阅读


总结

DiT-IC 的核心贡献是一套系统性解决方案

  1. DiT 替换 U-Net → 打破了潜空间深度限制(8x → 32x),计算规模缩小 16x
  2. 三个对齐机制 → 解决了预训练多步生成模型迁移到单步压缩重建的三个根本障碍
  3. 单步扩散 → 去掉多步迭代,解码速度从 10 秒级变为亚秒级

这套思路的价值超越图像压缩本身:把预训练大模型的生成先验,通过对齐机制约束到特定任务的信息瓶颈,是当前生成式 AI 落地的核心范式。理解 DiT-IC,也就理解了如何把 Stable Diffusion 这类生成模型变成可实际部署的工程产品。