How to use PyTorch's nn.MultiheadAttention

I want to use PyTorch's nn.MultiheadAttention but it doesn't work.
I just want to use PyTorch's built-in implementation to reproduce a manually calculated attention example.

I always get an error when I try to run this example.

import torch
import torch.nn as nn

embed_dim = 4
num_heads = 1

x = [
  [1, 0, 1, 0], # Input 1
  [0, 2, 0, 2], # Input 2
  [1, 1, 1, 1]  # Input 3
 ]
x = torch.tensor(x, dtype=torch.float32)

w_key = [
  [0, 0, 1],
  [1, 1, 0],
  [0, 1, 0],
  [1, 1, 0]
]
w_query = [
  [1, 0, 1],
  [1, 0, 0],
  [0, 0, 1],
  [0, 1, 1]
]
w_value = [
  [0, 2, 0],
  [0, 3, 0],
  [1, 0, 3],
  [1, 1, 0]
]
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)


keys = x @ w_key
querys = x @ w_query
values = x @ w_value

multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
attn_output, attn_output_weights = multihead_attn(querys, keys, values)


Solution 1:[1]

Try this.

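First, your x is a (3x4) matrix, and embed_dim = 4, so the projected queries, keys, and values must also have 4 features each. That means every weight matrix has to be (4x4) instead of (4x3).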
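A minimal sketch of why the (4x3) weights fail, assuming a recent PyTorch release: the last dimension of whatever is passed as query/key/value must equal embed_dim. The helper `accepts` is hypothetical and only wraps the call in a try/except; both inputs get a batch dimension so that the embedding size is the only thing that differs.

```python
import torch
import torch.nn as nn

embed_dim = 4
num_heads = 1

x = torch.tensor([[1, 0, 1, 0],
                  [0, 2, 0, 2],
                  [1, 1, 1, 1]], dtype=torch.float32)  # (seq_len=3, features=4)

# Query weights from the question: 4 inputs -> 3 outputs
w_query_4x3 = torch.tensor([[1, 0, 1],
                            [1, 0, 0],
                            [0, 0, 1],
                            [0, 1, 1]], dtype=torch.float32)
# Query weights from the answer: 4 inputs -> 4 outputs
w_query_4x4 = torch.tensor([[1, 0, 1, 1],
                            [1, 0, 0, 1],
                            [0, 0, 1, 1],
                            [0, 1, 1, 1]], dtype=torch.float32)

q3 = x @ w_query_4x3  # shape (3, 3): last dim 3 != embed_dim
q4 = x @ w_query_4x4  # shape (3, 4): last dim matches embed_dim

mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)


def accepts(t):
    """Hypothetical helper: True if self-attention on t (plus a batch dim) runs."""
    t = t.unsqueeze(0)
    try:
        mha(t, t, t)
        return True
    except (AssertionError, RuntimeError):
        return False


print(q3.shape, accepts(q3))  # torch.Size([3, 3]) False
print(q4.shape, accepts(q4))  # torch.Size([3, 4]) True
```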

It also seems that nn.MultiheadAttention only accepts batched input, at least in the version I tried, even though the docs say it supports unbatched input. So let's just put your single sequence into batch mode via .unsqueeze(0) and set batch_first=True.
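Which axis .unsqueeze() should add depends on batch_first. A small sketch, assuming random inputs and weights chosen only for illustration: with batch_first=True the module expects (batch, seq_len, embed_dim), so the batch axis goes first; with the default batch_first=False it expects (seq_len, batch, embed_dim), so the batch axis goes second. Copying the weights from one module to the other shows that both layouts compute the same result.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(3, 4)  # one sequence: (seq_len=3, embed_dim=4), no batch dim

# batch_first=True expects (N, L, E): add the batch axis in front
mha_bf = nn.MultiheadAttention(4, 1, batch_first=True)
x_bf = x.unsqueeze(0)  # (1, 3, 4)
out_bf, _ = mha_bf(x_bf, x_bf, x_bf)

# batch_first=False (the default) expects (L, N, E): add the batch axis second
mha_sf = nn.MultiheadAttention(4, 1)
mha_sf.load_state_dict(mha_bf.state_dict())  # same weights, different layout
x_sf = x.unsqueeze(1)  # (3, 1, 4)
out_sf, _ = mha_sf(x_sf, x_sf, x_sf)

print(out_bf.shape, out_sf.shape)  # torch.Size([1, 3, 4]) torch.Size([3, 1, 4])
# Same values, only the position of the batch axis differs
print(torch.allclose(out_bf[0], out_sf[:, 0], atol=1e-6))  # True
```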

import torch
import torch.nn as nn

embed_dim = 4
num_heads = 1

x = [
  [1, 0, 1, 0], # Token 1
  [0, 2, 0, 2], # Token 2
  [1, 1, 1, 1]  # Token 3
 ]
x = torch.tensor(x, dtype=torch.float32)

w_key = [
  [0, 0, 1, 1],
  [1, 1, 0, 1],
  [0, 1, 0, 1],
  [1, 1, 0, 1]
]
w_query = [
  [1, 0, 1, 1],
  [1, 0, 0, 1],
  [0, 0, 1, 1],
  [0, 1, 1, 1]
]
w_value = [
  [0, 2, 0, 1],
  [0, 3, 0, 1],
  [1, 0, 3, 1],
  [1, 1, 0, 1]
]
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)


keys = (x @ w_key).unsqueeze(0)     # add a batch dimension: (3, 4) -> (1, 3, 4)
querys = (x @ w_query).unsqueeze(0)
values = (x @ w_value).unsqueeze(0)

# batch_first=True: inputs are laid out as (batch, seq_len, embed_dim)
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
attn_output, attn_output_weights = multihead_attn(querys, keys, values)
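One caveat if the goal is to reproduce a hand calculation: nn.MultiheadAttention applies its own randomly initialized input projections (in_proj_weight, in_proj_bias) and output projection (out_proj), so passing already-projected querys/keys/values projects them a second time and the numbers will not match the manual result. A sketch of one way to line them up, assuming the manual example uses scaled dot-product attention (scores divided by sqrt(head_dim)), which is what the module computes: copy the hand-made weights into the module, make out_proj the identity, and feed x directly.

```python
import math

import torch
import torch.nn as nn

embed_dim = 4
num_heads = 1

x = torch.tensor([[1, 0, 1, 0], [0, 2, 0, 2], [1, 1, 1, 1]], dtype=torch.float32)
w_key = torch.tensor([[0, 0, 1, 1], [1, 1, 0, 1], [0, 1, 0, 1], [1, 1, 0, 1]], dtype=torch.float32)
w_query = torch.tensor([[1, 0, 1, 1], [1, 0, 0, 1], [0, 0, 1, 1], [0, 1, 1, 1]], dtype=torch.float32)
w_value = torch.tensor([[0, 2, 0, 1], [0, 3, 0, 1], [1, 0, 3, 1], [1, 1, 0, 1]], dtype=torch.float32)

# Manual single-head scaled dot-product attention
q = x @ w_query
k = x @ w_key
v = x @ w_value
head_dim = embed_dim // num_heads
manual_weights = torch.softmax(q @ k.T / math.sqrt(head_dim), dim=-1)
manual_out = manual_weights @ v

multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
with torch.no_grad():
    # The module computes input @ W.T, so store the transposed matrices,
    # stacked in query, key, value order
    multihead_attn.in_proj_weight.copy_(torch.cat([w_query.T, w_key.T, w_value.T], dim=0))
    multihead_attn.in_proj_bias.zero_()
    # Identity output projection so the attention result passes through unchanged
    multihead_attn.out_proj.weight.copy_(torch.eye(embed_dim))
    multihead_attn.out_proj.bias.zero_()

xb = x.unsqueeze(0)  # raw inputs with a batch dimension: (1, 3, 4)
attn_output, attn_output_weights = multihead_attn(xb, xb, xb)

print(torch.allclose(attn_output[0], manual_out, atol=1e-5))  # True
```

With more than one head the same idea applies, but the module splits the projected features into num_heads chunks of head_dim each, so the manual calculation would have to do the same split.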

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Muscle Guy