Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
b7766fc1
Commit
b7766fc1
authored
Apr 22, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix swig name. add some df
parent
3efce3ae
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
46 additions
and
17 deletions
+46
-17
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+42
-14
tensorpack/predict.py
tensorpack/predict.py
+2
-1
tensorpack/tfutils/common.py
tensorpack/tfutils/common.py
+1
-1
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+1
-1
No files found.
tensorpack/dataflow/common.py
View file @
b7766fc1
...
...
@@ -9,7 +9,8 @@ from .base import DataFlow, ProxyDataFlow
from
..utils
import
*
__all__
=
[
'BatchData'
,
'FixedSizeData'
,
'FakeData'
,
'MapData'
,
'MapDataComponent'
,
'RandomChooseData'
,
'RandomMixData'
,
'JoinData'
]
'MapDataComponent'
,
'RandomChooseData'
,
'RandomMixData'
,
'JoinData'
,
'ConcatData'
,
'SelectComponent'
]
class
BatchData
(
ProxyDataFlow
):
def
__init__
(
self
,
ds
,
batch_size
,
remainder
=
False
):
...
...
@@ -249,7 +250,7 @@ class RandomMixData(DataFlow):
for
k
in
idxs
:
yield
next
(
itrs
[
k
])
class
Join
Data
(
DataFlow
):
class
Concat
Data
(
DataFlow
):
"""
Concatenate several dataflows.
"""
...
...
@@ -271,21 +272,48 @@ class JoinData(DataFlow):
for
dp
in
d
.
get_data
():
yield
dp
class
SelectComponent
(
Proxy
DataFlow
):
class
JoinData
(
DataFlow
):
"""
Select component from a datapoint.
Join the components from each DataFlow.
e.g.: df1: [dp1, dp2]
df2: [dp3, dp4]
join: [dp1, dp2, dp3, dp4]
"""
def
__init__
(
self
,
ds
,
idxs
):
def
__init__
(
self
,
df_lists
):
"""
:param df_lists: list of :mod:`DataFlow` instances
"""
self
.
df_lists
=
df_lists
self
.
_size
=
self
.
df_lists
[
0
]
.
size
()
for
d
in
self
.
df_lists
:
assert
d
.
size
()
==
self
.
_size
,
\
"All DataFlow must have the same size! {} != {}"
.
format
(
d
.
size
(),
self
.
_size
)
def
reset_state
(
self
):
for
d
in
self
.
df_lists
:
d
.
reset_state
()
def
size
(
self
):
return
self
.
_size
def
get_data
(
self
):
itrs
=
[
k
.
get_data
()
for
k
in
self
.
df_lists
]
try
:
while
True
:
dp
=
[]
for
itr
in
itrs
:
dp
.
extend
(
next
(
itr
))
yield
dp
except
StopIteration
:
pass
finally
:
for
itr
in
itrs
:
del
itr
def
SelectComponent
(
ds
,
idxs
):
"""
:param ds: a :mod:`DataFlow` instance
:param idxs: a list of datapoint component index of the original dataflow
"""
super
(
SelectComponent
,
self
)
.
__init__
(
ds
)
self
.
idxs
=
idxs
return
MapData
(
ds
,
lambda
dp
:
[
dp
[
i
]
for
i
in
idxs
])
def
get_data
(
self
):
for
dp
in
self
.
ds
.
get_data
():
newdp
=
[]
for
idx
in
self
.
idxs
:
newdp
.
append
(
dp
[
idx
])
yield
newdp
tensorpack/predict.py
View file @
b7766fc1
...
...
@@ -79,7 +79,8 @@ def get_predict_func(config):
# check output_var_names against output_vars
if
output_var_names
is
not
None
:
output_vars
=
[
tf
.
get_default_graph
()
.
get_tensor_by_name
(
n
)
for
n
in
output_var_names
]
output_vars
=
[
tf
.
get_default_graph
()
.
get_tensor_by_name
(
get_op_var_name
(
n
)[
1
])
for
n
in
output_var_names
]
else
:
output_vars
=
[]
...
...
tensorpack/tfutils/common.py
View file @
b7766fc1
...
...
@@ -6,7 +6,7 @@
from
..utils.naming
import
*
import
tensorflow
as
tf
def
get_default_sess_config
(
mem_fraction
=
0.
5
):
def
get_default_sess_config
(
mem_fraction
=
0.
99
):
"""
Return a better session config to use as default.
Tensorflow default session config consume too much resources.
...
...
tensorpack/tfutils/sessinit.py
View file @
b7766fc1
...
...
@@ -85,7 +85,7 @@ class SaverRestore(SessionInit):
@
staticmethod
def
_read_checkpoint_vars
(
model_path
):
reader
=
tf
.
train
.
NewCheckpointReader
(
model_path
)
return
set
(
reader
.
GetVariableToShapeM
ap
()
.
keys
())
return
set
(
reader
.
get_variable_to_shape_m
ap
()
.
keys
())
@
staticmethod
def
_get_vars_to_restore_multimap
(
vars_available
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment