Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
8759e324
Commit
8759e324
authored
Feb 24, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
reset rng
parent
f4507d45
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
44 additions
and
22 deletions
+44
-22
tensorpack/dataflow/base.py
tensorpack/dataflow/base.py
+15
-1
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+24
-18
tensorpack/dataflow/prefetch.py
tensorpack/dataflow/prefetch.py
+1
-0
tensorpack/train/config.py
tensorpack/train/config.py
+2
-2
tensorpack/utils/__init__.py
tensorpack/utils/__init__.py
+2
-1
No files found.
tensorpack/dataflow/base.py
View file @
8759e324
...
...
@@ -6,7 +6,7 @@
from
abc
import
abstractmethod
,
ABCMeta
__all__
=
[
'DataFlow'
]
__all__
=
[
'DataFlow'
,
'ProxyDataFlow'
]
class
DataFlow
(
object
):
__metaclass__
=
ABCMeta
...
...
@@ -23,4 +23,18 @@ class DataFlow(object):
"""
raise
NotImplementedError
()
def
reset_state
(
self
):
"""
Reset state of the dataflow (usually the random seed)
"""
pass
class
ProxyDataFlow
(
DataFlow
):
def
__init__
(
self
,
ds
):
self
.
ds
=
ds
def
reset_state
(
self
):
self
.
ds
.
reset_state
()
def
size
(
self
):
return
self
.
ds
.
size
()
tensorpack/dataflow/common.py
View file @
8759e324
...
...
@@ -5,14 +5,15 @@
import
numpy
as
np
import
copy
from
.base
import
DataFlow
from
.base
import
DataFlow
,
ProxyDataFlow
from
.imgaug
import
AugmentorList
,
Image
from
..utils
import
*
__all__
=
[
'BatchData'
,
'FixedSizeData'
,
'FakeData'
,
'MapData'
,
'MapDataComponent'
,
'RandomChooseData'
,
'AugmentImageComponent'
]
class
BatchData
(
DataFlow
):
class
BatchData
(
Proxy
DataFlow
):
def
__init__
(
self
,
ds
,
batch_size
,
remainder
=
False
):
"""
Group data in ds into batches
...
...
@@ -20,7 +21,7 @@ class BatchData(DataFlow):
remainder: whether to return the remaining data smaller than a batch_size.
if set True, will possibly return a data point of a smaller 1st dimension
"""
s
elf
.
ds
=
ds
s
uper
(
BatchData
,
self
)
.
__init__
(
ds
)
if
not
remainder
:
assert
batch_size
<=
ds
.
size
()
self
.
batch_size
=
batch_size
...
...
@@ -60,10 +61,10 @@ class BatchData(DataFlow):
np
.
array
([
x
[
k
]
for
x
in
data_holder
],
dtype
=
tp
))
return
result
class
FixedSizeData
(
DataFlow
):
class
FixedSizeData
(
Proxy
DataFlow
):
""" generate data from another dataflow, but with a fixed epoch size"""
def
__init__
(
self
,
ds
,
size
):
s
elf
.
ds
=
ds
s
uper
(
FixedSizeData
,
self
)
.
__init__
(
ds
)
self
.
_size
=
size
self
.
itr
=
None
...
...
@@ -86,13 +87,13 @@ class FixedSizeData(DataFlow):
if
cnt
==
self
.
_size
:
return
class
RepeatedData
(
DataFlow
):
class
RepeatedData
(
Proxy
DataFlow
):
""" repeat another dataflow for certain times
if nr == -1, repeat infinitely many times
"""
def
__init__
(
self
,
ds
,
nr
):
self
.
nr
=
nr
s
elf
.
ds
=
ds
s
uper
(
RepeatedData
,
self
)
.
__init__
(
ds
)
def
size
(
self
):
if
self
.
nr
==
-
1
:
...
...
@@ -117,37 +118,35 @@ class FakeData(DataFlow):
"""
self
.
shapes
=
shapes
self
.
_size
=
size
self
.
rng
=
get_rng
(
self
)
def
size
(
self
):
return
self
.
_size
def
reset_state
(
self
):
self
.
rng
=
get_rng
(
self
)
def
get_data
(
self
):
for
_
in
xrange
(
self
.
_size
):
yield
[
np
.
random
.
random
(
k
)
for
k
in
self
.
shapes
]
yield
[
self
.
rng
.
random_sample
(
k
)
for
k
in
self
.
shapes
]
class
MapData
(
DataFlow
):
class
MapData
(
Proxy
DataFlow
):
""" Map a function to the datapoint"""
def
__init__
(
self
,
ds
,
func
):
s
elf
.
ds
=
ds
s
uper
(
MapData
,
self
)
.
__init_
(
ds
)
self
.
func
=
func
def
size
(
self
):
return
self
.
ds
.
size
()
def
get_data
(
self
):
for
dp
in
self
.
ds
.
get_data
():
yield
self
.
func
(
dp
)
class
MapDataComponent
(
DataFlow
):
class
MapDataComponent
(
Proxy
DataFlow
):
""" Apply a function to the given index in the datapoint"""
def
__init__
(
self
,
ds
,
func
,
index
=
0
):
s
elf
.
ds
=
ds
s
uper
(
MapDataComponent
,
self
)
.
__init__
(
ds
)
self
.
func
=
func
self
.
index
=
index
def
size
(
self
):
return
self
.
ds
.
size
()
def
get_data
(
self
):
for
dp
in
self
.
ds
.
get_data
():
dp
=
copy
.
deepcopy
(
dp
)
# avoid modifying the original dp
...
...
@@ -169,6 +168,13 @@ class RandomChooseData(DataFlow):
prob
=
1.0
/
len
(
df_lists
)
self
.
df_lists
=
[(
k
,
prob
)
for
k
in
df_lists
]
def
reset_state
(
self
):
for
d
in
self
.
df_lists
:
if
isinstance
(
d
,
tuple
):
d
[
0
]
.
reset_state
()
else
:
d
.
reset_state
()
def
get_data
(
self
):
itrs
=
[
v
[
0
]
.
get_data
()
for
v
in
self
.
df_lists
]
probs
=
np
.
array
([
v
[
1
]
for
v
in
self
.
df_lists
])
...
...
tensorpack/dataflow/prefetch.py
View file @
8759e324
...
...
@@ -23,6 +23,7 @@ class PrefetchProcess(multiprocessing.Process):
self
.
queue
=
queue
def
run
(
self
):
self
.
ds
.
reset_state
()
try
:
for
dp
in
self
.
ds
.
get_data
():
self
.
queue
.
put
(
dp
)
...
...
tensorpack/train/config.py
View file @
8759e324
...
...
@@ -30,7 +30,7 @@ class TrainConfig(object):
initialize variables of a session. default to a new session.
model: a ModelDesc instance
step_per_epoch: the number of steps (parameter updates) to perform
in each epoch.
default to dataset.size()
in each epoch.
max_epoch: maximum number of epoch to run training. default to 100
nr_tower: int. number of towers. default to 1.
"""
...
...
@@ -49,7 +49,7 @@ class TrainConfig(object):
assert_type
(
self
.
session_config
,
tf
.
ConfigProto
)
self
.
session_init
=
kwargs
.
pop
(
'session_init'
,
NewSession
())
assert_type
(
self
.
session_init
,
SessionInit
)
self
.
step_per_epoch
=
int
(
kwargs
.
pop
(
'step_per_epoch'
,
self
.
dataset
.
size
()
))
self
.
step_per_epoch
=
int
(
kwargs
.
pop
(
'step_per_epoch'
))
self
.
max_epoch
=
int
(
kwargs
.
pop
(
'max_epoch'
,
100
))
assert
self
.
step_per_epoch
>
0
and
self
.
max_epoch
>
0
self
.
nr_tower
=
int
(
kwargs
.
pop
(
'nr_tower'
,
1
))
...
...
tensorpack/utils/__init__.py
View file @
8759e324
...
...
@@ -87,4 +87,5 @@ def get_global_step():
get_global_step_var
())
def
get_rng
(
self
):
return
np
.
random
.
RandomState
()
seed
=
(
id
(
self
)
+
os
.
getpid
())
%
4294967295
return
np
.
random
.
RandomState
(
seed
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment