Commit 1095c8b8

Merge branch 'reset-state'

Authored Jul 12, 2016 by Yuxin Wu
Parents: 6607d856, 2f3b8502

Showing 17 changed files with 47 additions and 57 deletions (+47, -57):
  examples/DisturbLabel/disturb.py          +2   -0
  scripts/dump_train_config.py              +1   -0
  tensorpack/RL/expreplay.py                +3   -3
  tensorpack/callbacks/inference.py         +1   -0
  tensorpack/dataflow/base.py               +1   -4
  tensorpack/dataflow/common.py             +0   -4
  tensorpack/dataflow/dataset/bsds500.py    +2   -6
  tensorpack/dataflow/dataset/cifar.py      +2   -6
  tensorpack/dataflow/dataset/ilsvrc.py     +2   -9
  tensorpack/dataflow/dataset/mnist.py      +3   -3
  tensorpack/dataflow/dataset/svhn.py       +2   -6
  tensorpack/dataflow/dftools.py            +3   -0
  tensorpack/dataflow/format.py             +7   -8
  tensorpack/dataflow/prefetch.py           +9   -1
  tensorpack/dataflow/remote.py             +1   -0
  tensorpack/tfutils/modelutils.py          +1   -1
  tensorpack/train/trainer.py               +7   -6
examples/DisturbLabel/disturb.py  (+2, -0)

@@ -9,6 +9,8 @@ class DisturbLabel(ProxyDataFlow):
     def __init__(self, ds, prob):
         super(DisturbLabel, self).__init__(ds)
         self.prob = prob

+    def reset_state(self):
+        self.rng = get_rng(self)

     def get_data(self):
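Why the two added lines matter: under prefetching, each worker process must own a differently-seeded RNG, and reset_state() is the hook where that seed is (re)drawn. A minimal sketch of the same pattern, assuming this revision of tensorpack is importable; NoisyLabels is a hypothetical flow, not part of the commit, and unlike the two-line version above it also forwards the reset to the wrapped dataflow:

# Hedged sketch of the reset_state pattern this diff adds to DisturbLabel.
# get_rng(self) draws a fresh numpy RandomState, seeded per object/process.
from tensorpack.dataflow import ProxyDataFlow
from tensorpack.utils import get_rng

class NoisyLabels(ProxyDataFlow):               # hypothetical example flow
    def reset_state(self):
        super(NoisyLabels, self).reset_state()  # reset the wrapped ds too
        self.rng = get_rng(self)                # re-seed after any fork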
scripts/dump_train_config.py  (+1, -0)

@@ -29,6 +29,7 @@ args = parser.parse_args()
 get_config_func = imp.load_source('config_script', args.config).get_config
 config = get_config_func()
+config.dataset.reset_state()

 if args.output:
     mkdir_p(args.output)
tensorpack/RL/expreplay.py  (+3, -3)

@@ -24,6 +24,9 @@ class ExpReplay(DataFlow, Callback):
     """
     Implement experience replay in the paper
     `Human-level control through deep reinforcement learning`.

+    This implementation provides the interface as a DataFlow.
+    This DataFlow is not fork-safe (doesn't support multiprocess prefetching).
     """
     def __init__(self,
                  predictor,
@@ -80,9 +83,6 @@ class ExpReplay(DataFlow, Callback):
                 pbar.update()
         self._init_memory_flag.set()

-    def reset_state(self):
-        raise RuntimeError("Don't run me in multiple processes")
-
     def _populate_exp(self):
         """ populate a transition by epsilon-greedy"""
         old_s = self.player.current_state()
tensorpack/callbacks/inference.py  (+1, -0)

@@ -106,6 +106,7 @@ class InferenceRunner(Callback):
             vc.before_inference()
         sess = tf.get_default_session()
+        self.ds.reset_state()
         with tqdm(total=self.ds.size(), ascii=True) as pbar:
             for dp in self.ds.get_data():
                 #feed = dict(zip(self.input_vars, dp))  # TODO custom dp mapping?
tensorpack/dataflow/base.py  (+1, -4)

@@ -29,7 +29,7 @@ class DataFlow(object):
     def reset_state(self):
         """
-        Reset state of the dataflow,
+        Reset state of the dataflow. Will always be called before consuming data points.
         for example, RNG **HAS** to be reset here if used in the DataFlow.
         Otherwise it may not work well with prefetching, because different
         processes will have the same RNG state.
@@ -39,9 +39,6 @@ class DataFlow(object):
 class RNGDataFlow(DataFlow):
     """ A dataflow with rng"""
-    def __init__(self):
-        self.rng = get_rng(self)
-
     def reset_state(self):
         self.rng = get_rng(self)
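These base-class hunks are the heart of the branch: reset_state() is promoted from an optional hook to a call that is guaranteed to happen before any data is consumed, so RNGDataFlow no longer needs an __init__ at all. A hedged end-to-end sketch of the new contract, assuming this revision of the library is importable; RandomBits is illustrative, not from the commit:

# Sketch of the contract after this change: create the flow, call
# reset_state() once (this is what seeds self.rng), then iterate.
from tensorpack.dataflow.base import RNGDataFlow

class RandomBits(RNGDataFlow):              # hypothetical flow
    def size(self):
        return 4

    def get_data(self):
        for _ in range(self.size()):
            yield [int(self.rng.randint(0, 2))]

df = RandomBits()
df.reset_state()                            # must precede get_data()
print([dp[0] for dp in df.get_data()])      # e.g. [0, 1, 1, 0]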
tensorpack/dataflow/common.py  (+0, -4)

@@ -306,11 +306,7 @@ class JoinData(DataFlow):
 class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
     def __init__(self, ds, cache_size):
         ProxyDataFlow.__init__(self, ds)
-        RNGDataFlow.__init__(self)
         self.q = deque(maxlen=cache_size)
-        self.ds_wrap = RepeatedData(ds, -1)
-        self.ds_itr = self.ds_wrap.get_data()
-        self.current_cnt = 0

     def reset_state(self):
         ProxyDataFlow.reset_state(self)
tensorpack/dataflow/dataset/bsds500.py  (+2, -6)

@@ -9,7 +9,7 @@ import numpy as np
 from ...utils import logger, get_rng, get_dataset_dir
 from ...utils.fs import download
-from ..base import DataFlow
+from ..base import RNGDataFlow

 try:
     from scipy.io import loadmat
@@ -21,7 +21,7 @@ except ImportError:
 DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
 IMG_W, IMG_H = 481, 321

-class BSDS500(DataFlow):
+class BSDS500(RNGDataFlow):
     """
     `Berkeley Segmentation Data Set and Benchmarks 500
     <http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html#bsds500>`_.
@@ -53,10 +53,6 @@ class BSDS500(DataFlow):
         self.shuffle = shuffle
         assert name in ['train', 'test', 'val']
         self._load(name)
-        self.rng = get_rng(self)
-
-    def reset_state(self):
-        self.rng = get_rng(self)

     def _load(self, name):
         image_glob = os.path.join(self.data_root, 'images', name, '*.jpg')
tensorpack/dataflow/dataset/cifar.py  (+2, -6)

@@ -15,7 +15,7 @@ import logging
 from ...utils import logger, get_rng, get_dataset_dir
 from ...utils.fs import download
-from ..base import DataFlow
+from ..base import RNGDataFlow

 __all__ = ['Cifar10', 'Cifar100']
@@ -77,7 +77,7 @@ def get_filenames(dir, cifar_classnum):
             os.path.join(dir, 'cifar-100-python', 'test')]
     return filenames

-class CifarBase(DataFlow):
+class CifarBase(RNGDataFlow):
     """
     Return [image, label],
     image is 32x32x3 in the range [0,255]
@@ -106,10 +106,6 @@ class CifarBase(DataFlow):
         self.data = read_cifar(self.fs, cifar_classnum)
         self.dir = dir
         self.shuffle = shuffle
-        self.rng = get_rng(self)
-
-    def reset_state(self):
-        self.rng = get_rng(self)

     def size(self):
         return 50000 if self.train_or_test == 'train' else 10000
tensorpack/dataflow/dataset/ilsvrc.py  (+2, -9)

@@ -11,7 +11,7 @@ from six.moves import range
 from ...utils import logger, get_rng, get_dataset_dir, memoized
 from ...utils.loadcaffe import get_caffe_pb
 from ...utils.fs import mkdir_p, download
-from ..base import DataFlow
+from ..base import RNGDataFlow

 __all__ = ['ILSVRCMeta', 'ILSVRC12']
@@ -79,7 +79,7 @@ class ILSVRCMeta(object):
             arr = cv2.resize(arr, size[::-1])
         return arr

-class ILSVRC12(DataFlow):
+class ILSVRC12(RNGDataFlow):
     def __init__(self, dir, name, meta_dir=None, shuffle=True):
         """
         :param dir: A directory containing a subdir named `name`, where the
@@ -119,17 +119,10 @@ class ILSVRC12(DataFlow):
         self.shuffle = shuffle
         self.meta = ILSVRCMeta(meta_dir)
         self.imglist = self.meta.get_image_list(name)
-        self.rng = get_rng(self)

     def size(self):
         return len(self.imglist)

-    def reset_state(self):
-        """
-        reset rng for shuffle
-        """
-        self.rng = get_rng(self)
-
     def get_data(self):
         """
         Produce original images or shape [h, w, 3], and label
tensorpack/dataflow/dataset/mnist.py  (+3, -3)

@@ -11,7 +11,7 @@ from six.moves import urllib, range
 from ...utils import logger, get_dataset_dir
 from ...utils.fs import download
-from ..base import DataFlow
+from ..base import RNGDataFlow

 __all__ = ['Mnist']
@@ -92,7 +92,7 @@ class DataSet(object):
     def num_examples(self):
         return self._num_examples

-class Mnist(DataFlow):
+class Mnist(RNGDataFlow):
     """
     Return [image, label],
     image is 28x28 in the range [0,1]
@@ -136,7 +136,7 @@ class Mnist(DataFlow):
         ds = self.train if self.train_or_test == 'train' else self.test
         idxs = list(range(ds.num_examples))
         if self.shuffle:
-            random.shuffle(idxs)
+            self.rng.shuffle(idxs)
         for k in idxs:
             img = ds.images[k].reshape((28, 28))
             label = ds.labels[k]
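The switch from random.shuffle to self.rng.shuffle is not cosmetic: forked prefetch workers inherit the module-level random state, so every worker would shuffle identically, whereas a per-object RNG re-seeded in reset_state() gives each worker its own stream. A standalone demonstration of the failure mode, assuming a POSIX fork is available:

# Forked children inherit the parent's module-level RNG state, so both
# sides of the fork produce the same "random" numbers.
import os
import random

random.seed(42)
if os.fork() == 0:
    print('child :', random.sample(range(10), 3))   # same 3 numbers...
    os._exit(0)
else:
    os.wait()
    print('parent:', random.sample(range(10), 3))   # ...as the parent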
tensorpack/dataflow/dataset/svhn.py  (+2, -6)

@@ -9,7 +9,7 @@ import numpy as np
 from six.moves import range
 from ...utils import logger, get_rng, get_dataset_dir
-from ..base import DataFlow
+from ..base import RNGDataFlow

 try:
     import scipy.io
@@ -20,7 +20,7 @@ except ImportError:
 SVHN_URL = "http://ufldl.stanford.edu/housenumbers/"

-class SVHNDigit(DataFlow):
+class SVHNDigit(RNGDataFlow):
     """
     SVHN Cropped Digit Dataset
     return img of 32x32x3, label of 0-9
@@ -33,7 +33,6 @@ class SVHNDigit(DataFlow):
         :param data_dir: a directory containing the original {train,test,extra}_32x32.mat
         """
         self.shuffle = shuffle
-        self.rng = get_rng(self)
         if name in SVHNDigit.Cache:
             self.X, self.Y = SVHNDigit.Cache[name]
@@ -54,9 +53,6 @@ class SVHNDigit(DataFlow):
     def size(self):
         return self.X.shape[0]

-    def reset_state(self):
-        self.rng = get_rng(self)
-
     def get_data(self):
         n = self.X.shape[0]
         idxs = np.arange(n)
tensorpack/dataflow/dftools.py  (+3, -0)

@@ -23,6 +23,7 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
     mkdir_p(dirname)
     if max_count is None:
         max_count = sys.maxint
+    ds.reset_state()
     for i, dp in enumerate(ds.get_data()):
         if i % 100 == 0:
             print(i)
@@ -34,6 +35,7 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
 def dataflow_to_process_queue(ds, size, nr_consumer):
     """
     Convert a `DataFlow` to a multiprocessing.Queue.
+    The dataflow will only be reset in the spawned process.
     :param ds: a `DataFlow`
     :param size: size of the queue
@@ -50,6 +52,7 @@ def dataflow_to_process_queue(ds, size, nr_consumer):
             self.q = q

         def run(self):
+            self.ds.reset_state()
             try:
                 for idx, dp in enumerate(self.ds.get_data()):
                     self.q.put((idx, dp))
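The docstring addition is the behavioural contract: callers hand over an un-reset dataflow, and the spawned producer process calls reset_state() itself (see run() above). A hedged usage sketch, assuming the helper returns the queue and the producer process as in this revision, and that ds is some DataFlow defined elsewhere:

# Do NOT reset ds here; the spawned process does it after the fork.
from tensorpack.dataflow.dftools import dataflow_to_process_queue

q, proc = dataflow_to_process_queue(ds, size=256, nr_consumer=4)
proc.start()
idx, dp = q.get()          # (index, datapoint) pairs arrive on the queue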
tensorpack/dataflow/format.py  (+7, -8)

@@ -5,7 +5,7 @@
 from ..utils import logger, get_rng
 from ..utils.timer import timed_operation
 from ..utils.loadcaffe import get_caffe_pb
-from .base import DataFlow
+from .base import RNGDataFlow
 import random
 from tqdm import tqdm
@@ -31,7 +31,7 @@ else:
 Adapters for different data format.
 """

-class HDF5Data(DataFlow):
+class HDF5Data(RNGDataFlow):
     """
     Zip data from different paths in an HDF5 file. Will load all data into memory.
     """
@@ -55,19 +55,18 @@ class HDF5Data(DataFlow):
     def get_data(self):
         idxs = list(range(self._size))
         if self.shuffle:
-            random.shuffle(idxs)
+            self.rng.shuffle(idxs)
         for k in idxs:
             yield [dp[k] for dp in self.dps]

-class LMDBData(DataFlow):
+class LMDBData(RNGDataFlow):
     """ Read a lmdb and produce k,v pair """
     def __init__(self, lmdb_dir, shuffle=True):
         self._lmdb = lmdb.open(lmdb_dir, readonly=True, lock=False,
                                map_size=1099511627776 * 2, max_readers=100)
         self._txn = self._lmdb.begin()
         self._shuffle = shuffle
-        self.rng = get_rng(self)
         self._size = self._txn.stat()['entries']
         if shuffle:
             self.keys = self._txn.get('__keys__')
@@ -81,8 +80,8 @@ class LMDBData(DataFlow):
                 pbar.update()

     def reset_state(self):
+        super(LMDBData, self).reset_state()
         self._txn = self._lmdb.begin()
-        self.rng = get_rng(self)

     def size(self):
         return self._size
@@ -96,8 +95,8 @@ class LMDBData(DataFlow):
                 yield [k, v]
         else:
             s = self.size()
-            for i in range(s):
-                k = self.rng.choice(self.keys)
+            self.rng.shuffle(self.keys)
+            for k in self.keys:
                 v = self._txn.get(k)
                 yield [k, v]
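Besides inheriting from RNGDataFlow, LMDBData's shuffled iteration changes meaning: the old loop sampled keys with replacement, so an epoch could repeat or skip records, while the new loop shuffles the key list once and visits every record exactly once. A self-contained numpy sketch of the difference, using toy keys rather than an actual LMDB:

import numpy as np

rng = np.random.RandomState(0)
keys = ['a', 'b', 'c', 'd']

# old behaviour: sample with replacement -- may repeat/skip within an epoch
epoch_old = [rng.choice(keys) for _ in range(len(keys))]

# new behaviour: one permutation -- each key appears exactly once
rng.shuffle(keys)
epoch_new = list(keys)

print(epoch_old)   # e.g. ['a', 'd', 'b', 'a']
print(epoch_new)   # a permutation of all four keys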
tensorpack/dataflow/prefetch.py  (+9, -1)

@@ -35,7 +35,7 @@ class PrefetchProcess(multiprocessing.Process):
         self.queue = queue

     def run(self):
-        # reset RNG of ds so each process will produce different data
+        # reset all ds so each process will produce different data
         self.ds.reset_state()
         while True:
             for dp in self.ds.get_data():
@@ -73,6 +73,10 @@ class PrefetchData(ProxyDataFlow):
             dp = self.queue.get()
             yield dp

+    def reset_state(self):
+        # do nothing. all ds are reset once and only once in spawned processes
+        pass
+
 class PrefetchProcessZMQ(multiprocessing.Process):
     def __init__(self, ds, conn_name):
         """
@@ -134,6 +138,10 @@ class PrefetchDataZMQ(ProxyDataFlow):
             dp = loads(self.socket.recv(copy=False))
             yield dp

+    def reset_state(self):
+        # do nothing. all ds are reset once and only once in spawned processes
+        pass
+
     def __del__(self):
         # on exit, logger may not be functional anymore
         try:
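The two new no-op reset_state() overrides encode a deliberate rule: after forking, only the worker processes may reset the underlying dataflow; resetting again in the parent would at best be useless and at worst clobber the children's freshly-seeded RNGs. A hedged usage sketch, assuming the constructor arguments and dataset paths of this revision of tensorpack:

from tensorpack.dataflow import PrefetchData
from tensorpack.dataflow.dataset import Mnist

ds = PrefetchData(Mnist('train'), nr_prefetch=256, nr_proc=4)
ds.reset_state()            # intentionally a no-op; workers reset in run()
for dp in ds.get_data():    # datapoints from 4 differently-seeded workers
    break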
tensorpack/dataflow/remote.py  (+1, -0)

@@ -23,6 +23,7 @@ def serve_data(ds, addr):
     socket.bind(addr)
     ds = RepeatedData(ds, -1)
     try:
+        ds.reset_state()
         logger.info("Serving data at {}".format(addr))
         while True:
             for dp in ds.get_data():
tensorpack/tfutils/modelutils.py  (+1, -1)

@@ -19,7 +19,7 @@ def describe_model():
             v.name, shape.as_list(), ele))
     size_mb = total * 4 / 1024.0 ** 2
     msg.append("Total param={} ({:01f} MB assuming all float32)".format(total, size_mb))
-    logger.info("Model Params: {}".format('\n'.join(msg)))
+    logger.info("Model Parameters: {}".format('\n'.join(msg)))

 def get_shape_str(tensors):
tensorpack/train/trainer.py  (+7, -6)

@@ -40,6 +40,7 @@ class SimpleTrainer(Trainer):
         self.init_session_and_coord()
         describe_model()
         # create an infinte data producer
+        self.config.dataset.reset_state()
         self.data_producer = RepeatedData(self.config.dataset, -1).get_data()
         self.main_loop()
@@ -62,21 +63,22 @@ class SimpleTrainer(Trainer):
         return func

 class EnqueueThread(threading.Thread):
-    def __init__(self, trainer, queue, enqueue_op, raw_input_var):
+    def __init__(self, trainer):
         super(EnqueueThread, self).__init__()
         self.sess = trainer.sess
         self.coord = trainer.coord
         self.dataflow = RepeatedData(trainer.config.dataset, -1)
-        self.input_vars = raw_input_var
-        self.op = enqueue_op
-        self.queue = queue
+        self.input_vars = trainer.input_vars
+        self.queue = trainer.input_queue
+        self.op = self.queue.enqueue(self.input_vars)
         self.close_op = self.queue.close(cancel_pending_enqueues=True)
         self.size_op = self.queue.size()
         self.daemon = True

     def run(self):
+        self.dataflow.reset_state()
         with self.sess.as_default():
             try:
                 while True:
@@ -155,8 +157,7 @@ class QueueInputTrainer(Trainer):
     def _build_enque_thread(self):
         """ create a thread that keeps filling the queue """
-        enqueue_op = self.input_queue.enqueue(self.input_vars)
-        self.input_th = EnqueueThread(self, self.input_queue, enqueue_op, self.input_vars)
+        self.input_th = EnqueueThread(self)
         self.extra_threads_procs.append(self.input_th)

     def train(self):
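Two things happen in trainer.py: SimpleTrainer now resets its dataset before building the infinite producer, and EnqueueThread derives the queue, enqueue op and input variables from the trainer itself instead of taking four constructor arguments. A hedged sketch of the producer pattern, where Ones is a toy flow for illustration and not part of the commit:

from tensorpack.dataflow import RepeatedData
from tensorpack.dataflow.base import DataFlow

class Ones(DataFlow):                 # hypothetical two-datapoint flow
    def get_data(self):
        yield [1]
        yield [1]

ds = Ones()
ds.reset_state()                      # the call this hunk adds before use
producer = RepeatedData(ds, -1).get_data()             # -1 => repeat forever
print(next(producer), next(producer), next(producer))  # never exhausts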