seminar-breakout (Shashank Suhas) · Commits · cc89b105

Commit cc89b105, authored Feb 18, 2017 by Yuxin Wu (parent: 20d1af11)

    EMA callback don't create variables itself. add old SaverRestore to be fast

Showing 9 changed files, with 141 additions and 56 deletions (+141 −56).
    examples/GAN/WGAN-CelebA.py        +4   −0
    scripts/ls-checkpoint.py           +14  −0
    tensorpack/callbacks/summary.py    +7   −19
    tensorpack/tfutils/sessinit.py     +82  −26
    tensorpack/tfutils/summary.py      +22  −3
    tensorpack/tfutils/varmanip.py     +1   −0
    tensorpack/train/base.py           +3   −1
    tensorpack/utils/fs.py             +6   −5
    tensorpack/utils/naming.py         +2   −2
examples/GAN/WGAN-CelebA.py

```diff
@@ -68,6 +68,10 @@ def get_config():
...
 class WGANTrainer(FeedfreeTrainerBase):
     """ A new trainer which runs two optimization ops with 5:1 ratio.
     This is to be consistent with the original code, but I found just
     running them 1:1 (i.e. using the existing GANTrainer) also works well.
     """
     def __init__(self, config):
         self._input_method = QueueInput(config.dataflow)
         super(WGANTrainer, self).__init__(config)
...
```
scripts/ls-checkpoint.py (new file, mode 100755)

```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: ls-checkpoint.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
import sys
import pprint

from tensorpack.tfutils.varmanip import get_checkpoint_path

path = get_checkpoint_path(sys.argv[1])
reader = tf.train.NewCheckpointReader(path)
pprint.pprint(reader.get_variable_to_shape_map())
```
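A usage sketch for the new script (the checkpoint path below is hypothetical): `get_checkpoint_path` normalizes whatever the user passes, and the reader then lists every saved variable with its shape.

```python
# Hypothetical invocation; 'train_log/model-10000' is a made-up path.
#   $ python scripts/ls-checkpoint.py train_log/model-10000
# The same thing done programmatically:
import pprint
import tensorflow as tf
from tensorpack.tfutils.varmanip import get_checkpoint_path

path = get_checkpoint_path('train_log/model-10000')   # also accepts a 'checkpoint' file
reader = tf.train.NewCheckpointReader(path)
pprint.pprint(reader.get_variable_to_shape_map())     # {variable name: shape}
```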
tensorpack/callbacks/summary.py

```diff
@@ -4,10 +4,8 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
 import tensorflow as tf
-import re
-from ..utils.naming import MOVING_SUMMARY_VARS_KEY
-from ..tfutils.common import get_global_step_var
+from ..utils.naming import MOVING_SUMMARY_OPS_KEY
 from .base import Callback

 __all__ = ['MovingAverageSummary']
@@ -17,28 +15,18 @@ class MovingAverageSummary(Callback):
     """ Maintain the moving average of the tensors
     in every step, and summarize them. Enabled by default.
     """
-    def __init__(self, collection=MOVING_SUMMARY_VARS_KEY, decay=0.95):
+    def __init__(self, collection=MOVING_SUMMARY_OPS_KEY):
         """
         Args:
-            collection(str): the collection of tensors to summarize. The
-                default would work with :func:`add_moving_summary`.
-            decay(float): the decay of the moving average.
+            collection(str): the collection of EMA-maintaining ops.
+                The default would work with :func:`add_moving_summary()`,
+                but you can use some others.
         """
         self._collection = collection
-        self._decay = decay

     def _setup_graph(self):
-        tensors = set(tf.get_collection(self._collection))
-        # TODO will produce tower0/xxx. not elegant
-        with tf.name_scope(None):
-            averager = tf.train.ExponentialMovingAverage(
-                self._decay, num_updates=get_global_step_var(), name='EMA')
-            avg_maintain_op = averager.apply(tensors)
-            for idx, c in enumerate(tensors):
-                name = re.sub('tower[p0-9]+/', '', c.op.name)
-                tf.summary.scalar(name + '-summary', averager.average(c))
-        self.ema_op = avg_maintain_op
+        ops = tf.get_collection(self._collection)
+        self.ema_op = tf.group(*ops, name='summary_moving_averages')

     def _extra_fetches(self):
         return [self.ema_op]
```
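A minimal sketch of the new division of labor (plain TF 1.x calls, independent of tensorpack): the code that wants a moving average builds the EMA op itself and registers it in a collection, and the callback only fetches the collection and groups the ops, so it no longer creates any variables.

```python
import tensorflow as tf

COLL = 'MOVING_SUMMARY_OPS'   # stands in for MOVING_SUMMARY_OPS_KEY

# Producer side (the role of add_moving_summary): create the EMA op up front.
cost = tf.Variable(0.0, name='cost')              # stand-in for a scalar training tensor
averager = tf.train.ExponentialMovingAverage(0.95, name='EMA')
tf.add_to_collection(COLL, averager.apply([cost]))

# Consumer side (the role of MovingAverageSummary): just group whatever was registered.
ema_op = tf.group(*tf.get_collection(COLL), name='summary_moving_averages')
```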
tensorpack/tfutils/sessinit.py

```diff
@@ -3,22 +3,20 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 import os
-from abc import abstractmethod, ABCMeta
 import numpy as np
 import tensorflow as tf
 import six

-from ..utils import logger, PREDICT_TOWER
+from ..utils import logger
 from .common import get_op_tensor_name
 from .varmanip import (SessionUpdate, get_savename_from_varname,
                        is_training_name, get_checkpoint_path)

 __all__ = ['SessionInit', 'NewSession', 'SaverRestore',
-           'ParamRestore', 'ChainInit',
+           'SaverRestoreRelaxed', 'ParamRestore', 'ChainInit',
            'JustCurrentSession', 'get_model_loader']


-@six.add_metaclass(ABCMeta)
 class SessionInit(object):
     """ Base class for utilities to initialize a session. """
     def init(self, sess):
@@ -30,23 +28,31 @@ class SessionInit(object):
         """
         self._init(sess)

-    @abstractmethod
     def _init(self, sess):
+        self._setup_graph()
+        self._run_init(sess)
+
+    def _setup_graph(self):
+        pass
+
+    def _run_init(self, sess):
         pass


 class JustCurrentSession(SessionInit):
     """ This is a no-op placeholder"""
-    def _init(self, sess):
-        pass
+    pass


 class NewSession(SessionInit):
     """
     Initialize global variables by their initializer.
     """
-    def _init(self, sess):
-        sess.run(tf.global_variables_initializer())
+    def _setup_graph(self):
+        self.op = tf.global_variables_initializer()
+
+    def _run_init(self, sess):
+        sess.run(self.op)


 class CheckpointReaderAdapter(object):
@@ -58,7 +64,7 @@ class CheckpointReaderAdapter(object):
         self._reader = reader
         m = self._reader.get_variable_to_shape_map()
         self._map = {k if k.endswith(':0') else k + ':0': v
-                     for k, v in m.iteritems()}
+                     for k, v in six.iteritems(m)}

     def get_variable_to_shape_map(self):
         return self._map
@@ -74,23 +80,77 @@ class CheckpointReaderAdapter(object):
     def has_tensor(self, name):
         return name in self._map

+    # some checkpoint might not have ':0'
+    def get_real_name(self, name):
+        if self._reader.has_tensor(name):
+            return name
+        assert self.has_tensor(name)
+        return name[:-2]
+

 class SaverRestore(SessionInit):
     """
     Restore a tensorflow checkpoint saved by :class:`tf.train.Saver` or :class:`ModelSaver`.
     """
     def __init__(self, model_path, prefix=None):
         """
         Args:
-            model_path (str): path to the model (model-xxxx) or a ``checkpoint`` file.
+            model_path (str): a model name (model-xxxx) or a ``checkpoint`` file.
             prefix (str): during restore, add a ``prefix/`` for every variable in this checkpoint
         """
         model_path = get_checkpoint_path(model_path)
         self.path = model_path
         self.prefix = prefix

-    def _init(self, sess):
+    def _setup_graph(self):
+        dic = self._get_restore_dict()
+        self.saver = tf.train.Saver(var_list=dic, name=str(id(dic)))
+
+    def _run_init(self, sess):
+        logger.info(
+            "Restoring checkpoint from {} ...".format(self.path))
+        self.saver.restore(sess, self.path)
+
+    @staticmethod
+    def _read_checkpoint_vars(model_path):
+        """ return a set of strings """
+        reader = tf.train.NewCheckpointReader(model_path)
+        reader = CheckpointReaderAdapter(reader)    # use an adapter to standardize the name
+        ckpt_vars = reader.get_variable_to_shape_map().keys()
+        return reader, set(ckpt_vars)
+
+    def _get_restore_dict(self):
+        reader, chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
+        graph_vars = tf.global_variables()
+        var_dict = {}
+        chkpt_vars_used = set()
+        for v in graph_vars:
+            name = get_savename_from_varname(v.name, varname_prefix=self.prefix)
+            if reader.has_tensor(name):
+                ckpt_name = reader.get_real_name(name)
+                assert ckpt_name not in var_dict, \
+                    "Restore conflict: {} and {}".format(v.name, var_dict[ckpt_name].name)
+                var_dict[ckpt_name] = v
+                chkpt_vars_used.add(name)
+            else:
+                vname = v.op.name
+                if not is_training_name(vname):
+                    logger.warn("Variable {} in the graph not found in checkpoint!".format(vname))
+        if len(chkpt_vars_used) < len(chkpt_vars):
+            unused = chkpt_vars - chkpt_vars_used
+            for name in sorted(unused):
+                if not is_training_name(name):
+                    logger.warn("Variable {} in checkpoint not found in the graph!".format(name))
+        return var_dict
+
+
+class SaverRestoreRelaxed(SaverRestore):
+    """ Same as :class:`SaverRestore`, but has more relaxed constraints.
+
+        It allows upcasting certain variables, or reshape certain
+        variables when there is a mismatch that can be fixed.
+        Another advantage is that it doesn't add any new ops to the graph.
+        But it is also slower than :class:`SaverRestore`.
+    """
+    def _run_init(self, sess):
         logger.info(
             "Restoring checkpoint from {} ...".format(self.path))
         reader, chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
...
@@ -114,18 +174,6 @@ class SaverRestore(SessionInit):
             if not is_training_name(name):
                 logger.warn("Variable {} in checkpoint not found in the graph!".format(name))

-    @staticmethod
-    def _read_checkpoint_vars(model_path):
-        """ return a set of strings """
-        reader = tf.train.NewCheckpointReader(model_path)
-        reader = CheckpointReaderAdapter(reader)
-        ckpt_vars = reader.get_variable_to_shape_map().keys()
-        for v in ckpt_vars:
-            if v.startswith(PREDICT_TOWER):
-                logger.error("Found {} in checkpoint. "
-                             "But anything from prediction tower shouldn't be saved.".format(v.name))
-        return reader, set(ckpt_vars)

 class ParamRestore(SessionInit):
     """
@@ -140,7 +188,7 @@ class ParamRestore(SessionInit):
         # use varname (with :0) for consistency
         self.prms = {get_op_tensor_name(n)[1]: v for n, v in six.iteritems(param_dict)}

-    def _init(self, sess):
+    def _run_init(self, sess):
         variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)  # TODO
         variable_names = set([k.name for k in variables])
@@ -182,6 +230,14 @@ class ChainInit(SessionInit):
         for i in self.inits:
             i.init(sess)

+    def _setup_graph(self):
+        for i in self.inits:
+            i._setup_graph()
+
+    def _run_init(self, sess):
+        for i in self.inits:
+            i._run_init(sess)
+

 def get_model_loader(filename):
     """
...
```
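With `_init` split into `_setup_graph` and `_run_init`, a subclass can keep all op creation in the first phase. A sketch with an invented `AssignZero` initializer (not part of this commit):

```python
import tensorflow as tf
from tensorpack.tfutils.sessinit import SessionInit

class AssignZero(SessionInit):
    """Hypothetical example: zero out every global variable."""
    def _setup_graph(self):
        # Create the assign ops now, while the graph is still mutable.
        self._ops = [v.assign(tf.zeros_like(v)) for v in tf.global_variables()]

    def _run_init(self, sess):
        # Only run the pre-built ops; no new graph ops are added here.
        sess.run(self._ops)
```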
tensorpack/tfutils/summary.py

```diff
@@ -7,9 +7,10 @@ import tensorflow as tf
 import re

 from ..utils import log_deprecated
-from ..utils.naming import MOVING_SUMMARY_VARS_KEY
+from ..utils.naming import MOVING_SUMMARY_OPS_KEY
 from .tower import get_current_tower_context
 from .symbolic_functions import rms
+from .common import get_global_step_var

 __all__ = ['create_scalar_summary', 'add_param_summary', 'add_activation_summary',
            'add_moving_summary']
@@ -98,13 +99,21 @@ def add_param_summary(*summary_lists):
         perform(p, act)


-def add_moving_summary(v, *args):
+def add_moving_summary(v, *args, **kwargs):
     """
     Args:
         v (tf.Tensor or list): tensor or list of tensors to summary. Must have
             scalar type.
         args: tensors to summary (support positional arguments)
+        decay (float): the decay rate. Defaults to 0.95.
+        collection (str): the name of the collection to add EMA-maintaining ops.
+            The default will work together with the default
+            :class:`MovingAverageSummary` callback.
     """
+    decay = kwargs.pop('decay', 0.95)
+    coll = kwargs.pop('collection', MOVING_SUMMARY_OPS_KEY)
+    assert len(kwargs) == 0, "Unknown arguments: " + str(kwargs)
+
     ctx = get_current_tower_context()
     if ctx is not None and not ctx.is_main_training_tower:
         return
@@ -112,5 +121,15 @@ def add_moving_summary(v, *args):
         v = [v]
     v.extend(args)
     for x in v:
+        assert isinstance(x, tf.Tensor), x
         assert x.get_shape().ndims == 0, x.get_shape()
-        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, x)
+    # TODO will produce tower0/xxx?
+    with tf.name_scope(None):
+        averager = tf.train.ExponentialMovingAverage(
+            decay, num_updates=get_global_step_var(), name='EMA')
+        avg_maintain_op = averager.apply(v)
+        for c in v:
+            name = re.sub('tower[p0-9]+/', '', c.op.name)
+            tf.summary.scalar(name + '-summary', averager.average(c))
+    tf.add_to_collection(coll, avg_maintain_op)
```
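A usage sketch of the extended signature (the tensors below are made-up stand-ins): positional tensors behave as before, while the new keyword arguments control the decay and the collection that `MovingAverageSummary` will read.

```python
import tensorflow as tf
from tensorpack.tfutils.summary import add_moving_summary

# Stand-in scalar tensors; in real code these come from the model.
cost = tf.reduce_mean(tf.random_normal([32]), name='cost')
accuracy = tf.reduce_mean(tf.random_uniform([32]), name='accuracy')

# decay defaults to 0.95 and the EMA op goes to MOVING_SUMMARY_OPS_KEY;
# both can now be overridden per call.
add_moving_summary(cost, accuracy, decay=0.99, collection='MY_EMA_OPS')
```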
tensorpack/tfutils/varmanip.py

```diff
@@ -147,6 +147,7 @@ def get_checkpoint_path(model_path):
     if os.path.basename(model_path) == model_path:
         model_path = os.path.join('.', model_path)  # avoid #4921 and #6142
     if os.path.basename(model_path) == 'checkpoint':
+        assert os.path.isfile(model_path), model_path
         model_path = tf.train.latest_checkpoint(os.path.dirname(model_path))
         # to be consistent with either v1 or v2
...
```
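The added assertion fails early when a named `checkpoint` file doesn't exist, instead of letting `tf.train.latest_checkpoint` silently return `None`. A sketch of the two accepted input forms (paths are hypothetical):

```python
from tensorpack.tfutils.varmanip import get_checkpoint_path

get_checkpoint_path('train_log/model-10000')   # an explicit model prefix
get_checkpoint_path('train_log/checkpoint')    # a 'checkpoint' file; must now exist,
                                               # resolved via tf.train.latest_checkpoint()
```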
tensorpack/train/base.py

```diff
@@ -138,6 +138,8 @@ class Trainer(object):
         # create an empty StatHolder
         self.stat_holder = StatHolder(logger.LOG_DIR)

+        self.config.session_init._setup_graph()
+
         def after_init(_, __):
             logger.info("Graph variables initialized.")
         scaffold = tf.train.Scaffold(
@@ -149,7 +151,7 @@ class Trainer(object):
             scaffold=scaffold, config=self.config.session_config),
             hooks=None)
         self.sess = self.monitored_sess._tf_sess()
-        self.config.session_init.init(self.sess)
+        self.config.session_init._run_init(self.sess)

     @abstractmethod
     def _setup(self):
...
```
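The point of splitting the call sites: `_setup_graph` runs before the `MonitoredSession` is built (which finalizes the graph), and `_run_init` runs after. A runnable toy outside the trainer, using `NewSession` just for illustration:

```python
import tensorflow as tf
from tensorpack.tfutils.sessinit import NewSession

v = tf.Variable(3.0)
init = NewSession()
init._setup_graph()        # phase 1: creates the init op while the graph is mutable

sess = tf.Session()
sess.graph.finalize()      # what MonitoredSession does to the graph
init._run_init(sess)       # phase 2: still fine, nothing new is added to the graph
print(sess.run(v))         # -> 3.0
```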
tensorpack/utils/fs.py

```diff
@@ -29,18 +29,19 @@ def mkdir_p(dirname):
             raise e


-def download(url, dir):
+def download(url, dir, filename=None):
     """
     Download URL to a directory. Will figure out the filename automatically
     from URL.
     """
     mkdir_p(dir)
-    fname = url.split('/')[-1]
-    fpath = os.path.join(dir, fname)
+    if filename is None:
+        filename = url.split('/')[-1]
+    fpath = os.path.join(dir, filename)

     def _progress(count, block_size, total_size):
         sys.stdout.write('\r>> Downloading %s %.1f%%' %
-                         (fname,
+                         (filename,
                           min(float(count * block_size) / total_size, 1.0) * 100.0))
         sys.stdout.flush()
@@ -54,7 +55,7 @@ def download(url, dir):
     assert size > 0, "Download an empty file!"
     sys.stdout.write('\n')
     # TODO human-readable size
-    print('Succesfully downloaded ' + fname + " " + str(size) + ' bytes.')
+    print('Succesfully downloaded ' + filename + " " + str(size) + ' bytes.')
     return fpath
...
```
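A usage sketch of the new parameter (the URL and paths are made up): without `filename` the old behaviour is preserved, with it the saved name is overridden.

```python
from tensorpack.utils.fs import download

url = 'http://example.com/data/train.tgz'            # hypothetical URL
download(url, '/tmp/data')                           # saved as /tmp/data/train.tgz
download(url, '/tmp/data', filename='train-v2.tgz')  # saved as /tmp/data/train-v2.tgz
```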
tensorpack/utils/naming.py

```diff
@@ -17,12 +17,12 @@ LOCAL_STEP_VAR_NAME = 'local_step:0'
 PREDICT_TOWER = 'towerp'

 # extra variables to summarize during training in a moving-average way
-MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
+MOVING_SUMMARY_OPS_KEY = 'MOVING_SUMMARY_OPS'

 # metainfo for input tensors
 INPUTS_KEY = 'INPUTS_METAINFO'

-SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
+SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_OPS_KEY]

 # export all upper case variables
 all_local_names = locals().keys()
...
```