import torch
import torch.nn as nn
from encoder_layer import Encoderlayer
from positional_embedding import Embeddings, PositionalEncoding
import torch.nn as nn
import torch
from mha import Mha
from residuallayernorm import ResidualLayerNorm
from pwffn import PWFFN
from positional_embedding import Embeddings, PositionalEncoding
class Encoder(nn.Module):
    """Stack of encoder layers applied on top of embeddings + positional encoding.

    ``forward`` embeds the input with ``Embedding``, adds positional
    information via ``PositionalEncoding``, then feeds the result through
    ``num_layers`` sequential ``Encoderlayer`` modules.

    Args:
        Embedding: Module applied to the raw input at the start of ``forward``.
        d_model: Model width; passed to the positional encoding and to each
            encoder layer.
        num_heads: Number of attention heads passed to each encoder layer.
        num_layers: Number of stacked ``Encoderlayer`` modules (may be 0).
        d_ff: Width passed as the first argument to each encoder layer
            (presumably the feed-forward hidden size — confirm against
            ``Encoderlayer``).
        device: Device forwarded to ``PositionalEncoding`` only.
        dropout: Dropout probability forwarded to every encoder layer.
        efficient_mha: Currently unused; kept for interface compatibility.
    """

    def __init__(self, Embedding: Embeddings, d_model,
                 num_heads, num_layers,
                 d_ff, device="cpu", dropout=0.3, efficient_mha=False):
        super().__init__()
        self.embedding = Embedding
        self.PE = PositionalEncoding(d_model, device=device)
        # Forward the caller's dropout; a fixed 0.3 here would silently
        # ignore the constructor argument.
        self.encoders = nn.ModuleList(
            [
                Encoderlayer(d_ff, d_model, num_heads, dropout=dropout)
                for _ in range(num_layers)
            ]
        )

    def forward(self, x):
        """Encode ``x`` and return the final encoding plus attention weights.

        Args:
            x: Input accepted by ``self.embedding`` (presumably a batch of
                token ids — TODO confirm against ``Embeddings``).

        Returns:
            Tuple ``(encoding, encoder_attention_weights)``, where
            ``encoding`` is the output of the last encoder layer and
            ``encoder_attention_weights`` is the weights returned by the
            LAST layer only. ``encoder_attention_weights`` is ``None``
            when the encoder has no layers.
        """
        embeddings = self.embedding(x)
        encoding = self.PE(embeddings)
        # Bound up front so a zero-layer encoder returns instead of raising
        # UnboundLocalError.
        encoder_attention_weights = None
        for encoder in self.encoders:
            # Each layer returns (output, attention_weights); earlier layers'
            # weights are overwritten.
            encoding, encoder_attention_weights = encoder(encoding)
        return encoding, encoder_attention_weights