Vision Language Adapter
Contents
Motivation
- cross-modal alignment between visual space and text space.
- visual feature compression
cross attention
A single-layer cross-attention module, randomly initialized, that compresses visual features using a set of trainable query embeddings (with fixed 2D sin-cos position embeddings added).
- Qwen-VL
# https://huggingface.co/Qwen/Qwen-VL/blob/main/visual.py
def get_abs_pos(abs_pos, tgt_size):
    # abs_pos: L, C
    # tgt_size: M
    # return: M, C
    # Assumes both position grids are square: L == src_size**2, M == tgt_size**2.
    src_size = int(math.sqrt(abs_pos.size(0)))
    tgt_size = int(math.sqrt(tgt_size))
    dtype = abs_pos.dtype

    if src_size != tgt_size:
        # Bicubic-resample the (src_size, src_size) embedding grid to the
        # target grid size, then flatten back to (M, C) in the input dtype.
        return F.interpolate(
            abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
            size=(tgt_size, tgt_size),
            mode="bicubic",
            align_corners=False,
        ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
    else:
        return abs_pos


class Resampler(nn.Module):
    """
    A 2D perceiver-resampler network with one cross attention layers by
        (grid_size**2) learnable queries and 2d sincos pos_emb
    Outputs:
        A tensor with the shape of (grid_size**2, embed_dim)
    """
    def __init__(
            self,
            grid_size,
            embed_dim,
            num_heads,
            kv_dim=None,
            norm_layer=nn.LayerNorm
    ):
        super().__init__()
        self.num_queries = grid_size ** 2
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Fixed (frozen) 2D sin-cos position embeddings for the query grid.
        self.pos_embed = nn.Parameter(
            torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
        ).requires_grad_(False)

        # Learnable query embeddings — one per output token.
        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
        trunc_normal_(self.query, std=.02)

        # Project visual features to embed_dim only when their width differs.
        if kv_dim is not None and kv_dim != embed_dim:
            self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        else:
            self.kv_proj = nn.Identity()

        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln_q = norm_layer(embed_dim)
        self.ln_kv = norm_layer(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, attn_mask=None):
        # Resample the key position embeddings to match x's sequence length
        # (handles input grids whose size differs from the query grid).
        pos_embed = get_abs_pos(self.pos_embed, x.size(1))

        x = self.kv_proj(x)
        # nn.MultiheadAttention defaults to (seq, batch, dim) layout.
        x = self.ln_kv(x).permute(1, 0, 2)

        N = x.shape[1]
        q = self.ln_q(self.query)
        # Queries get the fixed grid pos_embed; keys get the resampled pos_embed.
        out = self.attn(
            self._repeat(q, N) + self.pos_embed.unsqueeze(1),
            x + pos_embed.unsqueeze(1),
            x,
            attn_mask=attn_mask)[0]
        # Back to (batch, num_queries, dim).
        return out.permute(1, 0, 2)

    def _repeat(self, query, N: int):
        # Tile queries across the batch dimension: (num_queries, N, dim).
        return query.unsqueeze(1).repeat(1, N, 1)
torch.view + MLP(Linear + GELU + Linear)
An MLP (Linear + GELU + Linear) that compresses each group of adjacent 2x2 tokens into a single token.
- Qwen2-VL
# print(huggingface model)
# (merger): PatchMerger(
#   (ln_q): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
#   (mlp): Sequential(
#     (0): Linear(in_features=5120, out_features=5120, bias=True)
#     (1): GELU(approximate='none')
#     (2): Linear(in_features=5120, out_features=1536, bias=True)
#   )
# )

# modeling.py: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
class PatchMerger(nn.Module):
    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        # Each merged token concatenates spatial_merge_size**2 neighboring patches,
        # so the MLP input width is context_dim * 4 by default.
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.ln_q = LayerNorm(context_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize per patch, then flatten each merge group into one vector
        # of width hidden_size before projecting to the LLM dimension `dim`.
        x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
        return x

- Qwen2.5-VL
# print(huggingface model)
# (merger): Qwen2_5_VLPatchMerger(
#   (ln_q): Qwen2RMSNorm((1280,), eps=1e-06)
#   (mlp): Sequential(
#     (0): Linear(in_features=5120, out_features=5120, bias=True)
#     (1): GELU(approximate='none')
#     (2): Linear(in_features=5120, out_features=2048, bias=True)
#   )
# )

# modeling.py: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
class Qwen2_5_VLPatchMerger(nn.Module):
    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        # Same structure as Qwen2-VL's PatchMerger, but pre-merge normalization
        # uses RMSNorm instead of LayerNorm.
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize per patch, flatten each merge group, project to `dim`.
        x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
        return x

- Qwen3-VL
# print(huggingface model)
# (merger): Qwen3VLVisionPatchMerger(
#   (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
#   (linear_fc1): Linear(in_features=4096, out_features=4096, bias=True)
#   (act_fn): GELU(approximate='none')
#   (linear_fc2): Linear(in_features=4096, out_features=2048, bias=True)
# )
# (deepstack_merger_list): ModuleList(
#   (0-2): 3 x Qwen3VLVisionPatchMerger(
#     (norm): LayerNorm((4096,), eps=1e-06, elementwise_affine=True)
#     (linear_fc1): Linear(in_features=4096, out_features=4096, bias=True)
#     (act_fn): GELU(approximate='none')
#     (linear_fc2): Linear(in_features=4096, out_features=2048, bias=True)
#   )
# )

# modeling.py: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
class Qwen3VLVisionPatchMerger(nn.Module):
    def __init__(self, config: Qwen3VLVisionConfig, use_postshuffle_norm=False) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
        # use_postshuffle_norm chooses whether LayerNorm is applied before or
        # after flattening the merge group (per-patch vs per-merged-token norm);
        # the deepstack mergers in the printout above use the post-shuffle variant.
        self.use_postshuffle_norm = use_postshuffle_norm
        self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6)
        self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.act_fn = nn.GELU()
        self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Either view-then-norm (post-shuffle) or norm-then-view, then a
        # Linear -> GELU -> Linear projection to out_hidden_size.
        x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size)
        x = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
        return x
Conv2d((2, 2), (2, 2)) + GELU + Conv2d((1, 1), (1, 1)) + Linear
- HunyuanOCR
# print(huggingface model)
# (perceive): HunYuanVisionPatchMerger(
#   (proj): Sequential(
#     (0): Conv2d(1152, 2304, kernel_size=(2, 2), stride=(2, 2))
#     (1): GELU(approximate='none')
#     (2): Conv2d(2304, 4608, kernel_size=(1, 1), stride=(1, 1))
#   )
#   (mlp): Linear(in_features=4608, out_features=1024, bias=True)
#   (before_rms): HunYuanVLRMSNorm((1152,), eps=1e-05)
#   (after_rms): HunYuanVLRMSNorm((1024,), eps=1e-05)
# )

# model.py: https://github.com/huggingface/transformers/blob/82a06db03535c49aa987719ed0746a76093b1ec4/src/transformers/models/hunyuan_vl/modeling_hunyuan_vl.py
class HunYuanVisionPatchMerger(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        spatial_merge_size,
        rms_norm_eps,
        **kwargs,
    ):
        super().__init__()
        embed_std = out_channels**-0.5
        self.spatial_merge_size = spatial_merge_size
        # Conv2d with kernel == stride == spatial_merge_size merges each
        # spatial_merge_size x spatial_merge_size patch group into one position,
        # expanding channels 1x -> 2x -> 4x.
        self.proj = nn.Sequential(
            nn.Conv2d(in_channels, in_channels * 2, kernel_size=spatial_merge_size, stride=spatial_merge_size),
            nn.GELU(),
            nn.Conv2d(in_channels * 2, in_channels * 4, kernel_size=1),
        )
        self.mlp = nn.Linear(in_channels * 4, out_channels)
        # Learned marker embeddings: a per-row "newline" column plus
        # image begin/end/sep tokens (sep is defined but not used in forward here).
        self.image_newline = nn.Parameter(torch.randn(in_channels * 4) * embed_std)
        self.image_begin = nn.Parameter(torch.randn(out_channels) * embed_std)
        self.image_end = nn.Parameter(torch.randn(out_channels) * embed_std)
        self.image_sep = nn.Parameter(torch.randn(out_channels) * embed_std)
        self.before_rms = HunYuanVLRMSNorm(in_channels, eps=rms_norm_eps)
        self.after_rms = HunYuanVLRMSNorm(out_channels, eps=rms_norm_eps)

    def forward(self, x, size=(16, 16)):
        x = self.before_rms(x)  # b, n, c
        # NOTE(review): h.item()/w.item() implies `size` holds tensors here — confirm against caller.
        h, w = size
        dtype = x.dtype
        x = x.permute(0, 2, 1).reshape(x.shape[0], -1, int(h.item()), int(w.item()))  # b, c, h, w. n = hxw
        x = self.proj(x)  # b, 4c, h//2, w//2
        b, c, h, w = x.shape
        # Append the learned "newline" column so row boundaries survive flattening.
        x = torch.cat(
            [x, self.image_newline.reshape(1, c, 1, 1).expand(b, c, h, 1).to(dtype, non_blocking=True)], dim=-1
        )  # b, 4c, h//2, w//2+1
        x = x.reshape(b, c, -1).permute(0, 2, 1)  # b, 4c, n. n= h//2 * (w//2+1)
        x = self.mlp(x)  # b, c, n
        # Wrap the token sequence with learned begin/end embeddings.
        begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype, non_blocking=True)
        end = self.image_end.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype, non_blocking=True)
        x = torch.cat([begin, x, end], dim=1)
        return self.after_rms(x)
| |
| |
| |
| |