import torch
import torch.nn as nn
import math as m
import torch.nn.functional as F
### shapes of q, k, v: batch * [seq_len * d_model]
### d_model (int) - the number of expected features in the encoder/decoder inputs, AKA the embedding size
### num_heads (int) - the number of attention heads (default here: 2)
#### dropout (float) - the dropout rate applied after the output projection (default here: 0.3),
# meaning roughly three in ten activations are randomly zeroed during each update cycle.
class Mha(nn.Module):
    def __init__(self, num_heads=2, d_model=4, dropout=0.3):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d = d_model // num_heads  # dimension of each head
        self.dropout = nn.Dropout(dropout)
        # One independent projection per head, each mapping d_model -> d_model // num_heads.
        self.linear_Qs = nn.ModuleList([nn.Linear(d_model, self.d)
                                        for _ in range(num_heads)])
        self.linear_Ks = nn.ModuleList([nn.Linear(d_model, self.d)
                                        for _ in range(num_heads)])
        self.linear_Vs = nn.ModuleList([nn.Linear(d_model, self.d)
                                        for _ in range(num_heads)])
        self.mha_linear = nn.Linear(d_model, d_model)
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        ## Attention(Q, K, V) = softmax(Q @ K^T / sqrt(self.d)) @ V
        ## shapes of Q, K, V: batch * [seq_len * self.d] (self.d = d_model // num_heads)
        ### Q matmul K^T => batch * [seq_len * seq_len] (attention_weight)
        ### output       => batch * [seq_len * self.d]
        Q_K_matmul = torch.matmul(Q, K.permute(0, 2, 1))
        scores = Q_K_matmul / m.sqrt(self.d)
        if mask is not None:
            # Masked positions get -inf so they receive ~0 weight after the softmax.
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attention_weight = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weight, V)
        return output, attention_weight
    def forward(self, q, k, v, mask=None):
        # shapes of q, k, v: batch * [seq_len * d_model]
        Q = [linear_Q(q) for linear_Q in self.linear_Qs]
        K = [linear_K(k) for linear_K in self.linear_Ks]
        V = [linear_V(v) for linear_V in self.linear_Vs]
        output_per_head = []
        attn_weights_per_head = []
        for Q_, K_, V_ in zip(Q, K, V):
            output, attn_weight = self.scaled_dot_product_attention(Q_, K_, V_, mask)
            output_per_head.append(output)
            attn_weights_per_head.append(attn_weight)
        # Concatenate the per-head outputs back to batch * [seq_len * d_model].
        output = torch.cat(
            output_per_head, -1
        )
        ### attn_weights: batch * [num_heads * seq_len * seq_len]
        attn_weights = torch.stack(attn_weights_per_head).permute(1, 0, 2, 3)
        projection = self.dropout(self.mha_linear(output))
        return projection, attn_weights
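

# A minimal smoke test for the module above. The batch size, sequence length, and the
# __main__ guard are illustrative assumptions, not part of the original module; they
# simply show the expected input/output shapes for a self-attention call (q = k = v).
if __name__ == "__main__":
    mha = Mha(num_heads=2, d_model=4, dropout=0.3)
    x = torch.rand(1, 3, 4)  # batch * [seq_len * d_model]
    out, attn = mha(x, x, x)
    print(out.shape)   # torch.Size([1, 3, 4])    -> batch * [seq_len * d_model]
    print(attn.shape)  # torch.Size([1, 2, 3, 3]) -> batch * [num_heads * seq_len * seq_len]

# Design note: this implementation uses one nn.Linear per head for Q, K, and V. A common
# alternative is a single fused d_model -> d_model projection per Q/K/V followed by a
# reshape into num_heads chunks, which is mathematically equivalent and usually faster.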