import torch.nn as nn
import torch
from mha import Mha
from pwffn import PWFFN
from residuallayernorm import ResidualLayerNorm
class DecoderLayer(nn.Module):
    """One Transformer decoder layer.

    Applies three sub-layers in sequence — masked self-attention over the
    target, encoder-decoder cross-attention, and a position-wise
    feed-forward network — each wrapped by a ``ResidualLayerNorm``
    (residual add + layer normalization).
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.3):
        """
        Args:
            d_model: model (embedding) dimension D.
            num_heads: number of attention heads per MHA sub-layer.
            d_ff: hidden width of the position-wise feed-forward network.
            dropout: dropout rate passed to both attention sub-layers.
        """
        super().__init__()
        # NOTE: attribute names and assignment order are part of the
        # module's state_dict layout — do not reorder or rename.
        self.norm_1 = ResidualLayerNorm(d_model)
        self.norm_2 = ResidualLayerNorm(d_model)
        self.norm_3 = ResidualLayerNorm(d_model)
        # Self-attention over the (causally masked) target sequence.
        self.masked_mha = Mha(d_model, num_heads, dropout)
        # Cross-attention: queries from the decoder, keys/values from the encoder.
        self.enc_dec_mha = Mha(d_model, num_heads, dropout)
        self.ff = PWFFN(d_model, d_ff)

    def forward(self, x, encoder_outputs, trg_mask, src_mask):
        """Run the decoder layer.

        Args:
            x: target-side input, presumably [B x TRG_seq_len x D]
               (matches the shape comments below — confirm against caller).
            encoder_outputs: encoder stack output, used as keys/values
                for cross-attention.
            trg_mask: mask applied in the self-attention sub-layer.
            src_mask: mask applied in the cross-attention sub-layer.

        Returns:
            A 3-tuple of (layer output, self-attention weights,
            cross-attention weights).
        """
        # 1) Masked self-attention, residual from x.
        self_attn_out, self_attn_weights = self.masked_mha(x, x, x, mask=trg_mask)
        after_self = self.norm_1(self_attn_out, x)

        # 2) Encoder-decoder attention, residual from the self-attention output.
        cross_out, cross_attn_weights = self.enc_dec_mha(
            after_self, encoder_outputs, encoder_outputs, mask=src_mask
        )
        after_cross = self.norm_2(cross_out, after_self)
        # shape(after_cross) = [B x TRG_seq_len x D]

        # 3) Position-wise feed-forward, residual from the cross-attention output.
        ff_out = self.ff(after_cross)
        out = self.norm_3(ff_out, after_cross)
        # shape(ff_out) = [B x TRG_seq_len x D]
        # shape(out)    = [B x TRG_seq_len x D]

        return out, self_attn_weights, cross_attn_weights