seminar-breakout: commit e9a6a5af
authored Feb 20, 2016 by Yuxin Wu
parent 2264b5a3

train/ directory

Showing 9 changed files with 365 additions and 21 deletions.
example_cifar10.py                 +2   −2
example_mnist.py                   +3   −3
tensorpack/models/batch_norm.py    +0   −1
tensorpack/train/__init__.py       +20  −0
tensorpack/train/base.py           +107 −0
tensorpack/train/config.py         +57  −0
tensorpack/train/train.py          +163 −0
tensorpack/utils/concurrency.py    +0   −2
tensorpack/utils/utils.py          +13  −13
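The user-visible effect of the commit is that the free function start_train(config) is replaced by explicit trainer classes (SimpleTrainer, QueueInputTrainer) that take a TrainConfig and expose .train(). A minimal sketch of the new entry point, with hypothetical my_* placeholders standing in for the pieces each example script actually builds (field names follow the TrainConfig docstring shown below):

    from tensorpack.train import TrainConfig, SimpleTrainer

    def launch(my_model, my_dataflow, my_optimizer, my_callbacks):
        config = TrainConfig(
            dataset=my_dataflow,        # a DataFlow instance
            optimizer=my_optimizer,     # a tf.train.Optimizer
            callbacks=my_callbacks,     # a Callbacks instance
            model=my_model,             # a ModelDesc
            max_epoch=100)
        SimpleTrainer(config).train()   # was: start_train(config)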
example_cifar10.py  (+2 −2)

@@ -8,7 +8,7 @@ import argparse
 import numpy as np
 import os

-from tensorpack.train import TrainConfig, start_train
+from tensorpack.train import TrainConfig, QueueInputTrainer
 from tensorpack.models import *
 from tensorpack.callbacks import *
 from tensorpack.utils import *
@@ -158,4 +158,4 @@ if __name__ == '__main__':
         config.session_init = SaverRestore(args.load)
     if args.gpu:
         config.nr_tower = len(args.gpu.split(','))
-    start_train(config)
+    QueueInputTrainer(config).train()
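The CIFAR-10 script derives the tower count from the comma-separated --gpu flag, so each listed device gets one training tower. In isolation the rule is just (hypothetical flag value):

    args_gpu = '0,1,2'
    nr_tower = len(args_gpu.split(','))   # 3 towers, one per GPU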
example_mnist.py  (+3 −3)

@@ -10,7 +10,7 @@ import numpy as np
 import os, sys
 import argparse

-from tensorpack.train import TrainConfig, start_train
+from tensorpack.train import TrainConfig, SimpleTrainer
 from tensorpack.models import *
 from tensorpack.utils import *
 from tensorpack.utils.symbolic_functions import *
@@ -92,7 +92,7 @@ def get_config():
     dataset_train = BatchData(dataset.Mnist('train'), 128)
     dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
     step_per_epoch = dataset_train.size()
-    step_per_epoch = 20
+    #step_per_epoch = 20

     # prepare session
     sess_config = get_default_sess_config()
@@ -131,5 +131,5 @@ if __name__ == '__main__':
     config = get_config()
     if args.load:
         config.session_init = SaverRestore(args.load)
-    start_train(config)
+    SimpleTrainer(config).train()
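In get_config(), remainder=True on the test DataFlow keeps the final, smaller batch, so every test sample is evaluated exactly once. A plain-Python sketch of that batching rule (an illustration, not tensorpack's implementation):

    def batch(items, size, remainder=False):
        buf = []
        for x in items:
            buf.append(x)
            if len(buf) == size:
                yield buf
                buf = []
        if buf and remainder:
            yield buf   # emit the leftover, smaller batch

    print([len(b) for b in batch(range(10), 4, remainder=True)])   # [4, 4, 2]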
tensorpack/models/batch_norm.py  (+0 −1)

@@ -47,7 +47,6 @@ def BatchNorm(x, is_training, gamma_init=1.0):
         x.set_shape(hack_shape)
     batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
-    print batch_mean
     ema = tf.train.ExponentialMovingAverage(decay=0.999)
     ema_apply_op = ema.apply([batch_mean, batch_var])
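Aside from deleting the stray debug print, this code tracks batch statistics with an exponential moving average. What tf.train.ExponentialMovingAverage(decay=0.999) maintains per tracked value, written as a one-line recurrence (plain-Python sketch, starting from an arbitrary initial shadow):

    def ema_update(shadow, value, decay=0.999):
        # the shadow statistic drifts slowly toward each new batch statistic
        return decay * shadow + (1.0 - decay) * value

    shadow = 0.9
    for batch_mean in [1.1, 1.0, 0.95]:
        shadow = ema_update(shadow, batch_mean)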
tensorpack/train/__init__.py  (new file, +20)

# !/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

from pkgutil import walk_packages
import os
import os.path

def global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
    for k in lst:
        globals()[k] = p.__dict__[k]

for _, module_name, _ in walk_packages([os.path.dirname(__file__)]):
    if not module_name.startswith('_'):
        global_import(module_name)
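This is tensorpack's auto-import idiom: walk_packages enumerates every submodule of train/, and global_import copies each module's public names (its __all__, when defined) into the package namespace. The practical effect is that both of these resolve to the same class:

    from tensorpack.train import SimpleTrainer          # via global_import
    from tensorpack.train.train import SimpleTrainer    # explicit module path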
tensorpack/train/base.py  (new file, +107)

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf
from abc import ABCMeta, abstractmethod
import tqdm
import re

from .config import TrainConfig
from ..utils import *
from ..callbacks import StatHolder
from ..utils.modelutils import describe_model

__all__ = ['Trainer']

class Trainer(object):
    __metaclass__ = ABCMeta

    def __init__(self, config):
        """
        Config: a `TrainConfig` instance
        """
        assert isinstance(config, TrainConfig), type(config)
        self.config = config
        tf.add_to_collection(MODEL_KEY, config.model)

    @abstractmethod
    def train(self):
        pass

    @abstractmethod
    def run_step(self):
        pass

    def trigger_epoch(self):
        self.global_step += self.config.step_per_epoch
        self._trigger_epoch()
        self.config.callbacks.trigger_epoch()
        self.summary_writer.flush()
        logger.stat_holder.finalize()

    @abstractmethod
    def _trigger_epoch(self):
        pass

    def _init_summary(self):
        if not hasattr(logger, 'LOG_DIR'):
            raise RuntimeError(
                "Please use logger.set_logger_dir at the beginning of your script.")
        self.summary_writer = tf.train.SummaryWriter(
            logger.LOG_DIR, graph_def=self.sess.graph_def)
        logger.writer = self.summary_writer
        self.summary_op = tf.merge_all_summaries()
        # create an empty StatHolder
        logger.stat_holder = StatHolder(logger.LOG_DIR, [])

    def _process_summary(self, summary_str):
        summary = tf.Summary.FromString(summary_str)
        for val in summary.value:
            if val.WhichOneof('value') == 'simple_value':
                val.tag = re.sub('tower[0-9]*/', '', val.tag)   # TODO move to subclasses
                logger.stat_holder.add_stat(val.tag, val.simple_value)
        self.summary_writer.add_summary(summary, self.global_step)

    def main_loop(self):
        callbacks = self.config.callbacks
        with self.sess.as_default():
            try:
                self._init_summary()
                self.global_step = get_global_step()
                logger.info("Start training with global_step={}".format(self.global_step))
                callbacks.before_train()
                tf.get_default_graph().finalize()
                for epoch in xrange(1, self.config.max_epoch):
                    with timed_operation('Epoch {}, global_step={}'.format(
                            epoch, self.global_step + self.config.step_per_epoch)):
                        for step in tqdm.trange(
                                self.config.step_per_epoch,
                                leave=True, mininterval=0.5,
                                dynamic_ncols=True, ascii=True):
                            if self.coord.should_stop():
                                return
                            self.run_step()
                            callbacks.trigger_step()
                        self.trigger_epoch()
            except (KeyboardInterrupt, Exception):
                raise
            finally:
                self.coord.request_stop()
                # Do I need to run queue.close?
                callbacks.after_train()
                self.summary_writer.close()
                self.sess.close()

    def init_session_and_coord(self):
        describe_model()
        self.sess = tf.Session(config=self.config.session_config)
        self.config.session_init.init(self.sess)

        # start training:
        self.coord = tf.train.Coordinator()
        tf.train.start_queue_runners(
            sess=self.sess, coord=self.coord, daemon=True, start=True)
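Trainer is an abstract base: train() builds the graph and enters main_loop(), run_step() performs one parameter update, and _trigger_epoch() runs once per epoch. A minimal hypothetical subclass, to show which hooks main_loop() drives (it assumes a self.train_op built in train(); real subclasses follow in train/train.py):

    class ToyTrainer(Trainer):
        def train(self):
            # ... build the graph and self.train_op here ...
            self.init_session_and_coord()
            self.main_loop()

        def run_step(self):
            # one parameter update; called step_per_epoch times per epoch
            self.sess.run([self.train_op])

        def _trigger_epoch(self):
            # per-epoch work, e.g. evaluate self.summary_op and feed the
            # result string to self._process_summary(summary_str)
            pass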
tensorpack/train/config.py  (new file, +57)

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: config.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf

from ..callbacks import Callbacks
from ..models import ModelDesc
from ..utils import *
from ..dataflow import DataFlow

__all__ = ['TrainConfig']

class TrainConfig(object):
    """ Config for training a model with a single loss """
    def __init__(self, **kwargs):
        """
        Args:
            dataset: the dataset to train on. A tensorpack.dataflow.DataFlow instance.
            optimizer: a tf.train.Optimizer instance defining the optimizer for training.
            callbacks: a tensorpack.utils.callback.Callbacks instance. Defines
                the callbacks to perform during training. Has to contain a
                SummaryWriter and a PeriodicSaver.
            session_config: a tf.ConfigProto instance to instantiate the
                session. Defaults to a session running 1 GPU.
            session_init: a tensorpack.utils.sessinit.SessionInit instance to
                initialize variables of a session. Defaults to a new session.
            model: a ModelDesc instance.
            step_per_epoch: the number of steps (parameter updates) to perform
                in each epoch. Defaults to dataset.size().
            max_epoch: maximum number of epochs to run training. Defaults to 100.
            nr_tower: int. Number of towers. Defaults to 1.
        """
        def assert_type(v, tp):
            assert isinstance(v, tp), v.__class__
        self.dataset = kwargs.pop('dataset')
        assert_type(self.dataset, DataFlow)
        self.optimizer = kwargs.pop('optimizer')
        assert_type(self.optimizer, tf.train.Optimizer)
        self.callbacks = kwargs.pop('callbacks')
        assert_type(self.callbacks, Callbacks)
        self.model = kwargs.pop('model')
        assert_type(self.model, ModelDesc)
        self.session_config = kwargs.pop('session_config', get_default_sess_config())
        assert_type(self.session_config, tf.ConfigProto)
        self.session_init = kwargs.pop('session_init', NewSession())
        assert_type(self.session_init, SessionInit)
        self.step_per_epoch = int(kwargs.pop('step_per_epoch', self.dataset.size()))
        self.max_epoch = int(kwargs.pop('max_epoch', 100))
        assert self.step_per_epoch > 0 and self.max_epoch > 0
        self.nr_tower = int(kwargs.pop('nr_tower', 1))
        assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
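The constructor consumes kwargs with pop() and asserts the dict is empty at the end, which gives keyword-only, typo-proof arguments in Python 2. The pattern in isolation, using a hypothetical Options class:

    class Options(object):
        def __init__(self, **kwargs):
            self.name = kwargs.pop('name')              # required
            self.depth = int(kwargs.pop('depth', 18))   # optional, with default
            assert len(kwargs) == 0, 'Unknown arguments: {}'.format(kwargs.keys())

    Options(name='resnet', depht=50)   # typo -> AssertionError: ['depht']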
tensorpack/train.py → tensorpack/train/train.py  (+163 −0)

@@ -4,65 +4,16 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 import tensorflow as tf
-from itertools import count
 import copy
-import argparse
 import re
-import tqdm
-from abc import ABCMeta

-from .models import ModelDesc
-from .dataflow.common import RepeatedData
-from .utils import *
-from .utils.concurrency import EnqueueThread
-from .callbacks import *
-from .utils.summary import summary_moving_average
-from .utils.modelutils import describe_model
-from .utils import logger
-from .dataflow import DataFlow
+from .base import Trainer
+from ..dataflow.common import RepeatedData
+from ..utils import *
+from ..utils.concurrency import EnqueueThread
+from ..utils.summary import summary_moving_average

-class TrainConfig(object):
-    [TrainConfig class body removed; it is identical to the new
-     tensorpack/train/config.py shown above]
+__all__ = ['SimpleTrainer', 'QueueInputTrainer']

 def summary_grads(grads):
     for grad, var in grads:
@@ -88,95 +39,6 @@ def scale_grads(grads, multiplier):
         ret.append((grad, var))
     return ret

-class Trainer(object):
-    [Trainer class body removed; it moved to tensorpack/train/base.py shown
-     above, which additionally calls describe_model() inside
-     init_session_and_coord()]

 class SimpleTrainer(Trainer):
     def run_step(self):
@@ -200,7 +62,6 @@ class SimpleTrainer(Trainer):
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             avg_maintain_op)

-        describe_model()
         self.init_session_and_coord()
         # create an infinte data producer
         self.data_producer = RepeatedData(self.config.dataset, -1).get_data()
@@ -280,7 +141,6 @@ class QueueInputTrainer(Trainer):
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             avg_maintain_op)

-        describe_model()
         self.init_session_and_coord()
         # create a thread that keeps filling the queue
@@ -299,11 +159,5 @@ class QueueInputTrainer(Trainer):
 def start_train(config):
-    #if config.model.get_input_queue() is not None:
-        ## XXX get_input_queue is called twice
-        #tr = QueueInputTrainer()
-    #else:
-        #tr = SimpleTrainer()
     tr = SimpleTrainer(config)
-    #tr = QueueInputTrainer(config)
     tr.train()
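Because the module moved one package level deeper, every relative import gains a dot; the targets themselves are unchanged. A sketch of the import arithmetic (not repo code):

    # in tensorpack/train.py (old):       from .utils  import logger   # -> tensorpack.utils
    # in tensorpack/train/train.py (new): from ..utils import logger   # same target, one level deeper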
tensorpack/utils/concurrency.py  (+0 −2)

@@ -7,7 +7,6 @@ import threading
 from contextlib import contextmanager
 import tensorflow as tf

-from .utils import expand_dim_if_necessary
 from .naming import *
 from . import logger
@@ -44,7 +43,6 @@ class EnqueueThread(threading.Thread):
                     return
                 feed = dict(zip(self.input_vars, dp))
                 self.sess.run([self.op], feed_dict=feed)
-            #print '\nExauhsted!!!'
             except tf.errors.CancelledError as e:
                 pass
             except Exception:
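EnqueueThread's job is the classic bounded-producer pattern: a daemon thread pulls datapoints from a source and pushes them into a queue that the training loop drains. The same shape in plain Python 2, without TensorFlow (toy data, hypothetical names):

    import threading
    import Queue   # Python 2 module name

    def producer(gen, q):
        for dp in gen:
            q.put(dp)              # blocks while the queue is full

    q = Queue.Queue(maxsize=50)
    t = threading.Thread(target=producer, args=(iter(range(200)), q))
    t.daemon = True                # dies with the main thread, like EnqueueThread
    t.start()
    print(q.get())                 # consumer side: 0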
tensorpack/utils/utils.py  (+13 −13)

@@ -4,19 +4,19 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 import os

-def expand_dim_if_necessary(var, dp):
-    """
-    Args:
-        var: a tensor
-        dp: a numpy array
-    Return a reshaped version of dp, if that makes it match the valid dimension of var
-    """
-    shape = var.get_shape().as_list()
-    valid_shape = [k for k in shape if k]
-    if dp.shape == tuple(valid_shape):
-        new_shape = [k if k else 1 for k in shape]
-        dp = dp.reshape(new_shape)
-    return dp
+# def expand_dim_if_necessary(var, dp):
+#     """
+#     Args:
+#         var: a tensor
+#         dp: a numpy array
+#     Return a reshaped version of dp, if that makes it match the valid dimension of var
+#     """
+#     shape = var.get_shape().as_list()
+#     valid_shape = [k for k in shape if k]
+#     if dp.shape == tuple(valid_shape):
+#         new_shape = [k if k else 1 for k in shape]
+#         dp = dp.reshape(new_shape)
+#     return dp

 def mkdir_p(dirname):
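The helper being commented out matched a numpy array against a tensor shape whose leading dims may be None, reinserting those dims as 1s. Its logic survives as this standalone numpy sketch, where `shape` stands in for var.get_shape().as_list():

    import numpy as np

    def expand_dim_if_necessary(shape, dp):
        valid_shape = [k for k in shape if k]        # drop the None dims
        if dp.shape == tuple(valid_shape):
            dp = dp.reshape([k if k else 1 for k in shape])
        return dp

    print(expand_dim_if_necessary([None, 28, 28], np.zeros((28, 28))).shape)
    # (1, 28, 28)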