Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
2ba9c3cd
Commit
2ba9c3cd
authored
Apr 10, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
auto load epoch number from JSON. (#171)
parent
7782e724
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
4 deletions
+14
-4
tensorpack/callbacks/monitor.py
tensorpack/callbacks/monitor.py
+11
-1
tensorpack/train/config.py
tensorpack/train/config.py
+3
-3
No files found.
tensorpack/callbacks/monitor.py
View file @
2ba9c3cd
...
@@ -143,6 +143,8 @@ class JSONWriter(TrainingMonitor):
...
@@ -143,6 +143,8 @@ class JSONWriter(TrainingMonitor):
"""
"""
Write all scalar data to a json, grouped by their global step.
Write all scalar data to a json, grouped by their global step.
"""
"""
FILENAME
=
'stat.json'
def
__new__
(
cls
):
def
__new__
(
cls
):
if
logger
.
LOG_DIR
:
if
logger
.
LOG_DIR
:
return
super
(
JSONWriter
,
cls
)
.
__new__
(
cls
)
return
super
(
JSONWriter
,
cls
)
.
__new__
(
cls
)
...
@@ -152,7 +154,7 @@ class JSONWriter(TrainingMonitor):
...
@@ -152,7 +154,7 @@ class JSONWriter(TrainingMonitor):
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
self
.
_dir
=
logger
.
LOG_DIR
self
.
_dir
=
logger
.
LOG_DIR
self
.
_fname
=
os
.
path
.
join
(
self
.
_dir
,
'stat.json'
)
self
.
_fname
=
os
.
path
.
join
(
self
.
_dir
,
self
.
FILENAME
)
if
os
.
path
.
isfile
(
self
.
_fname
):
if
os
.
path
.
isfile
(
self
.
_fname
):
# TODO make a backup first?
# TODO make a backup first?
...
@@ -160,6 +162,14 @@ class JSONWriter(TrainingMonitor):
...
@@ -160,6 +162,14 @@ class JSONWriter(TrainingMonitor):
with
open
(
self
.
_fname
)
as
f
:
with
open
(
self
.
_fname
)
as
f
:
self
.
_stats
=
json
.
load
(
f
)
self
.
_stats
=
json
.
load
(
f
)
assert
isinstance
(
self
.
_stats
,
list
),
type
(
self
.
_stats
)
assert
isinstance
(
self
.
_stats
,
list
),
type
(
self
.
_stats
)
try
:
epoch
=
self
.
_stats
[
-
1
][
'epoch_num'
]
+
1
except
Exception
:
pass
else
:
logger
.
info
(
"Found training history from JSON, now starting from epoch number {}."
.
format
(
epoch
))
self
.
trainer
.
config
.
starting_epoch
=
epoch
else
:
else
:
self
.
_stats
=
[]
self
.
_stats
=
[]
self
.
_stat_now
=
{}
self
.
_stat_now
=
{}
...
...
tensorpack/train/config.py
View file @
2ba9c3cd
...
@@ -6,7 +6,8 @@ import tensorflow as tf
...
@@ -6,7 +6,8 @@ import tensorflow as tf
from
..callbacks
import
(
from
..callbacks
import
(
Callbacks
,
MovingAverageSummary
,
Callbacks
,
MovingAverageSummary
,
ProgressBar
,
MergeAllSummaries
)
ProgressBar
,
MergeAllSummaries
,
TFSummaryWriter
,
JSONWriter
,
ScalarPrinter
)
from
..dataflow.base
import
DataFlow
from
..dataflow.base
import
DataFlow
from
..models
import
ModelDesc
from
..models
import
ModelDesc
from
..utils
import
logger
from
..utils
import
logger
...
@@ -16,7 +17,6 @@ from ..tfutils import (JustCurrentSession,
...
@@ -16,7 +17,6 @@ from ..tfutils import (JustCurrentSession,
from
..tfutils.sesscreate
import
NewSessionCreator
from
..tfutils.sesscreate
import
NewSessionCreator
from
..tfutils.optimizer
import
apply_grad_processors
from
..tfutils.optimizer
import
apply_grad_processors
from
.input_data
import
InputData
from
.input_data
import
InputData
from
..callbacks.monitor
import
TFSummaryWriter
,
JSONWriter
,
ScalarPrinter
__all__
=
[
'TrainConfig'
]
__all__
=
[
'TrainConfig'
]
...
@@ -44,7 +44,7 @@ class TrainConfig(object):
...
@@ -44,7 +44,7 @@ class TrainConfig(object):
callbacks (list): a list of :class:`Callback` to perform during training.
callbacks (list): a list of :class:`Callback` to perform during training.
extra_callbacks (list): the same as ``callbacks``. This argument
extra_callbacks (list): the same as ``callbacks``. This argument
is only used to provide the defaults. The defaults are
is only used to provide the defaults. The defaults are
``[MovingAverageSummary(), ProgressBar(), MergeAllSummaries()]``. The list of
``[MovingAverageSummary(), ProgressBar(), MergeAllSummaries()
, LoadEpochNum()
]``. The list of
callbacks that will be used in the end are ``callbacks + extra_callbacks``.
callbacks that will be used in the end are ``callbacks + extra_callbacks``.
monitors (list): a list of :class:`TrainingMonitor`.
monitors (list): a list of :class:`TrainingMonitor`.
Defaults to ``[TFSummaryWriter(), JSONWriter(), ScalarPrinter()]``.
Defaults to ``[TFSummaryWriter(), JSONWriter(), ScalarPrinter()]``.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment