import torch.nn as nn
import math as m
import torch
class Embeddings(nn.Module):
    """Token embedding lookup whose output is multiplied by ``sqrt(d_model)``.

    Wraps :class:`torch.nn.Embedding`. The scaling is presumably meant to match
    the Transformer input-embedding convention; confirm against the callers.

    Args:
        vocab_size: Number of rows in the embedding table.
        padding_idx: Index passed through to ``nn.Embedding`` as its padding row.
        d_model: Width of each embedding vector; also sets the output scale.
    """

    def __init__(self, vocab_size: int, padding_idx: int, d_model: int) -> None:
        super().__init__()
        self.d_model = d_model
        # Keep the attribute name ``embed`` so state_dict keys stay ``embed.weight``.
        self.embed = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_model,
            padding_idx=padding_idx,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Look up ``x`` and multiply the vectors by ``sqrt(self.d_model)``.

        ``x`` is an index tensor. The result has shape ``x.shape + (d_model,)``.
        """
        # Compute the scale on every call, so a later change to ``self.d_model``
        # is still honoured.
        scale = m.sqrt(self.d_model)
        return self.embed(x) * scale