Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
f3644ce9
Commit
f3644ce9
authored
Feb 20, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'dev'
parents
05f7ba8f
0f2eaeea
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
158 additions
and
149 deletions
+158
-149
examples/GAN/GAN.py
examples/GAN/GAN.py
+0
-4
examples/GAN/WGAN-CelebA.py
examples/GAN/WGAN-CelebA.py
+2
-3
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+31
-28
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+17
-27
tensorpack/callbacks/steps.py
tensorpack/callbacks/steps.py
+47
-22
tensorpack/callbacks/summary.py
tensorpack/callbacks/summary.py
+1
-1
tensorpack/callbacks/trigger.py
tensorpack/callbacks/trigger.py
+5
-7
tensorpack/models/conv2d.py
tensorpack/models/conv2d.py
+1
-1
tensorpack/tfutils/common.py
tensorpack/tfutils/common.py
+26
-14
tensorpack/tfutils/summary.py
tensorpack/tfutils/summary.py
+1
-0
tensorpack/train/base.py
tensorpack/train/base.py
+10
-22
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+16
-17
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+1
-3
No files found.
examples/GAN/GAN.py
View file @
f3644ce9
...
@@ -77,10 +77,6 @@ class GANTrainer(FeedfreeTrainerBase):
...
@@ -77,10 +77,6 @@ class GANTrainer(FeedfreeTrainerBase):
self
.
d_min
=
opt
.
minimize
(
self
.
model
.
d_loss
,
var_list
=
self
.
model
.
d_vars
,
name
=
'd_op'
)
self
.
d_min
=
opt
.
minimize
(
self
.
model
.
d_loss
,
var_list
=
self
.
model
.
d_vars
,
name
=
'd_op'
)
self
.
train_op
=
self
.
d_min
self
.
train_op
=
self
.
d_min
def
run_step
(
self
):
ret
=
self
.
sess
.
run
([
self
.
train_op
]
+
self
.
get_extra_fetches
())
return
ret
[
1
:]
class
RandomZData
(
DataFlow
):
class
RandomZData
(
DataFlow
):
def
__init__
(
self
,
shape
):
def
__init__
(
self
,
shape
):
...
...
examples/GAN/WGAN-CelebA.py
View file @
f3644ce9
...
@@ -88,9 +88,8 @@ class WGANTrainer(FeedfreeTrainerBase):
...
@@ -88,9 +88,8 @@ class WGANTrainer(FeedfreeTrainerBase):
def
run_step
(
self
):
def
run_step
(
self
):
for
k
in
range
(
5
):
for
k
in
range
(
5
):
self
.
sess
.
run
(
self
.
d_min
)
self
.
hooked_sess
.
run
(
self
.
d_min
)
ret
=
self
.
sess
.
run
([
self
.
g_min
]
+
self
.
get_extra_fetches
())
self
.
hooked_sess
.
run
(
self
.
g_min
)
return
ret
[
1
:]
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
tensorpack/callbacks/base.py
View file @
f3644ce9
...
@@ -54,42 +54,44 @@ class Callback(object):
...
@@ -54,42 +54,44 @@ class Callback(object):
def
_before_train
(
self
):
def
_before_train
(
self
):
pass
pass
def
trigger_step
(
self
,
*
args
):
def
trigger_step
(
self
):
"""
"""
Callback to be triggered after every step (every backpropagation).
Callback to be triggered after every run_step.
Args:
args: a list of values corresponding to :meth:`extra_fetches`.
Could be useful to apply some tricks on parameters (clipping, low-rank, etc)
"""
"""
self
.
_trigger_step
(
*
args
)
self
.
_trigger_step
()
def
_trigger_step
(
self
,
*
args
):
def
_trigger_step
(
self
):
pass
pass
def
extra_fetches
(
self
):
def
after_run
(
self
,
run_context
,
run_values
):
"""
self
.
_after_run
(
run_context
,
run_values
)
Returns:
list: a list of elements to be fetched in every step and
passed to :meth:`trigger_step`. Elements can be
Operations/Tensors, or names of Operations/Tensors.
This function will be called only after the graph is finalized.
def
_after_run
(
self
,
run_context
,
run_values
):
pass
This function should be a pure function (i.e. no side-effect when called)
def
before_run
(
self
,
ctx
):
"""
Same as ``tf.train.SessionRunHook.before_run``.
"""
"""
fetches
=
self
.
_extra_fetches
()
fetches
=
self
.
_before_run
(
ctx
)
if
fetches
is
None
:
return
None
if
isinstance
(
fetches
,
tf
.
train
.
SessionRunArgs
):
return
fetches
# also support list of names
assert
isinstance
(
fetches
,
list
),
fetches
ret
=
[]
ret
=
[]
for
f
in
fetches
:
for
f
in
fetches
:
if
isinstance
(
f
,
(
tf
.
Tensor
,
tf
.
Operation
)):
if
isinstance
(
f
,
(
tf
.
Tensor
,
tf
.
Operation
)):
ret
.
append
(
f
)
ret
.
append
(
f
)
else
:
else
:
# warn about speed
ret
.
append
(
get_op_or_tensor_by_name
(
f
))
ret
.
append
(
get_op_or_tensor_by_name
(
f
))
return
ret
return
tf
.
train
.
SessionRunArgs
(
fetches
=
ret
)
def
_
extra_fetches
(
self
):
def
_
before_run
(
self
,
ctx
):
return
[]
return
None
def
trigger_epoch
(
self
):
def
trigger_epoch
(
self
):
"""
"""
...
@@ -113,11 +115,6 @@ class Callback(object):
...
@@ -113,11 +115,6 @@ class Callback(object):
def
epoch_num
(
self
):
def
epoch_num
(
self
):
return
self
.
trainer
.
epoch_num
return
self
.
trainer
.
epoch_num
@
property
def
local_step
(
self
):
# inside trainer, we're still in the 'local_step' loop, so the number is off by 1
return
self
.
trainer
.
local_step
+
1
@
property
@
property
def
global_step
(
self
):
def
global_step
(
self
):
return
self
.
trainer
.
global_step
return
self
.
trainer
.
global_step
...
@@ -177,12 +174,18 @@ class ProxyCallback(Callback):
...
@@ -177,12 +174,18 @@ class ProxyCallback(Callback):
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
self
.
cb
.
trigger_epoch
()
self
.
cb
.
trigger_epoch
()
def
_trigger_step
(
self
,
*
args
):
def
_trigger_step
(
self
):
self
.
cb
.
trigger_step
(
*
args
)
self
.
cb
.
trigger_step
()
def
_after_train
(
self
):
def
_after_train
(
self
):
self
.
cb
.
after_train
()
self
.
cb
.
after_train
()
def
_before_run
(
self
,
ctx
):
self
.
cb
.
_before_run
(
ctx
)
def
_after_run
(
self
,
ctx
,
run_values
):
self
.
cb
.
_after_run
(
ctx
,
run_values
)
def
__str__
(
self
):
def
__str__
(
self
):
return
"Proxy-"
+
str
(
self
.
cb
)
return
"Proxy-"
+
str
(
self
.
cb
)
...
...
tensorpack/callbacks/group.py
View file @
f3644ce9
...
@@ -4,7 +4,6 @@
...
@@ -4,7 +4,6 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
collections
import
defaultdict
import
time
import
time
import
traceback
import
traceback
...
@@ -15,8 +14,18 @@ from ..utils import logger
...
@@ -15,8 +14,18 @@ from ..utils import logger
__all__
=
[
'Callbacks'
]
__all__
=
[
'Callbacks'
]
class
CallbackTimeLogger
(
object
):
class
CallbackHook
(
tf
.
train
.
SessionRunHook
):
def
__init__
(
self
,
cb
):
self
.
cb
=
cb
def
before_run
(
self
,
ctx
):
return
self
.
cb
.
before_run
(
ctx
)
def
after_run
(
self
,
ctx
,
vals
):
self
.
cb
.
after_run
(
ctx
,
vals
)
class
CallbackTimeLogger
(
object
):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
times
=
[]
self
.
times
=
[]
self
.
tot
=
0
self
.
tot
=
0
...
@@ -71,7 +80,6 @@ class Callbacks(Callback):
...
@@ -71,7 +80,6 @@ class Callbacks(Callback):
break
break
self
.
cbs
=
cbs
self
.
cbs
=
cbs
self
.
_extra_fetches_cache
=
None
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
with
tf
.
name_scope
(
None
):
with
tf
.
name_scope
(
None
):
...
@@ -90,30 +98,12 @@ class Callbacks(Callback):
...
@@ -90,30 +98,12 @@ class Callbacks(Callback):
except
Exception
:
except
Exception
:
traceback
.
print_exc
()
traceback
.
print_exc
()
def
_extra_fetches
(
self
):
def
get_hooks
(
self
):
if
self
.
_extra_fetches_cache
is
not
None
:
return
[
CallbackHook
(
cb
)
for
cb
in
self
.
cbs
]
return
self
.
_extra_fetches_cache
# TODO use dispatch mechanism to avoid duplication
def
trigger_step
(
self
):
self
.
_cbid_to_fetchid
=
defaultdict
(
list
)
for
cb
in
self
.
cbs
:
ret
=
[]
for
idx
,
cb
in
enumerate
(
self
.
cbs
):
fetch
=
cb
.
extra_fetches
()
if
len
(
fetch
)
==
0
:
continue
for
f
in
fetch
:
ret
.
append
(
f
)
self
.
_cbid_to_fetchid
[
idx
]
.
append
(
len
(
ret
)
-
1
)
self
.
_extra_fetches_cache
=
ret
return
ret
def
_trigger_step
(
self
,
*
args
):
for
idx
,
cb
in
enumerate
(
self
.
cbs
):
fid
=
self
.
_cbid_to_fetchid
[
idx
]
if
len
(
fid
)
==
0
:
cb
.
trigger_step
()
cb
.
trigger_step
()
else
:
data
=
[
args
[
k
]
for
k
in
fid
]
cb
.
trigger_step
(
*
data
)
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
tm
=
CallbackTimeLogger
()
tm
=
CallbackTimeLogger
()
...
...
tensorpack/callbacks/steps.py
View file @
f3644ce9
...
@@ -10,9 +10,10 @@ from six.moves import zip
...
@@ -10,9 +10,10 @@ from six.moves import zip
import
tqdm
import
tqdm
from
..utils
import
logger
,
get_tqdm_kwargs
from
..utils
import
logger
,
get_tqdm_kwargs
from
..utils.naming
import
(
GLOBAL_STEP_INCR_OP_NAME
,
from
..utils.naming
import
GLOBAL_STEP_INCR_OP_NAME
LOCAL_STEP_OP_NAME
)
from
..tfutils.common
import
(
from
..tfutils.common
import
get_op_tensor_name
,
get_global_step_var
,
get_global_step_value
get_op_tensor_name
,
get_global_step_var
,
get_global_step_value
,
get_op_or_tensor_by_name
)
from
.base
import
Callback
from
.base
import
Callback
__all__
=
[
'StepTensorPrinter'
,
'MaintainStepCounter'
,
__all__
=
[
'StepTensorPrinter'
,
'MaintainStepCounter'
,
...
@@ -33,10 +34,14 @@ class StepTensorPrinter(Callback):
...
@@ -33,10 +34,14 @@ class StepTensorPrinter(Callback):
logger
.
warn
(
"Using print_stat or tf.Print in the graph is much faster than StepTensorPrinter!"
)
logger
.
warn
(
"Using print_stat or tf.Print in the graph is much faster than StepTensorPrinter!"
)
self
.
_names
=
names
self
.
_names
=
names
def
_extra_fetches
(
self
):
def
_before_train
(
self
):
return
self
.
_names
self
.
_fetches
=
get_op_or_tensor_by_name
(
self
.
_names
)
def
_before_run
(
self
,
_
):
return
self
.
_fetches
def
_trigger_step
(
self
,
*
args
):
def
_after_run
(
self
,
_
,
vals
):
args
=
vals
.
results
assert
len
(
args
)
==
len
(
self
.
_names
),
len
(
args
)
assert
len
(
args
)
==
len
(
self
.
_names
),
len
(
args
)
for
n
,
v
in
zip
(
self
.
_names
,
args
):
for
n
,
v
in
zip
(
self
.
_names
,
args
):
logger
.
info
(
"{}: {}"
.
format
(
n
,
v
))
logger
.
info
(
"{}: {}"
.
format
(
n
,
v
))
...
@@ -55,17 +60,24 @@ class MaintainStepCounter(Callback):
...
@@ -55,17 +60,24 @@ class MaintainStepCounter(Callback):
self
.
gs_incr_var
=
tf
.
assign_add
(
self
.
gs_incr_var
=
tf
.
assign_add
(
gs_var
,
1
,
gs_var
,
1
,
name
=
GLOBAL_STEP_INCR_OP_NAME
)
name
=
GLOBAL_STEP_INCR_OP_NAME
)
tf
.
mod
(
# tf.mod(
self
.
gs_incr_var
,
self
.
trainer
.
config
.
steps_per_epoch
,
# self.gs_incr_var, self.trainer.config.steps_per_epoch,
name
=
LOCAL_STEP_OP_NAME
)
# name=LOCAL_STEP_OP_NAME)
self
.
_fetches
=
tf
.
train
.
SessionRunArgs
(
self
.
gs_incr_var
)
def
_before_train
(
self
):
def
_before_train
(
self
):
gs_val
=
get_global_step_value
()
gs_val
=
get_global_step_value
()
if
gs_val
!=
0
:
if
gs_val
!=
0
:
logger
.
info
(
"Start training with global_step={}"
.
format
(
gs_val
))
logger
.
info
(
"Start training with global_step={}"
.
format
(
gs_val
))
self
.
_last_updated
=
self
.
trainer
.
local_step
def
_extra_fetches
(
self
):
def
_before_run
(
self
,
_
):
return
[
self
.
gs_incr_var
.
op
]
# increase global_step, when trainer.local_step changed
if
self
.
trainer
.
local_step
!=
self
.
_last_updated
:
self
.
_last_updated
=
self
.
trainer
.
local_step
return
self
.
_fetches
else
:
return
None
class
ProgressBar
(
Callback
):
class
ProgressBar
(
Callback
):
...
@@ -80,21 +92,34 @@ class ProgressBar(Callback):
...
@@ -80,21 +92,34 @@ class ProgressBar(Callback):
self
.
_names
=
[
get_op_tensor_name
(
n
)[
1
]
for
n
in
names
]
self
.
_names
=
[
get_op_tensor_name
(
n
)[
1
]
for
n
in
names
]
self
.
_tags
=
[
get_op_tensor_name
(
n
)[
0
]
.
split
(
"/"
)[
-
1
]
for
n
in
names
]
self
.
_tags
=
[
get_op_tensor_name
(
n
)[
0
]
.
split
(
"/"
)[
-
1
]
for
n
in
names
]
def
_extra_fetches
(
self
):
return
self
.
_names
def
_before_train
(
self
):
def
_before_train
(
self
):
self
.
_last_updated
=
self
.
trainer
.
local_step
self
.
_total
=
self
.
trainer
.
config
.
steps_per_epoch
self
.
_total
=
self
.
trainer
.
config
.
steps_per_epoch
self
.
_tqdm_args
=
get_tqdm_kwargs
(
leave
=
True
)
self
.
_tqdm_args
=
get_tqdm_kwargs
(
leave
=
True
)
if
len
(
self
.
_names
):
self
.
_fetches
=
get_op_or_tensor_by_name
(
self
.
_names
)
or
None
if
self
.
_fetches
:
self
.
_fetches
=
tf
.
train
.
SessionRunArgs
(
self
.
_fetches
)
self
.
_tqdm_args
[
'bar_format'
]
=
self
.
_tqdm_args
[
'bar_format'
]
+
"{postfix} "
self
.
_tqdm_args
[
'bar_format'
]
=
self
.
_tqdm_args
[
'bar_format'
]
+
"{postfix} "
def
_trigger_step
(
self
,
*
args
):
def
_before_run
(
self
,
_
):
if
self
.
local_step
==
1
:
if
self
.
trainer
.
local_step
!=
self
.
_last_updated
:
self
.
_last_updated
=
self
.
trainer
.
local_step
if
self
.
trainer
.
local_step
==
0
:
self
.
_bar
=
tqdm
.
trange
(
self
.
_total
,
**
self
.
_tqdm_args
)
self
.
_bar
=
tqdm
.
trange
(
self
.
_total
,
**
self
.
_tqdm_args
)
if
len
(
self
.
_names
):
self
.
_bar
.
set_postfix
(
zip
(
self
.
_tags
,
args
))
self
.
_bar
.
update
()
if
self
.
local_step
==
self
.
_total
:
return
self
.
_fetches
else
:
return
None
def
_after_run
(
self
,
_
,
run_values
):
res
=
run_values
.
results
if
res
:
self
.
_bar
.
set_postfix
(
zip
(
self
.
_tags
,
res
))
def
_trigger_step
(
self
):
self
.
_bar
.
update
()
if
self
.
trainer
.
local_step
==
self
.
_total
-
1
:
self
.
_bar
.
close
()
self
.
_bar
.
close
()
tensorpack/callbacks/summary.py
View file @
f3644ce9
...
@@ -28,5 +28,5 @@ class MovingAverageSummary(Callback):
...
@@ -28,5 +28,5 @@ class MovingAverageSummary(Callback):
ops
=
tf
.
get_collection
(
self
.
_collection
)
ops
=
tf
.
get_collection
(
self
.
_collection
)
self
.
ema_op
=
tf
.
group
(
*
ops
,
name
=
'summary_moving_averages'
)
self
.
ema_op
=
tf
.
group
(
*
ops
,
name
=
'summary_moving_averages'
)
def
_
extra_fetches
(
self
):
def
_
before_run
(
self
,
_
):
return
[
self
.
ema_op
]
return
[
self
.
ema_op
]
tensorpack/callbacks/trigger.py
View file @
f3644ce9
...
@@ -31,13 +31,15 @@ class PeriodicTrigger(ProxyCallback):
...
@@ -31,13 +31,15 @@ class PeriodicTrigger(ProxyCallback):
self
.
_step_k
=
every_k_steps
self
.
_step_k
=
every_k_steps
self
.
_epoch_k
=
every_k_epochs
self
.
_epoch_k
=
every_k_epochs
def
_trigger_step
(
self
,
*
args
):
def
_trigger_step
(
self
):
if
self
.
_step_k
is
None
:
if
self
.
_step_k
is
None
:
return
return
if
self
.
local_step
%
self
.
_step_k
==
0
:
# trigger_step is triggered after run_step, so
# local_step + 1 is the number of step that have finished
if
(
self
.
trainer
.
local_step
+
1
)
%
self
.
_step_k
==
0
:
self
.
cb
.
trigger
()
self
.
cb
.
trigger
()
def
_trigger_epoch
(
self
,
*
args
):
def
_trigger_epoch
(
self
):
if
self
.
_epoch_k
is
None
:
if
self
.
_epoch_k
is
None
:
return
return
if
self
.
epoch_num
%
self
.
_epoch_k
==
0
:
if
self
.
epoch_num
%
self
.
_epoch_k
==
0
:
...
@@ -62,10 +64,6 @@ class PeriodicCallback(ProxyCallback):
...
@@ -62,10 +64,6 @@ class PeriodicCallback(ProxyCallback):
Args:
Args:
cb(Callback): the callback to be triggered periodically
cb(Callback): the callback to be triggered periodically
period(int): the period, the number of epochs for a callback to be triggered.
period(int): the period, the number of epochs for a callback to be triggered.
Note:
In ``cb``, ``self.epoch_num`` will not be the true number of
epochs any more.
"""
"""
super
(
PeriodicCallback
,
self
)
.
__init__
(
cb
)
super
(
PeriodicCallback
,
self
)
.
__init__
(
cb
)
self
.
period
=
int
(
period
)
self
.
period
=
int
(
period
)
...
...
tensorpack/models/conv2d.py
View file @
f3644ce9
...
@@ -141,7 +141,7 @@ def Deconv2D(x, out_shape, kernel_shape,
...
@@ -141,7 +141,7 @@ def Deconv2D(x, out_shape, kernel_shape,
for
k
in
out_shape
:
for
k
in
out_shape
:
if
not
isinstance
(
k
,
int
):
if
not
isinstance
(
k
,
int
):
raise
ValueError
(
"[Deconv2D] out_shape {} is invalid!"
.
format
(
k
))
raise
ValueError
(
"[Deconv2D] out_shape {} is invalid!"
.
format
(
k
))
out_channel
=
out_shape
[
channel_axis
-
1
]
out_channel
=
out_shape
[
channel_axis
-
1
]
# out_shape doesn't have batch
shp3_static
=
shp3_dyn
=
out_shape
shp3_static
=
shp3_dyn
=
out_shape
filter_shape
=
kernel_shape
+
[
out_channel
,
in_channel
]
filter_shape
=
kernel_shape
+
[
out_channel
,
in_channel
]
...
...
tensorpack/tfutils/common.py
View file @
f3644ce9
...
@@ -4,19 +4,18 @@
...
@@ -4,19 +4,18 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
tensorflow
as
tf
from
six.moves
import
map
from
..utils.naming
import
(
from
..utils.naming
import
(
GLOBAL_STEP_VAR_NAME
,
GLOBAL_STEP_VAR_NAME
,
GLOBAL_STEP_OP_NAME
,
GLOBAL_STEP_OP_NAME
)
LOCAL_STEP_VAR_NAME
)
from
..utils
import
logger
from
..utils.argtools
import
memoized
from
..utils.argtools
import
memoized
__all__
=
[
'get_default_sess_config'
,
__all__
=
[
'get_default_sess_config'
,
'get_global_step_value'
,
'get_global_step_value'
,
'get_global_step_var'
,
'get_global_step_var'
,
'get_local_step_var'
,
#
'get_local_step_var',
'get_op_tensor_name'
,
'get_op_tensor_name'
,
'get_tensors_by_names'
,
'get_tensors_by_names'
,
...
@@ -74,13 +73,13 @@ def get_global_step_value():
...
@@ -74,13 +73,13 @@ def get_global_step_value():
get_global_step_var
())
get_global_step_var
())
@
memoized
#
@memoized
def
get_local_step_var
():
#
def get_local_step_var():
try
:
#
try:
return
tf
.
get_default_graph
()
.
get_tensor_by_name
(
LOCAL_STEP_VAR_NAME
)
#
return tf.get_default_graph().get_tensor_by_name(LOCAL_STEP_VAR_NAME)
except
KeyError
:
#
except KeyError:
logger
.
warn
(
"get_local_step_var() is only available to use in callbacks!"
)
#
logger.warn("get_local_step_var() is only available to use in callbacks!")
raise
#
raise
def
get_op_tensor_name
(
name
):
def
get_op_tensor_name
(
name
):
...
@@ -116,11 +115,24 @@ def get_tensors_by_names(names):
...
@@ -116,11 +115,24 @@ def get_tensors_by_names(names):
def
get_op_or_tensor_by_name
(
name
):
def
get_op_or_tensor_by_name
(
name
):
"""
Get either tf.Operation of tf.Tensor from names.
Args:
name (list[str] or str): names of operations or tensors.
"""
G
=
tf
.
get_default_graph
()
G
=
tf
.
get_default_graph
()
if
len
(
name
)
>=
3
and
name
[
-
2
]
==
':'
:
return
G
.
get_tensor_by_name
(
name
)
def
f
(
n
):
if
len
(
n
)
>=
3
and
n
[
-
2
]
==
':'
:
return
G
.
get_tensor_by_name
(
n
)
else
:
return
G
.
get_operation_by_name
(
n
)
if
not
isinstance
(
name
,
list
):
return
f
(
name
)
else
:
else
:
return
G
.
get_operation_by_name
(
name
)
return
list
(
map
(
f
,
name
)
)
def
get_name_scope_name
():
def
get_name_scope_name
():
...
...
tensorpack/tfutils/summary.py
View file @
f3644ce9
...
@@ -129,6 +129,7 @@ def add_moving_summary(v, *args, **kwargs):
...
@@ -129,6 +129,7 @@ def add_moving_summary(v, *args, **kwargs):
decay
,
num_updates
=
get_global_step_var
(),
name
=
'EMA'
)
decay
,
num_updates
=
get_global_step_var
(),
name
=
'EMA'
)
avg_maintain_op
=
averager
.
apply
(
v
)
avg_maintain_op
=
averager
.
apply
(
v
)
for
c
in
v
:
for
c
in
v
:
# TODO do this in the EMA callback?
name
=
re
.
sub
(
'tower[p0-9]+/'
,
''
,
c
.
op
.
name
)
name
=
re
.
sub
(
'tower[p0-9]+/'
,
''
,
c
.
op
.
name
)
tf
.
summary
.
scalar
(
name
+
'-summary'
,
averager
.
average
(
c
))
tf
.
summary
.
scalar
(
name
+
'-summary'
,
averager
.
average
(
c
))
...
...
tensorpack/train/base.py
View file @
f3644ce9
...
@@ -40,8 +40,8 @@ class Trainer(object):
...
@@ -40,8 +40,8 @@ class Trainer(object):
summary_writer (tf.summary.FileWriter)
summary_writer (tf.summary.FileWriter)
summary_op (tf.Operation): an Op which outputs all summaries.
summary_op (tf.Operation): an Op which outputs all summaries.
epoch_num (int): the
current epoch number
.
epoch_num (int): the
number of epochs that have finished
.
local_step (int): the
current step number (in an epoch)
.
local_step (int): the
number of steps that have finished in the current epoch
.
"""
"""
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
...
@@ -54,7 +54,7 @@ class Trainer(object):
...
@@ -54,7 +54,7 @@ class Trainer(object):
self
.
model
=
config
.
model
self
.
model
=
config
.
model
self
.
epoch_num
=
self
.
config
.
starting_epoch
-
1
self
.
epoch_num
=
self
.
config
.
starting_epoch
-
1
self
.
local_step
=
0
self
.
local_step
=
-
1
def
train
(
self
):
def
train
(
self
):
""" Start training """
""" Start training """
...
@@ -65,15 +65,6 @@ class Trainer(object):
...
@@ -65,15 +65,6 @@ class Trainer(object):
def
run_step
(
self
):
def
run_step
(
self
):
""" Abstract method. Run one iteration. """
""" Abstract method. Run one iteration. """
def
get_extra_fetches
(
self
):
"""
Returns:
list: list of tensors/ops to fetch in each step.
This function should only get called after :meth:`setup()` has finished.
"""
return
self
.
_extra_fetches
def
trigger_epoch
(
self
):
def
trigger_epoch
(
self
):
"""
"""
Called after each epoch.
Called after each epoch.
...
@@ -130,7 +121,6 @@ class Trainer(object):
...
@@ -130,7 +121,6 @@ class Trainer(object):
# some final operations that might modify the graph
# some final operations that might modify the graph
logger
.
info
(
"Setup callbacks graph ..."
)
logger
.
info
(
"Setup callbacks graph ..."
)
self
.
config
.
callbacks
.
setup_graph
(
weakref
.
proxy
(
self
))
self
.
config
.
callbacks
.
setup_graph
(
weakref
.
proxy
(
self
))
self
.
_extra_fetches
=
self
.
config
.
callbacks
.
extra_fetches
()
logger
.
info
(
"Setup summaries ..."
)
logger
.
info
(
"Setup summaries ..."
)
self
.
summary_writer
=
tf
.
summary
.
FileWriter
(
logger
.
LOG_DIR
,
graph
=
tf
.
get_default_graph
())
self
.
summary_writer
=
tf
.
summary
.
FileWriter
(
logger
.
LOG_DIR
,
graph
=
tf
.
get_default_graph
())
...
@@ -149,8 +139,10 @@ class Trainer(object):
...
@@ -149,8 +139,10 @@ class Trainer(object):
self
.
monitored_sess
=
tf
.
train
.
MonitoredSession
(
self
.
monitored_sess
=
tf
.
train
.
MonitoredSession
(
session_creator
=
tf
.
train
.
ChiefSessionCreator
(
session_creator
=
tf
.
train
.
ChiefSessionCreator
(
scaffold
=
scaffold
,
config
=
self
.
config
.
session_config
),
scaffold
=
scaffold
,
config
=
self
.
config
.
session_config
),
hooks
=
None
)
hooks
=
self
.
config
.
callbacks
.
get_hooks
())
self
.
sess
=
self
.
monitored_sess
.
_tf_sess
()
self
.
hooked_sess
=
self
.
monitored_sess
# just create an alias
self
.
sess
=
self
.
monitored_sess
.
_tf_sess
()
# expose the underlying session also
self
.
config
.
session_init
.
_run_init
(
self
.
sess
)
self
.
config
.
session_init
.
_run_init
(
self
.
sess
)
@
abstractmethod
@
abstractmethod
...
@@ -162,7 +154,7 @@ class Trainer(object):
...
@@ -162,7 +154,7 @@ class Trainer(object):
try
:
try
:
return
self
.
_starting_step
+
\
return
self
.
_starting_step
+
\
self
.
config
.
steps_per_epoch
*
(
self
.
epoch_num
-
1
)
+
\
self
.
config
.
steps_per_epoch
*
(
self
.
epoch_num
-
1
)
+
\
self
.
local_step
+
1
self
.
local_step
+
1
# +1: the ongoing step
except
AttributeError
:
except
AttributeError
:
return
get_global_step_value
()
return
get_global_step_value
()
...
@@ -182,12 +174,8 @@ class Trainer(object):
...
@@ -182,12 +174,8 @@ class Trainer(object):
for
self
.
local_step
in
range
(
self
.
config
.
steps_per_epoch
):
for
self
.
local_step
in
range
(
self
.
config
.
steps_per_epoch
):
if
self
.
monitored_sess
.
should_stop
():
if
self
.
monitored_sess
.
should_stop
():
return
return
fetch_data
=
self
.
run_step
()
# implemented by subclass
self
.
run_step
()
# implemented by subclass
if
fetch_data
is
None
:
# old trainer doesn't return fetch data
callbacks
.
trigger_step
()
callbacks
.
trigger_step
()
else
:
callbacks
.
trigger_step
(
*
fetch_data
)
logger
.
info
(
"Epoch {} (global_step {}) finished, time:{:.2f} sec."
.
format
(
logger
.
info
(
"Epoch {} (global_step {}) finished, time:{:.2f} sec."
.
format
(
self
.
epoch_num
,
self
.
global_step
,
time
.
time
()
-
start_time
))
self
.
epoch_num
,
self
.
global_step
,
time
.
time
()
-
start_time
))
...
...
tensorpack/train/feedfree.py
View file @
f3644ce9
...
@@ -46,25 +46,9 @@ class FeedfreeTrainerBase(Trainer):
...
@@ -46,25 +46,9 @@ class FeedfreeTrainerBase(Trainer):
assert
isinstance
(
self
.
_input_method
,
FeedfreeInput
),
type
(
self
.
_input_method
)
assert
isinstance
(
self
.
_input_method
,
FeedfreeInput
),
type
(
self
.
_input_method
)
self
.
_input_method
.
_setup
(
self
)
self
.
_input_method
.
_setup
(
self
)
class
SingleCostFeedfreeTrainer
(
FeedfreeTrainerBase
):
""" A feedfree Trainer which assumes a single cost. """
def
_get_cost_and_grad
(
self
):
""" get the cost and gradient"""
self
.
build_train_tower
()
cost
=
self
.
model
.
get_cost
()
opt
=
self
.
config
.
optimizer
# GATE_NONE faster?
grads
=
opt
.
compute_gradients
(
cost
,
gate_gradients
=
tf
.
train
.
Optimizer
.
GATE_NONE
,
colocate_gradients_with_ops
=
True
)
return
cost
,
grads
def
run_step
(
self
):
def
run_step
(
self
):
""" Simply run ``self.train_op``, which minimizes the cost."""
""" Simply run ``self.train_op``, which minimizes the cost."""
ret
=
self
.
sess
.
run
([
self
.
train_op
]
+
self
.
get_extra_fetches
())
self
.
hooked_sess
.
run
(
self
.
train_op
)
return
ret
[
1
:]
# if not hasattr(self, 'cnt'):
# if not hasattr(self, 'cnt'):
# self.cnt = 0
# self.cnt = 0
# else:
# else:
...
@@ -83,6 +67,21 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
...
@@ -83,6 +67,21 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
# import sys; sys.exit()
# import sys; sys.exit()
class
SingleCostFeedfreeTrainer
(
FeedfreeTrainerBase
):
""" A feedfree Trainer which assumes a single cost. """
def
_get_cost_and_grad
(
self
):
""" get the cost and gradient"""
self
.
build_train_tower
()
cost
=
self
.
model
.
get_cost
()
opt
=
self
.
config
.
optimizer
# GATE_NONE faster?
grads
=
opt
.
compute_gradients
(
cost
,
gate_gradients
=
tf
.
train
.
Optimizer
.
GATE_NONE
,
colocate_gradients_with_ops
=
True
)
return
cost
,
grads
class
SimpleFeedfreeTrainer
(
class
SimpleFeedfreeTrainer
(
SingleCostFeedfreeTrainer
,
SingleCostFeedfreeTrainer
,
MultiPredictorTowerTrainer
):
MultiPredictorTowerTrainer
):
...
...
tensorpack/train/trainer.py
View file @
f3644ce9
...
@@ -87,9 +87,7 @@ class SimpleTrainer(Trainer):
...
@@ -87,9 +87,7 @@ class SimpleTrainer(Trainer):
def
run_step
(
self
):
def
run_step
(
self
):
""" Feed data into the graph and run the updates. """
""" Feed data into the graph and run the updates. """
feed
=
self
.
_input_method
.
next_feed
()
feed
=
self
.
_input_method
.
next_feed
()
ret
=
self
.
sess
.
run
([
self
.
train_op
]
+
self
.
get_extra_fetches
(),
self
.
hooked_sess
.
run
(
self
.
train_op
,
feed_dict
=
feed
)
feed_dict
=
feed
)
return
ret
[
1
:]
def
_setup
(
self
):
def
_setup
(
self
):
self
.
_input_method
.
_setup
(
self
)
self
.
_input_method
.
_setup
(
self
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment