iT邦幫忙

2021 iThome 鐵人賽

DAY 24
0
AI & Data

Attention到底在關注什麼?系列 第 24

Day 24 利用transformer自己實作一個翻譯程式(六) Masking

  • 分享至 

  • twitterImage
  •  

Masking

需要把填充的部分標記為0,其餘部分標記為1,才不會導致填充的部分被誤認為是輸入

def create_padding_mask(seq):
  seq = tf.cast(tf.math.equal(seq, 0), tf.float32)

  # add extra dimensions to add the padding
  # to the attention logits.
  return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)
x = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
create_padding_mask(x)
<tf.Tensor: shape=(3, 1, 1, 5), dtype=float32, numpy=
array([[[[0., 0., 1., 1., 0.]]],

[[[0., 0., 0., 1., 1.]]],

[[[1., 1., 1., 0., 0.]]]], dtype=float32)>
def create_look_ahead_mask(size):
  mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
  return mask  # (seq_len, seq_len)
x = tf.random.uniform((1, 3))
temp = create_look_ahead_mask(x.shape[1])
temp
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[0., 1., 1.],
       [0., 0., 1.],
       [0., 0., 0.]], dtype=float32)>

上一篇
Day 23 利用transformer自己實作一個翻譯程式(五) Positional encoding
下一篇
Day 25 利用transformer自己實作一個翻譯程式(七) Scaled dot product attention
系列文
Attention到底在關注什麼?30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言