Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
aa7e18fc
Commit
aa7e18fc
authored
Mar 26, 2020
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
sessinit
parent
acb441ca
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
19 additions
and
14 deletions
+19
-14
scripts/dump-model-params.py
scripts/dump-model-params.py
+4
-0
tensorpack/callbacks/monitor.py
tensorpack/callbacks/monitor.py
+7
-5
tensorpack/train/config.py
tensorpack/train/config.py
+8
-9
No files found.
scripts/dump-model-params.py
View file @
aa7e18fc
...
@@ -37,6 +37,10 @@ def _import_external_ops(message):
...
@@ -37,6 +37,10 @@ def _import_external_ops(message):
else
:
else
:
from
tensorflow.python.ops
import
gen_nccl_ops
# noqa
from
tensorflow.python.ops
import
gen_nccl_ops
# noqa
return
return
if
'ZMQConnection'
in
message
:
import
zmq_ops
return
logger
.
error
(
"Unhandled error: "
+
message
)
def
guess_inputs
(
input_dir
):
def
guess_inputs
(
input_dir
):
...
...
tensorpack/callbacks/monitor.py
View file @
aa7e18fc
...
@@ -306,11 +306,13 @@ class JSONWriter(MonitorBase):
...
@@ -306,11 +306,13 @@ class JSONWriter(MonitorBase):
return
NoOpMonitor
(
"JSONWriter"
)
return
NoOpMonitor
(
"JSONWriter"
)
@
staticmethod
@
staticmethod
def
load_existing_json
():
def
load_existing_json
(
dir
=
None
):
"""
"""
Look for an existing json under :meth:`logger.get_logger_dir()` named "stats.json",
Look for an existing json under dir (defaults to
:meth:`logger.get_logger_dir()`) named "stats.json",
and return the loaded list of statistics if found. Returns None otherwise.
and return the loaded list of statistics if found. Returns None otherwise.
"""
"""
if
dir
is
None
:
dir
=
logger
.
get_logger_dir
()
dir
=
logger
.
get_logger_dir
()
fname
=
os
.
path
.
join
(
dir
,
JSONWriter
.
FILENAME
)
fname
=
os
.
path
.
join
(
dir
,
JSONWriter
.
FILENAME
)
if
tf
.
gfile
.
Exists
(
fname
):
if
tf
.
gfile
.
Exists
(
fname
):
...
@@ -321,12 +323,12 @@ class JSONWriter(MonitorBase):
...
@@ -321,12 +323,12 @@ class JSONWriter(MonitorBase):
return
None
return
None
@
staticmethod
@
staticmethod
def
load_existing_epoch_number
():
def
load_existing_epoch_number
(
dir
=
None
):
"""
"""
Try to load the latest epoch number from an existing json stats file (if any).
Try to load the latest epoch number from an existing json stats file (if any).
Returns None if not found.
Returns None if not found.
"""
"""
stats
=
JSONWriter
.
load_existing_json
()
stats
=
JSONWriter
.
load_existing_json
(
dir
)
try
:
try
:
return
int
(
stats
[
-
1
][
'epoch_num'
])
return
int
(
stats
[
-
1
][
'epoch_num'
])
except
Exception
:
except
Exception
:
...
...
tensorpack/train/config.py
View file @
aa7e18fc
...
@@ -206,7 +206,7 @@ class AutoResumeTrainConfig(TrainConfig):
...
@@ -206,7 +206,7 @@ class AutoResumeTrainConfig(TrainConfig):
"""
"""
found_sessinit
=
False
found_sessinit
=
False
if
always_resume
or
'session_init'
not
in
kwargs
:
if
always_resume
or
'session_init'
not
in
kwargs
:
sessinit
=
self
.
_
get_sessinit_resume
()
sessinit
=
self
.
get_sessinit_resume
()
if
sessinit
is
not
None
:
if
sessinit
is
not
None
:
found_sessinit
=
True
found_sessinit
=
True
path
=
sessinit
.
path
path
=
sessinit
.
path
...
@@ -219,7 +219,7 @@ class AutoResumeTrainConfig(TrainConfig):
...
@@ -219,7 +219,7 @@ class AutoResumeTrainConfig(TrainConfig):
found_last_epoch
=
False
found_last_epoch
=
False
if
always_resume
or
'starting_epoch'
not
in
kwargs
:
if
always_resume
or
'starting_epoch'
not
in
kwargs
:
last_epoch
=
self
.
_get_last_epoch
()
last_epoch
=
JSONWriter
.
load_existing_epoch_number
()
if
last_epoch
is
not
None
:
if
last_epoch
is
not
None
:
found_last_epoch
=
True
found_last_epoch
=
True
now_epoch
=
last_epoch
+
1
now_epoch
=
last_epoch
+
1
...
@@ -231,14 +231,13 @@ class AutoResumeTrainConfig(TrainConfig):
...
@@ -231,14 +231,13 @@ class AutoResumeTrainConfig(TrainConfig):
super
(
AutoResumeTrainConfig
,
self
)
.
__init__
(
**
kwargs
)
super
(
AutoResumeTrainConfig
,
self
)
.
__init__
(
**
kwargs
)
def
_get_sessinit_resume
(
self
):
@
staticmethod
logdir
=
logger
.
get_logger_dir
()
def
get_sessinit_resume
(
dir
=
None
):
if
not
logdir
:
if
dir
is
None
:
dir
=
logger
.
get_logger_dir
()
if
not
dir
:
return
None
return
None
path
=
os
.
path
.
join
(
log
dir
,
'checkpoint'
)
path
=
os
.
path
.
join
(
dir
,
'checkpoint'
)
if
not
tf
.
gfile
.
Exists
(
path
):
if
not
tf
.
gfile
.
Exists
(
path
):
return
None
return
None
return
SaverRestore
(
path
)
return
SaverRestore
(
path
)
def
_get_last_epoch
(
self
):
return
JSONWriter
.
load_existing_epoch_number
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment