Skip to content

Some questions about replacing pyramid_fuse.py with gat_pyramid_fuse.py (self-implemented) #51

@chinagalaxy2002

Description

@chinagalaxy2002

Dear author,

I hope to replace the multi-scale pyramid module you proposed with another multi-modal fusion module.
It would be highly appreciated if you could give me some advice!
In my research I found that the paper "Attention Based Feature Fusion for Multi-Agent Collaborative Perception" proposes using a GAT module as a replacement.
I tried it and reproduced it using HEAL's stages 1–4. The evaluation results in stages 1 and 2 are similar to HEAL's, but in stage 4 the use_cav2 result is 0.1 lower than yours. Besides, use_cav2's AP is lower than use_cav1's.
I would like to ask you, what do you think is the problem?
The code and configuration files are modified as follows:

config:

fusion_backbone:
  resnext: true
  layer_nums: [3, 5, 8]
  layer_strides: [1, 2, 2]
  num_filters: [64, 128, 256]
  upsample_strides: [1, 2, 4]
  num_upsample_filter: [128, 128, 128]
  anchor_number: *anchor_num
  **num_heads: 4  # add this line**

heter_pyramid_collab.py / heter_pyramid_single.py:

# Swap HEAL's PyramidFusion backbone for the self-implemented GAT variant.
#self.pyramid_backbone = PyramidFusion(args['fusion_backbone'])

from opencood.models.fuse_modules.gat_pyramid_fuse import GATPyramidFusion

self.pyramid_backbone = GATPyramidFusion(args['fusion_backbone'])

gat_pyramid_fuse.py:

import torch
import torch.nn as nn
import torch.nn.functional as F
from opencood.models.sub_modules.base_bev_backbone_resnet import ResNetBEVBackbone
from opencood.models.sub_modules.torch_transformation_utils import warp_affine_simple
from opencood.models.fuse_modules.fusion_in_one import regroup

class GATPyramidFusion(ResNetBEVBackbone):
    """Multi-scale BEV fusion backbone that replaces HEAL's PyramidFusion
    weighted fusion with a per-scale, per-pixel dot-product attention
    ("GAT-style") fusion across collaborating agents.

    NOTE(review): unlike PyramidFusion, the occupancy maps produced by the
    ``single_head_i`` convs are only returned here — they are NOT used to
    weight the fusion itself. Confirm this matches the intended training
    objective; it may contribute to the reported stage-4 AP drop.
    """

    def __init__(self, model_cfg, input_channels=64):
        """Build the backbone plus per-scale attention projections.

        Args:
            model_cfg: dict with at least 'num_filters' (per-scale channel
                counts) and optionally 'align_corners', 'num_heads'.
            input_channels: channel count of the input BEV feature map.
        """
        super().__init__(model_cfg, input_channels)
        # One fusion stage per pyramid scale.
        self.num_levels = len(model_cfg['num_filters'])
        self.align_corners = model_cfg.get('align_corners', False)
        self.num_heads = model_cfg.get('num_heads', 4)  # default: 4 attention heads
        
        # Per-scale 1x1 conv projections for Q/K/V.
        # NOTE(review): Q and K are projected down to C // num_heads channels
        # here, and then split AGAIN into num_heads chunks inside gat_fuse, so
        # the effective per-head dim is C / num_heads**2 (e.g. 64/16 = 4).
        # Standard multi-head attention projects to C and splits once — this
        # double split is probably unintended and may explain the AP drop.
        self.gat_queries = nn.ModuleList([nn.Conv2d(C, C // self.num_heads, 1) for C in model_cfg['num_filters']])
        self.gat_keys = nn.ModuleList([nn.Conv2d(C, C // self.num_heads, 1) for C in model_cfg['num_filters']])
        self.gat_values = nn.ModuleList([nn.Conv2d(C, C, 1) for C in model_cfg['num_filters']])
        
        # Keep the per-scale occupancy heads identical to PyramidFusion
        # (1x1 conv producing a single-channel occupancy logit map).
        for i in range(self.num_levels):
            setattr(self, f"single_head_{i}", nn.Conv2d(model_cfg['num_filters'][i], 1, 1))

    def gat_fuse(self, features, record_len, affine_matrix, level):
        """Fuse one pyramid level across agents with attention.

        Args:
            features: [sum(record_len), C, H, W] stacked per-agent features.
            record_len: [B] number of agents in each batch sample.
            affine_matrix: [B, L, L, 2, 3] pairwise agent-to-agent warps.
            level: index of the current pyramid scale.
        Returns:
            fused_feature: [B, C, H, W] ego-view fused feature per sample.
        """
        split_features = regroup(features, record_len)  # List of [N_b, C, H, W]
        fused_out = []
        
        for b, feat in enumerate(split_features):
            N = record_len[b]
            # Row 0 of the pairwise matrix warps every agent into the
            # ego (agent 0) coordinate frame.
            t_matrix = affine_matrix[b][:N, :N, :, :]
            warped_features = warp_affine_simple(feat, t_matrix[0], (feat.shape[2], feat.shape[3]), self.align_corners)
            
            # Project queries, keys and values.
            Q = self.gat_queries[level](warped_features)  # [N, C//num_heads, H, W]
            K = self.gat_keys[level](warped_features)     # [N, C//num_heads, H, W]
            V = self.gat_values[level](warped_features)   # [N, C, H, W]
            
            # Reshape for per-pixel, cross-agent multi-head attention.
            # NOTE(review): C_qk is already C // num_heads, so this view
            # splits by num_heads a second time (see __init__ note).
            N, C_qk, H, W = Q.shape
            Q = Q.view(N, self.num_heads, C_qk // self.num_heads, H * W).permute(1, 3, 0, 2)  # [heads, H*W, N, C//heads]
            K = K.view(N, self.num_heads, C_qk // self.num_heads, H * W).permute(1, 3, 0, 2)
            V = V.view(N, self.num_heads, V.shape[1] // self.num_heads, H * W).permute(1, 3, 0, 2)
            
            # Scaled dot-product attention over the N agents at each pixel.
            scores = torch.matmul(Q, K.transpose(-1, -2)) / (C_qk // self.num_heads) ** 0.5  # [heads, H*W, N, N]
            attn_weights = F.softmax(scores, dim=-1)
            
            # Attention-weighted aggregation of the agent features.
            fused = torch.matmul(attn_weights, V)  # [heads, H*W, N, C//heads]
            # Merge heads back to C channels and keep only the ego row [0].
            fused = fused.permute(2, 0, 3, 1).reshape(N, -1, H, W)[0]  # [C, H, W], take the ego view
            fused_out.append(fused)
        
        return torch.stack(fused_out)

    def forward_collab(self, spatial_features, record_len, affine_matrix, agent_modality_list=None, cam_crop_info=None):
        """Collaborative forward: attention-fuse each scale, then decode.

        NOTE(review): occ_map is computed on the UNFUSED per-agent feature
        `feat` (batch dim sum(record_len)), while fused_feature has batch
        dim B — confirm downstream supervision expects that mismatch.

        Returns:
            fused_feature: decoded multi-scale fused BEV feature, [B, ...].
            occ_maps: list of per-scale occupancy logit maps.
        """
        feature_list = self.get_multiscale_feature(spatial_features)  # multi-scale features
        fused_features = []
        occ_maps = []
        
        for i, feat in enumerate(feature_list):
            fused_feat = self.gat_fuse(feat, record_len, affine_matrix, i)
            occ_map = getattr(self, f"single_head_{i}")(feat)  # occupancy logits
            fused_features.append(fused_feat)
            occ_maps.append(occ_map)
        
        fused_feature = self.decode_multiscale_feature(fused_features)
        return fused_feature, occ_maps

    def forward_single(self, spatial_features):
        """Single-agent forward: no fusion, just occupancy heads + decode.

        Returns:
            fused_feature: decoded multi-scale feature (unfused).
            occ_maps: list of per-scale occupancy logit maps.
        """
        feature_list = self.get_multiscale_feature(spatial_features)
        occ_maps = []
        
        for i, feat in enumerate(feature_list):
            occ_map = getattr(self, f"single_head_{i}")(feat)
            occ_maps.append(occ_map)
        
        fused_feature = self.decode_multiscale_feature(feature_list)
        return fused_feature, occ_maps

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions