Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
f5d5d4c2
Commit
f5d5d4c2
authored
Jul 29, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix SimpleTrainer and use longer test survival limit
parent
3b7e7c55
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
17 additions
and
24 deletions
+17
-24
tensorpack/train/base.py
tensorpack/train/base.py
+1
-1
tensorpack/train/distributed.py
tensorpack/train/distributed.py
+2
-1
tensorpack/train/simple.py
tensorpack/train/simple.py
+3
-4
tests/case_script.py
tests/case_script.py
+5
-5
tests/test_char_rnn.py
tests/test_char_rnn.py
+1
-1
tests/test_infogan.py
tests/test_infogan.py
+1
-1
tests/test_mnist.py
tests/test_mnist.py
+1
-1
tests/test_resnet.py
tests/test_resnet.py
+3
-10
No files found.
tensorpack/train/base.py
View file @
f5d5d4c2
...
@@ -64,7 +64,7 @@ class Trainer(object):
...
@@ -64,7 +64,7 @@ class Trainer(object):
self
.
monitors
=
[]
self
.
monitors
=
[]
self
.
_epoch_num
=
None
self
.
_epoch_num
=
None
self
.
_setup
()
# subclass will setup the graph
self
.
_setup
()
# subclass will setup the graph
and InputSource
@
property
@
property
def
epoch_num
(
self
):
def
epoch_num
(
self
):
...
...
tensorpack/train/distributed.py
View file @
f5d5d4c2
...
@@ -63,7 +63,6 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
...
@@ -63,7 +63,6 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
self
.
_input_source
=
config
.
data
self
.
_input_source
=
config
.
data
self
.
is_chief
=
(
self
.
task_index
==
0
and
self
.
job_name
==
'worker'
)
self
.
is_chief
=
(
self
.
task_index
==
0
and
self
.
job_name
==
'worker'
)
super
(
DistributedReplicatedTrainer
,
self
)
.
__init__
(
config
)
worker_prefix
=
'/job:worker/task:
%
s'
%
self
.
task_index
worker_prefix
=
'/job:worker/task:
%
s'
%
self
.
task_index
self
.
param_server_device
=
tf
.
train
.
replica_device_setter
(
self
.
param_server_device
=
tf
.
train
.
replica_device_setter
(
...
@@ -79,6 +78,8 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
...
@@ -79,6 +78,8 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
self
.
sync_queue_devices
=
[
'/job:ps/task:
%
s/cpu:0'
%
i
for
i
in
range
(
self
.
num_ps
)]
self
.
sync_queue_devices
=
[
'/job:ps/task:
%
s/cpu:0'
%
i
for
i
in
range
(
self
.
num_ps
)]
self
.
sync_queue_counter
=
0
self
.
sync_queue_counter
=
0
super
(
DistributedReplicatedTrainer
,
self
)
.
__init__
(
config
)
@
staticmethod
@
staticmethod
def
_average_grads
(
tower_grads
,
devices
):
def
_average_grads
(
tower_grads
,
devices
):
"""
"""
...
...
tensorpack/train/simple.py
View file @
f5d5d4c2
...
@@ -23,11 +23,9 @@ class SimpleTrainer(Trainer):
...
@@ -23,11 +23,9 @@ class SimpleTrainer(Trainer):
Args:
Args:
config (TrainConfig): the training config.
config (TrainConfig): the training config.
"""
"""
super
(
SimpleTrainer
,
self
)
.
__init__
(
config
)
assert
len
(
config
.
tower
)
==
1
,
\
assert
len
(
self
.
config
.
tower
)
==
1
,
\
"Got nr_tower={}, but doesn't support multigpu!"
\
"Got nr_tower={}, but doesn't support multigpu!"
\
" Use Sync/AsyncMultiGPUTrainer instead."
.
format
(
len
(
self
.
config
.
tower
))
" Use Sync/AsyncMultiGPUTrainer instead."
.
format
(
len
(
config
.
tower
))
if
config
.
dataflow
is
None
:
if
config
.
dataflow
is
None
:
self
.
_input_source
=
config
.
data
self
.
_input_source
=
config
.
data
...
@@ -35,6 +33,7 @@ class SimpleTrainer(Trainer):
...
@@ -35,6 +33,7 @@ class SimpleTrainer(Trainer):
self
.
_input_source
=
FeedInput
(
config
.
dataflow
)
self
.
_input_source
=
FeedInput
(
config
.
dataflow
)
logger
.
warn
(
"FeedInput is slow (and this is the default of SimpleTrainer). "
logger
.
warn
(
"FeedInput is slow (and this is the default of SimpleTrainer). "
"Consider QueueInput or other InputSource instead."
)
"Consider QueueInput or other InputSource instead."
)
super
(
SimpleTrainer
,
self
)
.
__init__
(
config
)
def
run_step
(
self
):
def
run_step
(
self
):
self
.
hooked_sess
.
run
(
self
.
train_op
)
self
.
hooked_sess
.
run
(
self
.
train_op
)
...
...
tests/case_script.py
View file @
f5d5d4c2
...
@@ -21,12 +21,12 @@ class PythonScript(threading.Thread):
...
@@ -21,12 +21,12 @@ class PythonScript(threading.Thread):
p: process handle
p: process handle
timeout (int): timeout in seconds
timeout (int): timeout in seconds
"""
"""
def
__init__
(
self
,
cmd
,
timeout
=
10
):
def
__init__
(
self
,
cmd
,
timeout
):
"""Prepare a python script
"""Prepare a python script
Args:
Args:
cmd (
TYPE
): command to execute the example with all flags (including python)
cmd (
str
): command to execute the example with all flags (including python)
timeout (int
, optional
): time in seconds the script has to survive
timeout (int): time in seconds the script has to survive
"""
"""
threading
.
Thread
.
__init__
(
self
)
threading
.
Thread
.
__init__
(
self
)
self
.
cmd
=
cmd
self
.
cmd
=
cmd
...
@@ -51,7 +51,7 @@ class PythonScript(threading.Thread):
...
@@ -51,7 +51,7 @@ class PythonScript(threading.Thread):
self
.
join
()
self
.
join
()
else
:
else
:
# something unexpected happend here, this script was supposed to survive at least the timeout
# something unexpected happend here, this script was supposed to survive at least the timeout
if
len
(
self
.
err
)
is
not
0
:
if
len
(
self
.
err
)
>
0
:
output
=
u"STDOUT:
\n\n\n
"
+
self
.
out
.
decode
(
'utf-8'
)
output
=
u"STDOUT:
\n\n\n
"
+
self
.
out
.
decode
(
'utf-8'
)
output
+=
u"
\n\n\n
STDERR:
\n\n\n
"
+
self
.
err
.
decode
(
'utf-8'
)
output
+=
u"
\n\n\n
STDERR:
\n\n\n
"
+
self
.
err
.
decode
(
'utf-8'
)
raise
AssertionError
(
output
)
raise
AssertionError
(
output
)
...
@@ -70,7 +70,7 @@ class TestPythonScript(unittest.TestCase):
...
@@ -70,7 +70,7 @@ class TestPythonScript(unittest.TestCase):
if
os
.
path
.
isdir
(
os
.
path
.
join
(
"train_log"
,
script
)):
if
os
.
path
.
isdir
(
os
.
path
.
join
(
"train_log"
,
script
)):
shutil
.
rmtree
(
os
.
path
.
join
(
"train_log"
,
script
))
shutil
.
rmtree
(
os
.
path
.
join
(
"train_log"
,
script
))
def
assertSurvive
(
self
,
script
,
args
=
None
,
timeout
=
1
0
):
# noqa
def
assertSurvive
(
self
,
script
,
args
=
None
,
timeout
=
2
0
):
# noqa
cmd
=
"python{} {}"
.
format
(
sys
.
version_info
.
major
,
script
)
cmd
=
"python{} {}"
.
format
(
sys
.
version_info
.
major
,
script
)
if
args
:
if
args
:
cmd
+=
" "
+
" "
.
join
(
args
)
cmd
+=
" "
+
" "
.
join
(
args
)
...
...
tests/test_char_rnn.py
View file @
f5d5d4c2
...
@@ -20,7 +20,7 @@ class CharRNNTest(TestPythonScript):
...
@@ -20,7 +20,7 @@ class CharRNNTest(TestPythonScript):
f
.
write
(
random_content
())
f
.
write
(
random_content
())
def
test
(
self
):
def
test
(
self
):
self
.
assertSurvive
(
self
.
script
,
args
=
[
'
--gpu 0'
,
'train'
],
timeout
=
10
)
self
.
assertSurvive
(
self
.
script
,
args
=
[
'
train'
]
)
def
tearDown
(
self
):
def
tearDown
(
self
):
super
(
CharRNNTest
,
self
)
.
tearDown
()
super
(
CharRNNTest
,
self
)
.
tearDown
()
...
...
tests/test_infogan.py
View file @
f5d5d4c2
...
@@ -8,4 +8,4 @@ class InfoGANTest(TestPythonScript):
...
@@ -8,4 +8,4 @@ class InfoGANTest(TestPythonScript):
return
'../examples/GAN/InfoGAN-mnist.py'
return
'../examples/GAN/InfoGAN-mnist.py'
def
test
(
self
):
def
test
(
self
):
self
.
assertSurvive
(
self
.
script
,
args
=
None
,
timeout
=
10
)
self
.
assertSurvive
(
self
.
script
,
args
=
None
)
tests/test_mnist.py
View file @
f5d5d4c2
...
@@ -8,4 +8,4 @@ class MnistTest(TestPythonScript):
...
@@ -8,4 +8,4 @@ class MnistTest(TestPythonScript):
return
'../examples/mnist-convnet.py'
return
'../examples/mnist-convnet.py'
def
test
(
self
):
def
test
(
self
):
self
.
assertSurvive
(
self
.
script
,
args
=
None
,
timeout
=
10
)
self
.
assertSurvive
(
self
.
script
,
args
=
None
)
tests/test_resnet.py
View file @
f5d5d4c2
from
case_script
import
TestPythonScript
from
case_script
import
TestPythonScript
import
os
import
shutil
class
ResnetTest
(
TestPythonScript
):
class
ResnetTest
(
TestPythonScript
):
@
property
@
property
def
script
(
self
):
def
script
(
self
):
return
'../examples/ResNet/imagenet-resnet.py'
return
'../examples/ResNet/imagenet-resnet.py'
def
test
(
self
):
def
test
(
self
):
self
.
assertSurvive
(
self
.
script
,
args
=
[
'--data .'
,
self
.
assertSurvive
(
'--gpu 0'
,
'--fake'
,
'--data_format NHWC'
],
timeout
=
10
)
self
.
script
,
args
=
[
'--fake'
,
'--data_format NHWC'
],
timeout
=
20
)
def
tearDown
(
self
):
super
(
ResnetTest
,
self
)
.
tearDown
()
if
os
.
path
.
isdir
(
'ilsvrc'
):
shutil
.
rmtree
(
'ilsvrc'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment