Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
F
FML Project
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Meet Narendra
FML Project
Commits
f462bad0
Commit
f462bad0
authored
Nov 23, 2022
by
Meet Narendra
💬
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Changed variable names
parent
a03af97b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
6 deletions
+9
-6
1703.10593/train.py
1703.10593/train.py
+9
-6
No files found.
1703.10593/train.py
View file @
f462bad0
...
...
@@ -13,12 +13,15 @@ from tqdm import tqdm
from
utils
import
initialize_weights
LOGGER
.
info
(
"Cuda status: "
+
str
(
device
))
torch
.
cuda
.
empty_cache
()
import
os
class
Train
():
def
__init__
(
self
,
data
=
"dataset/vangogh2photo"
,
pair
=
False
,
epochs
=
200
,
batch_size
=
1
):
'''
@params
@return
'''
if
not
os
.
path
.
exists
(
"images"
):
os
.
mkdir
(
"images"
)
self
.
epochs
=
epochs
self
.
batch_size
=
batch_size
...
...
@@ -137,12 +140,12 @@ class Train():
self
.
dis_X_optim
.
zero_grad
()
real_X_label
=
self
.
dis_X
(
real_X
)
loss_dis_real_
A
=
adversarial_loss
(
real_X_label
,
real_label
)
loss_dis_real_
X
=
adversarial_loss
(
real_X_label
,
real_label
)
fake_X_label
=
self
.
dis_X
(
fake_gen_X
.
detach
())
loss_dis_fake_
A
=
adversarial_loss
(
fake_X_label
,
fake_label
)
loss_dis_fake_
X
=
adversarial_loss
(
fake_X_label
,
fake_label
)
loss_dis_X
=
(
loss_dis_real_
A
+
loss_dis_fake_A
)
/
2
loss_dis_X
=
(
loss_dis_real_
X
+
loss_dis_fake_X
)
/
2
#backprop
loss_dis_X
.
backward
()
self
.
dis_X_optim
.
step
()
...
...
@@ -151,12 +154,12 @@ class Train():
self
.
dis_Y_optim
.
zero_grad
()
real_Y_label
=
self
.
dis_Y
(
real_Y
)
loss_dis_real_
B
=
adversarial_loss
(
real_Y_label
,
real_label
)
loss_dis_real_
Y
=
adversarial_loss
(
real_Y_label
,
real_label
)
fake_Y_label
=
self
.
dis_Y
(
fake_gen_Y
.
detach
())
loss_dis_fake_
B
=
adversarial_loss
(
fake_Y_label
,
fake_label
)
loss_dis_fake_
Y
=
adversarial_loss
(
fake_Y_label
,
fake_label
)
loss_dis_Y
=
(
loss_dis_real_
B
+
loss_dis_fake_B
)
/
2
loss_dis_Y
=
(
loss_dis_real_
Y
+
loss_dis_fake_Y
)
/
2
#backprop
loss_dis_Y
.
backward
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment