import torch.nn as nn
import torch 
from mha import Mha
from pwffn import PWFFN
from residuallayernorm import ResidualLayerNorm

class DecoderLayer(nn.Module):
    """A single Transformer decoder layer.

    Runs three sub-layers in sequence: masked self-attention over the
    target sequence, encoder-decoder (cross) attention over the encoder
    outputs, and a position-wise feed-forward network. Each sub-layer's
    output is combined with its input through ``ResidualLayerNorm``
    (residual connection + layer normalization).
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.3):
        """
        Args:
            d_model: model/embedding dimension.
            num_heads: number of attention heads per attention module.
            d_ff: hidden size of the position-wise feed-forward network.
            dropout: dropout rate forwarded to both attention modules.
        """
        super().__init__()

        # One residual + layer-norm wrapper per sub-layer.
        self.norm_1 = ResidualLayerNorm(d_model)
        self.norm_2 = ResidualLayerNorm(d_model)
        self.norm_3 = ResidualLayerNorm(d_model)

        # Self-attention (masked) and cross-attention modules.
        self.masked_mha = Mha(d_model, num_heads, dropout)
        self.enc_dec_mha = Mha(d_model, num_heads, dropout)

        # Position-wise feed-forward network.
        self.ff = PWFFN(d_model, d_ff)

    def forward(self, x, encoder_outputs, trg_mask, src_mask):
        """Apply the decoder layer to one batch.

        Args:
            x: target-side activations; assumed [B x TRG_seq_len x D]
                per the shape comments below — TODO confirm.
            encoder_outputs: encoder activations used as keys/values in
                the cross-attention.
            trg_mask: mask applied inside the masked self-attention.
            src_mask: mask applied inside the encoder-decoder attention.

        Returns:
            A 3-tuple ``(output, self_attn_weights, cross_attn_weights)``
            where ``output`` has the same shape as ``x``.
        """
        # Sub-layer 1: masked self-attention (q = k = v = x),
        # then residual + norm against the original input.
        self_attn_out, self_attn_weights = self.masked_mha(x, x, x, mask=trg_mask)
        out_1 = self.norm_1(self_attn_out, x)

        # Sub-layer 2: cross-attention — queries come from the decoder,
        # keys/values from the encoder outputs.
        cross_attn_out, cross_attn_weights = self.enc_dec_mha(
            out_1, encoder_outputs, encoder_outputs, mask=src_mask)
        out_2 = self.norm_2(cross_attn_out, out_1)
        # shape(out_2) = [B x TRG_seq_len x D]

        # Sub-layer 3: position-wise feed-forward, then residual + norm.
        out_3 = self.norm_3(self.ff(out_2), out_2)
        # shape(out_3) = [B x TRG_seq_len x D]

        return out_3, self_attn_weights, cross_attn_weights