-
Notifications
You must be signed in to change notification settings - Fork 21
Open
Description
Dear author,
I hope to replace the multi-scale pyramid module you proposed with another multi-modal fusion module.
It's highly appreciated that you could give me some advise!
I researched and found that this paper: Attention Based Feature Fusion For Multi-Agent Collaborative Perception proposed that the GAT module can be used as a replacement.
I tried it and reproduced it using HEAL's stage1~4. The evaluation results in stage1 and 2 are similar to HEAL's method, but in stage4, the result of use_cav2
is 0.1 less than your result !Besides,the result of use_cav2's AP is less than the results of use_cav1's.
I would like to ask you, what do you think is the problem?
The code and configuration files are modified as follows:
config:
fusion_backbone:
resnext: true
layer_nums: [3, 5, 8]
layer_strides: [1, 2, 2]
num_filters: [64, 128, 256]
upsample_strides: [1, 2, 4]
num_upsample_filter: [128, 128, 128]
anchor_number: *anchor_num
**num_heads: 4 # 添加此行**
heter_pyramid_collab.py / heter_pyramid_single.py:
#self.pyramid_backbone = PyramidFusion(args['fusion_backbone'])
from opencood.models.fuse_modules.gat_pyramid_fuse import GATPyramidFusion
self.pyramid_backbone = GATPyramidFusion(args['fusion_backbone'])
gat_pyramid_fuse.py:
import torch
import torch.nn as nn
import torch.nn.functional as F
from opencood.models.sub_modules.base_bev_backbone_resnet import ResNetBEVBackbone
from opencood.models.sub_modules.torch_transformation_utils import warp_affine_simple
from opencood.models.fuse_modules.fusion_in_one import regroup
class GATPyramidFusion(ResNetBEVBackbone):
def __init__(self, model_cfg, input_channels=64):
super().__init__(model_cfg, input_channels)
self.num_levels = len(model_cfg['num_filters'])
self.align_corners = model_cfg.get('align_corners', False)
self.num_heads = model_cfg.get('num_heads', 4) # 默认4头注意力
# 为每个尺度定义GAT参数
self.gat_queries = nn.ModuleList([nn.Conv2d(C, C // self.num_heads, 1) for C in model_cfg['num_filters']])
self.gat_keys = nn.ModuleList([nn.Conv2d(C, C // self.num_heads, 1) for C in model_cfg['num_filters']])
self.gat_values = nn.ModuleList([nn.Conv2d(C, C, 1) for C in model_cfg['num_filters']])
# 保持占用图头与PyramidFusion一致
for i in range(self.num_levels):
setattr(self, f"single_head_{i}", nn.Conv2d(model_cfg['num_filters'][i], 1, 1))
def gat_fuse(self, features, record_len, affine_matrix, level):
"""
在单尺度上使用GAT融合特征。
Args:
features: [sum(record_len), C, H, W]
record_len: [B] batch中每个样本的代理数
affine_matrix: [B, L, L, 2, 3]
level: 当前尺度索引
Returns:
fused_feature: [B, C, H, W]
"""
split_features = regroup(features, record_len) # List of [N_b, C, H, W]
fused_out = []
for b, feat in enumerate(split_features):
N = record_len[b]
t_matrix = affine_matrix[b][:N, :N, :, :]
warped_features = warp_affine_simple(feat, t_matrix[0], (feat.shape[2], feat.shape[3]), self.align_corners)
# 计算Q, K, V
Q = self.gat_queries[level](warped_features) # [N, C//num_heads, H, W]
K = self.gat_keys[level](warped_features) # [N, C//num_heads, H, W]
V = self.gat_values[level](warped_features) # [N, C, H, W]
# 重塑为注意力计算
N, C_qk, H, W = Q.shape
Q = Q.view(N, self.num_heads, C_qk // self.num_heads, H * W).permute(1, 3, 0, 2) # [heads, H*W, N, C//heads]
K = K.view(N, self.num_heads, C_qk // self.num_heads, H * W).permute(1, 3, 0, 2)
V = V.view(N, self.num_heads, V.shape[1] // self.num_heads, H * W).permute(1, 3, 0, 2)
# 注意力得分
scores = torch.matmul(Q, K.transpose(-1, -2)) / (C_qk // self.num_heads) ** 0.5 # [heads, H*W, N, N]
attn_weights = F.softmax(scores, dim=-1)
# 加权融合
fused = torch.matmul(attn_weights, V) # [heads, H*W, N, C//heads]
fused = fused.permute(2, 0, 3, 1).reshape(N, -1, H, W)[0] # [C, H, W],取ego视角
fused_out.append(fused)
return torch.stack(fused_out)
def forward_collab(self, spatial_features, record_len, affine_matrix, agent_modality_list=None, cam_crop_info=None):
feature_list = self.get_multiscale_feature(spatial_features) # 多尺度特征
fused_features = []
occ_maps = []
for i, feat in enumerate(feature_list):
fused_feat = self.gat_fuse(feat, record_len, affine_matrix, i)
occ_map = getattr(self, f"single_head_{i}")(feat) # 计算占用图
fused_features.append(fused_feat)
occ_maps.append(occ_map)
fused_feature = self.decode_multiscale_feature(fused_features)
return fused_feature, occ_maps
def forward_single(self, spatial_features):
feature_list = self.get_multiscale_feature(spatial_features)
occ_maps = []
for i, feat in enumerate(feature_list):
occ_map = getattr(self, f"single_head_{i}")(feat)
occ_maps.append(occ_map)
fused_feature = self.decode_multiscale_feature(feature_list)
return fused_feature, occ_maps
Metadata
Metadata
Assignees
Labels
No labels