import torch.nn as nn
import torch
from mha import Mha
from residuallayernorm import ResidualLayerNorm
from pwffn import PWFFN
from positional_embedding import Embeddings, PositionalEncoding
class Encoderlayer(nn.Module):
    """One Transformer encoder layer: self-attention followed by a
    position-wise feed-forward network, each wrapped in a residual
    add + layer-norm (`ResidualLayerNorm`).

    Args:
        d_ff: hidden width of the position-wise feed-forward network.
        d_model: model/embedding dimension of the token representations.
        num_heads: number of attention heads in the multi-head attention.
        dropout: dropout probability forwarded to the sublayers.
    """

    def __init__(self, d_ff, d_model, num_heads, dropout=0.3):
        super().__init__()
        # Sublayers (attribute names are part of the state_dict contract).
        self.mha = Mha(d_model, num_heads, dropout)
        self.ff = PWFFN(d_model, d_ff, dropout)
        # One residual+norm wrapper per sublayer.
        self.norm_1 = ResidualLayerNorm(d_model, dropout)
        self.norm_2 = ResidualLayerNorm(d_model, dropout)

    def forward(self, x):
        """Run the encoder layer.

        Args:
            x: input activations; shape (batch, seq_len, d_model).

        Returns:
            A pair ``(output, attention_weights)`` where ``output`` has the
            same shape as ``x`` and ``attention_weights`` is whatever the
            ``Mha`` sublayer returns as its second value.
        """
        # Self-attention sublayer + residual/norm.
        attn_out, attn_weights = self.mha(x)
        x_attn = self.norm_1(attn_out, x)
        # Feed-forward sublayer + residual/norm.
        ff_out = self.ff(x_attn)
        x_ff = self.norm_2(ff_out, x_attn)
        return x_ff, attn_weights