Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
4cde005e
Commit
4cde005e
authored
Apr 07, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update docs & small changes
parent
b5ac2443
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
33 additions
and
23 deletions
+33
-23
examples/DeepQNetwork/common.py
examples/DeepQNetwork/common.py
+1
-0
examples/FasterRCNN/basemodel.py
examples/FasterRCNN/basemodel.py
+1
-1
examples/ImageNetModels/imagenet_utils.py
examples/ImageNetModels/imagenet_utils.py
+4
-2
examples/ResNet/imagenet-resnet.py
examples/ResNet/imagenet-resnet.py
+9
-10
examples/Saliency/README.md
examples/Saliency/README.md
+1
-1
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+4
-1
tensorpack/dataflow/imgaug/imgproc.py
tensorpack/dataflow/imgaug/imgproc.py
+2
-1
tensorpack/dataflow/parallel.py
tensorpack/dataflow/parallel.py
+6
-4
tensorpack/tfutils/common.py
tensorpack/tfutils/common.py
+5
-3
No files found.
examples/DeepQNetwork/common.py
View file @
4cde005e
...
@@ -87,6 +87,7 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn, verbose=False):
...
@@ -87,6 +87,7 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn, verbose=False):
for
_
in
tqdm
(
range
(
nr_eval
),
**
get_tqdm_kwargs
()):
for
_
in
tqdm
(
range
(
nr_eval
),
**
get_tqdm_kwargs
()):
fetch
()
fetch
()
# waiting is necessary, otherwise the estimated mean score is biased
logger
.
info
(
"Waiting for all the workers to finish the last run..."
)
logger
.
info
(
"Waiting for all the workers to finish the last run..."
)
for
k
in
threads
:
for
k
in
threads
:
k
.
stop
()
k
.
stop
()
...
...
examples/FasterRCNN/basemodel.py
View file @
4cde005e
...
@@ -26,7 +26,7 @@ def maybe_freeze_affine(getter, *args, **kwargs):
...
@@ -26,7 +26,7 @@ def maybe_freeze_affine(getter, *args, **kwargs):
def
resnet_argscope
():
def
resnet_argscope
():
with
argscope
([
Conv2D
,
MaxPooling
,
BatchNorm
],
data_format
=
'channels_first'
),
\
with
argscope
([
Conv2D
,
MaxPooling
,
BatchNorm
],
data_format
=
'channels_first'
),
\
argscope
(
Conv2D
,
use_bias
=
False
),
\
argscope
(
Conv2D
,
use_bias
=
False
),
\
argscope
(
BatchNorm
,
training
=
False
,
epsilon
=
0
),
\
argscope
(
BatchNorm
,
training
=
False
),
\
custom_getter_scope
(
maybe_freeze_affine
):
custom_getter_scope
(
maybe_freeze_affine
):
yield
yield
...
...
examples/ImageNetModels/imagenet_utils.py
View file @
4cde005e
...
@@ -150,8 +150,10 @@ class ImageNetModel(ModelDesc):
...
@@ -150,8 +150,10 @@ class ImageNetModel(ModelDesc):
"""
"""
weight_decay_on_bn
=
False
weight_decay_on_bn
=
False
def
__init__
(
self
,
data_format
=
'NCHW'
):
"""
self
.
data_format
=
data_format
Either 'NCHW' or 'NHWC'
"""
data_format
=
'NCHW'
def
inputs
(
self
):
def
inputs
(
self
):
return
[
tf
.
placeholder
(
self
.
image_dtype
,
[
None
,
self
.
image_shape
,
self
.
image_shape
,
3
],
'input'
),
return
[
tf
.
placeholder
(
self
.
image_dtype
,
[
None
,
self
.
image_shape
,
self
.
image_shape
,
3
],
'input'
),
...
...
examples/ResNet/imagenet-resnet.py
View file @
4cde005e
...
@@ -25,9 +25,7 @@ from resnet_model import (
...
@@ -25,9 +25,7 @@ from resnet_model import (
class
Model
(
ImageNetModel
):
class
Model
(
ImageNetModel
):
def
__init__
(
self
,
depth
,
data_format
=
'NCHW'
,
mode
=
'resnet'
):
def
__init__
(
self
,
depth
,
mode
=
'resnet'
):
super
(
Model
,
self
)
.
__init__
(
data_format
)
if
mode
==
'se'
:
if
mode
==
'se'
:
assert
depth
>=
50
assert
depth
>=
50
...
@@ -64,17 +62,17 @@ def get_config(model, fake=False):
...
@@ -64,17 +62,17 @@ def get_config(model, fake=False):
assert
args
.
batch
%
nr_tower
==
0
assert
args
.
batch
%
nr_tower
==
0
batch
=
args
.
batch
//
nr_tower
batch
=
args
.
batch
//
nr_tower
logger
.
info
(
"Running on {} towers. Batch size per tower: {}"
.
format
(
nr_tower
,
batch
))
if
fake
:
if
fake
:
logger
.
info
(
"For benchmark, batch size is fixed to 64 per tower."
)
dataset_train
=
FakeData
(
dataset_train
=
FakeData
(
[[
64
,
224
,
224
,
3
],
[
64
]],
1000
,
random
=
False
,
dtype
=
'uint8'
)
[[
batch
,
224
,
224
,
3
],
[
batch
]],
1000
,
random
=
False
,
dtype
=
'uint8'
)
callbacks
=
[]
callbacks
=
[]
else
:
else
:
logger
.
info
(
"Running on {} towers. Batch size per tower: {}"
.
format
(
nr_tower
,
batch
))
dataset_train
=
get_data
(
'train'
,
batch
)
dataset_train
=
get_data
(
'train'
,
batch
)
dataset_val
=
get_data
(
'val'
,
batch
)
dataset_val
=
get_data
(
'val'
,
batch
)
BASE_LR
=
0.1
*
(
args
.
batch
/
256.0
)
START_LR
=
0.1
BASE_LR
=
START_LR
*
(
args
.
batch
/
256.0
)
callbacks
=
[
callbacks
=
[
ModelSaver
(),
ModelSaver
(),
EstimatedTimeLeft
(),
EstimatedTimeLeft
(),
...
@@ -82,10 +80,10 @@ def get_config(model, fake=False):
...
@@ -82,10 +80,10 @@ def get_config(model, fake=False):
'learning_rate'
,
[(
30
,
BASE_LR
*
1e-1
),
(
60
,
BASE_LR
*
1e-2
),
'learning_rate'
,
[(
30
,
BASE_LR
*
1e-1
),
(
60
,
BASE_LR
*
1e-2
),
(
90
,
BASE_LR
*
1e-3
),
(
100
,
BASE_LR
*
1e-4
)]),
(
90
,
BASE_LR
*
1e-3
),
(
100
,
BASE_LR
*
1e-4
)]),
]
]
if
BASE_LR
>
0.1
:
if
BASE_LR
>
START_LR
:
callbacks
.
append
(
callbacks
.
append
(
ScheduledHyperParamSetter
(
ScheduledHyperParamSetter
(
'learning_rate'
,
[(
0
,
0.1
),
(
3
,
BASE_LR
)],
interp
=
'linear'
))
'learning_rate'
,
[(
0
,
START_LR
),
(
5
,
BASE_LR
)],
interp
=
'linear'
))
infs
=
[
ClassificationError
(
'wrong-top1'
,
'val-error-top1'
),
infs
=
[
ClassificationError
(
'wrong-top1'
,
'val-error-top1'
),
ClassificationError
(
'wrong-top5'
,
'val-error-top5'
)]
ClassificationError
(
'wrong-top5'
,
'val-error-top5'
)]
...
@@ -126,7 +124,8 @@ if __name__ == '__main__':
...
@@ -126,7 +124,8 @@ if __name__ == '__main__':
if
args
.
gpu
:
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
model
=
Model
(
args
.
depth
,
args
.
data_format
,
args
.
mode
)
model
=
Model
(
args
.
depth
,
args
.
mode
)
model
.
data_format
=
args
.
data_format
if
args
.
eval
:
if
args
.
eval
:
batch
=
128
# something that can run on one gpu
batch
=
128
# something that can run on one gpu
ds
=
get_data
(
'val'
,
batch
)
ds
=
get_data
(
'val'
,
batch
)
...
...
examples/Saliency/README.md
View file @
4cde005e
...
@@ -39,7 +39,7 @@ Usage:
...
@@ -39,7 +39,7 @@ Usage:
./CAM-resnet.py
--data
/path/to/imagenet
[
--load
ImageNet-ResNet18-Preact.npz]
[
--gpu
0,1,2,3]
./CAM-resnet.py
--data
/path/to/imagenet
[
--load
ImageNet-ResNet18-Preact.npz]
[
--gpu
0,1,2,3]
```
```
Pretrained and fine-tuned ResNet can be downloaded
Pretrained and fine-tuned ResNet can be downloaded
[
here
](
http://models.tensorpack.com/ResNet/
)
and
[
here
](
http://models.tensorpack.com/Visualization
/
)
.
in the
[
model zoo
](
http://models.tensorpack.com
/
)
.
2.
Generate CAM on ImageNet validation set:
2.
Generate CAM on ImageNet validation set:
```
bash
```
bash
...
...
tensorpack/dataflow/common.py
View file @
4cde005e
...
@@ -301,7 +301,7 @@ class MapDataComponent(MapData):
...
@@ -301,7 +301,7 @@ class MapDataComponent(MapData):
r
=
func
(
dp
[
index
])
r
=
func
(
dp
[
index
])
if
r
is
None
:
if
r
is
None
:
return
None
return
None
dp
=
copy
(
dp
)
# shallow copy to avoid modifying the list
dp
=
list
(
dp
)
# shallow copy to avoid modifying the list
dp
[
index
]
=
r
dp
[
index
]
=
r
return
dp
return
dp
super
(
MapDataComponent
,
self
)
.
__init__
(
ds
,
f
)
super
(
MapDataComponent
,
self
)
.
__init__
(
ds
,
f
)
...
@@ -606,6 +606,9 @@ class CacheData(ProxyDataFlow):
...
@@ -606,6 +606,9 @@ class CacheData(ProxyDataFlow):
"""
"""
Cache the first pass of a DataFlow completely in memory,
Cache the first pass of a DataFlow completely in memory,
and produce from the cache thereafter.
and produce from the cache thereafter.
NOTE: The user should not stop the iterator before it has reached the end.
Otherwise the cache may be incomplete.
"""
"""
def
__init__
(
self
,
ds
,
shuffle
=
False
):
def
__init__
(
self
,
ds
,
shuffle
=
False
):
"""
"""
...
...
tensorpack/dataflow/imgaug/imgproc.py
View file @
4cde005e
...
@@ -268,7 +268,8 @@ class Lighting(ImageAugmentor):
...
@@ -268,7 +268,8 @@ class Lighting(ImageAugmentor):
def
_get_augment_params
(
self
,
img
):
def
_get_augment_params
(
self
,
img
):
assert
img
.
shape
[
2
]
==
3
assert
img
.
shape
[
2
]
==
3
return
self
.
rng
.
randn
(
3
)
*
self
.
std
ret
=
self
.
rng
.
randn
(
3
)
*
self
.
std
return
ret
.
astype
(
'float32'
)
def
_augment
(
self
,
img
,
v
):
def
_augment
(
self
,
img
,
v
):
old_dtype
=
img
.
dtype
old_dtype
=
img
.
dtype
...
...
tensorpack/dataflow/parallel.py
View file @
4cde005e
...
@@ -403,12 +403,13 @@ class PlasmaPutData(ProxyDataFlow):
...
@@ -403,12 +403,13 @@ class PlasmaPutData(ProxyDataFlow):
Experimental.
Experimental.
"""
"""
def
__init__
(
self
,
ds
):
def
__init__
(
self
,
ds
,
socket
=
"/tmp/plasma"
):
self
.
_socket
=
socket
super
(
PlasmaPutData
,
self
)
.
__init__
(
ds
)
super
(
PlasmaPutData
,
self
)
.
__init__
(
ds
)
def
reset_state
(
self
):
def
reset_state
(
self
):
super
(
PlasmaPutData
,
self
)
.
reset_state
()
super
(
PlasmaPutData
,
self
)
.
reset_state
()
self
.
client
=
plasma
.
connect
(
"/tmp/plasma"
,
""
,
0
)
self
.
client
=
plasma
.
connect
(
self
.
_socket
,
""
,
0
)
def
get_data
(
self
):
def
get_data
(
self
):
for
dp
in
self
.
ds
.
get_data
():
for
dp
in
self
.
ds
.
get_data
():
...
@@ -421,12 +422,13 @@ class PlasmaGetData(ProxyDataFlow):
...
@@ -421,12 +422,13 @@ class PlasmaGetData(ProxyDataFlow):
Take plasma object id from a DataFlow, and retrieve it from plasma shared
Take plasma object id from a DataFlow, and retrieve it from plasma shared
memory object store.
memory object store.
"""
"""
def
__init__
(
self
,
ds
):
def
__init__
(
self
,
ds
,
socket
=
"/tmp/plasma"
):
self
.
_socket
=
socket
super
(
PlasmaGetData
,
self
)
.
__init__
(
ds
)
super
(
PlasmaGetData
,
self
)
.
__init__
(
ds
)
def
reset_state
(
self
):
def
reset_state
(
self
):
super
(
PlasmaGetData
,
self
)
.
reset_state
()
super
(
PlasmaGetData
,
self
)
.
reset_state
()
self
.
client
=
plasma
.
connect
(
"/tmp/plasma"
,
""
,
0
)
self
.
client
=
plasma
.
connect
(
self
.
_socket
,
""
,
0
)
def
get_data
(
self
):
def
get_data
(
self
):
for
dp
in
self
.
ds
.
get_data
():
for
dp
in
self
.
ds
.
get_data
():
...
...
tensorpack/tfutils/common.py
View file @
4cde005e
...
@@ -38,8 +38,10 @@ def get_default_sess_config(mem_fraction=0.99):
...
@@ -38,8 +38,10 @@ def get_default_sess_config(mem_fraction=0.99):
# Didn't see much difference.
# Didn't see much difference.
conf
.
gpu_options
.
per_process_gpu_memory_fraction
=
0.99
conf
.
gpu_options
.
per_process_gpu_memory_fraction
=
0.99
if
get_tf_version_number
()
>=
1.2
:
conf
.
gpu_options
.
force_gpu_compatible
=
True
# This hurt performance of large data pipeline:
# https://github.com/tensorflow/benchmarks/commit/1528c46499cdcff669b5d7c006b7b971884ad0e6
# conf.gpu_options.force_gpu_compatible = True
conf
.
gpu_options
.
allow_growth
=
True
conf
.
gpu_options
.
allow_growth
=
True
...
@@ -47,7 +49,7 @@ def get_default_sess_config(mem_fraction=0.99):
...
@@ -47,7 +49,7 @@ def get_default_sess_config(mem_fraction=0.99):
# conf.graph_options.rewrite_options.memory_optimization = \
# conf.graph_options.rewrite_options.memory_optimization = \
# rwc.RewriterConfig.HEURISTICS
# rwc.RewriterConfig.HEURISTICS
# May hurt performance
# May hurt performance
?
# conf.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
# conf.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
# conf.graph_options.place_pruned_graph = True
# conf.graph_options.place_pruned_graph = True
return
conf
return
conf
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment