DAY 26
0
AI & Data

## Day 26 利用transformer自己實作一個翻譯程式(八) Multi-head attention

code的詳細解說之後會補上，由於我自己也還在讀這方面的內容，因此可能需要一點時間

``````class MultiHeadAttention(tf.keras.layers.Layer):
self.d_model = d_model

assert d_model % self.num_heads == 0

self.wq = tf.keras.layers.Dense(d_model)
self.wk = tf.keras.layers.Dense(d_model)
self.wv = tf.keras.layers.Dense(d_model)

self.dense = tf.keras.layers.Dense(d_model)

"""Split the last dimension into (num_heads, depth).
Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
"""
x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
return tf.transpose(x, perm=[0, 2, 1, 3])

def call(self, v, k, q, mask):
batch_size = tf.shape(q)[0]

q = self.wq(q)  # (batch_size, seq_len, d_model)
k = self.wk(k)  # (batch_size, seq_len, d_model)
v = self.wv(v)  # (batch_size, seq_len, d_model)

# scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
# attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
scaled_attention, attention_weights = scaled_dot_product_attention(

scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

concat_attention = tf.reshape(scaled_attention,
(batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

return output, attention_weights
``````
``````temp_mha = MultiHeadAttention(d_model=512, num_heads=8)
y = tf.random.uniform((1, 60, 512))  # (batch_size, encoder_sequence, d_model)
out, attn = temp_mha(y, k=y, q=y, mask=None)
out.shape, attn.shape
``````
```
(TensorShape([1, 60, 512]), TensorShape([1, 8, 60, 60]))
```