import torch
import torch.nn as nn
import math as m
import torch.nn.functional as F

### Multi-head attention hyperparameters:
###   d_model (int)   – number of expected features in the encoder/decoder inputs, i.e. the embedding size
###   num_heads (int) – number of attention heads (defaults to 2 in this implementation)
###   dropout (float) – dropout rate applied to the output projection; the default of 0.3
###                     means roughly one in three activations is randomly zeroed during each update

class Mha(nn.Module):
    def __init__(self, num_heads = 2, d_model = 4, dropout = 0.3):
        super().__init__()

        self.d_model = d_model 
        self.num_heads = num_heads

        self.d = d_model // num_heads 
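        # self.d is the per-head dimension; d_model is assumed to be evenly divisible
        # by num_heads, otherwise the integer division above would silently truncate
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"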

        self.dropout = nn.Dropout(dropout)

        # one Q/K/V projection per head, each mapping d_model -> self.d
        self.linear_Qs = nn.ModuleList(
            [nn.Linear(d_model, self.d) for _ in range(num_heads)])

        self.linear_Ks = nn.ModuleList(
            [nn.Linear(d_model, self.d) for _ in range(num_heads)])

        self.linear_Vs = nn.ModuleList(
            [nn.Linear(d_model, self.d) for _ in range(num_heads)])

        # output projection: maps the concatenated per-head outputs back to d_model
        self.mha_linear = nn.Linear(d_model, d_model)

    
    def scaled_dot_product_attention(self, Q, K, V, mask = None):
        ## shapes: Q, K, V are each [batch, seq_len, self.d]  where self.d = d_model / num_heads
        ## Q @ K^T -> [batch, seq_len, seq_len]  (attention weights)
        ## output  -> [batch, seq_len, self.d]
        ## computes scaled dot-product attention: softmax(Q K^T / sqrt(self.d)) V

        Q_K_matmul = torch.matmul(Q, K.permute(0,2,1))

        scores = Q_K_matmul / m.sqrt(self.d)

        if mask is not None:
            # fill masked positions with a large negative value so they get ~0 weight after softmax
            scores = scores.masked_fill(mask == 0, -1e9)

        attention_weight = F.softmax(scores, dim = -1)

        output = torch.matmul(attention_weight, V)

        return output, attention_weight 

    
    def forward(self, q, k, v, mask=None):
        # q, k, v: [batch, seq_len, d_model]

        # project q, k, v into per-head subspaces: each list entry is [batch, seq_len, self.d]
        Q = [linear_Q(q) for linear_Q in self.linear_Qs]
        K = [linear_K(k) for linear_K in self.linear_Ks]
        V = [linear_V(v) for linear_V in self.linear_Vs]

        output_per_head = []
        attn_weights_per_head = []

        for Q_, K_, V_ in zip(Q, K, V):
            output, attn_weight = self.scaled_dot_product_attention(Q_, K_, V_, mask)
            output_per_head.append(output)
            attn_weights_per_head.append(attn_weight)

        # concatenate the per-head outputs back to [batch, seq_len, d_model]
        output = torch.cat(output_per_head, dim=-1)

        # stack per-head attention weights into [batch, num_heads, seq_len, seq_len]
        attn_weights = torch.stack(attn_weights_per_head).permute(1, 0, 2, 3)

        # final output projection followed by dropout
        projection = self.dropout(self.mha_linear(output))

        return projection, attn_weights
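

# Usage sketch (illustration only, not part of the module): instantiates Mha with the
# defaults documented above and runs a self-attention pass on random tensors.
# The batch size, sequence length, and the causal mask below are arbitrary choices
# made for this example.
if __name__ == "__main__":
    torch.manual_seed(0)

    batch_size, seq_len, d_model = 3, 5, 4
    x = torch.randn(batch_size, seq_len, d_model)

    mha = Mha(num_heads=2, d_model=d_model, dropout=0.3)
    mha.eval()  # disable dropout for a deterministic forward pass

    # self-attention: the same tensor serves as query, key, and value
    out, attn = mha(x, x, x)
    print(out.shape)   # torch.Size([3, 5, 4])    -> [batch, seq_len, d_model]
    print(attn.shape)  # torch.Size([3, 2, 5, 5]) -> [batch, num_heads, seq_len, seq_len]

    # optional causal mask: position i may only attend to positions <= i
    causal_mask = torch.tril(torch.ones(1, seq_len, seq_len))
    out_masked, attn_masked = mha(x, x, x, mask=causal_mask)
    print(attn_masked[0, 0])  # upper-triangular entries are ~0 after masking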