Thanks for sharing the code. However, I'm quite confused by the code for the QGM (Query Generation Module), since its naming differs a little from the original paper (if I understand the paper correctly).
I think this module is implemented by the function `lang_tf_enc` in `model/transformer_model.py`:
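```python
# Imports assumed from the repo (not part of the quoted snippet):
import keras.layers as L
from keras_multi_head import MultiHeadAttention
from keras_pos_embd import TrigPosEmbedding


def lang_tf_enc(vision_input,
                lang_input,
                head_num=8,
                hidden_dim=256):
    # Add sinusoidal positional embeddings to both modalities
    decoder_embed_lang = TrigPosEmbedding(
        mode=TrigPosEmbedding.MODE_ADD,
        name='Fusion-Lang-Decoder-Embedding',
    )(lang_input)
    decoder_embed_vis = TrigPosEmbedding(
        mode=TrigPosEmbedding.MODE_ADD,
        name='Fusion-Vis-Decoder-Embedding',
    )(vision_input)
    # Queries come from vision; keys and values come from language
    q_inp = L.Dense(hidden_dim, activation='relu')(decoder_embed_vis)
    k_inp = L.Dense(hidden_dim, activation='relu')(decoder_embed_lang)
    v_inp = L.Dense(hidden_dim, activation='relu')(decoder_embed_lang)
    decoded_layer = MultiHeadAttention(head_num=head_num)(
        [q_inp, k_inp, v_inp])
    # Residual connection back onto the vision input
    add_layer = L.Add(name='Fusion-Add')([decoded_layer, vision_input])
    return add_layer
```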
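For reference, the computation in `lang_tf_enc` boils down to cross-attention in which the vision tokens are the queries and the language tokens are the keys and values, followed by a residual add onto the vision input. Below is a minimal single-head NumPy sketch of that data flow; it is my own simplification, not the repo's code. It assumes flattened inputs of shape `(tokens, channels)`, drops the positional embeddings and the multi-head split, and the helper names (`cross_attention_fusion`, `w_q`, `w_k`, `w_v`) and the sizes used are hypothetical.

```python
import numpy as np

# Hypothetical single-head sketch of the cross-attention in lang_tf_enc.
# Positional embeddings and multi-head splitting are omitted for brevity.


def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def relu(x):
    return np.maximum(x, 0.0)


def cross_attention_fusion(vision, lang, w_q, w_k, w_v):
    """Vision tokens attend over language tokens; the result is added back to vision.

    vision:   (N_vis, C)       flattened visual features, e.g. N_vis = H * W
    lang:     (N_word, C_lang) word-level language features
    w_q:      (C, C)           query projection (width must equal C for the residual)
    w_k, w_v: (C_lang, C)      key / value projections
    """
    q = relu(vision @ w_q)                   # queries come from the vision side
    k = relu(lang @ w_k)                     # keys come from the language side
    v = relu(lang @ w_v)                     # values come from the language side
    scores = q @ k.T / np.sqrt(q.shape[-1])  # (N_vis, N_word)
    attn = softmax(scores, axis=-1)          # each visual token's weights over words sum to 1
    return attn @ v + vision                 # residual add, mirroring L.Add in lang_tf_enc


rng = np.random.default_rng(0)
H, W, C = 26, 26, 256      # hypothetical: 26x26 map, hidden_dim = 256
N_word, C_lang = 15, 1024  # hypothetical: 15 words, 1024-d word features
vision = rng.standard_normal((H * W, C))
lang = rng.standard_normal((N_word, C_lang))
w_q = rng.standard_normal((C, C)) * 0.01
w_k = rng.standard_normal((C_lang, C)) * 0.01
w_v = rng.standard_normal((C_lang, C)) * 0.01

out = cross_attention_fusion(vision, lang, w_q, w_k, w_v)
print(out.shape)  # (676, 256)
```

Whatever is passed as `vision` here is what the language features get fused into, which is why it matters whether the function receives the raw backbone features or features that are already fused.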
As Figure 4 suggests, the vision input to this module should be the raw visual features extracted by the vision backbone. Yet the vision input actually passed in is `Fm_query`, which has already been fused with the language features (see `make_multitask_braches` in `model/vlt_model.py`):
```python
def make_multitask_braches(Fv, fq, fq_word, config):
    # fq: bs, 1024
    # fq_word: bs, 15, 1024
    Fm = simple_fusion(Fv[0], fq, config.jemb_dim)  # 13, 13, 1024
    Fm_mid_query = up_proj_cat_proj(Fm, Fv[1], K.int_shape(Fv[1],)[-1], K.int_shape(Fm)[-1]//2)  # 26, 26, 512
    Fm_query = pool_proj_cat_proj(Fm_mid_query, Fv[2], K.int_shape(Fv[2])[-1], K.int_shape(Fm)[-1]//2)  # 26, 26, 512
    Fm_mid_tf = proj_cat(Fm_query, Fm_mid_query, K.int_shape(Fm)[-1]//2)  # 26, 26, 1024
    F_tf = up_proj_cat_proj(Fm, Fm_mid_tf, K.int_shape(Fm)[-1] // 2)
    F_tf = V.DarknetConv2D_BN_Leaky(config.hidden_dim, (1, 1))(F_tf)
    # Fm_query: bs, Hm, Wm, C (None, 26, 26, 512)
    # Fm_top_tf : bs, Hc, Wc, C (None, 26, 26, 512)
    query_out = vlt_querynet(Fm_query, config)
    mask_out = vlt_transformer(F_tf, fq_word, query_out, config)
    mask_out = vlt_postproc(mask_out, Fm_query, config)
    return mask_out
```
Could you tell me whether I've got this wrong? Thanks for your patience.