Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
ef9e27a6
Commit
ef9e27a6
authored
Jul 11, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
dynamically generate the remapped class
parent
adf51f22
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
45 additions
and
26 deletions
+45
-26
tensorpack/train/input_source.py
tensorpack/train/input_source.py
+45
-26
No files found.
tensorpack/train/input_source.py
View file @
ef9e27a6
...
@@ -20,7 +20,6 @@ from ..tfutils.summary import add_moving_summary
...
@@ -20,7 +20,6 @@ from ..tfutils.summary import add_moving_summary
from
..tfutils
import
get_op_tensor_name
from
..tfutils
import
get_op_tensor_name
from
..tfutils.tower
import
get_current_tower_context
from
..tfutils.tower
import
get_current_tower_context
from
..utils
import
logger
from
..utils
import
logger
from
..utils.argtools
import
memoized
from
..utils.concurrency
import
ShareSessionThread
from
..utils.concurrency
import
ShareSessionThread
from
..callbacks.concurrency
import
StartProcOrThread
from
..callbacks.concurrency
import
StartProcOrThread
from
..callbacks.base
import
Callback
from
..callbacks.base
import
Callback
...
@@ -62,6 +61,9 @@ class InputSource(object):
...
@@ -62,6 +61,9 @@ class InputSource(object):
@
abstractmethod
@
abstractmethod
def
reset_state
(
self
):
def
reset_state
(
self
):
"""
Semantics of this method has not been well defined.
"""
pass
pass
@
abstractmethod
@
abstractmethod
...
@@ -80,6 +82,33 @@ class InputSource(object):
...
@@ -80,6 +82,33 @@ class InputSource(object):
return
NotImplementedError
()
return
NotImplementedError
()
class
ProxyInputSource
(
InputSource
):
"""
An InputSource which proxy every method to ``self._input``.
"""
def
__init__
(
self
,
input
):
assert
isinstance
(
input
,
InputSource
),
input
self
.
_input
=
input
def
get_input_tensors
(
self
):
return
self
.
_input
.
get_input_tensors
()
def
setup
(
self
,
inputs_desc
):
self
.
_input
.
setup
(
inputs_desc
)
def
get_callbacks
(
self
):
return
self
.
_input
.
get_callbacks
()
def
size
(
self
):
return
self
.
_input
.
size
()
def
next_feed
(
self
):
return
self
.
_input
.
next_feed
()
def
reset_state
(
self
):
self
.
_input
.
reset_state
()
class
FeedInput
(
InputSource
):
class
FeedInput
(
InputSource
):
""" Input by iterating over a DataFlow and feed datapoints. """
""" Input by iterating over a DataFlow and feed datapoints. """
def
__init__
(
self
,
ds
):
def
__init__
(
self
,
ds
):
...
@@ -465,57 +494,47 @@ class StagingInputWrapper(FeedfreeInput):
...
@@ -465,57 +494,47 @@ class StagingInputWrapper(FeedfreeInput):
def
get_staging_name
(
idx
):
def
get_staging_name
(
idx
):
return
'StagingArea{}'
.
format
(
idx
)
return
'StagingArea{}'
.
format
(
idx
)
@
memoized
def
get_stage_op
(
self
):
def
get_stage_op
(
self
):
return
tf
.
group
(
*
self
.
_stage_ops
)
return
tf
.
group
(
*
self
.
_stage_ops
)
@
memoized
def
get_unstage_op
(
self
):
def
get_unstage_op
(
self
):
all_outputs
=
list
(
chain
.
from_iterable
(
self
.
_unstage_ops
))
all_outputs
=
list
(
chain
.
from_iterable
(
self
.
_unstage_ops
))
return
tf
.
group
(
*
all_outputs
)
return
tf
.
group
(
*
all_outputs
)
# TODO dynamically generate inheritance
def
remap_input_source
(
input
,
names
):
# TODO make it a function, not a class
class
remap_input_source
(
FeedInput
,
FeedfreeInput
):
"""
"""
When you have some :class:`InputSource` which doesn't match the inputs in
When you have some :class:`InputSource` which doesn't match the inputs in
your :class:`ModelDesc`, use `RemapInputSource`.
your :class:`ModelDesc`, use `RemapInputSource`.
It produces placeholders for all the inputs in your model,
It produces placeholders for all the inputs in your model,
except that the corresponding ones are replaced with the tensor produced
except that the corresponding ones are replaced with the tensor produced
by the given :class:`InputSource`.
by the given :class:`InputSource`.
"""
def
__init__
(
self
,
input
,
names
):
"""
Args:
Args:
input(InputSource): a :class:`InputSource`, whose tensors will get mapped.
input(InputSource): a :class:`InputSource`, whose tensors will get mapped.
names(list[str]): list of input names corresponding to the tensors
names(list[str]): list of input names corresponding to the tensors
produced by ``input``.
produced by ``input``.
"""
"""
assert
isinstance
(
input
,
InputSource
),
input
def
__init__
(
self
,
input
,
names
):
self
.
_input
=
input
ProxyInputSource
.
__init__
(
self
,
input
)
assert
isinstance
(
names
,
(
list
,
tuple
)),
names
assert
isinstance
(
names
,
(
list
,
tuple
)),
names
self
.
_names
=
tuple
(
names
)
self
.
_names
=
tuple
(
names
)
def
size
(
self
):
return
self
.
_input
.
size
()
def
setup
(
self
,
inputs
):
def
setup
(
self
,
inputs
):
self
.
_all_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
self
.
_all_placehdrs
=
[
v
.
build_placeholder_reuse
()
for
v
in
inputs
]
inputs_subset
=
get_sublist_by_names
(
inputs
,
self
.
_names
)
inputs_subset
=
get_sublist_by_names
(
inputs
,
self
.
_names
)
self
.
_input
.
setup
(
inputs_subset
)
self
.
_input
.
setup
(
inputs_subset
)
def
get_callbacks
(
self
):
return
self
.
_input
.
get_callbacks
()
def
reset_state
(
self
):
self
.
_input
.
reset_state
()
def
next_feed
(
self
):
return
self
.
_input
.
next_feed
()
def
get_input_tensors
(
self
):
def
get_input_tensors
(
self
):
ret
=
self
.
_input
.
get_input_tensors
()
ret
=
self
.
_input
.
get_input_tensors
()
assert
len
(
ret
)
==
len
(
self
.
_names
)
assert
len
(
ret
)
==
len
(
self
.
_names
)
return
get_tensors_inputs
(
return
get_tensors_inputs
(
self
.
_all_placehdrs
,
ret
,
self
.
_names
)
self
.
_all_placehdrs
,
ret
,
self
.
_names
)
oldcls
=
type
(
input
)
# inherit oldcls so that type check in various places would work
cls
=
type
(
'Remapped'
+
oldcls
.
__name__
,
(
ProxyInputSource
,
oldcls
),
{
'__init__'
:
__init__
,
'setup'
:
setup
,
'get_input_tensors'
:
get_input_tensors
})
return
cls
(
input
,
names
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment