Commit 76b62e21 authored by Meet Narendra

Model file

parent d5f9c39a
import numpy as np
import torch
import torch.nn as nn
from logger import Logger
LOGGER = Logger().logger()
class Model(nn.Module):
    '''
    PyTorch GRU-based embedding model: token embedding -> GRU -> linear
    projection back onto the vocabulary.
    '''
    def __init__(self, embedding_size, hidden_size, device, vocab_size, samples) -> None:
        '''
        Init function
        @params
        embedding_size: int
        hidden_size: int
        device: torch.device
        vocab_size: int
        samples: int, the DataLoader batch size (used for the hidden state)
        '''
        super().__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.device = device
        self.samples = samples
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
        # batch_first=True so inputs arrive as (batch, seq_len), which is
        # the layout a standard DataLoader yields.
        self.gru = nn.GRU(self.embedding_size, self.hidden_size, batch_first=True)
        self.unembedding = nn.Linear(self.hidden_size, self.vocab_size)
        # Note: nn.Sequential cannot chain these layers, because nn.GRU
        # returns an (output, hidden) tuple rather than a tensor; the layers
        # are therefore applied explicitly in forward().
        self.to(self.device)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        self.loss = nn.CrossEntropyLoss()
        LOGGER.info("Model initialized")
    def init_hidden(self):
        '''
        Initialize the GRU hidden state with zeros; the shape is
        (num_layers, batch, hidden_size).
        '''
        return torch.zeros(1, self.samples, self.hidden_size).to(self.device)

    def get_model(self):
        '''
        Return the model (the layers are registered directly on this
        nn.Module, so the model is the instance itself).
        '''
        return self
    def forward(self, inp, hidden):
        '''
        Forward pass.
        @params
        inp: torch.LongTensor of token ids, shape (batch, seq_len)
        hidden: torch.Tensor, shape (1, batch, hidden_size)
        '''
        embedded = self.embedding(inp)               # (batch, seq_len, embedding_size)
        output, hidden = self.gru(embedded, hidden)  # (batch, seq_len, hidden_size)
        logits = self.unembedding(output)            # (batch, seq_len, vocab_size)
        return logits, hidden
    def fit(self, epochs, train_loader):
        '''
        Train the model. (Named fit rather than train so it does not shadow
        nn.Module.train(), which toggles training mode.)
        @params
        epochs: int
        train_loader: torch.utils.data.DataLoader yielding (inp, target)
        batches of shape (batch, seq_len)
        '''
        for epoch in range(epochs):
            LOGGER.info("Epoch " + str(epoch))
            for i, (inp, target) in enumerate(train_loader):
                inp = inp.to(self.device)
                target = target.to(self.device)
                hidden = self.init_hidden()
                self.optimizer.zero_grad()
                output, hidden = self(inp, hidden)
                # CrossEntropyLoss expects (N, C) logits against (N,)
                # targets, so flatten the batch and time dimensions.
                loss = self.loss(output.reshape(-1, self.vocab_size), target.reshape(-1))
                loss.backward()
                self.optimizer.step()
                if i % 1000 == 0:
                    LOGGER.info("Loss " + str(loss.item()))
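

# --- Minimal usage sketch (an illustration, not part of the original
# commit). It builds a toy next-token dataset from random token ids; a real
# pipeline would tokenize a corpus, and the hyperparameter values below are
# illustrative only. `samples` must equal the DataLoader batch size, since
# init_hidden() uses it as the hidden state's batch dimension.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    vocab_size, embedding_size, hidden_size, samples = 100, 32, 64, 4
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 8 random sequences of length 13; targets are inputs shifted one step.
    data = torch.randint(0, vocab_size, (8, 13))
    inputs, targets = data[:, :-1], data[:, 1:]
    # drop_last=True keeps every batch at exactly `samples` sequences.
    loader = DataLoader(TensorDataset(inputs, targets),
                        batch_size=samples, drop_last=True)

    model = Model(embedding_size, hidden_size, device, vocab_size, samples)
    model.fit(epochs=2, train_loader=loader)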