Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
b06fa732
Commit
b06fa732
authored
Mar 03, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
util updates
parent
b6c75ae5
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
54 additions
and
9 deletions
+54
-9
example_cifar10.py
example_cifar10.py
+1
-3
example_mnist.py
example_mnist.py
+1
-4
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+34
-1
tensorpack/models/conv2d.py
tensorpack/models/conv2d.py
+4
-1
tensorpack/utils/concurrency.py
tensorpack/utils/concurrency.py
+1
-0
tensorpack/utils/symbolic_functions.py
tensorpack/utils/symbolic_functions.py
+13
-0
No files found.
example_cifar10.py
View file @
b06fa732
...
...
@@ -73,9 +73,7 @@ class Model(ModelDesc):
tf
.
add_to_collection
(
MOVING_SUMMARY_VARS_KEY
,
cost
)
# compute the number of failed samples, for ValidationError to use at test time
wrong
=
tf
.
not_equal
(
tf
.
cast
(
tf
.
argmax
(
prob
,
1
),
tf
.
int32
),
label
)
wrong
=
tf
.
cast
(
wrong
,
tf
.
float32
)
wrong
=
prediction_incorrect
(
logits
,
label
)
nr_wrong
=
tf
.
reduce_sum
(
wrong
,
name
=
'wrong'
)
# monitor training error
tf
.
add_to_collection
(
...
...
example_mnist.py
View file @
b06fa732
...
...
@@ -64,9 +64,7 @@ class Model(ModelDesc):
tf
.
add_to_collection
(
MOVING_SUMMARY_VARS_KEY
,
cost
)
# compute the number of failed samples, for ValidationError to use at test time
wrong
=
tf
.
not_equal
(
tf
.
cast
(
tf
.
argmax
(
prob
,
1
),
tf
.
int32
),
label
)
wrong
=
tf
.
cast
(
wrong
,
tf
.
float32
)
wrong
=
prediction_incorrect
(
logits
,
label
)
nr_wrong
=
tf
.
reduce_sum
(
wrong
,
name
=
'wrong'
)
# monitor training error
tf
.
add_to_collection
(
...
...
@@ -90,7 +88,6 @@ def get_config():
dataset_train
=
BatchData
(
dataset
.
Mnist
(
'train'
),
128
)
dataset_test
=
BatchData
(
dataset
.
Mnist
(
'test'
),
256
,
remainder
=
True
)
step_per_epoch
=
dataset_train
.
size
()
step_per_epoch
=
30
# prepare session
sess_config
=
get_default_sess_config
()
...
...
tensorpack/dataflow/common.py
View file @
b06fa732
...
...
@@ -3,13 +3,14 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
numpy
as
np
import
random
import
copy
from
six.moves
import
range
from
.base
import
DataFlow
,
ProxyDataFlow
from
..utils
import
*
__all__
=
[
'BatchData'
,
'FixedSizeData'
,
'FakeData'
,
'MapData'
,
'MapDataComponent'
,
'RandomChooseData'
]
'MapDataComponent'
,
'RandomChooseData'
,
'RandomMixData'
]
class
BatchData
(
ProxyDataFlow
):
def
__init__
(
self
,
ds
,
batch_size
,
remainder
=
False
):
...
...
@@ -182,3 +183,35 @@ class RandomChooseData(DataFlow):
yield
next
(
itr
)
except
StopIteration
:
return
class RandomMixData(DataFlow):
    """
    Randomly choose from several dataflow, will eventually exhaust all dataflow.
    So it's a perfect mix.
    """
    def __init__(self, df_lists):
        """
        :param df_lists: list of dataflow.
            All DataFlow in df_lists must have size() implemented.
        """
        self.df_lists = df_lists
        # Cache each component's size once; size() and get_data() both use it.
        self.sizes = [k.size() for k in self.df_lists]

    def reset_state(self):
        # Propagate state reset to every underlying dataflow.
        for d in self.df_lists:
            d.reset_state()

    def size(self):
        # Total number of datapoints across all components.
        return sum(self.sizes)

    def get_data(self):
        """
        Yield datapoints from all components in a uniformly shuffled order,
        consuming each component exactly self.sizes[i] times.
        """
        sums = np.cumsum(self.sizes)
        idxs = np.arange(self.size())
        np.random.shuffle(idxs)
        # Map each global index to the dataflow it falls into.
        # BUGFIX: the original used np.array(map(...)), which under Python 3
        # wraps the map iterator into a 0-d object array (map() is lazy in
        # py3), breaking .max() and indexing below. np.searchsorted is
        # vectorized, so pass the whole index array directly.
        idxs = np.searchsorted(sums, idxs, side='right')
        itrs = [k.get_data() for k in self.df_lists]
        # Sanity check: every component must be hit at least once by the
        # largest mapped index (fails e.g. if the last dataflow has size 0).
        assert idxs.max() == len(itrs) - 1, \
            "{}!={}".format(idxs.max(), len(itrs) - 1)
        for k in idxs:
            yield next(itrs[k])
tensorpack/models/conv2d.py
View file @
b06fa732
...
...
@@ -3,6 +3,7 @@
# File: conv2d.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
numpy
as
np
import
tensorflow
as
tf
import
math
from
._common
import
*
...
...
@@ -21,6 +22,7 @@ def Conv2D(x, out_channel, kernel_shape,
split: split channels. used in Alexnet
"""
in_shape
=
x
.
get_shape
()
.
as_list
()
num_in
=
np
.
prod
(
in_shape
[
1
:])
in_channel
=
in_shape
[
-
1
]
assert
in_channel
%
split
==
0
assert
out_channel
%
split
==
0
...
...
@@ -31,7 +33,8 @@ def Conv2D(x, out_channel, kernel_shape,
stride
=
shape4d
(
stride
)
if
W_init
is
None
:
W_init
=
tf
.
truncated_normal_initializer
(
stddev
=
1e-2
)
#W_init = tf.truncated_normal_initializer(stddev=3e-2)
W_init
=
tf
.
contrib
.
layers
.
xavier_initializer_conv2d
()
if
b_init
is
None
:
b_init
=
tf
.
constant_initializer
()
...
...
tensorpack/utils/concurrency.py
View file @
b06fa732
...
...
@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
threading
import
multiprocessing
from
contextlib
import
contextmanager
import
tensorflow
as
tf
import
atexit
...
...
tensorpack/utils/symbolic_functions.py
View file @
b06fa732
...
...
@@ -16,6 +16,19 @@ def one_hot(y, num_labels):
onehot_labels
.
set_shape
([
None
,
num_labels
])
return
tf
.
cast
(
onehot_labels
,
tf
.
float32
)
def prediction_incorrect(logits, label):
    """
    Build a symbolic per-sample error indicator.

    logits: batchxN
    label: batch
    return a binary vector with 1 means incorrect prediction
    """
    with tf.op_scope([logits, label], 'incorrect'):
        # Predicted class is the argmax over the N logits; compare against
        # the ground-truth label (cast to int64 to match argmax's dtype).
        mismatch = tf.not_equal(tf.argmax(logits, 1),
                                tf.cast(label, tf.int64))
        # Float 0/1 vector so callers can reduce_sum / average it directly.
        mismatch = tf.cast(mismatch, tf.float32)
    return mismatch
def flatten(x):
    """Collapse a tensor of any shape into a rank-1 (flat) tensor."""
    flat_shape = [-1]
    return tf.reshape(x, flat_shape)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment