V2EX = way to explore
V2EX 是一个关于分享和探索的地方
现在注册
已注册用户请  登录
V2EX 提问指南
Richard14
V2EX  ›  问与答

Bert 神经网络结构中是否使用了多头自注意力机制?

  •  
  •   Richard14 · 2022-06-30 00:48:17 +08:00 · 981 次点击
    这是一个创建于 889 天前的主题,其中的信息可能已经有所发展或是发生改变。

    学习完 RNN 之后学习效果更好的 bert ,从网上 trasformers 的预训练库里加载了一个 bert 的模型,但是将其打印出来的结构类似下面这样:

    BertForSequenceClassification(
      (bert): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(21128, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=768, out_features=768, bias=True)
                  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=768, out_features=3072, bias=True)
                (intermediate_act_fn): GELUActivation()
              )
              (output): BertOutput(
                (dense): Linear(in_features=3072, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
            (1): BertLayer(...)
            (2): BertLayer(...)
            (3): BertLayer(...)
            (4): BertLayer(...)
            (5): BertLayer(...)
            (6): BertLayer(...)
            (7): BertLayer(...)
            (8): BertLayer(...)
            (9): BertLayer(...)
            (10): BertLayer(...)
            (11): BertLayer(...)
          )
        )
        (pooler): BertPooler(
          (dense): Linear(in_features=768, out_features=768, bias=True)
          (activation): Tanh()
        )
      )
      (dropout): Dropout(p=0.1, inplace=False)
      (classifier): Linear(in_features=768, out_features=14, bias=True)
    )
    

    可以观察到它的实现里 embedding 完成之后是十二个 bertlayer ,而每个 bertlayer 的内容,输入长度是 768 ,分别生成长度为 768 的三个 kqv ,那么这应该是一个完整的自注意力要素,multihead 的成分在哪里呢?

    还是说我理解错了,它打印出来这个结果意思是 embedding 完事之后有 12 个平行的 self-attention 并行计算,而不是顺序结算?

    看到网上资料里都写 bert 是有 multihead selfattention 的

    6 条回复    2022-07-06 05:20:22 +08:00
    MeePawn666
        1
    MeePawn666  
       2022-06-30 01:22:03 +08:00 via Android
    multi head 是计算策略,你看一下源码就知道了。self attention 就是 parallel 的不是 sequential 的,这也是它比 rnn 好的原因之一。
    Richard14
        3
    Richard14  
    OP
       2022-06-30 04:16:22 +08:00
    @MeePawn666
    @Xs0ul 好的,所以它只是打印结果看起来像是线性的,实际上各层之间只有掐头去尾的关系,中间是并行的?
    Xs0ul
        4
    Xs0ul  
       2022-06-30 04:38:33 +08:00   ❤️ 2
    可能没说清楚,你可以先看一眼可选的 config: https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertConfig
    里面有两个参数分别是 num_hidden_layers 和 num_attention_heads ,而它们的默认值刚好都是 12.

    你说打印出来看起来像是线性的 12 层,这是个 num_hidden_layers: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L577, 可以从源代码看出来确实是线性进行的。

    而 multihead 的并行,是我上面发的那个,包括往下几行的 forward 。这个在打印的结构里是没有体现的
    Richard14
        5
    Richard14  
    OP
       2022-07-01 19:03:11 +08:00
    @Xs0ul
    @MeePawn666 大佬们再问一下,paper 的预训练部分没有看懂,他说是预训练工作是随机屏蔽 15%的词做完形填空,那么具体的输出是什么形式呢,比如用一个长度为 1000 的向量表示词的话,输入是[batch_size, 句子最大长度, 1000],输出就是[batch_size, 1000]这样吗?不过既然是 15%的话,那一句话里不是有可能会多于一个被屏蔽词,这又怎么表现在输出当中呢。另外如果按照上述输出,是否意味着网络需要自己学习哪个位置是被屏蔽的,这是不是让网络学习抽取不必要特征了。。
    Xs0ul
        6
    Xs0ul  
       2022-07-06 05:20:22 +08:00
    我记得输出是 [batch_size, 句子最大长度,vocab_size],也就是输出填好以后的整段话。但可以把非 mask 的位置上的 loss 屏蔽掉不参与 BP
    关于   ·   帮助文档   ·   博客   ·   API   ·   FAQ   ·   实用小工具   ·   6027 人在线   最高记录 6679   ·     Select Language
    创意工作者们的社区
    World is powered by solitude
    VERSION: 3.9.8.5 · 26ms · UTC 02:08 · PVG 10:08 · LAX 18:08 · JFK 21:08
    Developed with CodeLauncher
    ♥ Do have faith in what you're doing.