FML Project, commit a03af97b
Authored Nov 23, 2022 by Meet Narendra
Parent: 938b36dc

CycleGAN: minor modifications

Showing 5 changed files with 34 additions and 25 deletions (+34 / -25)
Changed files:
  1703.10593/generator.py    +4   -4
  1703.10593/loss.py         +0   -9
  1703.10593/preprocess.py   +0   -1
  1703.10593/train.py        +30  -10
  1703.10593/utils.py        +0   -1
1703.10593/generator.py  (view file @ a03af97b)

@@ -44,10 +44,10 @@ class Generator(torch.nn.Module):
             ResidualBlock(),
             ResidualBlock(),
             ResidualBlock(),
-            ResidualBlock(),
-            ResidualBlock(),
-            ResidualBlock(),
-            ResidualBlock(),
+            # ResidualBlock(),
+            # ResidualBlock(),
+            # ResidualBlock(),
+            # ResidualBlock(),
             nn.ConvTranspose2d(256, 128, 3, 2, 1, 1),
             nn.InstanceNorm2d(128),
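The only change here is that four of the generator's ResidualBlock() entries are commented out, shrinking the residual bottleneck from nine active blocks to five. As a rough illustration, the sketch below builds a bottleneck of configurable depth; the ResidualBlock definition is an assumed stand-in (reflection-padded 3x3 convolutions with instance norm, in the style of the CycleGAN paper), not the repo's exact implementation.

    # Hypothetical sketch: shrinking the residual bottleneck the way this commit does.
    import torch
    import torch.nn as nn

    class ResidualBlock(nn.Module):
        """Assumed stand-in for the block used in 1703.10593/generator.py."""
        def __init__(self, channels=256):
            super().__init__()
            self.block = nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(channels, channels, 3),
                nn.InstanceNorm2d(channels),
                nn.ReLU(inplace=True),
                nn.ReflectionPad2d(1),
                nn.Conv2d(channels, channels, 3),
                nn.InstanceNorm2d(channels),
            )

        def forward(self, x):
            return x + self.block(x)  # residual skip connection

    # Five active blocks instead of nine, mirroring the four commented-out entries.
    bottleneck = nn.Sequential(*[ResidualBlock(256) for _ in range(5)])
    x = torch.randn(1, 256, 64, 64)
    print(bottleneck(x).shape)  # torch.Size([1, 256, 64, 64])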
1703.10593/loss.py  (view file @ a03af97b)

@@ -5,7 +5,6 @@ LOGGER = Logger().logger()
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 #Author: @meetdoshi
-device = torch.device("cpu")
 class Loss:
     @staticmethod
     def adversarial_G():

@@ -15,14 +14,6 @@ class Loss:
         '''
         return torch.nn.MSELoss().to(device)

-    @staticmethod
-    def adversarial_D():
-        '''
-        @params
-        @return
-        '''
-        return torch.nn.MSELoss().to(device)
-
     @staticmethod
     def cycle_consistency():
         '''
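The removed adversarial_D method duplicated adversarial_G: both return an MSE criterion, i.e. the least-squares GAN objective, so a single method suffices. A minimal sketch of how that criterion is typically applied to discriminator scores; the tensors below are illustrative placeholders, not the repo's variables.

    import torch

    criterion = torch.nn.MSELoss()  # least-squares GAN objective, as in loss.py

    # Illustrative discriminator scores for real and generated images.
    d_real = torch.rand(4, 1)
    d_fake = torch.rand(4, 1)

    # Generator wants D(fake) -> 1; discriminator wants D(real) -> 1 and D(fake) -> 0.
    loss_g = criterion(d_fake, torch.ones_like(d_fake))
    loss_d = 0.5 * (criterion(d_real, torch.ones_like(d_real)) +
                    criterion(d_fake, torch.zeros_like(d_fake)))
    print(loss_g.item(), loss_d.item())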
1703.10593/preprocess.py  (view file @ a03af97b)

@@ -3,7 +3,6 @@ import torch
 from logger import Logger
 LOGGER = Logger().logger()
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-device = torch.device("cpu")
 from torch.utils.data import Dataset, DataLoader
 from torchvision import transforms, utils
 import glob
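As in loss.py, the deleted line had been overriding the CUDA-aware selection and pinning everything to the CPU; with it gone, the module again picks "cuda:0" whenever a GPU is available. The same pattern, shown standalone:

    import torch

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    x = torch.randn(2, 3).to(device)  # tensors and modules follow the selected device
    print(device)                     # "cuda:0" on a GPU machine, otherwise "cpu"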
1703.10593/train.py  (view file @ a03af97b)

@@ -5,19 +5,23 @@ from torchvision.utils import save_image
 from logger import Logger
 LOGGER = Logger().logger()
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-device = torch.device("cpu")
 from discriminator import Discriminator
 from generator import Generator
 from loss import Loss
 from preprocess import LoadData
 from tqdm import tqdm
 from utils import initialize_weights
+LOGGER.info("Cuda status: " + str(device))
+torch.cuda.empty_cache()
 class Train():
-    def __init__(self, data="dataset/vangogh2photo", pair=False):
+    def __init__(self, data="dataset/vangogh2photo", pair=False, epochs=200, batch_size=1):
         '''
         @params
         @return
         '''
+        self.epochs = epochs
+        self.batch_size = batch_size
         self.gen_XY = Generator().to(device)
         self.gen_YX = Generator().to(device)
         self.dis_X = Discriminator().to(device)
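Besides dropping the CPU override, this hunk parameterizes the trainer: epochs and batch_size become constructor arguments, with the previously hard-coded values (200 and 1) as defaults, and the selected device is now logged when the module loads. A hedged usage sketch; the dataset path and hyperparameter values are illustrative, and the import assumes the script is run from the 1703.10593 directory.

    # Hypothetical invocation of the updated constructor.
    from train import Train

    trainer = Train(data="dataset/vangogh2photo", pair=False, epochs=50, batch_size=4)
    trainer.train()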
@@ -42,14 +46,16 @@ class Train():
         self.losses = {"G": [], "D": [], "C": [], "I": [], "T": []}
         self.dataset = LoadData(data=data, pair=pair)
-        self.dataloader = DataLoader(self.dataset, batch_size=1, shuffle=True, num_workers=4)
+        self.dataloader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)

     def train(self):
         '''
         @params
         @return
         '''
-        EPOCHS = 200
+        EPOCHS = self.epochs
+        batch_size = self.batch_size
         for epoch in range(EPOCHS):
             '''
             Steps:
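The DataLoader now reads its batch size from the constructor argument instead of the hard-coded 1, and the epoch count comes from self.epochs. The sketch below shows the same DataLoader pattern with a stand-in dataset; TensorDataset replaces the repo's LoadData, which is not shown in this diff.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Stand-in for LoadData: 8 paired 3x256x256 images from domains X and Y.
    dataset = TensorDataset(torch.randn(8, 3, 256, 256), torch.randn(8, 3, 256, 256))
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

    for real_X, real_Y in loader:
        print(real_X.shape)  # torch.Size([4, 3, 256, 256])
        break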
@@ -79,10 +85,14 @@ class Train():
             '''
             adversarial_loss = self.adversarial_loss()
             cycle_loss = self.cycle_loss()
+            identity_loss = self.identity_loss()
             size = len(self.dataloader)
             for i, data in tqdm(enumerate(self.dataloader), total=size):
+                torch.cuda.empty_cache()
                 real_X = data['X'].to(device)
                 real_Y = data['Y'].to(device)
                 #print(real_X.shape)
                 #print(real_Y.shape)
                 batch_size = real_X.size(0)
                 real_label = torch.ones(batch_size, 1).to(device)
                 fake_label = torch.zeros(batch_size, 1).to(device)
@@ -91,8 +101,15 @@ class Train():
                 # Training the generator
                 self.gen_XY_optim.zero_grad()
                 fake_gen_X = self.gen_XY(real_Y)
                 fake_gen_Y = self.gen_YX(real_X)
+                identity_X = self.gen_YX(real_X)
+                identity_Y = self.gen_XY(real_Y)
+                loss_iden_X = identity_loss(identity_X, real_X) * 10
+                loss_iden_Y = identity_loss(identity_Y, real_Y) * 10
+                fake_gen_X = self.gen_YX(real_Y)
+                fake_gen_Y = self.gen_XY(real_X)
                 fake_gen_X_label = self.dis_X(fake_gen_X)
                 fake_gen_Y_label = self.dis_Y(fake_gen_Y)
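The new identity terms pass a real image through the generator that maps into its own domain and penalize any change, weighted by 10 here; the later reassignment of fake_gen_X and fake_gen_Y makes gen_YX produce the fake X from real_Y and gen_XY the fake Y from real_X. A small sketch of the identity computation with placeholder tensors; the L1 criterion below is an illustrative choice, the repo's identity_loss() may differ.

    import torch

    identity_criterion = torch.nn.L1Loss()  # illustrative; the repo's identity_loss() may differ

    # Placeholder tensors standing in for real_X and gen_YX(real_X).
    real_X = torch.rand(1, 3, 256, 256)
    identity_X = real_X + 0.01 * torch.randn_like(real_X)  # a generator output close to its input

    loss_iden_X = identity_criterion(identity_X, real_X) * 10  # weighting used in this commit
    print(loss_iden_X.item())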
@@ -106,10 +123,10 @@ class Train():
                 #print(recovered_Y.shape,recovered_X.shape)
-                loss_cycle_Y2X = cycle_loss(recovered_Y, real_Y)
-                loss_cycle_X2Y = cycle_loss(recovered_X, real_X)
+                loss_cycle_Y2X = cycle_loss(recovered_Y, real_Y) * 20
+                loss_cycle_X2Y = cycle_loss(recovered_X, real_X) * 20
-                total_loss = loss_gen_Y2X + loss_gen_X2Y + loss_cycle_Y2X + loss_cycle_X2Y
+                total_loss = loss_gen_Y2X + loss_gen_X2Y + loss_cycle_Y2X + loss_cycle_X2Y + loss_iden_X + loss_iden_Y
                 #backprop
                 total_loss.backward()
                 self.gen_XY_optim.step()
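With the cycle terms scaled by 20 and the identity terms (scaled by 10 in the previous hunk) folded into total_loss, the combined generator objective is adversarial + weighted cycle-consistency + weighted identity. Restated with the per-term weights pulled into one place; this is a sketch of the arithmetic only, not the repo's API.

    def combined_generator_loss(adv_Y2X, adv_X2Y, cyc_Y2X, cyc_X2Y, idt_X, idt_Y,
                                lambda_cycle=20.0, lambda_identity=10.0):
        """Weights used by this commit: cycle terms x20, identity terms x10."""
        return (adv_Y2X + adv_X2Y
                + lambda_cycle * (cyc_Y2X + cyc_X2Y)
                + lambda_identity * (idt_X + idt_Y))

    # Example with scalar stand-ins for the individual loss values:
    print(combined_generator_loss(0.9, 1.1, 0.05, 0.04, 0.02, 0.03))  # about 4.3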
@@ -149,7 +166,7 @@ class Train():
                 self.losses["G"].append(total_loss.item())
                 self.losses["D"].append((loss_dis_X.item() + loss_dis_Y.item()) / 2)
                 self.losses["C"].append((loss_cycle_Y2X.item() + loss_cycle_X2Y.item()) / 2)
-                LOGGER.info("Epoch: {} | G: {} | D: {} | C: {}".format(epoch, total_loss.item(), (loss_dis_X.item() + loss_dis_Y.item()) / 2, (loss_cycle_Y2X.item() + loss_cycle_X2Y.item()) / 2))
+                LOGGER.info("Epoch: {} | i: {} | G: {} | D: {} | C: {}".format(epoch, i, total_loss.item(), (loss_dis_X.item() + loss_dis_Y.item()) / 2, (loss_cycle_Y2X.item() + loss_cycle_X2Y.item()) / 2))
                 # Save Image
                 if i % 100 == 0:
@@ -169,6 +186,9 @@ class Train():
         torch.save(self.dis_X.state_dict(), "weights/dis_X.pth")
         torch.save(self.dis_Y.state_dict(), "weights/dis_Y.pth")
+        #Save losses
+        torch.save(self.losses, "losses.pt")

 if __name__ == "__main__":
     train = Train()
     train.train()
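The new torch.save call persists the accumulated loss history alongside the model weights. Once a training run has produced losses.pt, it can be reloaded for inspection or plotting; the keys follow the dictionary initialized in __init__ ("G", "D", "C", "I", "T").

    import torch

    # Reload the loss history written by train.py (path assumes the training run's working directory).
    losses = torch.load("losses.pt")
    print(list(losses.keys()))  # ['G', 'D', 'C', 'I', 'T']
    print(len(losses["G"]), "generator-loss values recorded")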
1703.10593/utils.py  (view file @ a03af97b)

@@ -4,7 +4,6 @@ from logger import Logger
 LOGGER = Logger().logger()
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-device = torch.device("cpu")
 #Author: @meetdoshi
 def initialize_weights(model):
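Here too only the CPU override is deleted; initialize_weights itself is untouched and its body is not shown in this diff. For context, a hypothetical initializer in the style commonly paired with this architecture (normal(0, 0.02) on convolutional layers); this is an assumption, not the repo's implementation.

    import torch.nn as nn

    def initialize_weights(model):
        """Hypothetical sketch; the repo's utils.initialize_weights is not shown in this diff."""
        for m in model.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)  # CycleGAN-style init
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.InstanceNorm2d) and m.weight is not None:
                nn.init.normal_(m.weight, mean=1.0, std=0.02)
                nn.init.constant_(m.bias, 0.0)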