import torch
import torch.nn as nn
from encoder_layer import Encoderlayer
from positional_embedding import Embeddings, PositionalEncoding

import torch.nn as nn
import torch
from mha import Mha
from residuallayernorm import ResidualLayerNorm
from pwffn import PWFFN
from positional_embedding import Embeddings, PositionalEncoding

class Encoder(nn.Module):
    """Transformer encoder stack: embedding + positional encoding, then
    ``num_layers`` identical encoder layers applied in sequence.

    Args:
        Embedding: an ``Embeddings`` module mapping token ids to vectors.
        d_model: model/embedding dimensionality.
        num_heads: number of attention heads per encoder layer.
        num_layers: number of stacked ``Encoderlayer`` modules.
        d_ff: hidden size of each layer's position-wise feed-forward net.
        device: device passed to the positional-encoding module.
        dropout: dropout probability forwarded to every encoder layer.
        efficient_mha: accepted for interface compatibility.
            NOTE(review): currently unused — it is never forwarded to
            ``Encoderlayer``; confirm whether the layers should receive it.
    """

    def __init__(self, Embedding: Embeddings, d_model,
                 num_heads, num_layers,
                 d_ff, device="cpu", dropout=0.3, efficient_mha=False):
        super().__init__()

        self.embedding = Embedding

        self.PE = PositionalEncoding(d_model, device=device)

        # BUG FIX: dropout was previously hard-coded to 0.3 here, which
        # silently ignored the `dropout` constructor argument.
        self.encoders = nn.ModuleList(
            [Encoderlayer(
                d_ff,
                d_model,
                num_heads,
                dropout=dropout,
            ) for _ in range(num_layers)]
        )

    def forward(self, x):
        """Encode token ids ``x``.

        Returns:
            A tuple ``(encoding, attention_weights)`` where ``encoding`` is
            the output of the final layer and ``attention_weights`` are the
            attention weights of the *last* layer only (``None`` when
            ``num_layers == 0``).
        """
        embeddings = self.embedding(x)
        encoding = self.PE(embeddings)

        # BUG FIX: initialize so that an empty stack (num_layers == 0) no
        # longer raises NameError on return.
        encoder_attention_weights = None

        for encoder in self.encoders:
            encoding, encoder_attention_weights = encoder(encoding)

        return encoding, encoder_attention_weights