Shashank Suhas / seminar-breakout / Commits

Commit 35beb43c authored Dec 28, 2018 by Yuxin Wu
add comet.ml monitor
parent a4d4eafc
Showing 4 changed files with 76 additions and 16 deletions (+76 -16)
docs/conf.py                       +2  -1
tensorpack/callbacks/monitor.py    +70 -11
tensorpack/train/base.py           +3  -3
tensorpack/train/config.py         +1  -1
docs/conf.py
@@ -380,13 +380,14 @@ _DEPRECATED_NAMES = set([
     'dump_dataflow_to_process_queue',
     'PrefetchOnGPUs',
 
-    # renamed stuff:
+    # renamed items that should not appear in docs
     'DumpTensor',
     'DumpParamAsImage',
     'PeriodicRunHooks',
     'get_nr_gpu',
     'start_test',     # TestDataSpeed
     'ThreadedMapData',
+    'TrainingMonitor',
 
     # deprecated or renamed symbolic code
     'BilinearUpSample',
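For context, a deprecated-names set like this is usually consumed by an autodoc-skip-member handler in a Sphinx conf.py. The sketch below only illustrates that common pattern under that assumption; it is not code from this commit, and the abbreviated set is a placeholder.

# Illustration only: hiding names listed in _DEPRECATED_NAMES from autodoc.
_DEPRECATED_NAMES = set(['TrainingMonitor'])  # abbreviated placeholder list


def autodoc_skip_member(app, what, name, obj, skip, options):
    # Skip anything explicitly listed as deprecated or renamed.
    if name in _DEPRECATED_NAMES:
        return True
    return skip  # otherwise defer to Sphinx's default decision


def setup(app):
    app.connect('autodoc-skip-member', autodoc_skip_member)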
tensorpack/callbacks/monitor.py
@@ -14,14 +14,16 @@ from datetime import datetime
 import six
 import tensorflow as tf
 
+from ..libinfo import __git_version__
 from ..tfutils.summary import create_image_summary, create_scalar_summary
 from ..utils import logger
 from ..utils.develop import HIDE_DOC
 from .base import Callback
 
-__all__ = ['TrainingMonitor', 'Monitors',
+__all__ = ['MonitorBase', 'Monitors',
            'TFEventWriter', 'JSONWriter',
-           'ScalarPrinter', 'SendMonitorData']
+           'ScalarPrinter', 'SendMonitorData',
+           'TrainingMonitor', 'CometMLMonitor']
 
 
 def image_to_nhwc(arr):
@@ -39,9 +41,9 @@ def image_to_nhwc(arr):
     return arr
 
 
-class TrainingMonitor(Callback):
+class MonitorBase(Callback):
     """
-    Monitor a training progress, by processing different types of
+    Base class for monitors which monitor a training progress, by processing different types of
     summary/statistics from trainer.
 
     .. document private functions
@@ -95,7 +97,13 @@ class TrainingMonitor(Callback):
     # TODO process other types
 
 
-class NoOpMonitor(TrainingMonitor):
+TrainingMonitor = MonitorBase
+"""
+Old name
+"""
+
+
+class NoOpMonitor(MonitorBase):
     def __init__(self, name=None):
         self._name = name
@@ -121,7 +129,7 @@ class Monitors(Callback):
         self._scalar_history = ScalarHistory()
         self._monitors = monitors + [self._scalar_history]
         for m in self._monitors:
-            assert isinstance(m, TrainingMonitor), m
+            assert isinstance(m, MonitorBase), m
 
     def _setup_graph(self):
         # scalar_history's other methods were not called.
@@ -219,7 +227,7 @@ class Monitors(Callback):
         return self._scalar_history.get_history(name)
 
 
-class TFEventWriter(TrainingMonitor):
+class TFEventWriter(MonitorBase):
     """
     Write summaries to TensorFlow event file.
     """
@@ -272,7 +280,7 @@ class TFEventWriter(TrainingMonitor):
         self._writer.close()
 
 
-class JSONWriter(TrainingMonitor):
+class JSONWriter(MonitorBase):
     """
     Write all scalar data to a json file under ``logger.get_logger_dir()``, grouped by their global step.
     If found an earlier json history file, will append to it.
@@ -390,7 +398,7 @@ class JSONWriter(TrainingMonitor):
             logger.exception("Exception in JSONWriter._write_stat()!")
 
 
-class ScalarPrinter(TrainingMonitor):
+class ScalarPrinter(MonitorBase):
     """
     Print scalar data into terminal.
     """
@@ -460,7 +468,7 @@ class ScalarPrinter(TrainingMonitor):
         self._dic = {}
 
 
-class ScalarHistory(TrainingMonitor):
+class ScalarHistory(MonitorBase):
     """
     Only internally used by monitors.
     """
@@ -483,7 +491,7 @@ class ScalarHistory(TrainingMonitor):
         return self._dic[name]
 
 
-class SendMonitorData(TrainingMonitor):
+class SendMonitorData(MonitorBase):
     """
     Execute a command with some specific scalar monitor data.
     This is useful for, e.g. building a custom statistics monitor.
@@ -531,3 +539,54 @@ class SendMonitorData(TrainingMonitor):
         if ret != 0:
             logger.error("Command '{}' failed with ret={}!".format(cmd, ret))
         self.dic = {}
+
+
+class CometMLMonitor(MonitorBase):
+    """
+    Send data to https://www.comet.ml.
+
+    Note:
+        1. comet_ml requires you to `import comet_ml` before importing tensorflow or tensorpack.
+        2. The "automatic output logging" feature will make the training progress bar appear to freeze.
+           Therefore the feature is disabled by default.
+    """
+    def __init__(self, experiment=None, api_key=None, tags=None, **kwargs):
+        """
+        Args:
+            experiment (comet_ml.Experiment): if provided, invalidate all other arguments
+            api_key (str): your comet.ml API key
+            tags (list[str]): experiment tags
+            kwargs: other arguments passed to :class:`comet_ml.Experiment`.
+        """
+        if experiment is not None:
+            self._exp = experiment
+            assert api_key is None and tags is None and len(kwargs) == 0
+        else:
+            from comet_ml import Experiment
+            kwargs.setdefault('log_code', True)  # though it's not functioning, git patch logging requires it
+            kwargs.setdefault('auto_output_logging', None)
+            self._exp = Experiment(api_key=api_key, **kwargs)
+            if tags is not None:
+                self._exp.add_tags(tags)
+
+        self._exp.set_code("Code logging is impossible because there are too many files ...")
+        self._exp.log_dependency('tensorpack', __git_version__)
+
+    @property
+    def experiment(self):
+        """
+        Returns: the :class:`comet_ml.Experiment` instance.
+        """
+        return self._exp
+
+    def _before_train(self):
+        self._exp.set_model_graph(tf.get_default_graph())
+
+    def process_scalar(self, name, val):
+        self._exp.log_metric(name, val, step=self.global_step)
+
+    def _after_train(self):
+        self._exp.end()
+
+    def _after_epoch(self):
+        self._exp.log_epoch_end(self.epoch_num)
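For orientation, here is a hedged construction sketch (not part of the commit), following the __init__ docstring above. "YOUR_COMET_API_KEY" is a placeholder, and only calls already shown in the diff or known from comet_ml are used.

# comet_ml must be imported before tensorflow/tensorpack, per the class docstring.
import comet_ml  # noqa

from tensorpack.callbacks import CometMLMonitor

# Option 1: let the monitor create the Experiment; extra kwargs are forwarded to comet_ml.Experiment.
mon = CometMLMonitor(api_key="YOUR_COMET_API_KEY", tags=["baseline"])

# Option 2: hand over an Experiment you created yourself; no other arguments are allowed then.
exp = comet_ml.Experiment(api_key="YOUR_COMET_API_KEY")
mon = CometMLMonitor(experiment=exp)

# The wrapped Experiment stays reachable for ad-hoc logging outside the monitor hooks.
mon.experiment.add_tags(["extra-tag"])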
tensorpack/train/base.py
@@ -8,7 +8,7 @@ import six
 import tensorflow as tf
 from six.moves import range
 
-from ..callbacks import Callback, Callbacks, Monitors, TrainingMonitor
+from ..callbacks import Callback, Callbacks, Monitors, MonitorBase
 from ..callbacks.steps import MaintainStepCounter
 from ..tfutils import get_global_step_value
 from ..tfutils.model_utils import describe_trainable_vars
@@ -186,7 +186,7 @@ class Trainer(object):
         Args:
             callbacks ([Callback]):
-            monitors ([TrainingMonitor]):
+            monitors ([MonitorBase]):
         """
         assert isinstance(callbacks, list), callbacks
         assert isinstance(monitors, list), monitors
@@ -196,7 +196,7 @@ class Trainer(object):
         for cb in callbacks:
             self.register_callback(cb)
         for cb in self._callbacks:
-            assert not isinstance(cb, TrainingMonitor), "Monitor cannot be pre-registered for now!"
+            assert not isinstance(cb, MonitorBase), "Monitor cannot be pre-registered for now!"
         registered_monitors = []
         for m in monitors:
             if self.register_callback(m):
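Since the trainer now checks isinstance(m, MonitorBase), a custom monitor simply subclasses the renamed base class. A minimal sketch, assuming only the process_scalar()/global_step interface visible in the diff above; the class and scalar names are hypothetical.

from tensorpack.callbacks import MonitorBase


class StdoutScalarMonitor(MonitorBase):
    """Hypothetical example: echo selected scalars as the trainer reports them."""

    def __init__(self, names):
        self._names = set(names)

    def process_scalar(self, name, val):
        # Monitors dispatches each scalar summary/statistic to this hook.
        if name in self._names:
            print("step {}: {}={}".format(self.global_step, name, val))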
tensorpack/train/config.py
@@ -84,7 +84,7 @@ class TrainConfig(object):
                 MergeAllSummaries(),
                 RunUpdateOps()]
-            monitors (list[TrainingMonitor]): Defaults to :func:`DEFAULT_MONITORS()`.
+            monitors (list[MonitorBase]): Defaults to :func:`DEFAULT_MONITORS()`.
             session_creator (tf.train.SessionCreator): Defaults to :class:`sesscreate.NewSessionCreator()`
                 with the config returned by :func:`tfutils.get_default_sess_config()`.
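Putting the pieces together, a hedged end-to-end sketch (not part of this commit): because a custom monitors= list replaces the DEFAULT_MONITORS() default mentioned above, one would typically append the new monitor to that list. MyModel and my_dataflow are placeholders, and the DEFAULT_MONITORS import path is an assumption based on the docstring reference.

import comet_ml  # noqa -- must precede tensorflow/tensorpack imports

from tensorpack.callbacks import CometMLMonitor
from tensorpack.train import SimpleTrainer, TrainConfig, launch_train_with_config
from tensorpack.train.config import DEFAULT_MONITORS  # assumed location of the helper

config = TrainConfig(
    model=MyModel(),        # placeholder ModelDesc subclass
    dataflow=my_dataflow,   # placeholder DataFlow
    monitors=DEFAULT_MONITORS() + [CometMLMonitor(api_key="YOUR_COMET_API_KEY")],
    max_epoch=10,
)
launch_train_with_config(config, SimpleTrainer())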