Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
ef4a15ca
Commit
ef4a15ca
authored
Mar 02, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
atexit for prefetch
parent
8d1ad775
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
52 additions
and
47 deletions
+52
-47
example_mnist.py
example_mnist.py
+1
-0
tensorpack/callbacks/validation_callback.py
tensorpack/callbacks/validation_callback.py
+1
-2
tensorpack/dataflow/prefetch.py
tensorpack/dataflow/prefetch.py
+3
-1
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+29
-15
tensorpack/utils/concurrency.py
tensorpack/utils/concurrency.py
+18
-29
No files found.
example_mnist.py
View file @
ef4a15ca
...
...
@@ -90,6 +90,7 @@ def get_config():
dataset_train
=
BatchData
(
dataset
.
Mnist
(
'train'
),
128
)
dataset_test
=
BatchData
(
dataset
.
Mnist
(
'test'
),
256
,
remainder
=
True
)
step_per_epoch
=
dataset_train
.
size
()
step_per_epoch
=
30
# prepare session
sess_config
=
get_default_sess_config
()
...
...
tensorpack/callbacks/validation_callback.py
View file @
ef4a15ca
...
...
@@ -3,7 +3,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
itertools
from
tqdm
import
tqdm
from
abc
import
ABCMeta
,
abstractmethod
from
six.moves
import
zip
...
...
@@ -80,7 +79,7 @@ class ValidationStatPrinter(ValidationCallback):
stats
=
np
.
mean
(
stats
,
axis
=
0
)
assert
len
(
stats
)
==
len
(
self
.
vars_to_print
)
for
stat
,
var
in
itertools
.
i
zip
(
stats
,
self
.
vars_to_print
):
for
stat
,
var
in
zip
(
stats
,
self
.
vars_to_print
):
name
=
var
.
name
.
replace
(
':0'
,
''
)
self
.
trainer
.
summary_writer
.
add_summary
(
create_summary
(
'{}_{}'
.
format
(
self
.
prefix
,
name
),
stat
),
self
.
global_step
)
...
...
tensorpack/dataflow/prefetch.py
View file @
ef4a15ca
...
...
@@ -2,9 +2,10 @@
# File: prefetch.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
multiprocessing
from
.base
import
DataFlow
import
multiprocessing
from
..utils.concurrency
import
ensure_procs_terminate
__all__
=
[
'PrefetchData'
]
...
...
@@ -45,6 +46,7 @@ class PrefetchData(DataFlow):
def
get_data
(
self
):
queue
=
multiprocessing
.
Queue
(
self
.
nr_prefetch
)
procs
=
[
PrefetchProcess
(
self
.
ds
,
queue
)
for
_
in
range
(
self
.
nr_proc
)]
ensure_procs_terminate
(
procs
)
[
x
.
start
()
for
x
in
procs
]
end_cnt
=
0
...
...
tensorpack/train/trainer.py
View file @
ef4a15ca
...
...
@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
threading
import
copy
import
re
from
six.moves
import
zip
...
...
@@ -10,25 +11,10 @@ from six.moves import zip
from
.base
import
Trainer
from
..dataflow.common
import
RepeatedData
from
..utils
import
*
from
..utils.concurrency
import
EnqueueThread
from
..utils.summary
import
summary_moving_average
__all__
=
[
'SimpleTrainer'
,
'QueueInputTrainer'
,
'start_train'
]
def scale_grads(grads, multiplier):
    """Scale gradients whose variable name matches a multiplier pattern.

    :param grads: iterable of ``(gradient, variable)`` pairs; only
        ``variable.name`` is read from each variable.
    :param multiplier: iterable of ``(regex, value)`` pairs, tried in order
        with ``re.search``; the first matching pattern wins.
    :returns: list of ``(gradient, variable)`` pairs in input order, where a
        matched gradient is replaced by ``gradient * value`` and an unmatched
        one is passed through untouched.
    """
    # Private sentinel so that a falsy multiplier value (e.g. 0) still counts
    # as a match.
    unmatched = object()
    scaled = []
    for gradient, variable in grads:
        varname = variable.name
        # Lazy lookup: stops at the first pattern that matches this name.
        factor = next(
            (val for regex, val in multiplier if re.search(regex, varname)),
            unmatched)
        if factor is unmatched:
            scaled.append((gradient, variable))
            continue
        logger.info("Apply lr multiplier {} for {}".format(factor, varname))
        scaled.append((gradient * factor, variable))
    return scaled
class
SimpleTrainer
(
Trainer
):
def
run_step
(
self
):
data
=
next
(
self
.
data_producer
)
...
...
@@ -61,6 +47,34 @@ class SimpleTrainer(Trainer):
summary_str
=
self
.
summary_op
.
eval
(
feed_dict
=
feed
)
self
.
_process_summary
(
summary_str
)
class EnqueueThread(threading.Thread):
    """Daemon thread that pushes datapoints from the trainer's dataflow
    through an enqueue op.

    Each datapoint yielded by ``trainer.config.dataset.get_data()`` is zipped
    with ``raw_input_var`` into a feed dict, and ``enqueue_op`` is run with
    that feed in the trainer's session.
    """
    def __init__(self, trainer, queue, enqueue_op, raw_input_var):
        """
        :param trainer: object providing ``sess``, ``coord`` and
            ``config.dataset``.
        :param queue: queue object; only its ``close`` method is used here.
        :param enqueue_op: op that is run once per datapoint.
        :param raw_input_var: sequence zipped against each datapoint to build
            the feed dict.
        """
        super(EnqueueThread, self).__init__()
        self.sess = trainer.sess
        self.coord = trainer.coord
        self.dataflow = trainer.config.dataset
        self.input_vars = raw_input_var
        self.op = enqueue_op
        self.queue = queue
        # Built once up front; it is only run from the error path in run().
        self.close_op = self.queue.close(cancel_pending_enqueues=True)
        # Daemon thread: it will not keep the interpreter alive on exit.
        self.daemon = True

    def run(self):
        """Feed datapoints until the coordinator asks to stop or an error occurs."""
        try:
            # Outer loop restarts the dataflow each time it is exhausted.
            while True:
                for dp in self.dataflow.get_data():
                    if self.coord.should_stop():
                        return
                    feed = dict(zip(self.input_vars, dp))
                    self.op.run(feed_dict=feed, session=self.sess)
        except tf.errors.CancelledError as e:
            # Cancellation is deliberately ignored and ends the thread quietly.
            pass
        except Exception:
            logger.exception("Exception in EnqueueThread:")
            # NOTE(review): indentation is lost in this view; the two
            # statements below are assumed to belong to this except branch
            # rather than to follow the whole try statement -- confirm.
            self.sess.run(self.close_op)
            self.coord.request_stop()
class
QueueInputTrainer
(
Trainer
):
"""
...
...
tensorpack/utils/concurrency.py
View file @
ef4a15ca
...
...
@@ -5,10 +5,11 @@
import
threading
from
contextlib
import
contextmanager
import
tensorflow
as
tf
import
atexit
import
weakref
from
six.moves
import
zip
from
.naming
import
*
from
.
import
logger
class
StoppableThread
(
threading
.
Thread
):
def
__init__
(
self
):
...
...
@@ -22,31 +23,19 @@ class StoppableThread(threading.Thread):
return
self
.
_stop
.
isSet
()
class EnqueueThread(threading.Thread):
    """Daemon thread that pushes datapoints from the trainer's dataflow
    through an enqueue op.

    Each datapoint yielded by ``trainer.config.dataset.get_data()`` is zipped
    with ``raw_input_var`` into a feed dict, and ``enqueue_op`` is run with
    that feed in the trainer's session.
    """
    def __init__(self, trainer, queue, enqueue_op, raw_input_var):
        """
        :param trainer: object providing ``sess``, ``coord`` and
            ``config.dataset``.
        :param queue: queue object; only its ``close`` method is used here.
        :param enqueue_op: op that is run once per datapoint.
        :param raw_input_var: sequence zipped against each datapoint to build
            the feed dict.
        """
        super(EnqueueThread, self).__init__()
        self.sess = trainer.sess
        self.coord = trainer.coord
        self.dataflow = trainer.config.dataset
        self.input_vars = raw_input_var
        self.op = enqueue_op
        self.queue = queue
        # Built once up front; it is only run from the error path in run().
        self.close_op = self.queue.close(cancel_pending_enqueues=True)
        # Daemon thread: it will not keep the interpreter alive on exit.
        self.daemon = True

    def run(self):
        """Feed datapoints until the coordinator asks to stop or an error occurs."""
        try:
            # Outer loop restarts the dataflow each time it is exhausted.
            while True:
                for dp in self.dataflow.get_data():
                    if self.coord.should_stop():
                        return
                    feed = dict(zip(self.input_vars, dp))
                    self.op.run(feed_dict=feed, session=self.sess)
        except tf.errors.CancelledError as e:
            # Cancellation is deliberately ignored and ends the thread quietly.
            pass
        except Exception:
            logger.exception("Exception in EnqueueThread:")
            # NOTE(review): indentation is lost in this view; the two
            # statements below are assumed to belong to this except branch
            # rather than to follow the whole try statement -- confirm.
            self.sess.run(self.close_op)
            self.coord.request_stop()


def ensure_proc_terminate(proc):
    """Make sure ``proc`` is terminated and joined when the interpreter exits.

    Only a weak reference to the process is handed to ``atexit``, so this
    registration does not by itself keep the process object alive.

    :param proc: a ``multiprocessing.Process`` instance.
    :raises AssertionError: if ``proc`` is not a ``multiprocessing.Process``.
    """
    # Imported locally: ``multiprocessing`` is not among the imports visible
    # for this module, and the isinstance check below needs it.
    import multiprocessing

    def stop_proc_by_weak_ref(ref):
        # Named ``target`` rather than ``proc`` to avoid shadowing the
        # enclosing function's argument.
        target = ref()
        if target is None:
            # The process object was already garbage collected.
            return
        if not target.is_alive():
            return
        target.terminate()
        target.join()

    assert isinstance(proc, multiprocessing.Process)
    atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
def ensure_procs_terminate(procs):
    """Apply ``ensure_proc_terminate`` to every process in ``procs``.

    :param procs: iterable of processes; each one is registered, in order,
        for termination at interpreter exit.
    """
    for process in procs:
        ensure_proc_terminate(process)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment