import torch.nn as nn
import torch
from mha import Mha
from residuallayernorm import ResidualLayerNorm
from pwffn import PWFFN
from positional_embedding import Embeddings, PositionalEncoding

class Encoderlayer(nn.Module):
    """One Transformer encoder layer.

    Applies multi-head self-attention followed by a position-wise
    feed-forward network, each wrapped by a residual connection plus
    layer normalization (via ``ResidualLayerNorm``).

    NOTE(review): class name should be ``EncoderLayer`` per PEP 8, but
    renaming would break existing imports — left as-is.
    """

    def __init__(self, d_ff, d_model, num_heads, dropout=0.3):
        """Build the encoder layer.

        Args:
            d_ff: hidden width of the position-wise feed-forward network.
            d_model: model/embedding dimension.
            num_heads: number of attention heads.
            dropout: dropout probability used by all sub-layers.
        """
        super().__init__()

        # One residual+norm wrapper per sub-layer.
        self.norm_1 = ResidualLayerNorm(d_model, dropout)
        self.norm_2 = ResidualLayerNorm(d_model, dropout)

        # Multi-head self-attention sub-layer.
        self.mha = Mha(d_model, num_heads, dropout)

        # Position-wise feed-forward sub-layer.
        self.ff = PWFFN(d_model, d_ff, dropout)

    def forward(self, x):
        """Run the layer.

        Args:
            x: input tensor, shape [batch, seq_len, d_model].

        Returns:
            A tuple ``(out, attn_weights)`` where ``out`` has the same
            shape as ``x`` and ``attn_weights`` are the self-attention
            weights produced by ``self.mha``.
        """
        attn_out, attn_weights = self.mha(x)
        attended = self.norm_1(attn_out, x)

        transformed = self.ff(attended)
        out = self.norm_2(transformed, attended)

        return out, attn_weights