Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
24f898ec
Commit
24f898ec
authored
Feb 20, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
pass trainer to callback
parent
e9a6a5af
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
22 additions
and
27 deletions
+22
-27
example_alexnet.py
example_alexnet.py
+1
-1
example_cifar10.py
example_cifar10.py
+1
-1
example_mnist.py
example_mnist.py
+0
-1
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+6
-2
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+2
-2
tensorpack/callbacks/summary.py
tensorpack/callbacks/summary.py
+1
-1
tensorpack/callbacks/validation_callback.py
tensorpack/callbacks/validation_callback.py
+4
-4
tensorpack/train/base.py
tensorpack/train/base.py
+5
-6
tensorpack/train/train.py
tensorpack/train/train.py
+2
-3
tensorpack/utils/logger.py
tensorpack/utils/logger.py
+0
-6
No files found.
example_alexnet.py
View file @
24f898ec
...
...
@@ -109,7 +109,7 @@ def get_config():
dataset
=
dataset_train
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
),
callbacks
=
Callbacks
([
S
ummaryWri
ter
(),
S
tatPrin
ter
(),
PeriodicSaver
(),
#ValidationError(dataset_test, prefix='test'),
]),
...
...
example_cifar10.py
View file @
24f898ec
...
...
@@ -131,7 +131,7 @@ def get_config():
dataset
=
dataset_train
,
optimizer
=
tf
.
train
.
AdamOptimizer
(
lr
),
callbacks
=
Callbacks
([
S
ummaryWri
ter
(),
S
tatPrin
ter
(),
PeriodicSaver
(),
ValidationError
(
dataset_test
,
prefix
=
'test'
),
]),
...
...
example_mnist.py
View file @
24f898ec
...
...
@@ -92,7 +92,6 @@ def get_config():
dataset_train
=
BatchData
(
dataset
.
Mnist
(
'train'
),
128
)
dataset_test
=
BatchData
(
dataset
.
Mnist
(
'test'
),
256
,
remainder
=
True
)
step_per_epoch
=
dataset_train
.
size
()
#step_per_epoch = 20
# prepare session
sess_config
=
get_default_sess_config
()
...
...
tensorpack/callbacks/base.py
View file @
24f898ec
...
...
@@ -27,7 +27,8 @@ class Callback(object):
Either TrainCallback or TestCallback
"""
def
before_train
(
self
):
def
before_train
(
self
,
trainer
):
self
.
trainer
=
trainer
self
.
graph
=
tf
.
get_default_graph
()
self
.
sess
=
tf
.
get_default_session
()
self
.
epoch_num
=
0
...
...
@@ -52,9 +53,12 @@ class Callback(object):
Could be useful to apply some tricks on parameters (clipping, low-rank, etc)
"""
@
property
def
global_step
(
self
):
return
self
.
trainer
.
global_step
def
trigger_epoch
(
self
):
self
.
epoch_num
+=
1
self
.
global_step
=
get_global_step
()
self
.
_trigger_epoch
()
def
_trigger_epoch
(
self
):
...
...
tensorpack/callbacks/group.py
View file @
24f898ec
...
...
@@ -110,10 +110,10 @@ class Callbacks(Callback):
def
_before_train
(
self
):
for
cb
in
self
.
cbs
:
if
isinstance
(
cb
.
type
,
TrainCallback
):
cb
.
before_train
()
cb
.
before_train
(
self
.
trainer
)
else
:
with
self
.
test_callback_context
.
before_train_context
():
cb
.
before_train
()
cb
.
before_train
(
self
.
trainer
)
def
_after_train
(
self
):
for
cb
in
self
.
cbs
:
...
...
tensorpack/callbacks/summary.py
View file @
24f898ec
...
...
@@ -56,4 +56,4 @@ class StatPrinter(Callback):
self
.
print_tag
=
print_tag
def
_before_train
(
self
):
logg
er
.
stat_holder
=
StatHolder
(
logger
.
LOG_DIR
,
self
.
print_tag
)
self
.
train
er
.
stat_holder
=
StatHolder
(
logger
.
LOG_DIR
,
self
.
print_tag
)
tensorpack/callbacks/validation_callback.py
View file @
24f898ec
...
...
@@ -63,9 +63,9 @@ class ValidationCallback(PeriodicCallback):
pbar
.
update
()
cost_avg
=
cost_sum
/
cnt
logger
.
writer
.
add_summary
(
create_summary
(
self
.
trainer
.
summary_
writer
.
add_summary
(
create_summary
(
'{}_cost'
.
format
(
self
.
prefix
),
cost_avg
),
self
.
global_step
)
logg
er
.
stat_holder
.
add_stat
(
"{}_cost"
.
format
(
self
.
prefix
),
cost_avg
)
self
.
train
er
.
stat_holder
.
add_stat
(
"{}_cost"
.
format
(
self
.
prefix
),
cost_avg
)
def
_trigger_periodic
(
self
):
for
dp
,
outputs
in
self
.
_run_validation
():
...
...
@@ -101,6 +101,6 @@ class ValidationError(ValidationCallback):
wrong
=
outputs
[
0
]
err_stat
.
feed
(
wrong
,
batch_size
)
logger
.
writer
.
add_summary
(
create_summary
(
self
.
trainer
.
summary_
writer
.
add_summary
(
create_summary
(
'{}_error'
.
format
(
self
.
prefix
),
err_stat
.
accuracy
),
self
.
global_step
)
logg
er
.
stat_holder
.
add_stat
(
"{}_error"
.
format
(
self
.
prefix
),
err_stat
.
accuracy
)
self
.
train
er
.
stat_holder
.
add_stat
(
"{}_error"
.
format
(
self
.
prefix
),
err_stat
.
accuracy
)
tensorpack/train/base.py
View file @
24f898ec
...
...
@@ -35,11 +35,10 @@ class Trainer(object):
pass
def
trigger_epoch
(
self
):
self
.
global_step
+=
self
.
config
.
step_per_epoch
self
.
_trigger_epoch
()
self
.
config
.
callbacks
.
trigger_epoch
()
self
.
summary_writer
.
flush
()
logger
.
stat_holder
.
finalize
()
self
.
stat_holder
.
finalize
()
@
abstractmethod
def
_trigger_epoch
(
self
):
...
...
@@ -50,17 +49,16 @@ class Trainer(object):
raise
RuntimeError
(
"Please use logger.set_logger_dir at the beginning of your script."
)
self
.
summary_writer
=
tf
.
train
.
SummaryWriter
(
logger
.
LOG_DIR
,
graph_def
=
self
.
sess
.
graph_def
)
logger
.
writer
=
self
.
summary_writer
self
.
summary_op
=
tf
.
merge_all_summaries
()
# create an empty StatHolder
logger
.
stat_holder
=
StatHolder
(
logger
.
LOG_DIR
,
[])
self
.
stat_holder
=
StatHolder
(
logger
.
LOG_DIR
,
[])
def
_process_summary
(
self
,
summary_str
):
summary
=
tf
.
Summary
.
FromString
(
summary_str
)
for
val
in
summary
.
value
:
if
val
.
WhichOneof
(
'value'
)
==
'simple_value'
:
val
.
tag
=
re
.
sub
(
'tower[0-9]*/'
,
''
,
val
.
tag
)
# TODO move to subclasses
logger
.
stat_holder
.
add_stat
(
val
.
tag
,
val
.
simple_value
)
self
.
stat_holder
.
add_stat
(
val
.
tag
,
val
.
simple_value
)
self
.
summary_writer
.
add_summary
(
summary
,
self
.
global_step
)
def
main_loop
(
self
):
...
...
@@ -70,7 +68,7 @@ class Trainer(object):
self
.
_init_summary
()
self
.
global_step
=
get_global_step
()
logger
.
info
(
"Start training with global_step={}"
.
format
(
self
.
global_step
))
callbacks
.
before_train
()
callbacks
.
before_train
(
self
)
tf
.
get_default_graph
()
.
finalize
()
for
epoch
in
xrange
(
1
,
self
.
config
.
max_epoch
):
...
...
@@ -85,6 +83,7 @@ class Trainer(object):
return
self
.
run_step
()
callbacks
.
trigger_step
()
self
.
global_step
+=
1
self
.
trigger_epoch
()
except
(
KeyboardInterrupt
,
Exception
):
raise
...
...
tensorpack/train/train.py
View file @
24f898ec
...
...
@@ -13,7 +13,7 @@ from ..utils import *
from
..utils.concurrency
import
EnqueueThread
from
..utils.summary
import
summary_moving_average
__all__
=
[
'SimpleTrainer'
,
'QueueInputTrainer'
]
__all__
=
[
'SimpleTrainer'
,
'QueueInputTrainer'
,
'start_train'
]
def
summary_grads
(
grads
):
for
grad
,
var
in
grads
:
...
...
@@ -157,7 +157,6 @@ class QueueInputTrainer(Trainer):
summary_str
=
self
.
summary_op
.
eval
()
self
.
_process_summary
(
summary_str
)
def
start_train
(
config
):
tr
=
Simple
Trainer
(
config
)
tr
=
QueueInput
Trainer
(
config
)
tr
.
train
()
tensorpack/utils/logger.py
View file @
24f898ec
...
...
@@ -83,9 +83,3 @@ unless you're resuming from a previous task.""".format(dirname))
# export logger functions
for
func
in
[
'info'
,
'warning'
,
'error'
,
'critical'
,
'warn'
,
'exception'
,
'debug'
]:
locals
()[
func
]
=
getattr
(
logger
,
func
)
# a global SummaryWriter
writer
=
None
# a global StatHolder
stat_holder
=
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment