Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
F
FML Project
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Meet Narendra
FML Project
Commits
938b36dc
Commit
938b36dc
authored
Nov 23, 2022
by
Meet Narendra
💬
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Cycle GAN
parent
388a23b3
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
187 additions
and
44 deletions
+187
-44
.gitignore
.gitignore
+2
-0
1703.10593/discriminator.py
1703.10593/discriminator.py
+12
-8
1703.10593/download_data.sh
1703.10593/download_data.sh
+18
-0
1703.10593/generator.py
1703.10593/generator.py
+17
-9
1703.10593/loss.py
1703.10593/loss.py
+4
-12
1703.10593/preprocess.py
1703.10593/preprocess.py
+26
-9
1703.10593/train.py
1703.10593/train.py
+106
-6
1703.10593/utils.py
1703.10593/utils.py
+2
-0
No files found.
.gitignore
View file @
938b36dc
...
@@ -3,3 +3,5 @@
...
@@ -3,3 +3,5 @@
*.csv
*.csv
*.ipynb
*.ipynb
*Logs*
*Logs*
*dataset*
*images*
1703.10593/discriminator.py
View file @
938b36dc
import
torch
import
torch
#Author: @meetdoshi
#Author: @meetdoshi
#Reference: https://github.com/Lornatang/CycleGAN-PyTorch/blob/master/cyclegan_pytorch/models.py
class
Discriminator
(
torch
.
nn
.
Module
):
class
Discriminator
(
torch
.
nn
.
Module
):
'''
'''
PatchGAN Discriminator with 70x70 overlapping image patches
PatchGAN Discriminator with 70x70 overlapping image patches
'''
'''
def
__init__
(
self
)
->
None
:
def
__init__
(
self
):
super
()
.
__init__
()
super
(
Discriminator
,
self
)
.
__init__
()
self
.
model
=
torch
.
nn
.
Sequential
(
self
.
model
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Conv2d
(
3
,
64
,
4
,
2
,
1
),
torch
.
nn
.
Conv2d
(
3
,
64
,
4
,
stride
=
2
,
padding
=
1
),
torch
.
nn
.
LeakyReLU
(
0.2
,
True
),
torch
.
nn
.
LeakyReLU
(
0.2
,
True
),
torch
.
nn
.
Conv2d
(
64
,
128
,
4
,
2
,
1
),
torch
.
nn
.
Conv2d
(
64
,
128
,
4
,
stride
=
2
,
padding
=
1
),
torch
.
nn
.
InstanceNorm2d
(
128
),
torch
.
nn
.
InstanceNorm2d
(
128
),
torch
.
nn
.
LeakyReLU
(
0.2
,
True
),
torch
.
nn
.
LeakyReLU
(
0.2
,
True
),
torch
.
nn
.
Conv2d
(
128
,
256
,
4
,
2
,
1
),
torch
.
nn
.
Conv2d
(
128
,
256
,
4
,
stride
=
2
,
padding
=
1
),
torch
.
nn
.
InstanceNorm2d
(
256
),
torch
.
nn
.
InstanceNorm2d
(
256
),
torch
.
nn
.
LeakyReLU
(
0.2
,
True
),
torch
.
nn
.
LeakyReLU
(
0.2
,
True
),
torch
.
nn
.
Conv2d
(
256
,
512
,
4
,
1
,
1
),
torch
.
nn
.
Conv2d
(
256
,
512
,
4
,
padding
=
1
),
torch
.
nn
.
InstanceNorm2d
(
512
),
torch
.
nn
.
InstanceNorm2d
(
512
),
torch
.
nn
.
LeakyReLU
(
0.2
,
True
),
torch
.
nn
.
LeakyReLU
(
0.2
,
True
),
torch
.
nn
.
Conv2d
(
512
,
1
,
4
,
1
,
1
),
torch
.
nn
.
Conv2d
(
512
,
1
,
4
,
padding
=
1
),
)
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
return
self
.
model
(
x
)
x
=
self
.
model
(
x
)
\ No newline at end of file
x
=
torch
.
nn
.
functional
.
avg_pool2d
(
x
,
x
.
size
()[
2
:])
x
=
torch
.
flatten
(
x
,
1
)
return
x
\ No newline at end of file
1703.10593/download_data.sh
0 → 100644
View file @
938b36dc
#!/bin/bash
mkdir
dataset
cd
dataset
for
FILE
in
"apple2orange"
"summer2winter_yosemite"
"horse2zebra"
"monet2photo"
"cezanne2photo"
"ukiyoe2photo"
"vangogh2photo"
"maps"
"cityscapes"
"facades"
"iphone2dslr_flower"
;
do
URL
=
https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/
${
FILE
}
.zip
ZIP_FILE
=
${
FILE
}
.zip
TARGET_DIR
=
${
FILE
}
wget
${
URL
}
unzip
${
ZIP_FILE
}
rm
${
ZIP_FILE
}
# Adapt to project expected directory heriarchy
mkdir
-p
"
$TARGET_DIR
/train"
"
$TARGET_DIR
/test"
mv
"
$TARGET_DIR
/trainA"
"
$TARGET_DIR
/train/A"
mv
"
$TARGET_DIR
/trainB"
"
$TARGET_DIR
/train/B"
mv
"
$TARGET_DIR
/testA"
"
$TARGET_DIR
/test/A"
mv
"
$TARGET_DIR
/testB"
"
$TARGET_DIR
/test/B"
done
1703.10593/generator.py
View file @
938b36dc
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
#Author: @meetdoshi
#Author: @meetdoshi
#Reference: https://github.com/Lornatang/CycleGAN-PyTorch/blob/master/cyclegan_pytorch/models.py
class
ResidualBlock
(
nn
.
Module
):
class
ResidualBlock
(
nn
.
Module
):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
):
super
()
.
__init__
()
super
(
ResidualBlock
,
self
)
.
__init__
()
self
.
block
=
nn
.
Sequential
(
self
.
block
=
nn
.
Sequential
(
nn
.
ReflectionPad2d
(
1
),
nn
.
ReflectionPad2d
(
1
),
nn
.
Conv2d
(
256
,
256
,
3
),
nn
.
Conv2d
(
256
,
256
,
3
),
...
@@ -22,33 +23,40 @@ class Generator(torch.nn.Module):
...
@@ -22,33 +23,40 @@ class Generator(torch.nn.Module):
https://arxiv.org/pdf/1603.08155.pdf
https://arxiv.org/pdf/1603.08155.pdf
'''
'''
def
__init__
(
self
)
->
None
:
def
__init__
(
self
):
super
()
.
__init__
()
super
(
Generator
,
self
)
.
__init__
()
self
.
model
=
nn
.
Sequential
(
self
.
model
=
nn
.
Sequential
(
nn
.
Conv2d
(
3
,
64
,
7
,
1
,
3
),
nn
.
ReflectionPad2d
(
3
),
nn
.
Conv2d
(
3
,
64
,
7
),
nn
.
InstanceNorm2d
(
64
),
nn
.
InstanceNorm2d
(
64
),
nn
.
ReLU
(
True
),
nn
.
ReLU
(
True
),
nn
.
Conv2d
(
64
,
128
,
3
,
2
,
1
),
nn
.
Conv2d
(
64
,
128
,
3
,
2
,
1
),
nn
.
InstanceNorm2d
(
128
),
nn
.
InstanceNorm2d
(
128
),
nn
.
ReLU
(
True
),
nn
.
ReLU
(
True
),
nn
.
Conv2d
(
128
,
256
,
3
,
2
,
1
),
nn
.
Conv2d
(
128
,
256
,
3
,
2
,
1
),
nn
.
InstanceNorm2d
(
256
),
nn
.
InstanceNorm2d
(
256
),
nn
.
ReLU
(
True
),
nn
.
ReLU
(
True
),
ResidualBlock
(),
ResidualBlock
(),
ResidualBlock
(),
ResidualBlock
(),
ResidualBlock
(),
ResidualBlock
(),
ResidualBlock
(),
ResidualBlock
(),
ResidualBlock
(),
ResidualBlock
(),
ResidualBlock
(),
ResidualBlock
(),
#ResidualBlock(),
ResidualBlock
(),
#ResidualBlock(),
ResidualBlock
(),
#ResidualBlock(),
#ResidualBlock(),
nn
.
ConvTranspose2d
(
256
,
128
,
3
,
2
,
1
,
1
),
nn
.
ConvTranspose2d
(
256
,
128
,
3
,
2
,
1
,
1
),
nn
.
InstanceNorm2d
(
128
),
nn
.
InstanceNorm2d
(
128
),
nn
.
ReLU
(
True
),
nn
.
ReLU
(
True
),
nn
.
ConvTranspose2d
(
128
,
64
,
3
,
2
,
1
,
1
),
nn
.
ConvTranspose2d
(
128
,
64
,
3
,
2
,
1
,
1
),
nn
.
InstanceNorm2d
(
64
),
nn
.
InstanceNorm2d
(
64
),
nn
.
ReLU
(
True
),
nn
.
ReLU
(
True
),
nn
.
ReflectionPad2d
(
3
),
nn
.
ReflectionPad2d
(
3
),
nn
.
Conv2d
(
64
,
3
,
7
),
nn
.
Conv2d
(
64
,
3
,
7
),
nn
.
Tanh
(),
nn
.
Tanh
(),
...
...
1703.10593/loss.py
View file @
938b36dc
...
@@ -5,6 +5,7 @@ LOGGER = Logger().logger()
...
@@ -5,6 +5,7 @@ LOGGER = Logger().logger()
device
=
torch
.
device
(
"cuda:0"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
device
=
torch
.
device
(
"cuda:0"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
#Author: @meetdoshi
#Author: @meetdoshi
device
=
torch
.
device
(
"cpu"
)
class
Loss
:
class
Loss
:
@
staticmethod
@
staticmethod
def
adversarial_G
():
def
adversarial_G
():
...
@@ -12,7 +13,7 @@ class Loss:
...
@@ -12,7 +13,7 @@ class Loss:
@params
@params
@return
@return
'''
'''
return
torch
.
nn
.
BC
ELoss
()
.
to
(
device
)
return
torch
.
nn
.
MS
ELoss
()
.
to
(
device
)
@
staticmethod
@
staticmethod
def
adversarial_D
():
def
adversarial_D
():
...
@@ -20,25 +21,16 @@ class Loss:
...
@@ -20,25 +21,16 @@ class Loss:
@params
@params
@return
@return
'''
'''
return
torch
.
nn
.
BC
ELoss
()
.
to
(
device
)
return
torch
.
nn
.
MS
ELoss
()
.
to
(
device
)
@
staticmethod
@
staticmethod
def
cycle_consistency
_forward
():
def
cycle_consistency
():
'''
'''
@params
@params
@return
@return
'''
'''
return
torch
.
nn
.
L1Loss
()
.
to
(
device
)
return
torch
.
nn
.
L1Loss
()
.
to
(
device
)
@
staticmethod
def
cycle_consistency_backward
():
'''
@params
@return
'''
return
torch
.
nn
.
L1Loss
()
.
to
(
device
)
@
staticmethod
@
staticmethod
def
identity
():
def
identity
():
'''
'''
...
...
1703.10593/preprocess.py
View file @
938b36dc
...
@@ -3,21 +3,38 @@ import torch
...
@@ -3,21 +3,38 @@ import torch
from
logger
import
Logger
from
logger
import
Logger
LOGGER
=
Logger
()
.
logger
()
LOGGER
=
Logger
()
.
logger
()
device
=
torch
.
device
(
"cuda:0"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
device
=
torch
.
device
(
"cuda:0"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
device
=
torch
.
device
(
"cpu"
)
from
torch.utils.data
import
Dataset
,
DataLoader
from
torch.utils.data
import
Dataset
,
DataLoader
from
torchvision
import
transforms
,
utils
import
glob
import
os
from
PIL
import
Image
import
random
#Author: @meetdoshi
#Author: @meetdoshi
class
LoadData
(
Dataset
):
class
LoadData
(
Dataset
):
def
__init__
(
self
,
data
,
transform
=
None
):
def
__init__
(
self
,
data
,
pair
=
False
,
type_data
=
"train"
):
self
.
data
=
data
self
.
data
=
data
self
.
transform
=
transform
self
.
pair
=
pair
self
.
type
=
type_data
self
.
transform
=
transforms
.
Compose
([
transforms
.
Resize
(
int
(
256
*
1.12
),
Image
.
BICUBIC
),
transforms
.
RandomCrop
(
256
),
transforms
.
RandomHorizontalFlip
(),
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.5
,
0.5
,
0.5
),
(
0.5
,
0.5
,
0.5
)),
])
#LOGGER.info(data)
#LOGGER.info(os.path.join(self.data,f"{type_data}/A"+ "*.jpg"))
self
.
X
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
self
.
data
,
f
"{type_data}/A/"
)
+
"*.jpg"
))
self
.
Y
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
self
.
data
,
f
"{type_data}/B/"
)
+
"*.jpg"
))
def
__len__
(
self
):
def
__len__
(
self
):
return
len
(
self
.
data
)
return
max
(
len
(
self
.
X
),
len
(
self
.
Y
)
)
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
if
torch
.
is_tensor
(
idx
):
if
self
.
pair
:
idx
=
idx
.
tolist
()
return
{
'X'
:
self
.
transform
(
Image
.
open
(
self
.
X
[
idx
%
len
(
self
.
X
)])),
'Y'
:
self
.
transform
(
Image
.
open
(
self
.
Y
[
idx
%
len
(
self
.
Y
)]))}
sample
=
self
.
data
[
idx
]
else
:
if
self
.
transform
:
return
{
'X'
:
self
.
transform
(
Image
.
open
(
self
.
X
[
idx
%
len
(
self
.
X
)])),
'Y'
:
self
.
transform
(
Image
.
open
(
self
.
Y
[
random
.
randint
(
0
,
len
(
self
.
Y
)
-
1
)
%
len
(
self
.
Y
)]))}
sample
=
self
.
transform
(
sample
)
return
sample
\ No newline at end of file
1703.10593/train.py
View file @
938b36dc
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
torch.utils.data
import
Dataset
,
DataLoader
from
torchvision.utils
import
save_image
from
logger
import
Logger
from
logger
import
Logger
LOGGER
=
Logger
()
.
logger
()
LOGGER
=
Logger
()
.
logger
()
device
=
torch
.
device
(
"cuda:0"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
device
=
torch
.
device
(
"cuda:0"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
device
=
torch
.
device
(
"cpu"
)
from
discriminator
import
Discriminator
from
discriminator
import
Discriminator
from
generator
import
Generator
from
generator
import
Generator
from
loss
import
Loss
from
loss
import
Loss
...
@@ -10,7 +13,7 @@ from preprocess import LoadData
...
@@ -10,7 +13,7 @@ from preprocess import LoadData
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
utils
import
initialize_weights
from
utils
import
initialize_weights
class
Train
():
class
Train
():
def
__init__
(
self
):
def
__init__
(
self
,
data
=
"dataset/vangogh2photo"
,
pair
=
False
):
'''
'''
@params
@params
@return
@return
...
@@ -33,13 +36,15 @@ class Train():
...
@@ -33,13 +36,15 @@ class Train():
self
.
dis_X_scheduler
=
torch
.
optim
.
lr_scheduler
.
StepLR
(
self
.
dis_X_optim
,
step_size
=
100
,
gamma
=
0.1
)
self
.
dis_X_scheduler
=
torch
.
optim
.
lr_scheduler
.
StepLR
(
self
.
dis_X_optim
,
step_size
=
100
,
gamma
=
0.1
)
self
.
dis_Y_scheduler
=
torch
.
optim
.
lr_scheduler
.
StepLR
(
self
.
dis_Y_optim
,
step_size
=
100
,
gamma
=
0.1
)
self
.
dis_Y_scheduler
=
torch
.
optim
.
lr_scheduler
.
StepLR
(
self
.
dis_Y_optim
,
step_size
=
100
,
gamma
=
0.1
)
self
.
cycle_loss
=
Loss
()
.
cycle_
loss
.
to
(
device
)
self
.
cycle_loss
=
Loss
()
.
cycle_
consistency
self
.
identity_loss
=
Loss
()
.
identity
_loss
.
to
(
device
)
self
.
identity_loss
=
Loss
()
.
identity
self
.
adversarial_loss
=
Loss
()
.
adversarial_
loss
.
to
(
device
)
self
.
adversarial_loss
=
Loss
()
.
adversarial_
G
self
.
losses
=
{
"G"
:
[],
"D"
:
[],
"C"
:
[],
"I"
:
[],
"T"
:
[]}
self
.
losses
=
{
"G"
:
[],
"D"
:
[],
"C"
:
[],
"I"
:
[],
"T"
:
[]}
self
.
dataset
=
LoadData
(
data
=
data
,
pair
=
pair
)
self
.
dataloader
=
DataLoader
(
self
.
dataset
,
batch_size
=
1
,
shuffle
=
True
,
num_workers
=
4
)
def
train
():
def
train
(
self
):
'''
'''
@params
@params
@return
@return
...
@@ -72,3 +77,98 @@ class Train():
...
@@ -72,3 +77,98 @@ class Train():
18. Update Learning Rate
18. Update Learning Rate
'''
'''
adversarial_loss
=
self
.
adversarial_loss
()
cycle_loss
=
self
.
cycle_loss
()
size
=
len
(
self
.
dataloader
)
for
i
,
data
in
tqdm
(
enumerate
(
self
.
dataloader
),
total
=
size
):
real_X
=
data
[
'X'
]
.
to
(
device
)
real_Y
=
data
[
'Y'
]
.
to
(
device
)
batch_size
=
real_X
.
size
(
0
)
real_label
=
torch
.
ones
(
batch_size
,
1
)
.
to
(
device
)
fake_label
=
torch
.
zeros
(
batch_size
,
1
)
.
to
(
device
)
#print(real_label.shape)
# Training the generator
self
.
gen_XY_optim
.
zero_grad
()
fake_gen_X
=
self
.
gen_XY
(
real_Y
)
fake_gen_Y
=
self
.
gen_YX
(
real_X
)
fake_gen_X_label
=
self
.
dis_X
(
fake_gen_X
)
fake_gen_Y_label
=
self
.
dis_Y
(
fake_gen_Y
)
#print(fake_gen_X.shape,fake_gen_X_label.shape)
#print(fake_gen_Y_label,real_label)
loss_gen_Y2X
=
adversarial_loss
(
fake_gen_X_label
,
real_label
)
loss_gen_X2Y
=
adversarial_loss
(
fake_gen_Y_label
,
real_label
)
recovered_Y
=
self
.
gen_XY
(
fake_gen_X
)
recovered_X
=
self
.
gen_YX
(
fake_gen_Y
)
#print(recovered_Y.shape,recovered_X.shape)
loss_cycle_Y2X
=
cycle_loss
(
recovered_Y
,
real_Y
)
loss_cycle_X2Y
=
cycle_loss
(
recovered_X
,
real_X
)
total_loss
=
loss_gen_Y2X
+
loss_gen_X2Y
+
loss_cycle_Y2X
+
loss_cycle_X2Y
#backprop
total_loss
.
backward
()
self
.
gen_XY_optim
.
step
()
# Training the discriminator
# Discriminator for X
self
.
dis_X_optim
.
zero_grad
()
real_X_label
=
self
.
dis_X
(
real_X
)
loss_dis_real_A
=
adversarial_loss
(
real_X_label
,
real_label
)
fake_X_label
=
self
.
dis_X
(
fake_gen_X
.
detach
())
loss_dis_fake_A
=
adversarial_loss
(
fake_X_label
,
fake_label
)
loss_dis_X
=
(
loss_dis_real_A
+
loss_dis_fake_A
)
/
2
#backprop
loss_dis_X
.
backward
()
self
.
dis_X_optim
.
step
()
# Discriminator for Y
self
.
dis_Y_optim
.
zero_grad
()
real_Y_label
=
self
.
dis_Y
(
real_Y
)
loss_dis_real_B
=
adversarial_loss
(
real_Y_label
,
real_label
)
fake_Y_label
=
self
.
dis_Y
(
fake_gen_Y
.
detach
())
loss_dis_fake_B
=
adversarial_loss
(
fake_Y_label
,
fake_label
)
loss_dis_Y
=
(
loss_dis_real_B
+
loss_dis_fake_B
)
/
2
#backprop
loss_dis_Y
.
backward
()
self
.
dis_Y_optim
.
step
()
# Update Logs
self
.
losses
[
"G"
]
.
append
(
total_loss
.
item
())
self
.
losses
[
"D"
]
.
append
((
loss_dis_X
.
item
()
+
loss_dis_Y
.
item
())
/
2
)
self
.
losses
[
"C"
]
.
append
((
loss_cycle_Y2X
.
item
()
+
loss_cycle_X2Y
.
item
())
/
2
)
LOGGER
.
info
(
"Epoch: {} | G: {} | D: {} | C: {}"
.
format
(
epoch
,
total_loss
.
item
(),
(
loss_dis_X
.
item
()
+
loss_dis_Y
.
item
())
/
2
,
(
loss_cycle_Y2X
.
item
()
+
loss_cycle_X2Y
.
item
())
/
2
))
# Save Image
if
i
%
100
==
0
:
save_image
(
fake_gen_X
,
"images/{}_{}_fake_X.png"
.
format
(
epoch
,
i
))
save_image
(
fake_gen_Y
,
"images/{}_{}_fake_Y.png"
.
format
(
epoch
,
i
))
save_image
(
real_X
,
"images/{}_{}_real_X.png"
.
format
(
epoch
,
i
))
save_image
(
real_Y
,
"images/{}_{}_real_Y.png"
.
format
(
epoch
,
i
))
# Update Learning Rate
self
.
gen_XY_scheduler
.
step
()
self
.
dis_X_scheduler
.
step
()
self
.
dis_Y_scheduler
.
step
()
# Save weights
torch
.
save
(
self
.
gen_XY
.
state_dict
(),
"weights/gen_XY.pth"
)
torch
.
save
(
self
.
gen_YX
.
state_dict
(),
"weights/gen_YX.pth"
)
torch
.
save
(
self
.
dis_X
.
state_dict
(),
"weights/dis_X.pth"
)
torch
.
save
(
self
.
dis_Y
.
state_dict
(),
"weights/dis_Y.pth"
)
if
__name__
==
"__main__"
:
train
=
Train
()
train
.
train
()
1703.10593/utils.py
View file @
938b36dc
...
@@ -4,6 +4,8 @@ from logger import Logger
...
@@ -4,6 +4,8 @@ from logger import Logger
LOGGER
=
Logger
()
.
logger
()
LOGGER
=
Logger
()
.
logger
()
device
=
torch
.
device
(
"cuda:0"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
device
=
torch
.
device
(
"cuda:0"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
device
=
torch
.
device
(
"cpu"
)
#Author: @meetdoshi
#Author: @meetdoshi
def
initialize_weights
(
model
):
def
initialize_weights
(
model
):
'''
'''
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment