Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
ef9e27a6
Commit
ef9e27a6
authored
Jul 11, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
dynamically generate the remapped class
parent
adf51f22
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
45 additions
and
26 deletions
+45
-26
tensorpack/train/input_source.py
tensorpack/train/input_source.py
+45
-26
No files found.
tensorpack/train/input_source.py
View file @
ef9e27a6
...
...
@@ -20,7 +20,6 @@ from ..tfutils.summary import add_moving_summary
from
..tfutils
import
get_op_tensor_name
from
..tfutils.tower
import
get_current_tower_context
from
..utils
import
logger
from
..utils.argtools
import
memoized
from
..utils.concurrency
import
ShareSessionThread
from
..callbacks.concurrency
import
StartProcOrThread
from
..callbacks.base
import
Callback
...
...
@@ -62,6 +61,9 @@ class InputSource(object):
@
abstractmethod
def
reset_state
(
self
):
"""
Semantics of this method has not been well defined.
"""
pass
@
abstractmethod
...
...
@@ -80,6 +82,33 @@ class InputSource(object):
return
NotImplementedError
()
class
ProxyInputSource
(
InputSource
):
"""
An InputSource which proxy every method to ``self._input``.
"""
def
__init__
(
self
,
input
):
assert
isinstance
(
input
,
InputSource
),
input
self
.
_input
=
input
def
get_input_tensors
(
self
):
return
self
.
_input
.
get_input_tensors
()
def
setup
(
self
,
inputs_desc
):
self
.
_input
.
setup
(
inputs_desc
)
def
get_callbacks
(
self
):
return
self
.
_input
.
get_callbacks
()
def
size
(
self
):
return
self
.
_input
.
size
()
def
next_feed
(
self
):
return
self
.
_input
.
next_feed
()
def
reset_state
(
self
):
self
.
_input
.
reset_state
()
class
FeedInput
(
InputSource
):
""" Input by iterating over a DataFlow and feed datapoints. """
def
__init__
(
self
,
ds
):
...
...
@@ -465,57 +494,47 @@ class StagingInputWrapper(FeedfreeInput):
def
get_staging_name
(
idx
):
return
'StagingArea{}'
.
format
(
idx
)
@
memoized
def
get_stage_op
(
self
):
return
tf
.
group
(
*
self
.
_stage_ops
)
@
memoized
def
get_unstage_op
(
self
):
all_outputs
=
list
(
chain
.
from_iterable
(
self
.
_unstage_ops
))
return
tf
.
group
(
*
all_outputs
)
# TODO dynamically generate inheritance
# TODO make it a function, not a class
class
remap_input_source
(
FeedInput
,
FeedfreeInput
):
def
remap_input_source
(
input
,
names
):
"""
When you have some :class:`InputSource` which doesn't match the inputs in
your :class:`ModelDesc`, use `RemapInputSource`.
It produces placeholders for all the inputs in your model,
except that the corresponding ones are replaced with the tensor produced
by the given :class:`InputSource`.
Args:
input(InputSource): a :class:`InputSource`, whose tensors will get mapped.
names(list[str]): list of input names corresponding to the tensors
produced by ``input``.
"""
def
__init__
(
self
,
input
,
names
):
"""
Args:
input(InputSource): a :class:`InputSource`, whose tensors will get mapped.
names(list[str]): list of input names corresponding to the tensors
produced by ``input``.
"""
assert
isinstance
(
input
,
InputSource
),
input
self
.
_input
=
input
ProxyInputSource
.
__init__
(
self
,
input
)
assert
isinstance
(
names
,
(
list
,
tuple
)),
names
self
.
_names
=
tuple
(
names
)
def
size
(
self
):
return
self
.
_input
.
size
()
def
setup
(
self
,
inputs
):
self
.
_all_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
inputs_subset
=
get_sublist_by_names
(
inputs
,
self
.
_names
)
self
.
_input
.
setup
(
inputs_subset
)
def
get_callbacks
(
self
):
return
self
.
_input
.
get_callbacks
()
def
reset_state
(
self
):
self
.
_input
.
reset_state
()
def
next_feed
(
self
):
return
self
.
_input
.
next_feed
()
def
get_input_tensors
(
self
):
ret
=
self
.
_input
.
get_input_tensors
()
assert
len
(
ret
)
==
len
(
self
.
_names
)
return
get_tensors_inputs
(
self
.
_all_placehdrs
,
ret
,
self
.
_names
)
oldcls
=
type
(
input
)
# inherit oldcls so that type check in various places would work
cls
=
type
(
'Remapped'
+
oldcls
.
__name__
,
(
ProxyInputSource
,
oldcls
),
{
'__init__'
:
__init__
,
'setup'
:
setup
,
'get_input_tensors'
:
get_input_tensors
})
return
cls
(
input
,
names
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment