Meet Narendra / CS626-Project
Commit 76b62e21, authored Nov 07, 2022 by Meet Narendra
Model file
parent d5f9c39a
Showing 1 changed file with 73 additions and 0 deletions.
model.py · new file (0 → 100644) · +73 −0
import torch
import torch.nn as nn
from logger import Logger

LOGGER = Logger().logger()

class Model(nn.Module):
    '''
    PyTorch GRU-based embedding model: embedding -> GRU -> linear unembedding.
    '''
    def __init__(self, embedding_size, hidden_size, device, vocab_size, samples) -> None:
        '''
        Init function
        @params
            embedding_size: int
            hidden_size: int
            device: torch.device
            vocab_size: int
            samples: int (batch size, used for the initial hidden state)
        '''
        super().__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.device = device
        self.samples = samples
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_size).to(self.device)
        self.gru = nn.GRU(self.embedding_size, self.hidden_size).to(self.device)
        self.unembedding = nn.Linear(self.hidden_size, self.vocab_size).to(self.device)
        # Note: these modules cannot be chained with nn.Sequential, because
        # nn.GRU takes an extra hidden-state argument and returns an
        # (output, hidden) tuple; the forward pass is written out explicitly.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        self.loss = nn.CrossEntropyLoss()
        LOGGER.info("Model initialized")

    def init_hidden(self):
        '''
        Initialize the GRU hidden state, shape (num_layers, batch, hidden_size).
        '''
        return torch.zeros(1, self.samples, self.hidden_size).to(self.device)

    def get_model(self):
        '''
        Function to get the model.
        '''
        return self

    def forward(self, inp, hidden):
        '''
        Forward pass.
        @params
            inp: torch.LongTensor of token ids, shape (seq_len, batch)
            hidden: torch.Tensor, shape (1, batch, hidden_size)
        @returns
            logits (seq_len, batch, vocab_size), hidden (1, batch, hidden_size)
        '''
        embedded = self.embedding(inp)
        output, hidden = self.gru(embedded, hidden)
        return self.unembedding(output), hidden

    def train(self, epochs, train_loader):
        '''
        Train the model. (Note: this shadows nn.Module.train(mode).)
        @params
            epochs: int
            train_loader: torch.utils.data.DataLoader
        '''
        for epoch in range(epochs):
            LOGGER.info("Epoch " + str(epoch))
            for i, (inp, target) in enumerate(train_loader):
                inp = inp.to(self.device)
                target = target.to(self.device)
                hidden = self.init_hidden()
                self.optimizer.zero_grad()
                output, hidden = self.forward(inp, hidden)
                # CrossEntropyLoss expects (N, C) logits and (N,) targets,
                # so flatten the sequence and batch dimensions together.
                loss = self.loss(output.view(-1, self.vocab_size), target.view(-1))
                loss.backward()
                self.optimizer.step()
                if i % 1000 == 0:
                    LOGGER.info("Loss " + str(loss.item()))
        return
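To see why the embedding → GRU → unembedding stack cannot be wrapped in a single `nn.Sequential`, it helps to trace the tensor shapes through one forward pass by hand. The sketch below mirrors the layer sizes and hidden-state layout used in model.py, but all dimensions (vocab of 50, sequence length of 10, batch of 4, etc.) are made-up values for illustration, not taken from the project:

```python
import torch
import torch.nn as nn

# Illustrative sizes only (assumptions, not the project's real hyperparameters).
vocab_size, embedding_size, hidden_size, samples = 50, 8, 16, 4
seq_len = 10

embedding = nn.Embedding(vocab_size, embedding_size)
gru = nn.GRU(embedding_size, hidden_size)
unembedding = nn.Linear(hidden_size, vocab_size)

inp = torch.randint(0, vocab_size, (seq_len, samples))   # token ids: (seq_len, batch)
hidden = torch.zeros(1, samples, hidden_size)            # matches init_hidden()

embedded = embedding(inp)                # (seq_len, batch, embedding_size)
output, hidden = gru(embedded, hidden)   # GRU returns a tuple, which is why
                                         # nn.Sequential cannot chain it
logits = unembedding(output)             # (seq_len, batch, vocab_size)

# CrossEntropyLoss wants (N, C) logits and (N,) class targets, so flatten
# the sequence and batch dimensions before computing the loss.
target = torch.randint(0, vocab_size, (seq_len, samples))
loss = nn.CrossEntropyLoss()(logits.view(-1, vocab_size), target.view(-1))
print(logits.shape, hidden.shape, loss.item())
```

With untrained random weights the loss sits near ln(50) ≈ 3.9, i.e. a uniform guess over the vocabulary, which is a quick sanity check that the shape plumbing is right.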