Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
1126fa5c
Commit
1126fa5c
authored
Jul 07, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add rngdataflow as a base
parent
c6de1746
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
42 additions
and
16 deletions
+42
-16
tensorpack/dataflow/base.py
tensorpack/dataflow/base.py
+9
-0
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+25
-11
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+8
-5
No files found.
tensorpack/dataflow/base.py
View file @
1126fa5c
...
@@ -35,6 +35,15 @@ class DataFlow(object):
...
@@ -35,6 +35,15 @@ class DataFlow(object):
"""
"""
pass
pass
class RNGDataFlow(DataFlow):
    """Base class for a DataFlow that owns its own random number generator.

    Subclasses read ``self.rng`` for any randomness they need.  The RNG is
    obtained from the project-level ``get_rng`` helper both at construction
    time and on every ``reset_state`` call (presumably so that forked worker
    processes get independent generators — confirm against ``get_rng``).
    """

    def __init__(self):
        # Initial generator, available even before reset_state() is called.
        self.rng = get_rng(self)

    def reset_state(self):
        # Replace the generator with a freshly-obtained one.
        self.rng = get_rng(self)
class
ProxyDataFlow
(
DataFlow
):
class
ProxyDataFlow
(
DataFlow
):
""" Base class for DataFlow that proxies another"""
""" Base class for DataFlow that proxies another"""
def
__init__
(
self
,
ds
):
def
__init__
(
self
,
ds
):
...
...
tensorpack/dataflow/common.py
View file @
1126fa5c
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
from
__future__
import
division
from
__future__
import
division
import
copy
import
copy
import
numpy
as
np
import
numpy
as
np
from
collections
import
deque
from
six.moves
import
range
,
map
from
six.moves
import
range
,
map
from
.base
import
DataFlow
,
ProxyDataFlow
from
.base
import
DataFlow
,
ProxyDataFlow
from
..utils
import
*
from
..utils
import
*
...
@@ -12,7 +13,7 @@ from ..utils import *
...
@@ -12,7 +13,7 @@ from ..utils import *
__all__
=
[
'BatchData'
,
'FixedSizeData'
,
'FakeData'
,
'MapData'
,
__all__
=
[
'BatchData'
,
'FixedSizeData'
,
'FakeData'
,
'MapData'
,
'RepeatedData'
,
'MapDataComponent'
,
'RandomChooseData'
,
'RepeatedData'
,
'MapDataComponent'
,
'RandomChooseData'
,
'RandomMixData'
,
'JoinData'
,
'ConcatData'
,
'SelectComponent'
,
'RandomMixData'
,
'JoinData'
,
'ConcatData'
,
'SelectComponent'
,
'DataFromQueue'
]
'DataFromQueue'
,
'LocallyShuffleData'
]
class
BatchData
(
ProxyDataFlow
):
class
BatchData
(
ProxyDataFlow
):
def
__init__
(
self
,
ds
,
batch_size
,
remainder
=
False
):
def
__init__
(
self
,
ds
,
batch_size
,
remainder
=
False
):
...
@@ -134,16 +135,16 @@ class RepeatedData(ProxyDataFlow):
...
@@ -134,16 +135,16 @@ class RepeatedData(ProxyDataFlow):
for
dp
in
self
.
ds
.
get_data
():
for
dp
in
self
.
ds
.
get_data
():
yield
dp
yield
dp
class
FakeData
(
DataFlow
):
class
FakeData
(
RNG
DataFlow
):
""" Generate fake random data of given shapes"""
""" Generate fake random data of given shapes"""
def
__init__
(
self
,
shapes
,
size
):
def
__init__
(
self
,
shapes
,
size
):
"""
"""
:param shapes: a list of lists/tuples
:param shapes: a list of lists/tuples
:param size: size of this DataFlow
:param size: size of this DataFlow
"""
"""
super
(
FakeData
,
self
)
.
__init__
()
self
.
shapes
=
shapes
self
.
shapes
=
shapes
self
.
_size
=
int
(
size
)
self
.
_size
=
int
(
size
)
self
.
rng
=
get_rng
(
self
)
def
size
(
self
):
def
size
(
self
):
return
self
.
_size
return
self
.
_size
...
@@ -191,7 +192,7 @@ class MapDataComponent(ProxyDataFlow):
...
@@ -191,7 +192,7 @@ class MapDataComponent(ProxyDataFlow):
dp
[
self
.
index
]
=
repl
# NOTE modifying
dp
[
self
.
index
]
=
repl
# NOTE modifying
yield
dp
yield
dp
class
RandomChooseData
(
DataFlow
):
class
RandomChooseData
(
RNG
DataFlow
):
"""
"""
Randomly choose from several DataFlow. Stop producing when any of them is
Randomly choose from several DataFlow. Stop producing when any of them is
exhausted.
exhausted.
...
@@ -200,21 +201,21 @@ class RandomChooseData(DataFlow):
...
@@ -200,21 +201,21 @@ class RandomChooseData(DataFlow):
"""
"""
:param df_lists: list of dataflow, or list of (dataflow, probability) tuple
:param df_lists: list of dataflow, or list of (dataflow, probability) tuple
"""
"""
super
(
RandomChooseData
,
self
)
.
__init__
()
if
isinstance
(
df_lists
[
0
],
(
tuple
,
list
)):
if
isinstance
(
df_lists
[
0
],
(
tuple
,
list
)):
assert
sum
([
v
[
1
]
for
v
in
df_lists
])
==
1.0
assert
sum
([
v
[
1
]
for
v
in
df_lists
])
==
1.0
self
.
df_lists
=
df_lists
self
.
df_lists
=
df_lists
else
:
else
:
prob
=
1.0
/
len
(
df_lists
)
prob
=
1.0
/
len
(
df_lists
)
self
.
df_lists
=
[(
k
,
prob
)
for
k
in
df_lists
]
self
.
df_lists
=
[(
k
,
prob
)
for
k
in
df_lists
]
self
.
rng
=
get_rng
(
self
)
def
reset_state
(
self
):
def
reset_state
(
self
):
super
(
RandomChooseData
,
self
)
.
reset_state
()
for
d
in
self
.
df_lists
:
for
d
in
self
.
df_lists
:
if
isinstance
(
d
,
tuple
):
if
isinstance
(
d
,
tuple
):
d
[
0
]
.
reset_state
()
d
[
0
]
.
reset_state
()
else
:
else
:
d
.
reset_state
()
d
.
reset_state
()
self
.
rng
=
get_rng
(
self
)
def
get_data
(
self
):
def
get_data
(
self
):
itrs
=
[
v
[
0
]
.
get_data
()
for
v
in
self
.
df_lists
]
itrs
=
[
v
[
0
]
.
get_data
()
for
v
in
self
.
df_lists
]
...
@@ -226,7 +227,7 @@ class RandomChooseData(DataFlow):
...
@@ -226,7 +227,7 @@ class RandomChooseData(DataFlow):
except
StopIteration
:
except
StopIteration
:
return
return
class
RandomMixData
(
DataFlow
):
class
RandomMixData
(
RNG
DataFlow
):
"""
"""
Randomly choose from several dataflow, and will eventually exhaust all dataflow. So it's a perfect mix.
Randomly choose from several dataflow, and will eventually exhaust all dataflow. So it's a perfect mix.
"""
"""
...
@@ -235,14 +236,14 @@ class RandomMixData(DataFlow):
...
@@ -235,14 +236,14 @@ class RandomMixData(DataFlow):
:param df_lists: list of dataflow.
:param df_lists: list of dataflow.
All DataFlow in `df_lists` must have :func:`size()` implemented
All DataFlow in `df_lists` must have :func:`size()` implemented
"""
"""
super
(
RandomMixData
,
self
)
.
__init__
()
self
.
df_lists
=
df_lists
self
.
df_lists
=
df_lists
self
.
sizes
=
[
k
.
size
()
for
k
in
self
.
df_lists
]
self
.
sizes
=
[
k
.
size
()
for
k
in
self
.
df_lists
]
self
.
rng
=
get_rng
(
self
)
def
reset_state
(
self
):
def
reset_state
(
self
):
super
(
RandomMixData
,
self
)
.
reset_state
()
for
d
in
self
.
df_lists
:
for
d
in
self
.
df_lists
:
d
.
reset_state
()
d
.
reset_state
()
self
.
rng
=
get_rng
(
self
)
def
size
(
self
):
def
size
(
self
):
return
sum
(
self
.
sizes
)
return
sum
(
self
.
sizes
)
...
@@ -318,9 +319,22 @@ class JoinData(DataFlow):
...
@@ -318,9 +319,22 @@ class JoinData(DataFlow):
for
itr
in
itrs
:
for
itr
in
itrs
:
del
itr
del
itr
class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
    """Buffer up to ``cache_size`` datapoints from ``ds`` in a bounded deque.

    Judging by the name, the intent is to shuffle datapoints within that
    local window; ``get_data`` is still a TODO in this revision, so no
    datapoints are produced yet.
    """

    def __init__(self, ds, cache_size):
        """
        :param ds: input `DataFlow` to draw datapoints from.
        :param cache_size: maximum number of datapoints kept in the cache.
        """
        ProxyDataFlow.__init__(self, ds)
        RNGDataFlow.__init__(self)
        self.cache_size = cache_size
        # Bounded buffer: once full, appending evicts the oldest datapoint.
        self.q = deque(maxlen=self.cache_size)

    def reset_state(self):
        # Reset BOTH bases.  The original called only RNGDataFlow.reset_state,
        # which re-seeds self.rng but leaves the proxied `ds` un-reset.
        ProxyDataFlow.reset_state(self)
        RNGDataFlow.reset_state(self)

    def get_data(self):
        # TODO: yield datapoints from self.q in locally-shuffled order.
        pass
class
DataFromQueue
(
DataFlow
):
class
DataFromQueue
(
DataFlow
):
""" provide data from a queue
""" Provide data from a queue """
"""
def
__init__
(
self
,
queue
):
def
__init__
(
self
,
queue
):
self
.
queue
=
queue
self
.
queue
=
queue
...
...
tensorpack/tfutils/sessinit.py
View file @
1126fa5c
...
@@ -51,7 +51,7 @@ class SaverRestore(SessionInit):
...
@@ -51,7 +51,7 @@ class SaverRestore(SessionInit):
"""
"""
Restore an old model saved by `ModelSaver`.
Restore an old model saved by `ModelSaver`.
"""
"""
def
__init__
(
self
,
model_path
):
def
__init__
(
self
,
model_path
,
prefix
=
None
):
"""
"""
:param model_path: a model file or a ``checkpoint`` file.
:param model_path: a model file or a ``checkpoint`` file.
"""
"""
...
@@ -61,12 +61,13 @@ class SaverRestore(SessionInit):
...
@@ -61,12 +61,13 @@ class SaverRestore(SessionInit):
os
.
path
.
dirname
(
model_path
))
.
model_checkpoint_path
os
.
path
.
dirname
(
model_path
))
.
model_checkpoint_path
assert
os
.
path
.
isfile
(
model_path
)
assert
os
.
path
.
isfile
(
model_path
)
self
.
set_path
(
model_path
)
self
.
set_path
(
model_path
)
self
.
prefix
=
prefix
def
_init
(
self
,
sess
):
def
_init
(
self
,
sess
):
logger
.
info
(
logger
.
info
(
"Restoring checkpoint from {}."
.
format
(
self
.
path
))
"Restoring checkpoint from {}."
.
format
(
self
.
path
))
chkpt_vars
=
SaverRestore
.
_read_checkpoint_vars
(
self
.
path
)
chkpt_vars
=
SaverRestore
.
_read_checkpoint_vars
(
self
.
path
)
vars_map
=
SaverRestore
.
_get_vars_to_restore_multimap
(
chkpt_vars
)
vars_map
=
self
.
_get_vars_to_restore_multimap
(
chkpt_vars
)
for
dic
in
SaverRestore
.
_produce_restore_dict
(
vars_map
):
for
dic
in
SaverRestore
.
_produce_restore_dict
(
vars_map
):
# multiple saver under same name scope would cause error:
# multiple saver under same name scope would cause error:
# training/saver.py: assert restore_op.name.endswith("restore_all"), restore_op.name
# training/saver.py: assert restore_op.name.endswith("restore_all"), restore_op.name
...
@@ -93,6 +94,7 @@ class SaverRestore(SessionInit):
...
@@ -93,6 +94,7 @@ class SaverRestore(SessionInit):
@
staticmethod
@
staticmethod
def
_read_checkpoint_vars
(
model_path
):
def
_read_checkpoint_vars
(
model_path
):
""" return a set of strings """
reader
=
tf
.
train
.
NewCheckpointReader
(
model_path
)
reader
=
tf
.
train
.
NewCheckpointReader
(
model_path
)
ckpt_vars
=
reader
.
get_variable_to_shape_map
()
.
keys
()
ckpt_vars
=
reader
.
get_variable_to_shape_map
()
.
keys
()
for
v
in
ckpt_vars
:
for
v
in
ckpt_vars
:
...
@@ -100,11 +102,10 @@ class SaverRestore(SessionInit):
...
@@ -100,11 +102,10 @@ class SaverRestore(SessionInit):
logger
.
warn
(
"Found {} in checkpoint. Anything from prediction tower shouldn't be saved."
.
format
(
v
.
name
))
logger
.
warn
(
"Found {} in checkpoint. Anything from prediction tower shouldn't be saved."
.
format
(
v
.
name
))
return
set
(
ckpt_vars
)
return
set
(
ckpt_vars
)
@
staticmethod
def
_get_vars_to_restore_multimap
(
self
,
vars_available
):
def
_get_vars_to_restore_multimap
(
vars_available
):
"""
"""
Get a dict of {var_name: [var, var]} to restore
Get a dict of {var_name: [var, var]} to restore
:param vars_available: variables available in the checkpoint, for existence checking
:param vars_available: variable
name
s available in the checkpoint, for existence checking
"""
"""
vars_to_restore
=
tf
.
all_variables
()
vars_to_restore
=
tf
.
all_variables
()
var_dict
=
defaultdict
(
list
)
var_dict
=
defaultdict
(
list
)
...
@@ -117,6 +118,8 @@ class SaverRestore(SessionInit):
...
@@ -117,6 +118,8 @@ class SaverRestore(SessionInit):
if
'tower'
in
name
:
if
'tower'
in
name
:
new_name
=
re
.
sub
(
'tower[p0-9]+/'
,
''
,
name
)
new_name
=
re
.
sub
(
'tower[p0-9]+/'
,
''
,
name
)
name
=
new_name
name
=
new_name
if
self
.
prefix
and
name
.
startswith
(
self
.
prefix
):
name
=
name
[
len
(
self
.
prefix
)
+
1
:]
if
name
in
vars_available
:
if
name
in
vars_available
:
var_dict
[
name
]
.
append
(
v
)
var_dict
[
name
]
.
append
(
v
)
vars_available
.
remove
(
name
)
vars_available
.
remove
(
name
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment