一句话总结

给模型一张图片和一句话(”找出左边那个反光的金属杯子”),它不仅能把目标分割出来,还能告诉你”视觉推理”是怎么一步步走到这个结论的——这就是 SegCompass 解决的问题。

为什么这个问题重要?

推理分割(Reasoning Segmentation)比普通实例分割难一个量级:

  • 普通分割:给出类别名 → 框出区域(”cat” → 框猫)
  • 推理分割:给出自然语言描述 → 理解语义 → 定位并分割(”厨房里最可能用来喝水的器皿” → 分割出杯子)

这对机器人操作、AR 交互、医疗图像理解都有直接价值。但现有方法面临一个核心矛盾:

方法 可推理 可解释 端到端
隐式查询对齐(如 LISA) ✗ 黑盒
文本定位读出 △ 可读但不可控
SegCompass ✓ 稀疏概念空间

LISA 风格的方法把语言和视觉特征硬对齐成一个隐向量,根本看不出”它在想什么”。SegCompass 的核心创新是:在语言推理和视觉定位之间插入一个稀疏自编码器(SAE),强制让这个连接变得可检查、可追踪

背景知识

推理分割的难点

推理分割需要两种能力同时在线:

  1. 语言推理:理解复杂指令,做多步推断
  2. 视觉感知:把推断结果映射回像素坐标

这两者的特征空间完全不同。语言模型输出 token 序列,视觉编码器输出空间特征图。如何让”厨房 → 金属表面 → 杯子”这条推理链”落地”到图像像素,是核心难题。

Sparse Autoencoder(SAE)是什么?

SAE 来自机械可解释性(Mechanistic Interpretability)研究,用来解析神经网络内部的激活。核心思想:

人类概念是稀疏的——当你看到”猫”,你的”猫”概念激活,”飞机”概念不激活。

SAE 把密集的神经激活投影到一个高维稀疏空间

\[\mathbf{h} = \text{ReLU}(W_{\text{enc}} \mathbf{x} + \mathbf{b}) \quad \text{(大多数维度为 0)}\] \[\hat{\mathbf{x}} = W_{\text{dec}} \mathbf{h} + \mathbf{b}' \quad \text{(重建原始特征)}\] \[\mathcal{L}_{\text{SAE}} = \underbrace{\|\mathbf{x} - \hat{\mathbf{x}}\|^2}_{\text{重建损失}} + \lambda \underbrace{\|\mathbf{h}\|_1}_{\text{稀疏惩罚}}\]

每个非零维度对应一个可解释的”概念单元”。Anthropic 的研究表明,SAE 隐藏维度能对应到真实语义概念(如”金属纹理”、”圆柱形状”)。

核心方法

直觉解释

想象你是个翻译官,要把语言推理的”草稿纸”翻译给视觉系统:

文字推理:"图中金属光泽、圆柱形、在桌子左边"
          ↓ 共享 SAE 编码
稀疏概念:[金属纹理:0.9, 圆柱形:0.7, 左侧:0.8, 其余维度≈0]
          ↓ Slot Mapper 空间定位
热力图:  每个概念激活哪些像素?
          ↓ 多 Slot 汇总
最终掩码:0/1 像素分类

SAE 的作用是把两种模态的信号压缩到同一个稀疏概念词典里,概念维度就成了语言和视觉之间真正可检查的中间层。

数学细节

共享 SAE 编码:给定语言特征 $\mathbf{f}^{\text{lang}}$ 和视觉 token 特征 $\mathbf{f}^{\text{vis}}_i$,用同一个 SAE 编码:

\[\mathbf{h}^{\text{lang}} = \text{SAE}_{\text{enc}}(\mathbf{f}^{\text{lang}}) \in \mathbb{R}^{D_s}\] \[\mathbf{h}^{\text{vis}}_i = \text{SAE}_{\text{enc}}(\mathbf{f}^{\text{vis}}_i) \in \mathbb{R}^{D_s}, \quad i = 1,\ldots,HW\]

$D_s \gg D$(如 16384 vs 1024),大部分维度为零。

查询码本选择:从语言稀疏激活中选出激活值最大的 $K$ 个维度,视为”推理触发的核心概念”:

\[\mathcal{K} = \text{TopK}(\mathbf{h}^{\text{lang}}, K)\]

Slot Mapper 空间定位,每个概念生成独立热力图:

\[\alpha_{k,i} = \text{softmax}_i\!\left(\frac{\mathbf{q}_k \cdot \mathbf{h}^{\text{vis}}_i}{\sqrt{D_s}}\right), \quad \mathbf{M}_k = \sum_i \alpha_{k,i} \cdot \text{pos}(i) \in \mathbb{R}^{H \times W}\]

联合训练(三路损失):

\[\mathcal{L} = \mathcal{L}_{\text{seg}} + \mathcal{L}_{\text{SAE}} + \mathcal{L}_{\text{RL}}\]
  • $\mathcal{L}_{\text{seg}}$:Dice + BCE 分割监督
  • $\mathcal{L}_{\text{SAE}}$:重建 + L1 稀疏惩罚
  • $\mathcal{L}_{\text{RL}}$:GRPO 强化学习,优化推理轨迹质量

Pipeline 概览

图像 + 自然语言指令
  → [视觉编码器 + MLLM] → CoT 推理轨迹 + 视觉特征图
  → [共享 SAE]           → 稀疏语言概念 + 稀疏视觉概念
  → [Query Codebook]     → Top-K 核心概念选择
  → [Slot Mapper]        → K 张空间热力图
  → [Mask Decoder]       → 最终分割掩码

实现

SAE 核心实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseAutoencoder(nn.Module):
    """
    稀疏自编码器:将密集特征映射到高维稀疏概念空间
    - encoder: ReLU 天然稀疏
    - L1 惩罚控制稀疏度
    - decoder 列归一化防止权重坍塌
    """
    def __init__(self, input_dim: int, hidden_dim: int, lambda_sparse: float = 1e-3):
        super().__init__()
        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=False)
        self.lambda_sparse = lambda_sparse
        nn.init.orthogonal_(self.encoder.weight)
        nn.init.zeros_(self.encoder.bias)

    def forward(self, x: torch.Tensor):
        # x: [B, seq_len, D] 或 [B, D]
        h = F.relu(self.encoder(x))       # 稀疏概念激活 [B, ..., D_s]
        x_hat = self.decoder(h)           # 重建原始特征
        recon_loss = F.mse_loss(x_hat, x)
        sparsity_loss = h.abs().mean()    # L1 稀疏惩罚
        total_loss = recon_loss + self.lambda_sparse * sparsity_loss
        return h, total_loss

    @torch.no_grad()
    def normalize_decoder(self):
        """训练中定期调用:decoder 列归一化,防止大权重绕过稀疏约束"""
        self.decoder.weight.data = F.normalize(self.decoder.weight.data, dim=0)

Query Codebook + Slot Mapper

class SlotMapper(nn.Module):
    """
    将选中的稀疏概念映射到图像空间热力图
    每个概念 slot 独立生成一张注意力热力图
    """
    def __init__(self, concept_dim: int, num_slots: int, spatial_size: int):
        super().__init__()
        self.num_slots = num_slots
        # 可学习的空间位置 embedding,注入几何信息
        self.pos_embed = nn.Parameter(
            torch.randn(1, spatial_size * spatial_size, concept_dim) * 0.02
        )

    def forward(
        self,
        concept_queries: torch.Tensor,  # [B, K, D_s]  选中的概念向量
        visual_sparse: torch.Tensor,    # [B, H*W, D_s] 视觉稀疏特征
        spatial_size: tuple,
    ) -> torch.Tensor:
        H, W = spatial_size
        vis = visual_sparse + self.pos_embed[:, :H*W]  # 注入位置信息
        # 点积注意力:概念 query 与视觉 token 匹配
        attn = torch.einsum('bkd,bnd->bkn', concept_queries, vis) / (vis.shape[-1] ** 0.5)
        heatmaps = attn.softmax(dim=-1).reshape(-1, self.num_slots, H, W)
        return heatmaps   # [B, K, H, W],每个 slot 一张热力图

完整推理分割 Sketch

class SegCompassMini(nn.Module):
    """
    SegCompass 核心数据流示意(去掉 MLLM,假设特征已提取)
    SAE → 概念选择 → 热力图 → 掩码 完整路径
    """
    def __init__(self, feat_dim=1024, sparse_dim=8192, num_slots=8, spatial_size=32):
        super().__init__()
        self.sae = SparseAutoencoder(feat_dim, sparse_dim)
        self.slot_mapper = SlotMapper(sparse_dim, num_slots, spatial_size)
        self.mask_head = nn.Sequential(
            nn.Conv2d(num_slots, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 1, 1),
        )
        self.num_slots = num_slots

    def forward(self, lang_feat, vis_feat, spatial_size=(32, 32)):
        # lang_feat: [B, D]      语言侧输出(CoT 末尾 token)
        # vis_feat:  [B, H*W, D] 视觉编码特征
        lang_sparse, lang_loss = self.sae(lang_feat)     # 语言概念
        vis_sparse, vis_loss = self.sae(vis_feat)         # 视觉概念(共享权重!)

        # Top-K 语言概念作为 slot 查询
        top_vals, top_idx = lang_sparse.topk(self.num_slots, dim=-1)
        # 用激活值加权概念向量(简化版 codebook)
        concept_queries = lang_sparse.unsqueeze(1).expand(-1, self.num_slots, -1)
        mask_select = torch.zeros_like(concept_queries).scatter_(
            2, top_idx.unsqueeze(-1).expand(-1, -1, lang_sparse.shape[-1]), 1.0
        )
        concept_queries = concept_queries * mask_select   # [B, K, D_s]

        heatmaps = self.slot_mapper(concept_queries, vis_sparse, spatial_size)
        mask_logits = self.mask_head(heatmaps)            # [B, 1, H, W]
        sae_loss = (lang_loss + vis_loss) / 2
        return mask_logits, sae_loss, heatmaps            # 热力图用于可解释性分析

结果可视化:热力图解释性分析

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

def visualize_concept_heatmaps(image, heatmaps, concept_labels=None, top_k=4):
    """
    可视化每个概念 slot 在图像上的激活区域
    这是 SegCompass "可解释性" 的核心体现
    """
    B, K, H, W = heatmaps.shape
    fig = plt.figure(figsize=(4 * (top_k + 1), 4))
    gs = gridspec.GridSpec(1, top_k + 1)

    ax0 = fig.add_subplot(gs[0, 0])
    ax0.imshow(image); ax0.set_title("Input Image"); ax0.axis('off')

    for i in range(min(top_k, K)):
        ax = fig.add_subplot(gs[0, i + 1])
        hm = heatmaps[0, i].detach().cpu().numpy()
        ax.imshow(image, alpha=0.6)
        ax.imshow(hm, cmap='jet', alpha=0.5)
        label = concept_labels[i] if concept_labels else f"Concept #{i}"
        ax.set_title(label, fontsize=9); ax.axis('off')

    plt.tight_layout()
    plt.savefig("concept_heatmaps.png", dpi=150, bbox_inches='tight')

热力图输出示意:每个 slot 对应推理链中的一个概念,颜色越红表示该概念在图像该位置的激活越强。这就是 SegCompass 比 LISA 更可解释的直接体现——你能看到”金属纹理”这个概念究竟激活了图像的哪个区域。

实验

数据集说明

数据集 规模 特点
ReasonSeg val/test ~1.2k 图 主要推理分割 benchmark,需要多步推断
RefCOCO / RefCOCO+ 大规模 指代表达分割,指令相对简单
gRefCOCO 中等 通用化指代分割

ReasonSeg 是核心战场——指令复杂(”找出看起来最危险的动物”),需要真正的推理能力。

定量评估(ReasonSeg 主要结果)

方法 gIoU (val) cIoU (val) 可解释性 开源
LISA-7B 49.9 52.9
PixelLM 46.7 51.2
PSALM 50.4 53.1
SegCompass ~52+ ~54+

核心论点:在性能持平甚至更优的前提下,获得了可解释性——这是帕累托改进,不是 trade-off。

定性分析

好的案例:带有明确属性描述的指令(”最大的红色物体”)→ 颜色概念和尺寸概念 slot 激活清晰,热力图精确定位目标。

失败案例:空间关系复杂时(”第三个窗口左边的第二盆花”),需要多步计数推理,概念 slot 可能把多个候选区域都高亮,最终掩码不够精确。

工程实践

实际部署考虑

  • 推理速度:MLLM 做 CoT 是瓶颈。7B 模型在 A100 上约 2-4 秒/样本,离实时还差得远
  • 内存占用sparse_dim 通常设为 8192~16384,额外增加约 512MB 参数量
  • 批量推理:CoT 长度不固定,动态批量很麻烦,建议固定最大 CoT 长度并 padding

SAE 训练的常见坑

坑 1:特征坍塌(Feature Collapse)

症状:sparsity_loss 很低但 recon_loss 很高——所有输入映射到同一稀疏模式。

# 修复:optimizer.step() 后立即调用 decoder 列归一化
optimizer.step()
model.sae.normalize_decoder()   # 见前面 SparseAutoencoder 定义

坑 2:稀疏度不够

# 诊断:检查活跃比例,目标 < 5%
def check_sparsity(h: torch.Tensor):
    frac = (h > 0).float().mean().item()
    print(f"活跃比例: {frac:.3f}")  # 若 > 0.1 则需增大 lambda_sparse

# 或改用 TopK-SAE:强制只保留 K 个激活,彻底规避软约束失效问题
h_topk = torch.zeros_like(h).scatter_(-1, h.topk(32, dim=-1).indices, h.topk(32, dim=-1).values)

坑 3:RL 和分割损失梯度冲突

RL 优化推理质量,分割损失优化像素精度,二者梯度方向可能冲突。建议先单独预训练 SAE 和 Mask Decoder,冻结后再 finetune RL 部分。

数据采集建议

ReasonSeg 风格的数据很贵——每条需要人工写复杂推理指令。可以用 GPT-4o 自动生成初版指令,再人工过滤低质量样本,成本可降至 1/5。

什么时候用 / 不用?

适用场景 不适用场景
需要解释”为什么分割这里” 需要实时运行(<100ms)
复杂语义推理(多属性、关系推断) 只做简单类别分割
医疗/法律等需要可审计性的场景 嵌入式设备部署
数据标注质量控制(检查模型逻辑) 动态场景(视频流)

与其他方法对比

方法 优点 缺点 适用场景
LISA 简单有效,端到端 黑盒,不可解释 需要高性能的通用场景
文本定位读出 人类可读 非端到端,后处理不可控 快速原型
SegCompass 可解释+高性能+端到端 推理慢,工程复杂 可审计性要求高的场景

我的观点

SegCompass 走在了一个有趣的十字路口上:把 LLM 可解释性领域的工具(SAE)搬进了多模态感知。这个思路比单纯堆性能要有意思得多。

但有几点保留意见:

“可解释性”有多深? 论文展示了热力图和概念激活的相关性,但 SAE 的概念维度是否真的对应人类语义概念,还是只是统计上的稀疏分解?需要更严格的用户研究来验证,而不只是定量 IoU。

速度是最大的工程瓶颈。 CoT 推理 + SAE 编码这条路径不可能实时,对于机器人操作等延迟敏感应用暂时用不上,必须等 LLM 推理提速若干量级才有工程价值。

更有趣的延伸是双向控制:如果 SAE 概念空间是共享的,能否直接用文本描述来”写入”特定概念的激活,实现更精细的语言引导分割?比如”抑制圆形概念,增强金属纹理概念”——这方向还没有人做。

和三维感知的结合:如果把这套可解释对齐机制放到三维表示上——比如用语言概念直接定位三维高斯基元(3DGS)——将是很有价值的研究方向。空间概念(”左边”、”前方”)天然适合在三维空间中做可解释的定位。

官方代码:github.com/ZhenyuLU-Heliodore/SegCompass