Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
2f3b8502
Commit
2f3b8502
authored
Jul 12, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
always use reset_state
parent
6607d856
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
47 additions
and
57 deletions
+47
-57
examples/DisturbLabel/disturb.py
examples/DisturbLabel/disturb.py
+2
-0
scripts/dump_train_config.py
scripts/dump_train_config.py
+1
-0
tensorpack/RL/expreplay.py
tensorpack/RL/expreplay.py
+3
-3
tensorpack/callbacks/inference.py
tensorpack/callbacks/inference.py
+1
-0
tensorpack/dataflow/base.py
tensorpack/dataflow/base.py
+1
-4
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+0
-4
tensorpack/dataflow/dataset/bsds500.py
tensorpack/dataflow/dataset/bsds500.py
+2
-6
tensorpack/dataflow/dataset/cifar.py
tensorpack/dataflow/dataset/cifar.py
+2
-6
tensorpack/dataflow/dataset/ilsvrc.py
tensorpack/dataflow/dataset/ilsvrc.py
+2
-9
tensorpack/dataflow/dataset/mnist.py
tensorpack/dataflow/dataset/mnist.py
+3
-3
tensorpack/dataflow/dataset/svhn.py
tensorpack/dataflow/dataset/svhn.py
+2
-6
tensorpack/dataflow/dftools.py
tensorpack/dataflow/dftools.py
+3
-0
tensorpack/dataflow/format.py
tensorpack/dataflow/format.py
+7
-8
tensorpack/dataflow/prefetch.py
tensorpack/dataflow/prefetch.py
+9
-1
tensorpack/dataflow/remote.py
tensorpack/dataflow/remote.py
+1
-0
tensorpack/tfutils/modelutils.py
tensorpack/tfutils/modelutils.py
+1
-1
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+7
-6
No files found.
examples/DisturbLabel/disturb.py
View file @
2f3b8502
...
...
@@ -9,6 +9,8 @@ class DisturbLabel(ProxyDataFlow):
def __init__(self, ds, prob):
    """Wrap ``ds`` and remember the disturbance setting.

    :param ds: the dataflow handed on to the parent constructor.
    :param prob: stored unchanged as ``self.prob``; presumably the
        probability with which a label gets disturbed -- confirm against
        ``get_data``.
    """
    # No RNG is created here: ``self.rng`` only exists once
    # ``reset_state`` has been called.
    super(DisturbLabel, self).__init__(ds)
    self.prob = prob
def reset_state(self):
    """Reset the parent dataflow state, then create a fresh RNG.

    Must be called before data points are consumed, since ``self.rng`` is
    not created in ``__init__``.
    """
    # Chain up first: overriding without calling the parent would skip
    # whatever the parent class resets.
    # NOTE(review): ProxyDataFlow.reset_state is assumed to forward the
    # reset to the wrapped dataflow -- confirm.
    super(DisturbLabel, self).reset_state()
    self.rng = get_rng(self)
def
get_data
(
self
):
...
...
scripts/dump_train_config.py
View file @
2f3b8502
...
...
@@ -29,6 +29,7 @@ args = parser.parse_args()
get_config_func
=
imp
.
load_source
(
'config_script'
,
args
.
config
)
.
get_config
config
=
get_config_func
()
config
.
dataset
.
reset_state
()
if
args
.
output
:
mkdir_p
(
args
.
output
)
...
...
tensorpack/RL/expreplay.py
View file @
2f3b8502
...
...
@@ -24,6 +24,9 @@ class ExpReplay(DataFlow, Callback):
"""
Implement experience replay in the paper
`Human-level control through deep reinforcement learning`.
This implementation provides the interface as a DataFlow.
This DataFlow is not fork-safe (doesn't support multiprocess prefetching)
"""
def
__init__
(
self
,
predictor
,
...
...
@@ -80,9 +83,6 @@ class ExpReplay(DataFlow, Callback):
pbar
.
update
()
self
.
_init_memory_flag
.
set
()
def
reset_state
(
self
):
raise
RuntimeError
(
"Don't run me in multiple processes"
)
def
_populate_exp
(
self
):
""" populate a transition by epsilon-greedy"""
old_s
=
self
.
player
.
current_state
()
...
...
tensorpack/callbacks/inference.py
View file @
2f3b8502
...
...
@@ -106,6 +106,7 @@ class InferenceRunner(Callback):
vc
.
before_inference
()
sess
=
tf
.
get_default_session
()
self
.
ds
.
reset_state
()
with
tqdm
(
total
=
self
.
ds
.
size
(),
ascii
=
True
)
as
pbar
:
for
dp
in
self
.
ds
.
get_data
():
#feed = dict(zip(self.input_vars, dp)) # TODO custom dp mapping?
...
...
tensorpack/dataflow/base.py
View file @
2f3b8502
...
...
@@ -29,7 +29,7 @@ class DataFlow(object):
def
reset_state
(
self
):
"""
Reset state of the dataflow. Will always be called before consuming data points.
For example, RNG **HAS** to be reset here if used in the DataFlow.
Otherwise it may not work well with prefetching, because different
processes will have the same RNG state.
...
...
@@ -39,9 +39,6 @@ class DataFlow(object):
class
RNGDataFlow
(
DataFlow
):
""" A dataflow with rng"""
def
__init__
(
self
):
self
.
rng
=
get_rng
(
self
)
def
reset_state
(
self
):
self
.
rng
=
get_rng
(
self
)
...
...
tensorpack/dataflow/common.py
View file @
2f3b8502
...
...
@@ -306,11 +306,7 @@ class JoinData(DataFlow):
class
LocallyShuffleData
(
ProxyDataFlow
,
RNGDataFlow
):
def
__init__
(
self
,
ds
,
cache_size
):
ProxyDataFlow
.
__init__
(
self
,
ds
)
RNGDataFlow
.
__init__
(
self
)
self
.
q
=
deque
(
maxlen
=
cache_size
)
self
.
ds_wrap
=
RepeatedData
(
ds
,
-
1
)
self
.
ds_itr
=
self
.
ds_wrap
.
get_data
()
self
.
current_cnt
=
0
def
reset_state
(
self
):
ProxyDataFlow
.
reset_state
(
self
)
...
...
tensorpack/dataflow/dataset/bsds500.py
View file @
2f3b8502
...
...
@@ -9,7 +9,7 @@ import numpy as np
from
...utils
import
logger
,
get_rng
,
get_dataset_dir
from
...utils.fs
import
download
from
..base
import
DataFlow
from
..base
import
RNG
DataFlow
try
:
from
scipy.io
import
loadmat
...
...
@@ -21,7 +21,7 @@ except ImportError:
DATA_URL
=
"http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
IMG_W
,
IMG_H
=
481
,
321
class
BSDS500
(
DataFlow
):
class
BSDS500
(
RNG
DataFlow
):
"""
`Berkeley Segmentation Data Set and Benchmarks 500
<http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html#bsds500>`_.
...
...
@@ -53,10 +53,6 @@ class BSDS500(DataFlow):
self
.
shuffle
=
shuffle
assert
name
in
[
'train'
,
'test'
,
'val'
]
self
.
_load
(
name
)
self
.
rng
=
get_rng
(
self
)
def
reset_state
(
self
):
self
.
rng
=
get_rng
(
self
)
def
_load
(
self
,
name
):
image_glob
=
os
.
path
.
join
(
self
.
data_root
,
'images'
,
name
,
'*.jpg'
)
...
...
tensorpack/dataflow/dataset/cifar.py
View file @
2f3b8502
...
...
@@ -15,7 +15,7 @@ import logging
from
...utils
import
logger
,
get_rng
,
get_dataset_dir
from
...utils.fs
import
download
from
..base
import
DataFlow
from
..base
import
RNG
DataFlow
__all__
=
[
'Cifar10'
,
'Cifar100'
]
...
...
@@ -77,7 +77,7 @@ def get_filenames(dir, cifar_classnum):
os
.
path
.
join
(
dir
,
'cifar-100-python'
,
'test'
)]
return
filenames
class
CifarBase
(
DataFlow
):
class
CifarBase
(
RNG
DataFlow
):
"""
Return [image, label],
image is 32x32x3 in the range [0,255]
...
...
@@ -106,10 +106,6 @@ class CifarBase(DataFlow):
self
.
data
=
read_cifar
(
self
.
fs
,
cifar_classnum
)
self
.
dir
=
dir
self
.
shuffle
=
shuffle
self
.
rng
=
get_rng
(
self
)
def
reset_state
(
self
):
self
.
rng
=
get_rng
(
self
)
def
size
(
self
):
return
50000
if
self
.
train_or_test
==
'train'
else
10000
...
...
tensorpack/dataflow/dataset/ilsvrc.py
View file @
2f3b8502
...
...
@@ -11,7 +11,7 @@ from six.moves import range
from
...utils
import
logger
,
get_rng
,
get_dataset_dir
,
memoized
from
...utils.loadcaffe
import
get_caffe_pb
from
...utils.fs
import
mkdir_p
,
download
from
..base
import
DataFlow
from
..base
import
RNG
DataFlow
__all__
=
[
'ILSVRCMeta'
,
'ILSVRC12'
]
...
...
@@ -79,7 +79,7 @@ class ILSVRCMeta(object):
arr
=
cv2
.
resize
(
arr
,
size
[::
-
1
])
return
arr
class
ILSVRC12
(
DataFlow
):
class
ILSVRC12
(
RNG
DataFlow
):
def
__init__
(
self
,
dir
,
name
,
meta_dir
=
None
,
shuffle
=
True
):
"""
:param dir: A directory containing a subdir named `name`, where the
...
...
@@ -119,17 +119,10 @@ class ILSVRC12(DataFlow):
self
.
shuffle
=
shuffle
self
.
meta
=
ILSVRCMeta
(
meta_dir
)
self
.
imglist
=
self
.
meta
.
get_image_list
(
name
)
self
.
rng
=
get_rng
(
self
)
def
size
(
self
):
return
len
(
self
.
imglist
)
def
reset_state
(
self
):
"""
reset rng for shuffle
"""
self
.
rng
=
get_rng
(
self
)
def
get_data
(
self
):
"""
Produce original images of shape [h, w, 3], and label
...
...
tensorpack/dataflow/dataset/mnist.py
View file @
2f3b8502
...
...
@@ -11,7 +11,7 @@ from six.moves import urllib, range
from
...utils
import
logger
,
get_dataset_dir
from
...utils.fs
import
download
from
..base
import
DataFlow
from
..base
import
RNG
DataFlow
__all__
=
[
'Mnist'
]
...
...
@@ -92,7 +92,7 @@ class DataSet(object):
def
num_examples
(
self
):
return
self
.
_num_examples
class
Mnist
(
DataFlow
):
class
Mnist
(
RNG
DataFlow
):
"""
Return [image, label],
image is 28x28 in the range [0,1]
...
...
@@ -136,7 +136,7 @@ class Mnist(DataFlow):
ds
=
self
.
train
if
self
.
train_or_test
==
'train'
else
self
.
test
idxs
=
list
(
range
(
ds
.
num_examples
))
if
self
.
shuffle
:
random
.
shuffle
(
idxs
)
self
.
rng
.
shuffle
(
idxs
)
for
k
in
idxs
:
img
=
ds
.
images
[
k
]
.
reshape
((
28
,
28
))
label
=
ds
.
labels
[
k
]
...
...
tensorpack/dataflow/dataset/svhn.py
View file @
2f3b8502
...
...
@@ -9,7 +9,7 @@ import numpy as np
from
six.moves
import
range
from
...utils
import
logger
,
get_rng
,
get_dataset_dir
from
..base
import
DataFlow
from
..base
import
RNG
DataFlow
try
:
import
scipy.io
...
...
@@ -20,7 +20,7 @@ except ImportError:
SVHN_URL
=
"http://ufldl.stanford.edu/housenumbers/"
class
SVHNDigit
(
DataFlow
):
class
SVHNDigit
(
RNG
DataFlow
):
"""
SVHN Cropped Digit Dataset
return img of 32x32x3, label of 0-9
...
...
@@ -33,7 +33,6 @@ class SVHNDigit(DataFlow):
:param data_dir: a directory containing the original {train,test,extra}_32x32.mat
"""
self
.
shuffle
=
shuffle
self
.
rng
=
get_rng
(
self
)
if
name
in
SVHNDigit
.
Cache
:
self
.
X
,
self
.
Y
=
SVHNDigit
.
Cache
[
name
]
...
...
@@ -54,9 +53,6 @@ class SVHNDigit(DataFlow):
def
size
(
self
):
return
self
.
X
.
shape
[
0
]
def
reset_state
(
self
):
self
.
rng
=
get_rng
(
self
)
def
get_data
(
self
):
n
=
self
.
X
.
shape
[
0
]
idxs
=
np
.
arange
(
n
)
...
...
tensorpack/dataflow/dftools.py
View file @
2f3b8502
...
...
@@ -23,6 +23,7 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
mkdir_p
(
dirname
)
if
max_count
is
None
:
max_count
=
sys
.
maxint
ds
.
reset_state
()
for
i
,
dp
in
enumerate
(
ds
.
get_data
()):
if
i
%
100
==
0
:
print
(
i
)
...
...
@@ -34,6 +35,7 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
def
dataflow_to_process_queue
(
ds
,
size
,
nr_consumer
):
"""
Convert a `DataFlow` to a multiprocessing.Queue.
The dataflow will only be reset in the spawned process.
:param ds: a `DataFlow`
:param size: size of the queue
...
...
@@ -50,6 +52,7 @@ def dataflow_to_process_queue(ds, size, nr_consumer):
self
.
q
=
q
def
run
(
self
):
self
.
ds
.
reset_state
()
try
:
for
idx
,
dp
in
enumerate
(
self
.
ds
.
get_data
()):
self
.
q
.
put
((
idx
,
dp
))
...
...
tensorpack/dataflow/format.py
View file @
2f3b8502
...
...
@@ -5,7 +5,7 @@
from
..utils
import
logger
,
get_rng
from
..utils.timer
import
timed_operation
from
..utils.loadcaffe
import
get_caffe_pb
from
.base
import
DataFlow
from
.base
import
RNG
DataFlow
import
random
from
tqdm
import
tqdm
...
...
@@ -31,7 +31,7 @@ else:
Adapters for different data format.
"""
class
HDF5Data
(
DataFlow
):
class
HDF5Data
(
RNG
DataFlow
):
"""
Zip data from different paths in an HDF5 file. Will load all data into memory.
"""
...
...
@@ -55,19 +55,18 @@ class HDF5Data(DataFlow):
def
get_data
(
self
):
idxs
=
list
(
range
(
self
.
_size
))
if
self
.
shuffle
:
random
.
shuffle
(
idxs
)
self
.
rng
.
shuffle
(
idxs
)
for
k
in
idxs
:
yield
[
dp
[
k
]
for
dp
in
self
.
dps
]
class
LMDBData
(
DataFlow
):
class
LMDBData
(
RNG
DataFlow
):
""" Read a lmdb and produce k,v pair """
def
__init__
(
self
,
lmdb_dir
,
shuffle
=
True
):
self
.
_lmdb
=
lmdb
.
open
(
lmdb_dir
,
readonly
=
True
,
lock
=
False
,
map_size
=
1099511627776
*
2
,
max_readers
=
100
)
self
.
_txn
=
self
.
_lmdb
.
begin
()
self
.
_shuffle
=
shuffle
self
.
rng
=
get_rng
(
self
)
self
.
_size
=
self
.
_txn
.
stat
()[
'entries'
]
if
shuffle
:
self
.
keys
=
self
.
_txn
.
get
(
'__keys__'
)
...
...
@@ -81,8 +80,8 @@ class LMDBData(DataFlow):
pbar
.
update
()
def
reset_state
(
self
):
super
(
LMDBData
,
self
)
.
reset_state
()
self
.
_txn
=
self
.
_lmdb
.
begin
()
self
.
rng
=
get_rng
(
self
)
def
size
(
self
):
return
self
.
_size
...
...
@@ -96,8 +95,8 @@ class LMDBData(DataFlow):
yield
[
k
,
v
]
else
:
s
=
self
.
size
()
for
i
in
range
(
s
):
k
=
self
.
rng
.
choice
(
self
.
keys
)
self
.
rng
.
shuffle
(
self
.
keys
)
for
k
in
self
.
keys
:
v
=
self
.
_txn
.
get
(
k
)
yield
[
k
,
v
]
...
...
tensorpack/dataflow/prefetch.py
View file @
2f3b8502
...
...
@@ -35,7 +35,7 @@ class PrefetchProcess(multiprocessing.Process):
self
.
queue
=
queue
def
run
(
self
):
# reset
RNG of
ds so each process will produce different data
# reset
all
ds so each process will produce different data
self
.
ds
.
reset_state
()
while
True
:
for
dp
in
self
.
ds
.
get_data
():
...
...
@@ -73,6 +73,10 @@ class PrefetchData(ProxyDataFlow):
dp
=
self
.
queue
.
get
()
yield
dp
def reset_state(self):
    """Deliberately do nothing.

    All dataflows are reset once and only once in the spawned processes,
    so there is nothing left to reset on this side.
    """
class
PrefetchProcessZMQ
(
multiprocessing
.
Process
):
def
__init__
(
self
,
ds
,
conn_name
):
"""
...
...
@@ -134,6 +138,10 @@ class PrefetchDataZMQ(ProxyDataFlow):
dp
=
loads
(
self
.
socket
.
recv
(
copy
=
False
))
yield
dp
def reset_state(self):
    """Deliberately do nothing.

    All dataflows are reset once and only once in the spawned processes,
    so there is nothing left to reset on this side.
    """
def
__del__
(
self
):
# on exit, logger may not be functional anymore
try
:
...
...
tensorpack/dataflow/remote.py
View file @
2f3b8502
...
...
@@ -23,6 +23,7 @@ def serve_data(ds, addr):
socket
.
bind
(
addr
)
ds
=
RepeatedData
(
ds
,
-
1
)
try
:
ds
.
reset_state
()
logger
.
info
(
"Serving data at {}"
.
format
(
addr
))
while
True
:
for
dp
in
ds
.
get_data
():
...
...
tensorpack/tfutils/modelutils.py
View file @
2f3b8502
...
...
@@ -19,7 +19,7 @@ def describe_model():
v
.
name
,
shape
.
as_list
(),
ele
))
size_mb
=
total
*
4
/
1024.0
**
2
msg
.
append
(
"Total param={} ({:01f} MB assuming all float32)"
.
format
(
total
,
size_mb
))
logger
.
info
(
"Model Params: {}"
.
format
(
'
\n
'
.
join
(
msg
)))
logger
.
info
(
"Model Param
eter
s: {}"
.
format
(
'
\n
'
.
join
(
msg
)))
def
get_shape_str
(
tensors
):
...
...
tensorpack/train/trainer.py
View file @
2f3b8502
...
...
@@ -40,6 +40,7 @@ class SimpleTrainer(Trainer):
self
.
init_session_and_coord
()
describe_model
()
# create an infinite data producer
self
.
config
.
dataset
.
reset_state
()
self
.
data_producer
=
RepeatedData
(
self
.
config
.
dataset
,
-
1
)
.
get_data
()
self
.
main_loop
()
...
...
@@ -62,21 +63,22 @@ class SimpleTrainer(Trainer):
return
func
class
EnqueueThread
(
threading
.
Thread
):
def
__init__
(
self
,
trainer
,
queue
,
enqueue_op
,
raw_input_var
):
def
__init__
(
self
,
trainer
):
super
(
EnqueueThread
,
self
)
.
__init__
()
self
.
sess
=
trainer
.
sess
self
.
coord
=
trainer
.
coord
self
.
dataflow
=
RepeatedData
(
trainer
.
config
.
dataset
,
-
1
)
self
.
input_vars
=
raw_input_var
self
.
op
=
enqueue_op
self
.
queue
=
queue
self
.
input_vars
=
trainer
.
input_vars
self
.
queue
=
trainer
.
input_queue
self
.
op
=
self
.
queue
.
enqueue
(
self
.
input_vars
)
self
.
close_op
=
self
.
queue
.
close
(
cancel_pending_enqueues
=
True
)
self
.
size_op
=
self
.
queue
.
size
()
self
.
daemon
=
True
def
run
(
self
):
self
.
dataflow
.
reset_state
()
with
self
.
sess
.
as_default
():
try
:
while
True
:
...
...
@@ -155,8 +157,7 @@ class QueueInputTrainer(Trainer):
def
_build_enque_thread
(
self
):
""" create a thread that keeps filling the queue """
enqueue_op
=
self
.
input_queue
.
enqueue
(
self
.
input_vars
)
self
.
input_th
=
EnqueueThread
(
self
,
self
.
input_queue
,
enqueue_op
,
self
.
input_vars
)
self
.
input_th
=
EnqueueThread
(
self
)
self
.
extra_threads_procs
.
append
(
self
.
input_th
)
def
train
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment