Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e68ea2a0
Commit
e68ea2a0
authored
Jan 25, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
some improvements in logging and fix #118.
parent
582cd482
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
33 additions
and
32 deletions
+33
-32
tensorpack/__init__.py
tensorpack/__init__.py
+0
-3
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+0
-1
tensorpack/callbacks/concurrency.py
tensorpack/callbacks/concurrency.py
+1
-1
tensorpack/tfutils/modelutils.py
tensorpack/tfutils/modelutils.py
+3
-0
tensorpack/train/base.py
tensorpack/train/base.py
+29
-27
No files found.
tensorpack/__init__.py
View file @
e68ea2a0
...
@@ -15,6 +15,3 @@ from tensorpack.predict import *
...
@@ -15,6 +15,3 @@ from tensorpack.predict import *
if
int
(
numpy
.
__version__
.
split
(
'.'
)[
1
])
<
9
:
if
int
(
numpy
.
__version__
.
split
(
'.'
)[
1
])
<
9
:
logger
.
warn
(
"Numpy < 1.9 could be extremely slow on some tasks."
)
logger
.
warn
(
"Numpy < 1.9 could be extremely slow on some tasks."
)
if
get_tf_version
()
<
10
:
logger
.
error
(
"tensorpack requires TensorFlow >= 0.10"
)
tensorpack/callbacks/base.py
View file @
e68ea2a0
...
@@ -168,7 +168,6 @@ class PeriodicCallback(ProxyCallback):
...
@@ -168,7 +168,6 @@ class PeriodicCallback(ProxyCallback):
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
if
self
.
epoch_num
%
self
.
period
==
0
:
if
self
.
epoch_num
%
self
.
period
==
0
:
self
.
cb
.
epoch_num
=
self
.
epoch_num
-
1
self
.
cb
.
trigger_epoch
()
self
.
cb
.
trigger_epoch
()
def
__str__
(
self
):
def
__str__
(
self
):
...
...
tensorpack/callbacks/concurrency.py
View file @
e68ea2a0
...
@@ -27,6 +27,6 @@ class StartProcOrThread(Callback):
...
@@ -27,6 +27,6 @@ class StartProcOrThread(Callback):
def
_before_train
(
self
):
def
_before_train
(
self
):
logger
.
info
(
"Starting "
+
logger
.
info
(
"Starting "
+
', '
.
join
([
k
.
name
for
k
in
self
.
_procs_threads
]))
', '
.
join
([
k
.
name
for
k
in
self
.
_procs_threads
])
+
' ...'
)
# avoid sigint get handled by other processes
# avoid sigint get handled by other processes
start_proc_mask_signal
(
self
.
_procs_threads
)
start_proc_mask_signal
(
self
.
_procs_threads
)
tensorpack/tfutils/modelutils.py
View file @
e68ea2a0
...
@@ -13,6 +13,9 @@ __all__ = ['describe_model', 'get_shape_str']
...
@@ -13,6 +13,9 @@ __all__ = ['describe_model', 'get_shape_str']
def
describe_model
():
def
describe_model
():
""" Print a description of the current model parameters """
""" Print a description of the current model parameters """
train_vars
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
train_vars
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
if
len
(
train_vars
)
==
0
:
logger
.
info
(
"No trainable variables in the graph!"
)
return
msg
=
[
""
]
msg
=
[
""
]
total
=
0
total
=
0
for
v
in
train_vars
:
for
v
in
train_vars
:
...
...
tensorpack/train/base.py
View file @
e68ea2a0
...
@@ -68,25 +68,14 @@ class Trainer(object):
...
@@ -68,25 +68,14 @@ class Trainer(object):
def
run_step
(
self
):
def
run_step
(
self
):
""" Abstract method. Run one iteration. """
""" Abstract method. Run one iteration. """
def
get_
predict_func
(
self
,
input_names
,
output_names
):
def
get_
extra_fetches
(
self
):
"""
"""
Args:
input_names (list), output_names(list): list of names
Returns:
Returns:
an OnlinePredictor
list: list of tensors/ops to fetch in each step.
"""
raise
NotImplementedError
()
def
get_predict_funcs
(
self
,
input_names
,
output_names
,
n
):
This function should only get called after :meth:`setup()` has finished.
""" Return n predictors.
Can be overwritten by subclasses to exploit more
parallelism among predictors.
"""
"""
if
len
(
self
.
config
.
predict_tower
)
>
1
:
return
self
.
_extra_fetches
logger
.
warn
(
"[Speed] Have set multiple predict_tower, but only have naive `get_predict_funcs` implementation"
)
return
[
self
.
get_predict_func
(
input_names
,
output_names
)
for
k
in
range
(
n
)]
def
trigger_epoch
(
self
):
def
trigger_epoch
(
self
):
"""
"""
...
@@ -129,28 +118,21 @@ class Trainer(object):
...
@@ -129,28 +118,21 @@ class Trainer(object):
"""
"""
self
.
add_summary
(
create_scalar_summary
(
name
,
val
))
self
.
add_summary
(
create_scalar_summary
(
name
,
val
))
def
get_extra_fetches
(
self
):
"""
Returns:
list: list of tensors/ops to fetch in each step.
This function should only get called after :meth:`setup()` has finished.
"""
return
self
.
_extra_fetches
def
setup
(
self
):
def
setup
(
self
):
"""
"""
Setup the trainer and be ready for the main loop.
Setup the trainer and be ready for the main loop.
"""
"""
self
.
_setup
()
if
not
hasattr
(
logger
,
'LOG_DIR'
):
raise
RuntimeError
(
"logger directory wasn't set!"
)
self
.
_setup
()
# subclass will setup the graph
describe_model
()
describe_model
()
# some final operations that might modify the graph
# some final operations that might modify the graph
logger
.
info
(
"Setup callbacks ..."
)
logger
.
info
(
"Setup callbacks ..."
)
self
.
config
.
callbacks
.
setup_graph
(
weakref
.
proxy
(
self
))
self
.
config
.
callbacks
.
setup_graph
(
weakref
.
proxy
(
self
))
self
.
_extra_fetches
=
self
.
config
.
callbacks
.
extra_fetches
()
self
.
_extra_fetches
=
self
.
config
.
callbacks
.
extra_fetches
()
if
not
hasattr
(
logger
,
'LOG_DIR'
):
raise
RuntimeError
(
"logger directory wasn't set!"
)
self
.
summary_writer
=
tf
.
summary
.
FileWriter
(
logger
.
LOG_DIR
,
graph
=
self
.
sess
.
graph
)
self
.
summary_writer
=
tf
.
summary
.
FileWriter
(
logger
.
LOG_DIR
,
graph
=
self
.
sess
.
graph
)
self
.
summary_op
=
tf
.
summary
.
merge_all
()
self
.
summary_op
=
tf
.
summary
.
merge_all
()
# create an empty StatHolder
# create an empty StatHolder
...
@@ -206,3 +188,23 @@ class Trainer(object):
...
@@ -206,3 +188,23 @@ class Trainer(object):
self
.
coord
.
request_stop
()
self
.
coord
.
request_stop
()
self
.
summary_writer
.
close
()
self
.
summary_writer
.
close
()
self
.
sess
.
close
()
self
.
sess
.
close
()
def
get_predict_func
(
self
,
input_names
,
output_names
):
"""
Args:
input_names (list), output_names(list): list of names
Returns:
an OnlinePredictor
"""
raise
NotImplementedError
()
def
get_predict_funcs
(
self
,
input_names
,
output_names
,
n
):
""" Return n predictors.
Can be overwritten by subclasses to exploit more
parallelism among predictors.
"""
if
len
(
self
.
config
.
predict_tower
)
>
1
:
logger
.
warn
(
"[Speed] Have set multiple predict_tower, but only have naive `get_predict_funcs` implementation"
)
return
[
self
.
get_predict_func
(
input_names
,
output_names
)
for
k
in
range
(
n
)]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment