Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
897d29e3
Commit
897d29e3
authored
Oct 29, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
callback.get_tensor_maybe_in_tower
parent
395786db
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
36 additions
and
3 deletions
+36
-3
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+21
-0
tensorpack/callbacks/steps.py
tensorpack/callbacks/steps.py
+3
-3
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+12
-0
No files found.
tensorpack/callbacks/base.py
View file @
897d29e3
...
@@ -7,6 +7,7 @@ from abc import ABCMeta
...
@@ -7,6 +7,7 @@ from abc import ABCMeta
import
six
import
six
from
..utils.develop
import
log_deprecated
from
..utils.develop
import
log_deprecated
from
..tfutils.common
import
get_op_or_tensor_by_name
from
..tfutils.common
import
get_op_or_tensor_by_name
from
..train.tower
import
TowerTrainer
__all__
=
[
'Callback'
,
'ProxyCallback'
,
'CallbackFactory'
]
__all__
=
[
'Callback'
,
'ProxyCallback'
,
'CallbackFactory'
]
...
@@ -205,6 +206,26 @@ class Callback(object):
...
@@ -205,6 +206,26 @@ class Callback(object):
def
__str__
(
self
):
def
__str__
(
self
):
return
type
(
self
)
.
__name__
return
type
(
self
)
.
__name__
def
get_tensors_maybe_in_tower
(
self
,
names
):
"""
Get tensors in the graph.
Will automatically check for the __first training tower__
if no tensor with the given name exists.
"""
def
get_tensor
(
name
):
msg
=
"Tensor {} not found in the graph!"
.
format
(
name
)
try
:
return
get_op_or_tensor_by_name
(
name
)
except
KeyError
:
pass
assert
isinstance
(
self
.
trainer
,
TowerTrainer
),
msg
towers
=
self
.
trainer
.
tower_func
.
towers
try
:
return
towers
.
training
()[
name
]
except
KeyError
:
raise
KeyError
(
msg
)
return
[
get_tensor
(
name
)
for
name
in
names
]
class
ProxyCallback
(
Callback
):
class
ProxyCallback
(
Callback
):
""" A callback which proxy all methods to another callback.
""" A callback which proxy all methods to another callback.
...
...
tensorpack/callbacks/steps.py
View file @
897d29e3
...
@@ -12,7 +12,7 @@ from ..utils import logger
...
@@ -12,7 +12,7 @@ from ..utils import logger
from
..utils.utils
import
get_tqdm_kwargs
from
..utils.utils
import
get_tqdm_kwargs
from
..utils.naming
import
GLOBAL_STEP_INCR_OP_NAME
from
..utils.naming
import
GLOBAL_STEP_INCR_OP_NAME
from
..tfutils.common
import
(
from
..tfutils.common
import
(
get_op_tensor_name
,
get_
op_or_tensor_by_name
,
get_
global_step_var
)
get_op_tensor_name
,
get_global_step_var
)
from
.base
import
Callback
from
.base
import
Callback
__all__
=
[
'TensorPrinter'
,
'StepTensorPrinter'
,
'ProgressBar'
]
__all__
=
[
'TensorPrinter'
,
'StepTensorPrinter'
,
'ProgressBar'
]
...
@@ -33,7 +33,7 @@ class TensorPrinter(Callback):
...
@@ -33,7 +33,7 @@ class TensorPrinter(Callback):
self
.
_names
=
names
self
.
_names
=
names
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
self
.
_fetches
=
get_op_or_tensor_by_name
(
self
.
_names
)
self
.
_fetches
=
self
.
get_tensors_maybe_in_tower
(
self
.
_names
)
def
_before_run
(
self
,
_
):
def
_before_run
(
self
,
_
):
return
self
.
_fetches
return
self
.
_fetches
...
@@ -70,7 +70,7 @@ class ProgressBar(Callback):
...
@@ -70,7 +70,7 @@ class ProgressBar(Callback):
self
.
_total
=
self
.
trainer
.
steps_per_epoch
self
.
_total
=
self
.
trainer
.
steps_per_epoch
self
.
_tqdm_args
=
get_tqdm_kwargs
(
leave
=
True
)
self
.
_tqdm_args
=
get_tqdm_kwargs
(
leave
=
True
)
self
.
_fetches
=
get_op_or_tensor_by_name
(
self
.
_names
)
or
None
self
.
_fetches
=
self
.
get_tensors_maybe_in_tower
(
self
.
_names
)
or
None
if
self
.
_fetches
:
if
self
.
_fetches
:
self
.
_fetches
=
tf
.
train
.
SessionRunArgs
(
self
.
_fetches
)
self
.
_fetches
=
tf
.
train
.
SessionRunArgs
(
self
.
_fetches
)
self
.
_tqdm_args
[
'bar_format'
]
=
self
.
_tqdm_args
[
'bar_format'
]
+
"{postfix} "
self
.
_tqdm_args
[
'bar_format'
]
=
self
.
_tqdm_args
[
'bar_format'
]
+
"{postfix} "
...
...
tensorpack/tfutils/tower.py
View file @
897d29e3
...
@@ -226,6 +226,14 @@ class TowerTensorHandles(object):
...
@@ -226,6 +226,14 @@ class TowerTensorHandles(object):
return
self
.
_handles
[
name_or_index
]
return
self
.
_handles
[
name_or_index
]
return
self
.
_name_to_handle
[
name_or_index
]
return
self
.
_name_to_handle
[
name_or_index
]
def
training
(
self
):
"""
Returns:
Still a :class:`TowerTensorHandles`, containing only the training towers.
"""
handles
=
[
h
for
h
in
self
.
_handles
if
h
.
is_training
]
return
TowerTensorHandles
(
handles
)
class
TowerTensorHandle
(
object
):
class
TowerTensorHandle
(
object
):
"""
"""
...
@@ -315,3 +323,7 @@ class TowerTensorHandle(object):
...
@@ -315,3 +323,7 @@ class TowerTensorHandle(object):
The output returned by the tower function.
The output returned by the tower function.
"""
"""
return
self
.
_output
return
self
.
_output
@
property
def
is_training
(
self
):
return
self
.
_ctx
.
is_training
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment