Commit 2264b5a3 ("both trainer works")
Authored Feb 20, 2016 by Yuxin Wu
Parent: ea72115e
Project: seminar-breakout

Showing 4 changed files with 61 additions and 51 deletions:
  example_mnist.py                 +2   -2
  tensorpack/callbacks/group.py    +0  -12
  tensorpack/callbacks/summary.py  +4  -22
  tensorpack/train.py              +55  -15
example_mnist.py
@@ -92,7 +92,7 @@ def get_config():
     dataset_train = BatchData(dataset.Mnist('train'), 128)
     dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
     step_per_epoch = dataset_train.size()
-    #step_per_epoch = 20
+    step_per_epoch = 20
     # prepare session
     sess_config = get_default_sess_config()
@@ -109,7 +109,7 @@ def get_config():
         dataset=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=Callbacks([
-            SummaryWriter(),
+            StatPrinter(),
             PeriodicSaver(),
             ValidationError(dataset_test, prefix='validation'),
         ]),
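A note on the dataset lines above: remainder=True on the test stream presumably keeps the final partial batch instead of dropping it, so evaluation covers the whole test set. A quick sanity check of the batch arithmetic, assuming the standard 10,000-image MNIST test set:

    # 10000 test images at batch size 256:
    # 39 full batches plus one remainder batch of 16 covers every image
    assert 39 * 256 + 16 == 10000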
tensorpack/callbacks/group.py
@@ -104,15 +104,6 @@ class Callbacks(Callback):
                 raise ValueError(
                     "Unknown callback running graph {}!".format(str(cb.type)))
-        # ensure a SummaryWriter
-        for idx, cb in enumerate(cbs):
-            if type(cb) == SummaryWriter:
-                cbs.insert(0, cbs.pop(idx))
-                break
-        else:
-            logger.warn("SummaryWriter must be used! Insert a default one automatically.")
-            cbs.insert(0, SummaryWriter())
         self.cbs = cbs
         self.test_callback_context = TestCallbackContext()
@@ -127,7 +118,6 @@ class Callbacks(Callback):
     def _after_train(self):
         for cb in self.cbs:
             cb.after_train()
-        logger.writer.close()
 
     def trigger_step(self):
         for cb in self.cbs:
@@ -152,5 +142,3 @@ class Callbacks(Callback):
             with tm.timed_callback(type(cb).__name__):
                 cb.trigger_epoch()
         tm.log()
-        logger.writer.flush()
-        logger.stat_holder.finalize()
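The deleted block leaned on Python's for/else, where the else branch runs only if the loop finishes without hitting break. A standalone sketch of that idiom, with Writer as a hypothetical stand-in for the old SummaryWriter callback:

    class Saver(object): pass
    class Writer(object): pass        # hypothetical stand-in for SummaryWriter

    cbs = [Saver()]
    for idx, cb in enumerate(cbs):
        if type(cb) == Writer:
            cbs.insert(0, cbs.pop(idx))   # found one: move it to the front
            break
    else:                                 # loop never hit break: no Writer present
        cbs.insert(0, Writer())           # prepend a default one
    assert type(cbs[0]) == Writer

After this commit the idiom is unnecessary here: the Trainer itself owns the tf.train.SummaryWriter (see tensorpack/train.py below), so Callbacks no longer has to guarantee one is registered.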
tensorpack/callbacks/summary.py
@@ -12,7 +12,7 @@ import pickle
 from .base import Callback, PeriodicCallback
 from ..utils import *
 
-__all__ = ['SummaryWriter']
+__all__ = ['StatHolder', 'StatPrinter']
 
 class StatHolder(object):
     def __init__(self, log_dir, print_tag=None):
@@ -48,30 +48,12 @@ class StatHolder(object):
             pickle.dump(self.stat_history, f)
         os.rename(tmp_filename, self.filename)
 
-class SummaryWriter(Callback):
+class StatPrinter(Callback):
     def __init__(self, print_tag=None):
         """ print_tag : a list of regex to match scalar summary to print
             if None, will print all scalar tags
         """
-        if not hasattr(logger, 'LOG_DIR'):
-            raise RuntimeError(
-                "Please use logger.set_logger_dir at the beginning of your script.")
-        self.log_dir = logger.LOG_DIR
-        logger.stat_holder = StatHolder(self.log_dir, print_tag)
+        self.print_tag = print_tag
 
     def _before_train(self):
-        logger.writer = tf.train.SummaryWriter(
-            self.log_dir, graph_def=self.sess.graph_def)
-        self.summary_op = tf.merge_all_summaries()
-
-    def _trigger_epoch(self):
-        # check if there is any summary to write
-        if self.summary_op is None:
-            return
-        summary_str = self.summary_op.eval()
-        summary = tf.Summary.FromString(summary_str)
-        for val in summary.value:
-            if val.WhichOneof('value') == 'simple_value':
-                val.tag = re.sub('tower[0-9]*/', '', val.tag)
-                logger.stat_holder.add_stat(val.tag, val.simple_value)
-        logger.writer.add_summary(summary, self.global_step)
+        logger.stat_holder = StatHolder(logger.LOG_DIR, self.print_tag)
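The re.sub('tower[0-9]*/', '', val.tag) cleanup that leaves this file reappears in Trainer._process_summary below; it strips what is presumably the per-GPU tower prefix so replicated towers report under one stat name. For example:

    import re
    assert re.sub('tower[0-9]*/', '', 'tower0/train_error') == 'train_error'
    assert re.sub('tower[0-9]*/', '', 'train_error') == 'train_error'  # no prefix: unchanged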
tensorpack/train.py
@@ -107,34 +107,65 @@ class Trainer(object):
     def run_step(self):
         pass
 
+    def trigger_epoch(self):
+        self.global_step += self.config.step_per_epoch
+        self._trigger_epoch()
+        self.config.callbacks.trigger_epoch()
+        self.summary_writer.flush()
+        logger.stat_holder.finalize()
+
+    @abstractmethod
+    def _trigger_epoch(self):
+        pass
+
+    def _init_summary(self):
+        if not hasattr(logger, 'LOG_DIR'):
+            raise RuntimeError(
+                "Please use logger.set_logger_dir at the beginning of your script.")
+        self.summary_writer = tf.train.SummaryWriter(
+            logger.LOG_DIR, graph_def=self.sess.graph_def)
+        logger.writer = self.summary_writer
+        self.summary_op = tf.merge_all_summaries()
+        # create an empty StatHolder
+        logger.stat_holder = StatHolder(logger.LOG_DIR, [])
+
+    def _process_summary(self, summary_str):
+        summary = tf.Summary.FromString(summary_str)
+        for val in summary.value:
+            if val.WhichOneof('value') == 'simple_value':
+                val.tag = re.sub('tower[0-9]*/', '', val.tag)  # TODO move to subclasses
+                logger.stat_holder.add_stat(val.tag, val.simple_value)
+        self.summary_writer.add_summary(summary, self.global_step)
+
     def main_loop(self):
         callbacks = self.config.callbacks
         with self.sess.as_default():
             try:
-                logger.info("Start training with global_step={}".format(get_global_step()))
+                self._init_summary()
+                self.global_step = get_global_step()
+                logger.info("Start training with global_step={}".format(self.global_step))
                 callbacks.before_train()
                 tf.get_default_graph().finalize()
                 for epoch in xrange(1, self.config.max_epoch):
                     with timed_operation(
                         'Epoch {}, global_step={}'.format(
-                            epoch, get_global_step() + self.config.step_per_epoch)):
+                            epoch, self.global_step + self.config.step_per_epoch)):
                         for step in tqdm.trange(
                                 self.config.step_per_epoch,
                                 leave=True, mininterval=0.5,
                                 dynamic_ncols=True, ascii=True):
                             if self.coord.should_stop():
                                 return
                             self.run_step()
                             callbacks.trigger_step()
-                        # note that summary_op will take a data from the queue
-                        callbacks.trigger_epoch()
+                        self.trigger_epoch()
             except (KeyboardInterrupt, Exception):
                 raise
             finally:
                 self.coord.request_stop()
                 # Do I need to run queue.close?
                 callbacks.after_train()
+                self.summary_writer.close()
                 self.sess.close()
 
     def init_session_and_coord(self):
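One effect of the hunk above: the epoch log line and trigger_epoch now use a cached self.global_step that is bumped in Python, instead of calling get_global_step() (a read of the TensorFlow variable) every epoch. A toy model of the bookkeeping, with made-up numbers:

    # read the value once (e.g. as restored from a checkpoint), then
    # advance it by step_per_epoch at every epoch boundary
    global_step, step_per_epoch = 100, 500
    for epoch in range(1, 4):
        global_step += step_per_epoch   # what Trainer.trigger_epoch does, no session call
    assert global_step == 100 + 3 * 500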
@@ -147,14 +178,9 @@ class Trainer(object):
             sess=self.sess, coord=self.coord, daemon=True, start=True)
 
 class SimpleTrainer(Trainer):
     def run_step(self):
-        try:
-            data = next(self.data_producer)
-        except StopIteration:
-            self.data_producer = self.config.dataset.get_data()
-            data = next(self.data_producer)
+        data = next(self.data_producer)
         feed = dict(zip(self.input_vars, data))
         self.sess.run([self.train_op], feed_dict=feed)    # faster since train_op return None
@@ -176,9 +202,17 @@ class SimpleTrainer(Trainer):
         describe_model()
         self.init_session_and_coord()
-        self.data_producer = self.config.dataset.get_data()
+        # create an infinte data producer
+        self.data_producer = RepeatedData(self.config.dataset, -1).get_data()
         self.main_loop()
 
+    def _trigger_epoch(self):
+        if self.summary_op is not None:
+            data = next(self.data_producer)
+            feed = dict(zip(self.input_vars, data))
+            summary_str = self.summary_op.eval(feed_dict=feed)
+            self._process_summary(summary_str)
+
 class QueueInputTrainer(Trainer):
     """
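The switch to RepeatedData(self.config.dataset, -1) is what lets run_step drop its try/except StopIteration: a repeat count of -1 presumably means cycle forever. A minimal stand-in for the behavior relied on here (the real RepeatedData lives in tensorpack's dataflow code and may differ in detail):

    def repeated_data(get_data, nr=-1):
        # re-run the underlying data stream nr times; nr == -1 loops forever,
        # so next() on the result can never raise StopIteration
        while nr != 0:
            for dp in get_data():
                yield dp
            nr -= 1

    producer = repeated_data(lambda: iter([[1], [2]]))
    assert [next(producer) for _ in range(5)] == [[1], [2], [1], [2], [1]]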
@@ -257,6 +291,12 @@ class QueueInputTrainer(Trainer):
     def run_step(self):
         self.sess.run([self.train_op])    # faster since train_op return None
 
+    def _trigger_epoch(self):
+        # note that summary_op will take a data from the queue
+        if self.summary_op is not None:
+            summary_str = self.summary_op.eval()
+            self._process_summary(summary_str)
+
 def start_train(config):
     #if config.model.get_input_queue() is not None:
@@ -264,6 +304,6 @@ def start_train(config):
     #tr = QueueInputTrainer()
     #else:
     #tr = SimpleTrainer()
-    #tr = SimpleTrainer(config)
-    tr = QueueInputTrainer(config)
+    tr = SimpleTrainer(config)
+    #tr = QueueInputTrainer(config)
     tr.train()
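Since both trainers now satisfy the same run_step/_trigger_epoch contract, choosing between them is the one-line swap in start_train above. A hypothetical flag-based dispatch, not in the source, would make that choice explicit:

    def start_train(config, use_queue=False):   # hypothetical 'use_queue' flag
        tr = QueueInputTrainer(config) if use_queue else SimpleTrainer(config)
        tr.train()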