Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
a5652699
Commit
a5652699
authored
Feb 20, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
some clean-ups, and add an alias `hooked_sess` (#147)
parent
ccf4a5a0
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
30 additions
and
27 deletions
+30
-27
examples/GAN/WGAN-CelebA.py
examples/GAN/WGAN-CelebA.py
+2
-2
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+2
-2
tensorpack/callbacks/steps.py
tensorpack/callbacks/steps.py
+11
-9
tensorpack/tfutils/common.py
tensorpack/tfutils/common.py
+9
-11
tensorpack/tfutils/summary.py
tensorpack/tfutils/summary.py
+1
-0
tensorpack/train/base.py
tensorpack/train/base.py
+3
-1
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+1
-1
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+1
-1
No files found.
examples/GAN/WGAN-CelebA.py
View file @
a5652699
...
@@ -88,8 +88,8 @@ class WGANTrainer(FeedfreeTrainerBase):
...
@@ -88,8 +88,8 @@ class WGANTrainer(FeedfreeTrainerBase):
def
run_step
(
self
):
def
run_step
(
self
):
for
k
in
range
(
5
):
for
k
in
range
(
5
):
self
.
monitor
ed_sess
.
run
(
self
.
d_min
)
self
.
hook
ed_sess
.
run
(
self
.
d_min
)
self
.
monitor
ed_sess
.
run
(
self
.
g_min
)
self
.
hook
ed_sess
.
run
(
self
.
g_min
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
tensorpack/callbacks/base.py
View file @
a5652699
...
@@ -74,10 +74,10 @@ class Callback(object):
...
@@ -74,10 +74,10 @@ class Callback(object):
Same as ``tf.train.SessionRunHook.before_run``.
Same as ``tf.train.SessionRunHook.before_run``.
"""
"""
fetches
=
self
.
_before_run
(
ctx
)
fetches
=
self
.
_before_run
(
ctx
)
if
isinstance
(
fetches
,
tf
.
train
.
SessionRunArgs
):
return
fetches
if
fetches
is
None
:
if
fetches
is
None
:
return
None
return
None
if
isinstance
(
fetches
,
tf
.
train
.
SessionRunArgs
):
return
fetches
# also support list of names
# also support list of names
assert
isinstance
(
fetches
,
list
),
fetches
assert
isinstance
(
fetches
,
list
),
fetches
...
...
tensorpack/callbacks/steps.py
View file @
a5652699
...
@@ -10,8 +10,7 @@ from six.moves import zip
...
@@ -10,8 +10,7 @@ from six.moves import zip
import
tqdm
import
tqdm
from
..utils
import
logger
,
get_tqdm_kwargs
from
..utils
import
logger
,
get_tqdm_kwargs
from
..utils.naming
import
(
GLOBAL_STEP_INCR_OP_NAME
,
from
..utils.naming
import
GLOBAL_STEP_INCR_OP_NAME
LOCAL_STEP_OP_NAME
)
from
..tfutils.common
import
(
from
..tfutils.common
import
(
get_op_tensor_name
,
get_global_step_var
,
get_op_tensor_name
,
get_global_step_var
,
get_global_step_value
,
get_op_or_tensor_by_name
)
get_global_step_value
,
get_op_or_tensor_by_name
)
...
@@ -61,9 +60,10 @@ class MaintainStepCounter(Callback):
...
@@ -61,9 +60,10 @@ class MaintainStepCounter(Callback):
self
.
gs_incr_var
=
tf
.
assign_add
(
self
.
gs_incr_var
=
tf
.
assign_add
(
gs_var
,
1
,
gs_var
,
1
,
name
=
GLOBAL_STEP_INCR_OP_NAME
)
name
=
GLOBAL_STEP_INCR_OP_NAME
)
tf
.
mod
(
# tf.mod(
self
.
gs_incr_var
,
self
.
trainer
.
config
.
steps_per_epoch
,
# self.gs_incr_var, self.trainer.config.steps_per_epoch,
name
=
LOCAL_STEP_OP_NAME
)
# name=LOCAL_STEP_OP_NAME)
self
.
_fetches
=
tf
.
train
.
SessionRunArgs
(
self
.
gs_incr_var
)
def
_before_train
(
self
):
def
_before_train
(
self
):
gs_val
=
get_global_step_value
()
gs_val
=
get_global_step_value
()
...
@@ -75,7 +75,7 @@ class MaintainStepCounter(Callback):
...
@@ -75,7 +75,7 @@ class MaintainStepCounter(Callback):
# increase global_step, when trainer.local_step changed
# increase global_step, when trainer.local_step changed
if
self
.
trainer
.
local_step
!=
self
.
_last_updated
:
if
self
.
trainer
.
local_step
!=
self
.
_last_updated
:
self
.
_last_updated
=
self
.
trainer
.
local_step
self
.
_last_updated
=
self
.
trainer
.
local_step
return
[
self
.
gs_incr_var
.
op
]
return
self
.
_fetches
else
:
else
:
return
None
return
None
...
@@ -93,12 +93,14 @@ class ProgressBar(Callback):
...
@@ -93,12 +93,14 @@ class ProgressBar(Callback):
self
.
_tags
=
[
get_op_tensor_name
(
n
)[
0
]
.
split
(
"/"
)[
-
1
]
for
n
in
names
]
self
.
_tags
=
[
get_op_tensor_name
(
n
)[
0
]
.
split
(
"/"
)[
-
1
]
for
n
in
names
]
def
_before_train
(
self
):
def
_before_train
(
self
):
self
.
_fetches
=
get_op_or_tensor_by_name
(
self
.
_names
)
self
.
_last_updated
=
self
.
trainer
.
local_step
self
.
_last_updated
=
self
.
trainer
.
local_step
self
.
_total
=
self
.
trainer
.
config
.
steps_per_epoch
self
.
_total
=
self
.
trainer
.
config
.
steps_per_epoch
self
.
_tqdm_args
=
get_tqdm_kwargs
(
leave
=
True
)
self
.
_tqdm_args
=
get_tqdm_kwargs
(
leave
=
True
)
if
len
(
self
.
_names
):
self
.
_fetches
=
get_op_or_tensor_by_name
(
self
.
_names
)
or
None
if
self
.
_fetches
:
self
.
_fetches
=
tf
.
train
.
SessionRunArgs
(
self
.
_fetches
)
self
.
_tqdm_args
[
'bar_format'
]
=
self
.
_tqdm_args
[
'bar_format'
]
+
"{postfix} "
self
.
_tqdm_args
[
'bar_format'
]
=
self
.
_tqdm_args
[
'bar_format'
]
+
"{postfix} "
def
_before_run
(
self
,
_
):
def
_before_run
(
self
,
_
):
...
@@ -114,7 +116,7 @@ class ProgressBar(Callback):
...
@@ -114,7 +116,7 @@ class ProgressBar(Callback):
def
_after_run
(
self
,
_
,
run_values
):
def
_after_run
(
self
,
_
,
run_values
):
res
=
run_values
.
results
res
=
run_values
.
results
if
len
(
res
)
:
if
res
:
self
.
_bar
.
set_postfix
(
zip
(
self
.
_tags
,
res
))
self
.
_bar
.
set_postfix
(
zip
(
self
.
_tags
,
res
))
def
_trigger_step
(
self
):
def
_trigger_step
(
self
):
...
...
tensorpack/tfutils/common.py
View file @
a5652699
...
@@ -8,16 +8,14 @@ from six.moves import map
...
@@ -8,16 +8,14 @@ from six.moves import map
from
..utils.naming
import
(
from
..utils.naming
import
(
GLOBAL_STEP_VAR_NAME
,
GLOBAL_STEP_VAR_NAME
,
GLOBAL_STEP_OP_NAME
,
GLOBAL_STEP_OP_NAME
)
LOCAL_STEP_VAR_NAME
)
from
..utils
import
logger
from
..utils.argtools
import
memoized
from
..utils.argtools
import
memoized
__all__
=
[
'get_default_sess_config'
,
__all__
=
[
'get_default_sess_config'
,
'get_global_step_value'
,
'get_global_step_value'
,
'get_global_step_var'
,
'get_global_step_var'
,
'get_local_step_var'
,
#
'get_local_step_var',
'get_op_tensor_name'
,
'get_op_tensor_name'
,
'get_tensors_by_names'
,
'get_tensors_by_names'
,
...
@@ -75,13 +73,13 @@ def get_global_step_value():
...
@@ -75,13 +73,13 @@ def get_global_step_value():
get_global_step_var
())
get_global_step_var
())
@
memoized
#
@memoized
def
get_local_step_var
():
#
def get_local_step_var():
try
:
#
try:
return
tf
.
get_default_graph
()
.
get_tensor_by_name
(
LOCAL_STEP_VAR_NAME
)
#
return tf.get_default_graph().get_tensor_by_name(LOCAL_STEP_VAR_NAME)
except
KeyError
:
#
except KeyError:
logger
.
warn
(
"get_local_step_var() is only available to use in callbacks!"
)
#
logger.warn("get_local_step_var() is only available to use in callbacks!")
raise
#
raise
def
get_op_tensor_name
(
name
):
def
get_op_tensor_name
(
name
):
...
...
tensorpack/tfutils/summary.py
View file @
a5652699
...
@@ -129,6 +129,7 @@ def add_moving_summary(v, *args, **kwargs):
...
@@ -129,6 +129,7 @@ def add_moving_summary(v, *args, **kwargs):
decay
,
num_updates
=
get_global_step_var
(),
name
=
'EMA'
)
decay
,
num_updates
=
get_global_step_var
(),
name
=
'EMA'
)
avg_maintain_op
=
averager
.
apply
(
v
)
avg_maintain_op
=
averager
.
apply
(
v
)
for
c
in
v
:
for
c
in
v
:
# TODO do this in the EMA callback?
name
=
re
.
sub
(
'tower[p0-9]+/'
,
''
,
c
.
op
.
name
)
name
=
re
.
sub
(
'tower[p0-9]+/'
,
''
,
c
.
op
.
name
)
tf
.
summary
.
scalar
(
name
+
'-summary'
,
averager
.
average
(
c
))
tf
.
summary
.
scalar
(
name
+
'-summary'
,
averager
.
average
(
c
))
...
...
tensorpack/train/base.py
View file @
a5652699
...
@@ -140,7 +140,9 @@ class Trainer(object):
...
@@ -140,7 +140,9 @@ class Trainer(object):
session_creator
=
tf
.
train
.
ChiefSessionCreator
(
session_creator
=
tf
.
train
.
ChiefSessionCreator
(
scaffold
=
scaffold
,
config
=
self
.
config
.
session_config
),
scaffold
=
scaffold
,
config
=
self
.
config
.
session_config
),
hooks
=
self
.
config
.
callbacks
.
get_hooks
())
hooks
=
self
.
config
.
callbacks
.
get_hooks
())
self
.
sess
=
self
.
monitored_sess
.
_tf_sess
()
self
.
hooked_sess
=
self
.
monitored_sess
# just create an alias
self
.
sess
=
self
.
monitored_sess
.
_tf_sess
()
# expose the underlying session also
self
.
config
.
session_init
.
_run_init
(
self
.
sess
)
self
.
config
.
session_init
.
_run_init
(
self
.
sess
)
@
abstractmethod
@
abstractmethod
...
...
tensorpack/train/feedfree.py
View file @
a5652699
...
@@ -48,7 +48,7 @@ class FeedfreeTrainerBase(Trainer):
...
@@ -48,7 +48,7 @@ class FeedfreeTrainerBase(Trainer):
def
run_step
(
self
):
def
run_step
(
self
):
""" Simply run ``self.train_op``, which minimizes the cost."""
""" Simply run ``self.train_op``, which minimizes the cost."""
self
.
monitor
ed_sess
.
run
(
self
.
train_op
)
self
.
hook
ed_sess
.
run
(
self
.
train_op
)
# if not hasattr(self, 'cnt'):
# if not hasattr(self, 'cnt'):
# self.cnt = 0
# self.cnt = 0
# else:
# else:
...
...
tensorpack/train/trainer.py
View file @
a5652699
...
@@ -87,7 +87,7 @@ class SimpleTrainer(Trainer):
...
@@ -87,7 +87,7 @@ class SimpleTrainer(Trainer):
def
run_step
(
self
):
def
run_step
(
self
):
""" Feed data into the graph and run the updates. """
""" Feed data into the graph and run the updates. """
feed
=
self
.
_input_method
.
next_feed
()
feed
=
self
.
_input_method
.
next_feed
()
self
.
monitor
ed_sess
.
run
(
self
.
train_op
,
feed_dict
=
feed
)
self
.
hook
ed_sess
.
run
(
self
.
train_op
,
feed_dict
=
feed
)
def
_setup
(
self
):
def
_setup
(
self
):
self
.
_input_method
.
_setup
(
self
)
self
.
_input_method
.
_setup
(
self
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment