一句话总结

GenWildSplat 从几张来自互联网的未知相机参数、光照各异的照片中,无需任何场景特定训练,一次前向推理即可生成可实时渲染的 3D 高斯场景。


为什么这个问题重要?

想象你要重建一个地标:手头只有从 Flickr 随机抓取的 5 张游客照——不同时间拍摄,光线完全不同,前景有路人遮挡,没有 GPS 或已知相机参数。

这就是 GenWildSplat 针对的现实场景。

现有方法的痛点:

  • NeRF/3DGS 需要 per-scene 优化:每换一个场景就要训练数十分钟到数小时
  • 依赖受控输入:绝大多数方法要求 COLMAP 提供精确位姿,或假设光照一致
  • 野外”脏数据”:路人、车辆、随机遮挡让基于光度一致性的方法极不稳定

三个核心创新:

  1. Feed-forward 架构:直接从图像预测 3D 高斯,无需测试时优化
  2. Appearance Adapter:解耦场景几何与光照外观
  3. 课程学习:合成 → 真实数据的渐进训练策略

背景知识

3D 表示方式对比

表示方式 内存 渲染速度 稀疏视角友好 可微分
点云 差(无外观)
NeRF(隐式)
体素
3D Gaussian 实时

GenWildSplat 选择 3D Gaussian Splatting(3DGS) 的原因:渲染速度快(>100 FPS)、可微分(支持端到端训练)、参数化表达能力强。

3DGS 渲染基础

每个 3D 高斯由以下参数定义:

\[\mathcal{G}_i = \{\mu_i \in \mathbb{R}^3,\ \Sigma_i \in \mathbb{R}^{3\times 3},\ \alpha_i \in [0,1],\ c_i \in \mathbb{R}^3\}\]

协方差矩阵通过旋转和缩放参数化:$\Sigma = R S S^T R^T$,保证正定性。

渲染时,高斯被投影到图像平面并通过 α-compositing 合成颜色:

\[C(\mathbf{r}) = \sum_{i=1}^{N} c_i \alpha_i \prod_{j=1}^{i-1}(1 - \alpha_j)\]

本质和 NeRF 体渲染一致,只是把连续积分换成离散高斯求和。


核心方法

直觉解释

输入:3-5 张野外图像(无位姿 / 光照不一致 / 有遮挡)
          ↓
   [语义分割] → 生成瞬态物体掩码(行人、车辆)
          ↓
   [深度预测网络] → 每张图的稠密深度估计
          ↓
   [相机估计网络] → 预测各图像间相对位姿
          ↓
   [Gaussian 预测网络] → Canonical 空间中生成 3D 高斯
          ↓
   [Appearance Adapter] → 根据目标光照调制颜色
          ↓
输出:实时可渲染的 3D 高斯场景(推理 < 1s)

核心思路是将原本需要 per-scene 优化的流程变成一次前向传播的预测问题,通过大规模预训练让网络学会”如何从少量图像推断 3D 场景”。

Canonical 空间与位姿预测

传统方法先跑 COLMAP 确定相机位置,GenWildSplat 直接从图像特征预测相对位姿,将所有视角的高斯归一化到标准坐标系

\[\mathbf{G}_{canonical} = \mathcal{F}_\theta\!\left(\{I_k\},\ \{D_k\},\ \{\hat{P}_k\}\right)\]

其中 $\hat{P}_k$ 是网络预测的相机位姿,而非 COLMAP 真值。

Appearance Adapter

野外图像光照差异是最大挑战。Adapter 网络分离几何结构(不变)和外观(光照相关):

\[c_i^{target} = \mathcal{A}(c_i^{source},\ \mathbf{e}_{target})\]

$\mathbf{e}_{target}$ 是从目标参考图像提取的外观嵌入向量。这个设计允许”换光照”——用白天图像重建,用黄昏外观渲染。

课程学习策略

阶段 数据来源 目的
阶段一 合成数据(位姿已知,光照可控) 学习几何先验
阶段二 合成 + 真实混合 适应真实噪声
阶段三 PhotoTourism / MegaScenes 野外泛化

实现

核心数据结构:3D 高斯

import torch
import torch.nn as nn

class GaussianSet(nn.Module):
    """3D 高斯场景表示"""
    
    def __init__(self, n_gaussians: int):
        super().__init__()
        self.means     = nn.Parameter(torch.randn(n_gaussians, 3))        # 中心 (N,3)
        self.quats     = nn.Parameter(torch.randn(n_gaussians, 4))        # 旋转四元数 (N,4)
        self.scales    = nn.Parameter(torch.zeros(n_gaussians, 3))        # 缩放 log-scale
        self.opacities = nn.Parameter(torch.zeros(n_gaussians, 1))        # 不透明度 pre-sigmoid
        self.colors    = nn.Parameter(torch.rand(n_gaussians, 3))         # RGB 颜色

    def get_covariance(self) -> torch.Tensor:
        """四元数 + 缩放 → 协方差矩阵 Σ = RSS^T R^T"""
        q = self.quats / (self.quats.norm(dim=-1, keepdim=True) + 1e-8)
        w, x, y, z = q.unbind(-1)
        R = torch.stack([
            1-2*(y*y+z*z),  2*(x*y-z*w),    2*(x*z+y*w),
            2*(x*y+z*w),    1-2*(x*x+z*z),  2*(y*z-x*w),
            2*(x*z-y*w),    2*(y*z+x*w),    1-2*(x*x+y*y)
        ], dim=-1).reshape(-1, 3, 3)
        S = torch.diag_embed(torch.exp(self.scales))
        return R @ S @ S.transpose(-1,-2) @ R.transpose(-1,-2)

Feed-Forward 高斯预测网络

class GenWildSplatEncoder(nn.Module):
    """
    前向编码器(简化版)
    输入: k 张 RGB 图 + 深度图 + 预测相机位姿
    输出: 3D 高斯参数
    """
    def __init__(self, n_gaussians=8192, feature_dim=256):
        super().__init__()
        # 图像 + 深度特征提取(实际用 DINOv2/ViT 替代此处简单 CNN)
        self.img_encoder = nn.Sequential(
            nn.Conv2d(4, 64, 3, padding=1),  nn.ReLU(),      # 4ch: RGB+Depth
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(128, feature_dim, 3, stride=2, padding=1), nn.ReLU(),
        )
        # 跨视角注意力:让各视图"互相感知",这是多视角融合的关键
        self.cross_view_attn = nn.MultiheadAttention(
            feature_dim, num_heads=8, batch_first=True
        )
        # 高斯参数预测头:mean(3) + quat(4) + scale(3) + opacity(1) + color(3)
        self.gaussian_head = nn.Sequential(
            nn.Linear(feature_dim, 512), nn.ReLU(),
            nn.Linear(512, n_gaussians * 14)
        )
        self.n_gaussians = n_gaussians

    def forward(self, images, depths, poses):
        # images: (B, k, 3, H, W),depths: (B, k, 1, H, W)
        B, k, _, H, W = images.shape
        x = torch.cat([images, depths], dim=2).view(B*k, 4, H, W)
        feats = self.img_encoder(x).mean(dim=[-1,-2]).view(B, k, -1)  # GAP

        # 跨视角融合:每个视角的特征都能看到其他视角
        feats, _ = self.cross_view_attn(feats, feats, feats)
        feats = feats.mean(dim=1)                      # (B, feature_dim)

        params = self.gaussian_head(feats)             # (B, N*14)
        return params.view(B, self.n_gaussians, 14)    # 后续分割各参数分量

Appearance Adapter

class AppearanceAdapter(nn.Module):
    """
    外观适配器:将源光照下的高斯颜色转换到目标光照条件
    基本思路类似 AdaIN:提取全局光照统计,调制颜色
    """
    def __init__(self, embed_dim=128):
        super().__init__()
        self.appearance_encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=4, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(8), nn.Flatten(),
            nn.Linear(64*64, embed_dim)
        )
        self.color_modulator = nn.Sequential(
            nn.Linear(3 + embed_dim, 64), nn.ReLU(),
            nn.Linear(64, 3), nn.Sigmoid()
        )

    def forward(self, source_colors, target_image):
        """
        source_colors: (N, 3) 高斯颜色
        target_image:  (1, 3, H, W) 目标光照参考图
        """
        embed = self.appearance_encoder(target_image)           # (1, embed_dim)
        embed = embed.expand(source_colors.shape[0], -1)        # (N, embed_dim)
        return self.color_modulator(torch.cat([source_colors, embed], dim=-1))

瞬态物体过滤

def filter_transient_gaussians(gaussians, seg_masks, cameras, threshold=0.7):
    """
    语义掩码投票:若一个高斯在多个视角中均被标记为瞬态区域
    (行人/车辆),则将其不透明度置零
    seg_masks: List[(H,W)] 语义掩码,1=瞬态
    """
    transient_votes = torch.zeros(len(gaussians.means))

    for mask, cam in zip(seg_masks, cameras):
        pts_2d = project_to_image(gaussians.means, cam)  # (N,2),投影函数省略
        in_frame = (pts_2d[:, 0] >= 0) & (pts_2d[:, 0] < mask.shape[1]) & \
                   (pts_2d[:, 1] >= 0) & (pts_2d[:, 1] < mask.shape[0])
        px = pts_2d[in_frame].long()
        transient_votes[in_frame] += mask[px[:, 1], px[:, 0]].float()

    # 超过阈值视为瞬态:sigmoid(-10) ≈ 0,有效"关闭"该高斯
    ratio = transient_votes / len(cameras)
    gaussians.opacities.data[ratio > threshold] = -10.0
    return gaussians

3D 可视化

import open3d as o3d

def visualize_gaussians(gaussian_set, opacity_threshold=0.1):
    """用点云可视化有效高斯中心(预期结果:建筑轮廓清晰可见)"""
    means  = gaussian_set.means.detach().cpu().numpy()
    colors = torch.sigmoid(gaussian_set.colors).detach().cpu().numpy()
    opacities = torch.sigmoid(gaussian_set.opacities).squeeze().detach().cpu().numpy()

    visible = opacities > opacity_threshold
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(means[visible])
    pcd.colors = o3d.utility.Vector3dVector(np.clip(colors[visible], 0, 1))

    o3d.visualization.draw_geometries([pcd],
        window_name="GenWildSplat — Canonical Gaussians")
    # 预期输出:带颜色的场景点云,建筑结构清晰,路人等瞬态物体已被过滤

实验

数据集说明

数据集 特点 主要挑战
PhotoTourism 旅游地标,Flickr 图片 光照跨度大,游客遮挡严重
MegaScenes 100+ 多样化室外场景 视角稀疏,无约束采集

定量评估

方法 PSNR ↑ SSIM ↑ LPIPS ↓ 推理时间 需要优化?
NeRF-W 23.1 0.79 0.21 ~2 h
3DGS + COLMAP 24.8 0.82 0.18 ~30 min
pixelSplat 22.6 0.77 0.24 ~0.5 s
GenWildSplat 25.3 0.84 0.16 < 1 s

推理时间差距是数量级级别——这才是 GenWildSplat 真正的工程价值所在。


工程实践

实际部署考虑

  • 推理速度:A100 < 1s;RTX 3090 约 2-3s;RTX 4090 约 1-1.5s
  • 渲染 FPS:3DGS 渲染本身可达 100+ FPS,适合实时应用
  • 内存:模型权重约 1-2 GB;单场景高斯约 200-500 MB
  • 批量场景:可以并行处理,无需独立优化,适合大规模部署

常见坑

坑 1:图像重叠度不足

# 完全不重叠的图像会让位姿预测失效、高斯分布混乱
# 用特征匹配检查重叠度,匹配点 < 50 时应发出警告
def check_overlap(images, min_matches=50):
    matches = extract_matches(images)   # LoFTR / SuperPoint
    if len(matches) < min_matches:
        raise ValueError(f"图像重叠度不足:{len(matches)} 匹配点,建议 > {min_matches}")

坑 2:极端光照差异

# 夜晚 vs 正午混搭,Appearance Adapter 难以覆盖 5+ stops 曝光差
# 预处理:直方图均衡化 + 过滤极端曝光
import cv2
def normalize_exposure(img):
    lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
    lab[:, :, 0] = cv2.equalizeHist(lab[:, :, 0])
    return cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)

坑 3:大场景位姿漂移

相对位姿的误差在大场景中会累积,可以用少量已知 GPS 坐标做锚点对齐:

# 用 Procrustes analysis 求解相似变换,将预测位姿对齐到锚点
def align_to_anchors(pred_poses, anchor_poses, anchor_idx):
    # S* = argmin_S ||S @ pred[anchor_idx] - anchor_poses||
    # 再将 S* 应用到所有预测位姿
    pass

数据采集建议

  • 相邻视角重叠 > 40%,否则特征匹配失败
  • 分辨率至少 720p,建议 1080p
  • 尽量选相近时间段拍摄的图像(减少光照跨度)
  • 场景遮挡严重时,额外采集几张”无人”图像辅助过滤

什么时候用 / 不用?

适用场景 不适用场景
旅游地标快速重建 高精度工业测量(亚毫米级)
互联网图像快速集成 完全动态场景(水流、旗帜飘动)
无 GPU 优化资源 强反射 / 透明材质(玻璃幕墙)
实时预览、AR 应用 室内精细结构(桌面细节)
野外大场景外观生成 需要确定性几何保证的应用

与其他方法对比

方法 优点 缺点 适合场景
NeRF-W 光照建模细腻 小时级训练,无法实时 离线高质量渲染
3DGS + COLMAP 实时渲染,质量高 需 COLMAP 位姿,受控拍摄 有条件的精确重建
pixelSplat Feed-forward 仅支持双视角 立体相机系统
MVSplat 多视角 Feed-forward 必须提供已知位姿 稀疏已知位姿场景
GenWildSplat 无位姿 + 光照自适应 + 实时 精度略低于优化方法 野外互联网图像

我的观点

GenWildSplat 真正的价值不在于比优化方法精度更高,而在于打破了”重建必须有受控数据”的隐性假设。让用户上传几张手机照片就能得到 3D 场景,这才是面向消费者应用的产品逻辑。

离实用还有哪些卡点?

  1. 室内场景:强反射和透明材质对所有 Gaussian 方法都是难题,GenWildSplat 没有例外
  2. 大规模场景:城市级重建需要额外的分块策略,目前方法对整体结构一致性缺乏全局约束
  3. 细节保真度:偶发的”幻觉高斯”(网络臆造的不存在结构)在工程应用中难以容忍

值得关注的开放方向:

  • 将 feed-forward 思路推广到动态场景(4D Gaussian)
  • Appearance Adapter 能否支持极端天气迁移(晴天 → 雪天)?
  • 与 SAM2、DINOv2 等 Foundation Model 更深度集成,提升遮挡处理鲁棒性

Feed-forward 3D 重建是当前最值得押注的方向之一——不是因为它已经完美,而是因为它第一次让“即拍即重建”有了工程上的可行路径。