Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e63fe50f
Commit
e63fe50f
authored
Oct 25, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix some examples
parent
22b91be9
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
8 additions
and
4 deletions
+8
-4
examples/DoReFa-Net/alexnet-dorefa.py
examples/DoReFa-Net/alexnet-dorefa.py
+1
-1
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+2
-1
examples/ResNet/imagenet-resnet.py
examples/ResNet/imagenet-resnet.py
+5
-2
No files found.
examples/DoReFa-Net/alexnet-dorefa.py
View file @
e63fe50f
...
...
@@ -227,7 +227,7 @@ def get_data(dataset_name):
ds
=
AugmentImageComponent
(
ds
,
augmentors
,
copy
=
False
)
ds
=
BatchData
(
ds
,
BATCH_SIZE
,
remainder
=
not
isTrain
)
if
isTrain
:
ds
=
PrefetchDataZMQ
(
ds
,
min
(
12
,
multiprocessing
.
cpu_count
()))
ds
=
PrefetchDataZMQ
(
ds
,
min
(
25
,
multiprocessing
.
cpu_count
()))
return
ds
...
...
examples/FasterRCNN/train.py
View file @
e63fe50f
...
...
@@ -223,12 +223,13 @@ class EvalCallback(Callback):
def
_setup_graph
(
self
):
self
.
pred
=
self
.
trainer
.
get_predictor
([
'image'
],
[
'fastrcnn_fg_probs'
,
'fastrcnn_fg_boxes'
])
self
.
df
=
PrefetchDataZMQ
(
get_eval_dataflow
(),
1
)
get_tf_nms
()
# just to make sure the nms part of graph is created
def
_before_train
(
self
):
EVAL_TIMES
=
5
# eval 5 times during training
interval
=
self
.
trainer
.
max_epoch
//
(
EVAL_TIMES
+
1
)
self
.
epochs_to_eval
=
set
([
interval
*
k
for
k
in
range
(
1
,
EVAL_TIMES
)])
self
.
epochs_to_eval
.
add
(
self
.
trainer
.
max_epoch
)
get_tf_nms
()
# just to make sure the nms part of graph is created
def
_eval
(
self
):
all_results
=
eval_on_dataflow
(
self
.
df
,
lambda
img
:
detect_one_image
(
img
,
self
.
pred
))
...
...
examples/ResNet/imagenet-resnet.py
View file @
e63fe50f
...
...
@@ -9,10 +9,12 @@ import os
import
tensorflow
as
tf
os
.
environ
[
'TENSORPACK_TRAIN_API'
]
=
'v2'
# will become default soon
from
tensorpack
import
logger
,
QueueInput
from
tensorpack.models
import
*
from
tensorpack.callbacks
import
*
from
tensorpack.train
import
TrainConfig
,
SyncMultiGPUTrainerParameterServer
from
tensorpack.train
import
(
TrainConfig
,
SyncMultiGPUTrainerParameterServer
,
launch_train_with_config
)
from
tensorpack.dataflow
import
imgaug
,
FakeData
from
tensorpack.tfutils
import
argscope
,
get_model_loader
from
tensorpack.utils.gpu
import
get_nr_gpu
...
...
@@ -132,4 +134,5 @@ if __name__ == '__main__':
config
=
get_config
(
model
,
fake
=
args
.
fake
)
if
args
.
load
:
config
.
session_init
=
get_model_loader
(
args
.
load
)
SyncMultiGPUTrainerParameterServer
(
config
)
.
train
()
trainer
=
SyncMultiGPUTrainerParameterServer
(
max
(
get_nr_gpu
(),
1
))
launch_train_with_config
(
config
,
trainer
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment