Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
26e609f8
Commit
26e609f8
authored
Jan 03, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
minor changes
parent
c81ea087
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
29 additions
and
10 deletions
+29
-10
examples/ResNet/imagenet-resnet.py
examples/ResNet/imagenet-resnet.py
+6
-3
examples/ResNet/imagenet_utils.py
examples/ResNet/imagenet_utils.py
+19
-1
tensorpack/contrib/keras.py
tensorpack/contrib/keras.py
+0
-3
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+2
-2
tensorpack/tfutils/scope_utils.py
tensorpack/tfutils/scope_utils.py
+2
-1
No files found.
examples/ResNet/imagenet-resnet.py
View file @
26e609f8
...
...
@@ -94,7 +94,7 @@ def get_config(model, fake=False):
model
=
model
,
dataflow
=
dataset_train
,
callbacks
=
callbacks
,
steps_per_epoch
=
5000
,
steps_per_epoch
=
100
if
args
.
fake
else
5000
,
# 5000 ~= 1.28M / TOTAL_BATCH_SIZE
max_epoch
=
110
,
nr_tower
=
nr_tower
)
...
...
@@ -123,6 +123,9 @@ if __name__ == '__main__':
batch
=
128
# something that can run on one gpu
ds
=
get_data
(
'val'
,
batch
)
eval_on_ILSVRC12
(
model
,
get_model_loader
(
args
.
load
),
ds
)
else
:
if
args
.
fake
:
logger
.
set_logger_dir
(
os
.
path
.
join
(
'train_log'
,
'tmp'
),
'd'
)
else
:
logger
.
set_logger_dir
(
os
.
path
.
join
(
'train_log'
,
'imagenet-resnet-d'
+
str
(
args
.
depth
)))
...
...
examples/ResNet/imagenet_utils.py
View file @
26e609f8
...
...
@@ -92,7 +92,7 @@ def get_imagenet_dataflow(
assert
datadir
is
not
None
assert
isinstance
(
augmentors
,
list
)
isTrain
=
name
==
'train'
cpu
=
min
(
3
0
,
multiprocessing
.
cpu_count
())
cpu
=
min
(
4
0
,
multiprocessing
.
cpu_count
())
if
isTrain
:
ds
=
dataset
.
ILSVRC12
(
datadir
,
name
,
shuffle
=
True
)
ds
=
AugmentImageComponent
(
ds
,
augmentors
,
copy
=
False
)
...
...
@@ -213,3 +213,21 @@ class ImageNetModel(ModelDesc):
wrong
=
prediction_incorrect
(
logits
,
label
,
5
,
name
=
'wrong-top5'
)
add_moving_summary
(
tf
.
reduce_mean
(
wrong
,
name
=
'train-error-top5'
))
return
loss
if
__name__
==
'__main__'
:
import
argparse
from
tensorpack.dataflow
import
TestDataSpeed
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--data'
,
required
=
True
)
parser
.
add_argument
(
'--batch'
,
type
=
int
,
default
=
32
)
args
=
parser
.
parse_args
()
augs
=
fbresnet_augmentor
(
False
)
augs
=
[
imgaug
.
ResizeShortestEdge
(
256
),
imgaug
.
CenterCrop
(
224
)
]
df
=
get_imagenet_dataflow
(
args
.
data
,
'train'
,
args
.
batch
,
augs
)
TestDataSpeed
(
df
)
.
start
()
tensorpack/contrib/keras.py
View file @
26e609f8
...
...
@@ -17,7 +17,6 @@ from ..callbacks import (
from
..tfutils.common
import
get_op_tensor_name
from
..tfutils.tower
import
get_current_tower_context
from
..tfutils.scope_utils
import
cached_name_scope
# from ..tfutils.collection import freeze_collection # TODO freeze UPDATE_OPS in replicated
from
..tfutils.summary
import
add_moving_summary
from
..utils.gpu
import
get_nr_gpu
...
...
@@ -108,8 +107,6 @@ def setup_keras_trainer(
nr_inputs
=
len
(
inputs_desc
)
def
get_cost
(
*
inputs
):
assert
len
(
inputs
)
==
len
(
inputs_desc
)
+
len
(
targets_desc
),
\
"Input source size {} != {} + {}"
.
format
(
len
(
inputs
),
len
(
inputs_desc
),
len
(
targets_desc
))
ctx
=
get_current_tower_context
()
input_tensors
=
list
(
inputs
[:
nr_inputs
])
target_tensors
=
list
(
inputs
[
nr_inputs
:])
...
...
tensorpack/dataflow/common.py
View file @
26e609f8
...
...
@@ -263,7 +263,7 @@ class MapData(ProxyDataFlow):
def
get_data
(
self
):
for
dp
in
self
.
ds
.
get_data
():
ret
=
self
.
func
(
dp
)
ret
=
self
.
func
(
copy
(
dp
))
# shallow copy the list
if
ret
is
not
None
:
yield
ret
...
...
@@ -292,7 +292,7 @@ class MapDataComponent(MapData):
r
=
func
(
dp
[
index
])
if
r
is
None
:
return
None
dp
=
copy
(
dp
)
# avoid modifying the list
dp
=
copy
(
dp
)
#
shallow copy to
avoid modifying the list
dp
[
index
]
=
r
return
dp
super
(
MapDataComponent
,
self
)
.
__init__
(
ds
,
f
)
...
...
tensorpack/tfutils/scope_utils.py
View file @
26e609f8
...
...
@@ -96,6 +96,7 @@ def cached_name_scope(name, top_level=True):
"""
if
not
top_level
:
current_ns
=
tf
.
get_default_graph
()
.
get_name_scope
()
if
current_ns
:
name
=
current_ns
+
'/'
+
name
ns
=
_get_cached_ns
(
name
)
with
tf
.
name_scope
(
ns
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment