Commit bbb47815 authored Jan 24, 2017 by Yuxin Wu
get_global_step -> get_global_step_value to avoid confusion with tf.train.get_global_step
parent c7021a87
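The rename fixes a genuine ambiguity: tf.train.get_global_step returns the global-step *tensor* (or None if none exists), while tensorpack's helper returns its current *value* as a Python int by evaluating it in the default session. A minimal sketch of the distinction (TF 1.x graph mode; the variable creation mirrors the diff below):

import tensorflow as tf

# Create the step variable the same way the common.py diff below does.
step = tf.get_variable('global_step', initializer=0,
                       trainable=False, dtype=tf.int32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tensor = tf.train.get_global_step()        # the Tensor 'global_step:0'
    value = tf.train.global_step(sess, step)   # 0 -- a plain Python int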
Showing 6 changed files with 15 additions and 16 deletions

tensorpack/callbacks/saver.py             +1 -2
tensorpack/callbacks/stats.py             +2 -2
tensorpack/tfutils/common.py              +5 -5
tensorpack/tfutils/symbolic_functions.py  +1 -2
tensorpack/train/base.py                  +5 -5
tensorpack/utils/naming.py                +1 -0
tensorpack/callbacks/saver.py

@@ -9,7 +9,6 @@ import shutil
 from .base import Callback
 from ..utils import logger
 from ..tfutils.varmanip import get_savename_from_varname
-from ..tfutils import get_global_step

 __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']
@@ -76,7 +75,7 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
             self.saver.save(
                 tf.get_default_session(),
                 self.path,
-                global_step=get_global_step(),
+                global_step=tf.train.get_global_step(),
                 write_meta_graph=False)
             logger.info("Model saved to %s" % tf.train.get_checkpoint_state(self.checkpoint_dir).model_checkpoint_path)
         except (OSError, IOError):   # disk error sometimes.. just ignore it
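A side note on the saver.py change: handing the variable from tf.train.get_global_step straight to Saver.save works because the global_step argument accepts a Tensor as well as an int; the saver evaluates it in the session to build the checkpoint filename. A hedged sketch (the '/tmp/model' path is illustrative):

import tensorflow as tf

step = tf.get_variable('global_step', initializer=0,
                       trainable=False, dtype=tf.int32)
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # global_step may be an int or a Tensor; a Tensor is evaluated by the
    # saver itself, so no separate "get the value" call is needed here.
    saver.save(sess, '/tmp/model', global_step=tf.train.get_global_step(),
               write_meta_graph=False)   # writes /tmp/model-0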
tensorpack/callbacks/stats.py

@@ -8,7 +8,7 @@ import json
 from .base import Callback
 from ..utils import logger
-from ..tfutils.common import get_global_step
+from ..tfutils.common import get_global_step_value

 __all__ = ['StatHolder', 'StatPrinter', 'SendStat']
@@ -134,7 +134,7 @@ class StatPrinter(Callback):
     def _trigger_epoch(self):
         # by default, add this two stat
-        self._stat_holder.add_stat('global_step', get_global_step())
+        self._stat_holder.add_stat('global_step', get_global_step_value())
         self._stat_holder.finalize()
         self._stat_holder.add_stat('epoch_num', self.epoch_num + 1)
tensorpack/tfutils/common.py

@@ -10,7 +10,7 @@ import six
 from contextlib import contextmanager

 __all__ = ['get_default_sess_config',
-           'get_global_step',
+           'get_global_step_value',
            'get_global_step_var',
            'get_op_tensor_name',
            'get_tensors_by_names',
@@ -56,16 +56,16 @@ def get_global_step_var():
     assert scope.name == '', \
         "Creating global_step_var under a variable scope would cause problems!"
     with tf.variable_scope(scope, reuse=False):
-        var = tf.get_variable(GLOBAL_STEP_OP_NAME, shape=[],
-                              initializer=tf.constant_initializer(dtype=tf.int32),
+        var = tf.get_variable(GLOBAL_STEP_OP_NAME,
+                              initializer=0,
                               trainable=False, dtype=tf.int32)
     return var


-def get_global_step():
+def get_global_step_value():
     """
     Returns:
-        float: global_step value in current graph and session"""
+        int: global_step value in current graph and session"""
     return tf.train.global_step(
         tf.get_default_session(), get_global_step_var())
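With the rename, the two helpers cleanly separate graph construction from value lookup: get_global_step_var creates or returns the int32 variable, and get_global_step_value evaluates it in the default session (hence the float -> int docstring fix). A usage sketch against this version of tensorpack; the assign_add increment is illustrative, not part of the library:

import tensorflow as tf
from tensorpack.tfutils.common import get_global_step_var, get_global_step_value

step_var = get_global_step_var()         # creates the variable on first call
increment = tf.assign_add(step_var, 1)   # illustrative manual increment

with tf.Session() as sess:               # entering the block installs the default session
    sess.run(tf.global_variables_initializer())
    sess.run(increment)
    print(get_global_step_value())       # 1 -- an int, as the docstring now says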
tensorpack/tfutils/symbolic_functions.py

@@ -140,8 +140,7 @@ def get_scalar_var(name, init_value, summary=False, trainable=False):
     Returns:
         tf.Variable: the variable
     """
-    ret = tf.get_variable(name, shape=[],
-                          initializer=tf.constant_initializer(init_value),
+    ret = tf.get_variable(name, initializer=init_value,
                           trainable=trainable)
     if summary:
         # this is recognized in callbacks.StatHolder
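This change and the one in common.py lean on the same tf.get_variable behavior: when initializer is a concrete value rather than an initializer object, the shape is inferred from it, and the dtype defaults to the value's type unless given explicitly. A small sketch of that behavior as I understand the TF 1.x API, with illustrative names:

import tensorflow as tf

lr = tf.get_variable('learning_rate', initializer=0.1)     # scalar, float32 inferred
it = tf.get_variable('iteration', initializer=0,            # scalar, int32 requested
                     trainable=False, dtype=tf.int32)
# A concrete value may not be combined with an explicit shape ("If initializer
# is a constant, do not specify shape."), so shape=[] disappears along with
# constant_initializer in the new one-line form.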
tensorpack/train/base.py

@@ -13,7 +13,7 @@ from .config import TrainConfig
 from ..utils import logger
 from ..utils.timer import timed_operation
 from ..callbacks import StatHolder
-from ..tfutils import get_global_step, get_global_step_var
+from ..tfutils import get_global_step_var, get_global_step_value
 from ..tfutils.modelutils import describe_model
 from ..tfutils.summary import create_scalar_summary
@@ -121,7 +121,7 @@ class Trainer(object):
                 if val.tag.endswith(suffix):
                     val.tag = val.tag[:-len(suffix)]
                 self.stat_holder.add_stat(val.tag, val.simple_value)
-        self.summary_writer.add_summary(summary, get_global_step())
+        self.summary_writer.add_summary(summary, get_global_step_value())

     def add_scalar_summary(self, name, val):
         """
@@ -144,7 +144,7 @@
         """
         self._setup()
         describe_model()
-        get_global_step_var()
+        get_global_step_var()   # ensure such var exists
         # some final operations that might modify the graph
         logger.info("Setup callbacks ...")
         self.config.callbacks.setup_graph(weakref.proxy(self))
@@ -178,12 +178,12 @@
         with self.sess.as_default():
             try:
                 callbacks.before_train()
-                logger.info("Start training with global_step={}".format(get_global_step()))
+                logger.info("Start training with global_step={}".format(get_global_step_value()))
                 for self.epoch_num in range(
                         self.config.starting_epoch, self.config.max_epoch + 1):
                     with timed_operation(
                         'Epoch {} (global_step {})'.format(
-                            self.epoch_num, get_global_step() + self.config.step_per_epoch),
+                            self.epoch_num, get_global_step_value() + self.config.step_per_epoch),
                             log_start=True):
                         for self.step_num in range(self.config.step_per_epoch):
                             if self.coord.should_stop():
tensorpack/utils/naming.py

@@ -3,6 +3,7 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 import tensorflow as tf

+# this is also the name used by tf.train.get_global_step
 GLOBAL_STEP_OP_NAME = 'global_step'
 GLOBAL_STEP_VAR_NAME = 'global_step:0'
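The new comment is load-bearing: tf.train.get_global_step first looks in the GLOBAL_STEP graph collection and, failing that, falls back to the tensor literally named 'global_step:0'. Because tensorpack creates its variable under GLOBAL_STEP_OP_NAME = 'global_step', TensorFlow's own helper can locate it, which is what the saver.py change above relies on. A sketch of the fallback (TF 1.x behavior as I understand it):

import tensorflow as tf

# The name matches GLOBAL_STEP_OP_NAME, so the tensor 'global_step:0' exists...
var = tf.get_variable('global_step', initializer=0,
                      trainable=False, dtype=tf.int32)

# ...and tf.train.get_global_step finds it by name, even though the variable
# was never added to the GLOBAL_STEP collection.
assert tf.train.get_global_step() is not None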