Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
0f4eda16
Commit
0f4eda16
authored
Feb 26, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
use Monitors as the backend for all the summaries and stats
parent
61f14083
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
246 additions
and
206 deletions
+246
-206
examples/PennTreebank/PTB-LSTM.py
examples/PennTreebank/PTB-LSTM.py
+4
-4
examples/SimilarityLearning/mnist-embeddings.py
examples/SimilarityLearning/mnist-embeddings.py
+0
-1
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+1
-1
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+1
-2
tensorpack/callbacks/saver.py
tensorpack/callbacks/saver.py
+1
-1
tensorpack/callbacks/stats.py
tensorpack/callbacks/stats.py
+8
-134
tensorpack/callbacks/summary.py
tensorpack/callbacks/summary.py
+2
-2
tensorpack/train/base.py
tensorpack/train/base.py
+20
-51
tensorpack/train/config.py
tensorpack/train/config.py
+14
-9
tensorpack/train/monitor.py
tensorpack/train/monitor.py
+194
-0
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+1
-1
No files found.
examples/PennTreebank/PTB-LSTM.py
View file @
0f4eda16
...
@@ -146,12 +146,12 @@ def get_config():
...
@@ -146,12 +146,12 @@ def get_config():
RunOp
(
lambda
:
M
.
reset_lstm_state
()),
RunOp
(
lambda
:
M
.
reset_lstm_state
()),
CallbackFactory
(
CallbackFactory
(
trigger_epoch
=
lambda
self
:
trigger_epoch
=
lambda
self
:
[
self
.
trainer
.
add_scalar_summary
(
[
self
.
trainer
.
monitors
.
put
(
'validation_perplexity'
,
'validation_perplexity'
,
np
.
exp
(
self
.
trainer
.
stat_holder
.
get_stat_now
(
'validation_cost'
)
/
SEQ_LEN
)),
np
.
exp
(
self
.
trainer
.
monitors
.
get_latest
(
'validation_cost'
)
/
SEQ_LEN
)),
self
.
trainer
.
add_scalar_summary
(
self
.
trainer
.
monitors
.
put
(
'test_perplexity'
,
'test_perplexity'
,
np
.
exp
(
self
.
trainer
.
stat_holder
.
get_stat_now
(
'test_cost'
)
/
SEQ_LEN
))]
np
.
exp
(
self
.
trainer
.
monitors
.
get_latest
(
'test_cost'
)
/
SEQ_LEN
))]
),
),
],
],
max_epoch
=
70
,
max_epoch
=
70
,
...
...
examples/SimilarityLearning/mnist-embeddings.py
View file @
0f4eda16
...
@@ -152,7 +152,6 @@ def get_config(model, algorithm_name):
...
@@ -152,7 +152,6 @@ def get_config(model, algorithm_name):
MovingAverageSummary
(),
MovingAverageSummary
(),
ProgressBar
(
extra_display
),
ProgressBar
(
extra_display
),
MergeAllSummaries
(),
MergeAllSummaries
(),
StatPrinter
()
],
],
max_epoch
=
20
,
max_epoch
=
20
,
)
)
...
...
tensorpack/callbacks/inference_runner.py
View file @
0f4eda16
...
@@ -63,7 +63,7 @@ def summary_inferencer(trainer, infs):
...
@@ -63,7 +63,7 @@ def summary_inferencer(trainer, infs):
except
:
except
:
logger
.
warn
(
"{} returns a non-scalar statistics!"
.
format
(
type
(
inf
)
.
__name__
))
logger
.
warn
(
"{} returns a non-scalar statistics!"
.
format
(
type
(
inf
)
.
__name__
))
continue
continue
trainer
.
add_scalar_summary
(
k
,
v
)
trainer
.
monitors
.
put
(
k
,
v
)
class
InferenceRunner
(
Triggerable
):
class
InferenceRunner
(
Triggerable
):
...
...
tensorpack/callbacks/param.py
View file @
0f4eda16
...
@@ -318,8 +318,7 @@ class StatMonitorParamSetter(HyperParamSetter):
...
@@ -318,8 +318,7 @@ class StatMonitorParamSetter(HyperParamSetter):
self
.
last_changed_epoch
=
0
self
.
last_changed_epoch
=
0
def
_get_value_to_set
(
self
):
def
_get_value_to_set
(
self
):
holder
=
self
.
trainer
.
stat_holder
hist
=
self
.
trainer
.
monitors
.
get_history
(
self
.
stat_name
)
hist
=
holder
.
get_stat_history
(
self
.
stat_name
)
if
len
(
hist
)
<
self
.
last_k
+
1
or
\
if
len
(
hist
)
<
self
.
last_k
+
1
or
\
self
.
epoch_num
-
self
.
last_changed_epoch
<
self
.
last_k
:
self
.
epoch_num
-
self
.
last_changed_epoch
<
self
.
last_k
:
return
None
return
None
...
...
tensorpack/callbacks/saver.py
View file @
0f4eda16
...
@@ -98,7 +98,7 @@ class MinSaver(Triggerable):
...
@@ -98,7 +98,7 @@ class MinSaver(Triggerable):
def
_get_stat
(
self
):
def
_get_stat
(
self
):
try
:
try
:
v
=
self
.
trainer
.
stat_holder
.
get_stat_now
(
self
.
monitor_stat
)
v
=
self
.
trainer
.
monitors
.
get_latest
(
self
.
monitor_stat
)
except
KeyError
:
except
KeyError
:
v
=
None
v
=
None
return
v
return
v
...
...
tensorpack/callbacks/stats.py
View file @
0f4eda16
...
@@ -3,148 +3,22 @@
...
@@ -3,148 +3,22 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
os
import
os
import
operator
import
json
from
.base
import
Triggerable
from
.base
import
Triggerable
from
..utils
import
logger
from
..utils
import
logger
from
..utils.develop
import
log_deprecated
__all__
=
[
'StatHolder'
,
'StatPrinter'
,
'SendStat'
]
__all__
=
[
'StatPrinter'
,
'SendStat'
]
class
StatHolder
(
object
):
"""
A holder to keep all statistics aside from tensorflow events.
"""
def
__init__
(
self
,
log_dir
):
"""
Args:
log_dir(str): directory to save the stats.
"""
self
.
set_print_tag
([])
self
.
blacklist_tag
=
set
()
self
.
stat_now
=
{}
self
.
log_dir
=
log_dir
self
.
filename
=
os
.
path
.
join
(
log_dir
,
'stat.json'
)
if
os
.
path
.
isfile
(
self
.
filename
):
# TODO make a backup first?
logger
.
info
(
"Found stats at {}, will append to it."
.
format
(
self
.
filename
))
with
open
(
self
.
filename
)
as
f
:
self
.
stat_history
=
json
.
load
(
f
)
else
:
self
.
stat_history
=
[]
# global step of the current list of stat
self
.
_current_gs
=
-
1
def
add_stat
(
self
,
k
,
v
,
global_step
,
epoch_num
):
"""
Add a stat.
"""
if
global_step
!=
self
.
_current_gs
:
self
.
_push
()
self
.
_current_gs
=
global_step
self
.
stat_now
[
'epoch_num'
]
=
epoch_num
self
.
stat_now
[
'global_step'
]
=
global_step
self
.
stat_now
[
k
]
=
float
(
v
)
def
set_print_tag
(
self
,
print_tag
):
"""
Set name of stats to print.
Args:
print_tag: a collection of string.
"""
self
.
print_tag
=
None
if
print_tag
is
None
else
set
(
print_tag
)
def
add_blacklist_tag
(
self
,
blacklist_tag
):
""" Disable printing for some tags
Args:
blacklist_tag: a collection of string.
"""
self
.
blacklist_tag
|=
set
(
blacklist_tag
)
def
get_stat_now
(
self
,
key
):
"""
Return the value of a stat in the current epoch.
Raises:
KeyError if the key hasn't been added in this epoch.
"""
return
self
.
stat_now
[
key
]
def
get_stat_history
(
self
,
key
):
"""
Returns:
list: all history of a stat. Empty if there is not history of this name.
"""
ret
=
[]
for
h
in
self
.
stat_history
:
v
=
h
.
get
(
key
,
None
)
if
v
is
not
None
:
ret
.
append
(
v
)
v
=
self
.
stat_now
.
get
(
key
,
None
)
if
v
is
not
None
:
ret
.
append
(
v
)
return
ret
def
finalize
(
self
):
"""
Print and write stats to disk.
This method is idempotent.
"""
self
.
_print_stat
()
self
.
_push
()
def
_push
(
self
):
""" Note that this method is idempotent"""
if
len
(
self
.
stat_now
):
self
.
stat_history
.
append
(
self
.
stat_now
)
self
.
stat_now
=
{}
self
.
_write_stat
()
def
_print_stat
(
self
):
for
k
,
v
in
sorted
(
self
.
stat_now
.
items
(),
key
=
operator
.
itemgetter
(
0
)):
if
self
.
print_tag
is
None
or
k
in
self
.
print_tag
:
if
k
not
in
self
.
blacklist_tag
:
logger
.
info
(
'{}: {:.5g}'
.
format
(
k
,
v
))
def
_write_stat
(
self
):
tmp_filename
=
self
.
filename
+
'.tmp'
try
:
with
open
(
tmp_filename
,
'w'
)
as
f
:
json
.
dump
(
self
.
stat_history
,
f
)
os
.
rename
(
tmp_filename
,
self
.
filename
)
except
IOError
:
# disk error sometimes..
logger
.
exception
(
"Exception in StatHolder.finalize()!"
)
class
StatPrinter
(
Triggerable
):
class
StatPrinter
(
Triggerable
):
"""
A callback to control what stats to print. Enable by default to print
everything in trainer.stat_holder.
"""
def
__init__
(
self
,
print_tag
=
None
):
def
__init__
(
self
,
print_tag
=
None
):
"""
log_deprecated
(
"StatPrinter"
,
Args:
"No need to add StatPrinter to callbacks anymore!"
,
print_tag: a list of stat names to print.
"2017-03-26"
)
If None, will print all scalar tags.
"""
self
.
print_tag
=
print_tag
def
_before_train
(
self
):
self
.
_stat_holder
=
self
.
trainer
.
stat_holder
self
.
_stat_holder
.
set_print_tag
(
self
.
print_tag
)
self
.
_stat_holder
.
add_blacklist_tag
([
'global_step'
,
'epoch_num'
])
def
_trigger
(
self
):
self
.
_stat_holder
.
finalize
()
# TODO make it into monitor?
class
SendStat
(
Triggerable
):
class
SendStat
(
Triggerable
):
"""
"""
Execute a command with some specific stats.
Execute a command with some specific stats.
...
@@ -173,8 +47,8 @@ class SendStat(Triggerable):
...
@@ -173,8 +47,8 @@ class SendStat(Triggerable):
self
.
stats
=
stats
self
.
stats
=
stats
def
_trigger
(
self
):
def
_trigger
(
self
):
holder
=
self
.
trainer
.
stat_holder
M
=
self
.
trainer
.
monitors
v
=
{
k
:
holder
.
get_stat_now
(
k
)
for
k
in
self
.
stats
}
v
=
{
k
:
M
.
get_latest
(
k
)
for
k
in
self
.
stats
}
cmd
=
self
.
command
.
format
(
**
v
)
cmd
=
self
.
command
.
format
(
**
v
)
ret
=
os
.
system
(
cmd
)
ret
=
os
.
system
(
cmd
)
if
ret
!=
0
:
if
ret
!=
0
:
...
...
tensorpack/callbacks/summary.py
View file @
0f4eda16
...
@@ -68,9 +68,9 @@ class MergeAllSummaries(Callback):
...
@@ -68,9 +68,9 @@ class MergeAllSummaries(Callback):
summary
=
run_values
.
results
summary
=
run_values
.
results
if
summary
is
None
:
if
summary
is
None
:
return
return
self
.
trainer
.
add
_summary
(
summary
)
self
.
trainer
.
monitors
.
put
_summary
(
summary
)
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
if
self
.
_run_alone
:
if
self
.
_run_alone
:
summary
=
self
.
summary_op
.
eval
()
summary
=
self
.
summary_op
.
eval
()
self
.
trainer
.
add
_summary
(
summary
)
self
.
trainer
.
monitors
.
put
_summary
(
summary
)
tensorpack/train/base.py
View file @
0f4eda16
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
import
re
import
time
import
time
import
weakref
import
weakref
import
six
import
six
...
@@ -15,12 +14,12 @@ from tensorflow.python.training.monitored_session \
...
@@ -15,12 +14,12 @@ from tensorflow.python.training.monitored_session \
from
.predict
import
PredictorFactory
from
.predict
import
PredictorFactory
from
.config
import
TrainConfig
from
.config
import
TrainConfig
from
.monitor
import
Monitors
,
TrainingMonitor
from
..utils
import
logger
from
..utils
import
logger
from
..utils.develop
import
deprecated
,
log_deprecated
from
..utils.develop
import
deprecated
,
log_deprecated
from
..callbacks
import
StatHolder
,
Callback
,
Callbacks
,
MaintainStepCounter
from
..callbacks
import
Callback
,
Callbacks
,
MaintainStepCounter
from
..tfutils
import
get_global_step_value
from
..tfutils
import
get_global_step_value
from
..tfutils.modelutils
import
describe_model
from
..tfutils.modelutils
import
describe_model
from
..tfutils.summary
import
create_scalar_summary
__all__
=
[
'Trainer'
,
'StopTraining'
,
'MultiPredictorTowerTrainer'
]
__all__
=
[
'Trainer'
,
'StopTraining'
,
'MultiPredictorTowerTrainer'
]
...
@@ -40,9 +39,7 @@ class Trainer(object):
...
@@ -40,9 +39,7 @@ class Trainer(object):
config (TrainConfig): the config used in this trainer.
config (TrainConfig): the config used in this trainer.
model (ModelDesc)
model (ModelDesc)
sess (tf.Session): the current session in use.
sess (tf.Session): the current session in use.
monitors (Monitors): the monitors
stat_holder (StatHolder)
summary_writer (tf.summary.FileWriter)
epoch_num (int): the number of epochs that have finished.
epoch_num (int): the number of epochs that have finished.
local_step (int): the number of steps that have finished in the current epoch.
local_step (int): the number of steps that have finished in the current epoch.
...
@@ -65,6 +62,8 @@ class Trainer(object):
...
@@ -65,6 +62,8 @@ class Trainer(object):
for
cb
in
self
.
config
.
callbacks
:
for
cb
in
self
.
config
.
callbacks
:
self
.
register_callback
(
cb
)
self
.
register_callback
(
cb
)
self
.
monitors
=
config
.
monitors
def
register_callback
(
self
,
cb
):
def
register_callback
(
self
,
cb
):
"""
"""
Use this method before :meth:`Trainer._setup` finishes,
Use this method before :meth:`Trainer._setup` finishes,
...
@@ -78,6 +77,12 @@ class Trainer(object):
...
@@ -78,6 +77,12 @@ class Trainer(object):
"Cannot register more callbacks after trainer was setup!"
"Cannot register more callbacks after trainer was setup!"
self
.
_callbacks
.
append
(
cb
)
self
.
_callbacks
.
append
(
cb
)
def
register_monitor
(
self
,
mon
):
assert
isinstance
(
mon
,
TrainingMonitor
),
mon
assert
not
isinstance
(
self
.
monitors
,
Monitors
),
\
"Cannot register more monitors after trainer was setup!"
self
.
monitors
.
append
(
mon
)
def
train
(
self
):
def
train
(
self
):
""" Start training """
""" Start training """
self
.
setup
()
self
.
setup
()
...
@@ -88,48 +93,9 @@ class Trainer(object):
...
@@ -88,48 +93,9 @@ class Trainer(object):
""" Abstract method: run one iteration. Subclass should define what is "iteration".
""" Abstract method: run one iteration. Subclass should define what is "iteration".
"""
"""
def
trigger_epoch
(
self
):
"""
Called after each epoch.
"""
# trigger subclass
self
.
_trigger_epoch
()
# trigger callbacks
self
.
_callbacks
.
trigger_epoch
()
self
.
summary_writer
.
flush
()
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
pass
pass
def
add_summary
(
self
,
summary
):
"""
Add summary to ``self.summary_writer``, and also
add scalar summary to ``self.stat_holder``.
Args:
summary (tf.Summary or str): a summary object, or a str which will
be interpreted as a serialized tf.Summary protobuf.
"""
if
isinstance
(
summary
,
six
.
binary_type
):
summary
=
tf
.
Summary
.
FromString
(
summary
)
assert
isinstance
(
summary
,
tf
.
Summary
),
type
(
summary
)
for
val
in
summary
.
value
:
if
val
.
WhichOneof
(
'value'
)
==
'simple_value'
:
val
.
tag
=
re
.
sub
(
'tower[p0-9]+/'
,
''
,
val
.
tag
)
# TODO move to subclasses
suffix
=
'-summary'
# issue#6150
if
val
.
tag
.
endswith
(
suffix
):
val
.
tag
=
val
.
tag
[:
-
len
(
suffix
)]
self
.
stat_holder
.
add_stat
(
val
.
tag
,
val
.
simple_value
,
self
.
global_step
,
self
.
epoch_num
)
self
.
summary_writer
.
add_summary
(
summary
,
get_global_step_value
())
def
add_scalar_summary
(
self
,
name
,
val
):
"""
Add a scalar summary to both TF events file and StatHolder.
"""
self
.
add_summary
(
create_scalar_summary
(
name
,
val
))
def
setup
(
self
):
def
setup
(
self
):
"""
"""
Setup the trainer and be ready for the main loop.
Setup the trainer and be ready for the main loop.
...
@@ -141,10 +107,9 @@ class Trainer(object):
...
@@ -141,10 +107,9 @@ class Trainer(object):
describe_model
()
describe_model
()
# some final operations that might modify the graph
# some final operations that might modify the graph
logger
.
info
(
"Setup summaries ..."
)
logger
.
info
(
"Setup monitors ..."
)
self
.
summary_writer
=
tf
.
summary
.
FileWriter
(
logger
.
LOG_DIR
,
graph
=
tf
.
get_default_graph
())
self
.
monitors
=
Monitors
(
self
.
monitors
)
# create an empty StatHolder
self
.
monitors
.
setup
(
weakref
.
proxy
(
self
))
self
.
stat_holder
=
StatHolder
(
logger
.
LOG_DIR
)
logger
.
info
(
"Setup callbacks graph ..."
)
logger
.
info
(
"Setup callbacks graph ..."
)
self
.
_callbacks
=
Callbacks
(
self
.
_callbacks
)
self
.
_callbacks
=
Callbacks
(
self
.
_callbacks
)
...
@@ -202,7 +167,10 @@ class Trainer(object):
...
@@ -202,7 +167,10 @@ class Trainer(object):
logger
.
info
(
"Epoch {} (global_step {}) finished, time:{:.2f} sec."
.
format
(
logger
.
info
(
"Epoch {} (global_step {}) finished, time:{:.2f} sec."
.
format
(
self
.
epoch_num
,
self
.
global_step
,
time
.
time
()
-
start_time
))
self
.
epoch_num
,
self
.
global_step
,
time
.
time
()
-
start_time
))
self
.
trigger_epoch
()
# trigger epoch outside the timing region.
# trigger epoch outside the timing region.
self
.
_trigger_epoch
()
self
.
_callbacks
.
trigger_epoch
()
self
.
monitors
.
flush
()
except
StopTraining
:
except
StopTraining
:
logger
.
info
(
"Training was stopped."
)
logger
.
info
(
"Training was stopped."
)
except
KeyboardInterrupt
:
except
KeyboardInterrupt
:
...
@@ -211,9 +179,10 @@ class Trainer(object):
...
@@ -211,9 +179,10 @@ class Trainer(object):
raise
raise
finally
:
finally
:
self
.
_callbacks
.
after_train
()
self
.
_callbacks
.
after_train
()
self
.
summary_writer
.
close
()
self
.
monitors
.
close
()
self
.
monitored_sess
.
close
()
self
.
monitored_sess
.
close
()
# Predictor related methods: TODO
def
get_predictor
(
self
,
input_names
,
output_names
,
tower
=
0
):
def
get_predictor
(
self
,
input_names
,
output_names
,
tower
=
0
):
"""
"""
Args:
Args:
...
...
tensorpack/train/config.py
View file @
0f4eda16
...
@@ -6,7 +6,7 @@ import tensorflow as tf
...
@@ -6,7 +6,7 @@ import tensorflow as tf
from
..callbacks
import
(
from
..callbacks
import
(
Callbacks
,
MovingAverageSummary
,
Callbacks
,
MovingAverageSummary
,
StatPrinter
,
ProgressBar
,
MergeAllSummaries
)
ProgressBar
,
MergeAllSummaries
)
from
..dataflow.base
import
DataFlow
from
..dataflow.base
import
DataFlow
from
..models
import
ModelDesc
from
..models
import
ModelDesc
from
..utils
import
logger
from
..utils
import
logger
...
@@ -15,6 +15,7 @@ from ..tfutils import (JustCurrentSession,
...
@@ -15,6 +15,7 @@ from ..tfutils import (JustCurrentSession,
get_default_sess_config
,
SessionInit
)
get_default_sess_config
,
SessionInit
)
from
..tfutils.optimizer
import
apply_grad_processors
from
..tfutils.optimizer
import
apply_grad_processors
from
.input_data
import
InputData
from
.input_data
import
InputData
from
.monitor
import
TFSummaryWriter
,
JSONWriter
,
ScalarPrinter
__all__
=
[
'TrainConfig'
]
__all__
=
[
'TrainConfig'
]
...
@@ -24,11 +25,12 @@ class TrainConfig(object):
...
@@ -24,11 +25,12 @@ class TrainConfig(object):
Config for trainer.
Config for trainer.
"""
"""
def
__init__
(
self
,
dataflow
=
None
,
data
=
None
,
def
__init__
(
self
,
dataflow
=
None
,
data
=
None
,
model
=
None
,
model
=
None
,
callbacks
=
None
,
extra_callbacks
=
None
,
callbacks
=
None
,
extra_callbacks
=
None
,
session_config
=
get_default_sess_config
()
,
monitors
=
None
,
session_init
=
None
,
session_
config
=
get_default_sess_config
(),
session_
init
=
None
,
starting_epoch
=
1
,
steps_per_epoch
=
None
,
max_epoch
=
99999
,
starting_epoch
=
1
,
steps_per_epoch
=
None
,
max_epoch
=
99999
,
nr_tower
=
1
,
tower
=
None
,
predict_tower
=
[
0
],
nr_tower
=
1
,
tower
=
None
,
predict_tower
=
[
0
],
**
kwargs
):
**
kwargs
):
...
@@ -41,10 +43,10 @@ class TrainConfig(object):
...
@@ -41,10 +43,10 @@ class TrainConfig(object):
callbacks (list): a list of :class:`Callback` to perform during training.
callbacks (list): a list of :class:`Callback` to perform during training.
extra_callbacks (list): the same as ``callbacks``. This argument
extra_callbacks (list): the same as ``callbacks``. This argument
is only used to provide the defaults. The defaults are
is only used to provide the defaults. The defaults are
``[MovingAverageSummary(), ProgressBar(), MergeAllSummaries()
, StatPrinter()
]``. The list of
``[MovingAverageSummary(), ProgressBar(), MergeAllSummaries()]``. The list of
callbacks that will be used in the end are ``callbacks + extra_callbacks``.
callbacks that will be used in the end are ``callbacks + extra_callbacks``.
Note that ``StatPrinter`` should be the last one to be able to print
monitors (list): a list of :class:`TrainingMonitor`.
stats generated by other callbacks
.
Defaults to ``[TFSummaryWriter(), JSONWriter(), ScalarPrinter()]``
.
session_config (tf.ConfigProto): the config used to instantiate the session.
session_config (tf.ConfigProto): the config used to instantiate the session.
session_init (SessionInit): how to initialize variables of a session. Defaults to a new session.
session_init (SessionInit): how to initialize variables of a session. Defaults to a new session.
starting_epoch (int): The index of the first epoch.
starting_epoch (int): The index of the first epoch.
...
@@ -86,11 +88,14 @@ class TrainConfig(object):
...
@@ -86,11 +88,14 @@ class TrainConfig(object):
extra_callbacks
=
[
extra_callbacks
=
[
MovingAverageSummary
(),
MovingAverageSummary
(),
ProgressBar
(),
ProgressBar
(),
MergeAllSummaries
(),
MergeAllSummaries
()]
StatPrinter
()]
self
.
callbacks
=
callbacks
+
extra_callbacks
self
.
callbacks
=
callbacks
+
extra_callbacks
assert_type
(
self
.
callbacks
,
list
)
assert_type
(
self
.
callbacks
,
list
)
if
monitors
is
None
:
monitors
=
[
TFSummaryWriter
(),
JSONWriter
(),
ScalarPrinter
()]
self
.
monitors
=
monitors
self
.
model
=
model
self
.
model
=
model
assert_type
(
self
.
model
,
ModelDesc
)
assert_type
(
self
.
model
,
ModelDesc
)
...
...
tensorpack/train/monitor.py
0 → 100644
View file @
0f4eda16
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: monitor.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
os
import
operator
from
collections
import
defaultdict
import
six
import
json
import
re
import
tensorflow
as
tf
from
..utils
import
logger
__all__
=
[
'TrainingMonitor'
,
'Monitors'
,
'TFSummaryWriter'
,
'JSONWriter'
,
'ScalarPrinter'
]
class
TrainingMonitor
(
object
):
def
setup
(
self
,
trainer
):
self
.
_trainer
=
trainer
def
put_summary
(
self
,
summary
):
pass
def
put
(
self
,
name
,
val
):
# TODO split by types?
pass
def
flush
(
self
):
pass
def
close
(
self
):
pass
class
Monitors
(
TrainingMonitor
):
def
__init__
(
self
,
monitors
):
# TODO filter by names
self
.
_scalar_history
=
ScalarHistory
()
self
.
_monitors
=
monitors
+
[
self
.
_scalar_history
]
def
setup
(
self
,
trainer
):
for
m
in
self
.
_monitors
:
m
.
setup
(
trainer
)
def
flush
(
self
):
for
m
in
self
.
_monitors
:
m
.
flush
()
def
close
(
self
):
for
m
in
self
.
_monitors
:
m
.
close
()
def
_dispatch_put_summary
(
self
,
summary
):
for
m
in
self
.
_monitors
:
m
.
put_summary
(
summary
)
def
_dispatch_put
(
self
,
name
,
val
):
for
m
in
self
.
_monitors
:
m
.
put
(
name
,
val
)
def
put_summary
(
self
,
summary
):
if
isinstance
(
summary
,
six
.
binary_type
):
summary
=
tf
.
Summary
.
FromString
(
summary
)
assert
isinstance
(
summary
,
tf
.
Summary
),
type
(
summary
)
self
.
_dispatch_put_summary
(
summary
)
# TODO other types
for
val
in
summary
.
value
:
if
val
.
WhichOneof
(
'value'
)
==
'simple_value'
:
val
.
tag
=
re
.
sub
(
'tower[p0-9]+/'
,
''
,
val
.
tag
)
# TODO move to subclasses
suffix
=
'-summary'
# issue#6150
if
val
.
tag
.
endswith
(
suffix
):
val
.
tag
=
val
.
tag
[:
-
len
(
suffix
)]
self
.
_dispatch_put
(
val
.
tag
,
val
.
simple_value
)
def
put
(
self
,
name
,
val
):
val
=
float
(
val
)
# TODO only support numeric for now
self
.
_dispatch_put
(
name
,
val
)
s
=
tf
.
Summary
()
s
.
value
.
add
(
tag
=
name
,
simple_value
=
val
)
self
.
_dispatch_put_summary
(
s
)
def
get_latest
(
self
,
name
):
return
self
.
_scalar_history
.
get_latest
(
name
)
def
get_history
(
self
,
name
):
return
self
.
_scalar_history
.
get_history
(
name
)
class
TFSummaryWriter
(
TrainingMonitor
):
def
setup
(
self
,
trainer
):
super
(
TFSummaryWriter
,
self
)
.
setup
(
trainer
)
self
.
_writer
=
tf
.
summary
.
FileWriter
(
logger
.
LOG_DIR
,
graph
=
tf
.
get_default_graph
())
def
put_summary
(
self
,
summary
):
self
.
_writer
.
add_summary
(
summary
,
self
.
_trainer
.
global_step
)
def
flush
(
self
):
self
.
_writer
.
flush
()
def
close
(
self
):
self
.
_writer
.
close
()
class
JSONWriter
(
TrainingMonitor
):
def
setup
(
self
,
trainer
):
super
(
JSONWriter
,
self
)
.
setup
(
trainer
)
self
.
_dir
=
logger
.
LOG_DIR
self
.
_fname
=
os
.
path
.
join
(
self
.
_dir
,
'stat.json'
)
if
os
.
path
.
isfile
(
self
.
_fname
):
# TODO make a backup first?
logger
.
info
(
"Found existing JSON at {}, will append to it."
.
format
(
self
.
_fname
))
with
open
(
self
.
_fname
)
as
f
:
self
.
_stats
=
json
.
load
(
f
)
assert
isinstance
(
self
.
_stats
,
list
),
type
(
self
.
_stats
)
else
:
self
.
_stats
=
[]
self
.
_stat_now
=
{}
self
.
_last_gs
=
-
1
def
put
(
self
,
name
,
val
):
gs
=
self
.
_trainer
.
global_step
if
gs
!=
self
.
_last_gs
:
self
.
_push
()
self
.
_last_gs
=
gs
self
.
_stat_now
[
'epoch_num'
]
=
self
.
_trainer
.
epoch_num
self
.
_stat_now
[
'global_step'
]
=
gs
self
.
_stat_now
[
name
]
=
float
(
val
)
# TODO will fail for non-numeric
def
_push
(
self
):
""" Note that this method is idempotent"""
if
len
(
self
.
_stat_now
):
self
.
_stats
.
append
(
self
.
_stat_now
)
self
.
_stat_now
=
{}
self
.
_write_stat
()
def
_write_stat
(
self
):
tmp_filename
=
self
.
_fname
+
'.tmp'
try
:
with
open
(
tmp_filename
,
'w'
)
as
f
:
json
.
dump
(
self
.
_stats
,
f
)
os
.
rename
(
tmp_filename
,
self
.
_fname
)
except
IOError
:
# disk error sometimes..
logger
.
exception
(
"Exception in StatHolder.finalize()!"
)
def
flush
(
self
):
self
.
_push
()
# TODO print interval
class
ScalarPrinter
(
TrainingMonitor
):
def
__init__
(
self
):
self
.
_whitelist
=
None
self
.
_blacklist
=
set
([])
def
setup
(
self
,
_
):
self
.
_dic
=
{}
def
put
(
self
,
name
,
val
):
self
.
_dic
[
name
]
=
float
(
val
)
def
_print_stat
(
self
):
for
k
,
v
in
sorted
(
self
.
_dic
.
items
(),
key
=
operator
.
itemgetter
(
0
)):
if
self
.
_whitelist
is
None
or
k
in
self
.
_whitelist
:
if
k
not
in
self
.
_blacklist
:
logger
.
info
(
'{}: {:.5g}'
.
format
(
k
,
v
))
def
flush
(
self
):
self
.
_print_stat
()
self
.
_dic
=
{}
class
ScalarHistory
(
TrainingMonitor
):
def
setup
(
self
,
_
):
self
.
_dic
=
defaultdict
(
list
)
def
put
(
self
,
name
,
val
):
self
.
_dic
[
name
]
.
append
(
float
(
val
))
def
get_latest
(
self
,
name
):
hist
=
self
.
_dic
[
name
]
if
len
(
hist
)
==
0
:
raise
KeyError
(
"Invalid key: {}"
.
format
(
name
))
else
:
return
hist
[
-
1
]
def
get_history
(
self
,
name
):
return
self
.
_dic
[
name
]
tensorpack/train/multigpu.py
View file @
0f4eda16
...
@@ -237,7 +237,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
...
@@ -237,7 +237,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
if
self
.
config
.
nr_tower
>
1
:
if
self
.
config
.
nr_tower
>
1
:
async_step_total_cnt
=
int
(
re
.
findall
(
async_step_total_cnt
=
int
(
re
.
findall
(
'[0-9]+'
,
self
.
async_step_counter
.
__str__
())[
0
])
'[0-9]+'
,
self
.
async_step_counter
.
__str__
())[
0
])
self
.
add_scalar_summary
(
self
.
monitors
.
put
(
'async_global_step'
,
async_step_total_cnt
)
'async_global_step'
,
async_step_total_cnt
)
except
:
except
:
logger
.
exception
(
"Cannot log async_global_step"
)
logger
.
exception
(
"Cannot log async_global_step"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment