import torch
import torch.nn as nn
from decoder_layer import DecoderLayer
from positional_embedding import Embeddings, PositionalEncoding
class Decoder(nn.Module):
    """Transformer decoder: embedding + positional encoding + a stack of decoder layers.

    Args:
        Embedding: an ``Embeddings`` module mapping token ids to vectors.
        d_model: model (embedding) dimensionality D.
        num_heads: number of attention heads per decoder layer.
        num_layers: number of stacked ``DecoderLayer`` blocks.
        d_ff: hidden size of each layer's position-wise feed-forward network.
        device: device passed to the positional-encoding table (default ``'cpu'``).
        dropout: dropout rate used inside each decoder layer (default ``0.3``).
    """

    def __init__(self, Embedding: Embeddings, d_model, num_heads, num_layers, d_ff, device='cpu', dropout=0.3):
        super().__init__()
        self.embedding = Embedding
        self.PE = PositionalEncoding(d_model, device=device)
        self.dropout = nn.Dropout(dropout)
        # FIX: was `nn.Modulist` (AttributeError). nn.ModuleList is required so the
        # layers are registered as submodules (visible to .parameters(), .to(), etc.).
        self.decoders = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

    def forward(self, x, encoder_output, trg_mask, src_mask):
        """Run the target sequence through all decoder layers.

        Args:
            x: target token ids, shape [B x TRG_seq_len].
            encoder_output: encoder states, shape [B x SRC_seq_len x D].
            trg_mask: target (causal/padding) mask for masked self-attention.
            src_mask: source padding mask for encoder-decoder attention.

        Returns:
            encoding: [B x TRG_seq_len x D] — output of the last decoder layer.
            masked_mha_attn_weights: [B x num_heads x TRG_seq_len x TRG_seq_len]
                (from the LAST layer only).
            enc_dec_mha_attn_weights: [B x num_heads x TRG_seq_len x SRC_seq_len]
                (from the LAST layer only).

        NOTE(review): if ``num_layers == 0`` the attention-weight names are never
        bound and this raises NameError; assumes num_layers >= 1 — TODO confirm.
        """
        embeddings = self.embedding(x)
        encoding = self.PE(embeddings)
        for decoder in self.decoders:
            encoding, masked_mha_attn_weights, enc_dec_mha_attn_weights = decoder(
                encoding, encoder_output, trg_mask, src_mask)
        return encoding, masked_mha_attn_weights, enc_dec_mha_attn_weights