Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
3e97f126
Commit
3e97f126
authored
Feb 17, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
use tf.train.MonitoredSession. enforce a boundary between graph finalize & session create
parent
224025e3
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
150 additions
and
114 deletions
+150
-114
README.md
README.md
+0
-3
examples/A3C-Gym/simulator.py
examples/A3C-Gym/simulator.py
+2
-2
examples/A3C-Gym/train-atari.py
examples/A3C-Gym/train-atari.py
+5
-4
examples/DeepQNetwork/common.py
examples/DeepQNetwork/common.py
+10
-9
examples/DeepQNetwork/expreplay.py
examples/DeepQNetwork/expreplay.py
+4
-5
examples/ResNet/cifar10-resnet.py
examples/ResNet/cifar10-resnet.py
+4
-4
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+3
-2
tensorpack/libinfo.py
tensorpack/libinfo.py
+1
-1
tensorpack/predict/base.py
tensorpack/predict/base.py
+11
-7
tensorpack/predict/concurrency.py
tensorpack/predict/concurrency.py
+32
-24
tensorpack/predict/multigpu.py
tensorpack/predict/multigpu.py
+2
-2
tensorpack/train/base.py
tensorpack/train/base.py
+18
-17
tensorpack/train/input_data.py
tensorpack/train/input_data.py
+13
-24
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+2
-2
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+8
-7
tensorpack/utils/concurrency.py
tensorpack/utils/concurrency.py
+35
-1
No files found.
README.md
View file @
3e97f126
...
...
@@ -62,9 +62,6 @@ Dependencies:
+
Python 2 or 3
+
TensorFlow >= 1.0.0rc0
+
Python bindings for OpenCV
+
(optional) use tcmalloc if running with large data
```
pip install --user -U git+https://github.com/ppwwyyxx/tensorpack.git
pip install --user -r opt-requirements.txt # (some optional dependencies required by certain submodules, you can install later if prompted)
```
examples/A3C-Gym/simulator.py
View file @
3e97f126
...
...
@@ -23,8 +23,8 @@ from tensorpack.utils.serialize import loads, dumps
from
tensorpack.utils.concurrency
import
LoopThread
,
ensure_proc_terminate
__all__
=
[
'SimulatorProcess'
,
'SimulatorMaster'
,
'SimulatorProcessStateExchange'
,
'SimulatorProcessSharedWeight'
,
'TransitionExperience'
,
'WeightSync'
]
'SimulatorProcessStateExchange'
,
'TransitionExperience'
]
class
TransitionExperience
(
object
):
...
...
examples/A3C-Gym/train-atari.py
View file @
3e97f126
...
...
@@ -38,7 +38,7 @@ CHANNEL = FRAME_HISTORY * 3
IMAGE_SHAPE3
=
IMAGE_SIZE
+
(
CHANNEL
,)
LOCAL_TIME_MAX
=
5
STEP_PER_EPOCH
=
6000
STEP
S
_PER_EPOCH
=
6000
EVAL_EPISODE
=
50
BATCH_SIZE
=
128
SIMULATOR_PROC
=
50
...
...
@@ -150,11 +150,12 @@ class MySimulatorMaster(SimulatorMaster, Callback):
self
.
queue
=
queue
.
Queue
(
maxsize
=
BATCH_SIZE
*
8
*
2
)
def
_setup_graph
(
self
):
self
.
sess
=
self
.
trainer
.
sess
self
.
async_predictor
=
MultiThreadAsyncPredictor
(
self
.
trainer
.
get_predict_funcs
([
'state'
],
[
'logitsT'
,
'pred_value'
],
PREDICTOR_THREAD
),
batch_size
=
15
)
self
.
async_predictor
.
run
()
def
_before_train
(
self
):
self
.
async_predictor
.
start
()
def
_on_state
(
self
,
state
,
ident
):
def
cb
(
outputs
):
...
...
@@ -222,7 +223,7 @@ def get_config():
],
session_config
=
get_default_sess_config
(
0.5
),
model
=
M
,
steps_per_epoch
=
STEP_PER_EPOCH
,
steps_per_epoch
=
STEP
S
_PER_EPOCH
,
max_epoch
=
1000
,
)
...
...
examples/DeepQNetwork/common.py
View file @
3e97f126
...
...
@@ -40,7 +40,7 @@ def play_model(cfg):
def
eval_with_funcs
(
predict_funcs
,
nr_eval
):
class
Worker
(
StoppableThread
):
class
Worker
(
StoppableThread
,
ShareSessionThread
):
def
__init__
(
self
,
func
,
queue
):
super
(
Worker
,
self
)
.
__init__
()
self
.
_func
=
func
...
...
@@ -52,14 +52,15 @@ def eval_with_funcs(predict_funcs, nr_eval):
return
self
.
_func
(
*
args
,
**
kwargs
)
def
run
(
self
):
player
=
get_player
(
train
=
False
)
while
not
self
.
stopped
():
try
:
score
=
play_one_episode
(
player
,
self
.
func
)
# print "Score, ", score
except
RuntimeError
:
return
self
.
queue_put_stoppable
(
self
.
q
,
score
)
with
self
.
default_sess
():
player
=
get_player
(
train
=
False
)
while
not
self
.
stopped
():
try
:
score
=
play_one_episode
(
player
,
self
.
func
)
# print "Score, ", score
except
RuntimeError
:
return
self
.
queue_put_stoppable
(
self
.
q
,
score
)
q
=
queue
.
Queue
()
threads
=
[
Worker
(
f
,
q
)
for
f
in
predict_funcs
]
...
...
examples/DeepQNetwork/expreplay.py
View file @
3e97f126
...
...
@@ -11,7 +11,7 @@ from six.moves import queue
from
tensorpack.dataflow
import
DataFlow
from
tensorpack.utils
import
logger
,
get_tqdm
,
get_rng
from
tensorpack.utils.concurrency
import
LoopThread
from
tensorpack.utils.concurrency
import
LoopThread
,
ShareSessionThread
from
tensorpack.callbacks.base
import
Callback
__all__
=
[
'ExpReplay'
]
...
...
@@ -75,10 +75,9 @@ class ExpReplay(DataFlow, Callback):
# spawn a separate thread to run policy, can speed up 1.3x
def
populate_job_func
():
self
.
_populate_job_queue
.
get
()
with
self
.
trainer
.
sess
.
as_default
():
for
_
in
range
(
self
.
update_frequency
):
self
.
_populate_exp
()
th
=
LoopThread
(
populate_job_func
,
pausable
=
False
)
for
_
in
range
(
self
.
update_frequency
):
self
.
_populate_exp
()
th
=
ShareSessionThread
(
LoopThread
(
populate_job_func
,
pausable
=
False
))
th
.
name
=
"SimulatorThread"
return
th
...
...
examples/ResNet/cifar10-resnet.py
View file @
3e97f126
...
...
@@ -21,9 +21,9 @@ This implementation uses the variants proposed in:
Identity Mappings in Deep Residual Networks, arxiv:1603.05027
I can reproduce the results on 2 TitanX for
n=5, about 7.1
%
val error after 67k steps (
8.6
step/s)
n=18, about 5.95
%
val error after 80k steps (
2.6
step/s)
n=30: a 182-layer network, about 5.6
%
val error after 51k steps (
1.5
5 step/s)
n=5, about 7.1
%
val error after 67k steps (
15
step/s)
n=18, about 5.95
%
val error after 80k steps (
4.2
step/s)
n=30: a 182-layer network, about 5.6
%
val error after 51k steps (
2.
5 step/s)
This model uses the whole training set instead of a train-val split.
To train:
...
...
@@ -131,7 +131,7 @@ def get_data(train_or_test):
imgaug
.
MapImage
(
lambda
x
:
x
-
pp_mean
)
]
ds
=
AugmentImageComponent
(
ds
,
augmentors
)
ds
=
BatchData
(
ds
,
128
,
remainder
=
not
isTrain
)
ds
=
BatchData
(
ds
,
BATCH_SIZE
,
remainder
=
not
isTrain
)
if
isTrain
:
ds
=
PrefetchData
(
ds
,
3
,
2
)
return
ds
...
...
tensorpack/callbacks/inference_runner.py
View file @
3e97f126
...
...
@@ -226,18 +226,19 @@ class FeedfreeInferenceRunner(Triggerable):
G
=
tf
.
get_default_graph
()
self
.
_output_tensors
=
[
G
.
get_tensor_by_name
(
self
.
_tower_prefix
+
'/'
+
n
)
for
n
in
all_names
]
self
.
_sess
=
self
.
trainer
.
sess
# list of list of id
self
.
inf_to_idxs
=
dispatcer
.
get_idx_for_each_entry
()
def
_trigger
(
self
):
sess
=
tf
.
get_default_session
()
for
inf
in
self
.
infs
:
inf
.
before_inference
()
with
get_tqdm
(
total
=
self
.
_size
)
as
pbar
:
for
_
in
range
(
self
.
_size
):
outputs
=
se
lf
.
_se
ss
.
run
(
fetches
=
self
.
_output_tensors
)
outputs
=
sess
.
run
(
fetches
=
self
.
_output_tensors
)
for
inf
,
idlist
in
zip
(
self
.
infs
,
self
.
inf_to_idxs
):
inf_output
=
[
outputs
[
k
]
for
k
in
idlist
]
inf
.
datapoint
(
inf_output
)
...
...
tensorpack/libinfo.py
View file @
3e97f126
...
...
@@ -6,4 +6,4 @@ import cv2 # noqa
import
os
os
.
environ
[
'OPENCV_OPENCL_RUNTIME'
]
=
''
__version__
=
'0.1.
5
'
__version__
=
'0.1.
6
'
tensorpack/predict/base.py
View file @
3e97f126
...
...
@@ -23,7 +23,6 @@ class PredictorBase(object):
Base class for all predictors.
Attributes:
session (tf.Session):
return_input (bool): whether the call will also return (inputs, outputs)
or just outputs
"""
...
...
@@ -91,25 +90,30 @@ class AsyncPredictorBase(PredictorBase):
class
OnlinePredictor
(
PredictorBase
):
""" A predictor which directly use an existing session. """
def
__init__
(
self
,
sess
,
input_tensors
,
output_tensors
,
return_input
=
False
):
def
__init__
(
self
,
input_tensors
,
output_tensors
,
return_input
=
False
,
sess
=
None
):
"""
Args:
sess (tf.Session): an existing session.
input_tensors (list): list of names.
output_tensors (list): list of names.
return_input (bool): same as :attr:`PredictorBase.return_input`.
sess (tf.Session): the session this predictor runs in. If None,
will use the default session.
"""
self
.
session
=
sess
self
.
return_input
=
return_input
self
.
input_tensors
=
input_tensors
self
.
output_tensors
=
output_tensors
self
.
sess
=
sess
def
_do_call
(
self
,
dp
):
assert
len
(
dp
)
==
len
(
self
.
input_tensors
),
\
"{} != {}"
.
format
(
len
(
dp
),
len
(
self
.
input_tensors
))
feed
=
dict
(
zip
(
self
.
input_tensors
,
dp
))
output
=
self
.
session
.
run
(
self
.
output_tensors
,
feed_dict
=
feed
)
if
self
.
sess
is
None
:
sess
=
tf
.
get_default_session
()
else
:
sess
=
self
.
sess
output
=
sess
.
run
(
self
.
output_tensors
,
feed_dict
=
feed
)
return
output
...
...
@@ -133,7 +137,7 @@ class OfflinePredictor(OnlinePredictor):
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
config
.
session_init
.
init
(
sess
)
super
(
OfflinePredictor
,
self
)
.
__init__
(
sess
,
input_vars
,
output_vars
,
config
.
return_input
)
input_vars
,
output_vars
,
config
.
return_input
,
sess
)
def
get_predict_func
(
config
):
...
...
tensorpack/predict/concurrency.py
View file @
3e97f126
...
...
@@ -9,9 +9,9 @@ from six.moves import queue, range
import
tensorflow
as
tf
from
..utils
import
logger
from
..utils.concurrency
import
DIE
,
StoppableThread
from
..utils.concurrency
import
DIE
,
StoppableThread
,
ShareSessionThread
from
..tfutils.modelutils
import
describe_model
from
.base
import
OfflinePredictor
,
AsyncPredictorBase
from
.base
import
O
nlinePredictor
,
O
fflinePredictor
,
AsyncPredictorBase
__all__
=
[
'MultiProcessPredictWorker'
,
'MultiProcessQueuePredictWorker'
,
'MultiThreadAsyncPredictor'
]
...
...
@@ -73,7 +73,7 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
self
.
outqueue
.
put
((
tid
,
self
.
predictor
(
dp
)))
class
PredictorWorkerThread
(
StoppableThread
):
class
PredictorWorkerThread
(
StoppableThread
,
ShareSessionThread
):
def
__init__
(
self
,
queue
,
pred_func
,
id
,
batch_size
=
5
):
super
(
PredictorWorkerThread
,
self
)
.
__init__
()
self
.
queue
=
queue
...
...
@@ -83,25 +83,26 @@ class PredictorWorkerThread(StoppableThread):
self
.
id
=
id
def
run
(
self
):
while
not
self
.
stopped
():
batched
,
futures
=
self
.
fetch_batch
()
try
:
outputs
=
self
.
func
(
batched
)
except
tf
.
errors
.
CancelledError
:
for
f
in
futures
:
f
.
cancel
()
logger
.
warn
(
"In PredictorWorkerThread id={}, call was cancelled."
.
format
(
self
.
id
))
return
# print "Worker {} batched {} Queue {}".format(
# self.id, len(futures), self.queue.qsize())
# debug, for speed testing
# if not hasattr(self, 'xxx'):
# self.xxx = outputs = self.func(batched)
# else:
# outputs = [[self.xxx[0][0]] * len(batched[0]), [self.xxx[1][0]] * len(batched[0])]
for
idx
,
f
in
enumerate
(
futures
):
f
.
set_result
([
k
[
idx
]
for
k
in
outputs
])
with
self
.
default_sess
():
while
not
self
.
stopped
():
batched
,
futures
=
self
.
fetch_batch
()
try
:
outputs
=
self
.
func
(
batched
)
except
tf
.
errors
.
CancelledError
:
for
f
in
futures
:
f
.
cancel
()
logger
.
warn
(
"In PredictorWorkerThread id={}, call was cancelled."
.
format
(
self
.
id
))
return
# print "Worker {} batched {} Queue {}".format(
# self.id, len(futures), self.queue.qsize())
# debug, for speed testing
# if not hasattr(self, 'xxx'):
# self.xxx = outputs = self.func(batched)
# else:
# outputs = [[self.xxx[0][0]] * len(batched[0]), [self.xxx[1][0]] * len(batched[0])]
for
idx
,
f
in
enumerate
(
futures
):
f
.
set_result
([
k
[
idx
]
for
k
in
outputs
])
def
fetch_batch
(
self
):
""" Fetch a batch of data without waiting"""
...
...
@@ -137,9 +138,12 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
batch_size (int): the maximum of an internal batch.
"""
assert
len
(
predictors
)
self
.
_need_default_sess
=
False
for
k
in
predictors
:
# assert isinstance(k, OnlinePredictor), type(k)
# TODO use predictors.return_input here
assert
isinstance
(
k
,
OnlinePredictor
),
type
(
k
)
if
k
.
sess
is
None
:
self
.
_need_default_sess
=
True
# TODO support predictors.return_input here
assert
not
k
.
return_input
self
.
input_queue
=
queue
.
Queue
(
maxsize
=
len
(
predictors
)
*
100
)
self
.
threads
=
[
...
...
@@ -153,6 +157,10 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
options
.
parse_command_line
([
'--logging=debug'
])
def
start
(
self
):
if
self
.
_need_default_sess
:
assert
tf
.
get_default_session
()
is
not
None
,
\
"Not session is bind to predictors, "
\
"MultiThreadAsyncPredictor.start() has to be called under a default session!"
for
t
in
self
.
threads
:
t
.
start
()
...
...
tensorpack/predict/multigpu.py
View file @
3e97f126
...
...
@@ -41,7 +41,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
[
'{}{}/'
.
format
(
PREDICT_TOWER
,
k
)
+
n
for
n
in
config
.
output_names
])
self
.
predictors
.
append
(
OnlinePredictor
(
self
.
sess
,
input_vars
,
output_vars
,
config
.
return_input
))
input_vars
,
output_vars
,
config
.
return_input
,
self
.
sess
))
def
_do_call
(
self
,
dp
):
# use the first tower for compatible PredictorBase interface
...
...
@@ -90,4 +90,4 @@ class DataParallelOfflinePredictor(OnlinePredictor):
input_vars
=
get_tensors_by_names
(
input_var_names
)
config
.
session_init
.
init
(
sess
)
super
(
DataParallelOfflinePredictor
,
self
)
.
__init__
(
sess
,
input_vars
,
output_vars
,
config
.
return_input
)
input_vars
,
output_vars
,
config
.
return_input
,
sess
)
tensorpack/train/base.py
View file @
3e97f126
...
...
@@ -35,7 +35,6 @@ class Trainer(object):
config (TrainConfig): the config used in this trainer.
model (ModelDesc)
sess (tf.Session): the current session in use.
coord (tf.train.Coordinator)
stat_holder (StatHolder)
summary_writer (tf.summary.FileWriter)
...
...
@@ -53,10 +52,8 @@ class Trainer(object):
assert
isinstance
(
config
,
TrainConfig
),
type
(
config
)
self
.
config
=
config
self
.
model
=
config
.
model
self
.
sess
=
tf
.
Session
(
config
=
self
.
config
.
session_config
)
self
.
coord
=
tf
.
train
.
Coordinator
()
self
.
epoch_num
=
self
.
config
.
starting_epoch
self
.
epoch_num
=
self
.
config
.
starting_epoch
-
1
self
.
local_step
=
0
def
train
(
self
):
...
...
@@ -131,24 +128,29 @@ class Trainer(object):
describe_model
()
# some final operations that might modify the graph
logger
.
info
(
"Setup callbacks ..."
)
logger
.
info
(
"Setup callbacks
graph
..."
)
self
.
config
.
callbacks
.
setup_graph
(
weakref
.
proxy
(
self
))
self
.
_extra_fetches
=
self
.
config
.
callbacks
.
extra_fetches
()
self
.
summary_writer
=
tf
.
summary
.
FileWriter
(
logger
.
LOG_DIR
,
graph
=
self
.
sess
.
graph
)
self
.
summary_op
=
tf
.
summary
.
merge_all
()
logger
.
info
(
"Setup summaries ..."
)
self
.
summary_writer
=
tf
.
summary
.
FileWriter
(
logger
.
LOG_DIR
,
graph
=
tf
.
get_default_graph
())
self
.
summary_op
=
tf
.
summary
.
merge_all
()
# XXX not good
# create an empty StatHolder
self
.
stat_holder
=
StatHolder
(
logger
.
LOG_DIR
)
logger
.
info
(
"Initializing graph variables ..."
)
initop
=
tf
.
global_variables_initializer
()
self
.
sess
.
run
(
initop
)
def
after_init
(
_
,
__
):
logger
.
info
(
"Graph variables initialized."
)
scaffold
=
tf
.
train
.
Scaffold
(
init_op
=
tf
.
global_variables_initializer
(),
init_fn
=
after_init
)
logger
.
info
(
"Finalize the graph, create the session ..."
)
self
.
monitored_sess
=
tf
.
train
.
MonitoredSession
(
session_creator
=
tf
.
train
.
ChiefSessionCreator
(
scaffold
=
scaffold
,
config
=
self
.
config
.
session_config
),
hooks
=
None
)
self
.
sess
=
self
.
monitored_sess
.
_tf_sess
()
self
.
config
.
session_init
.
init
(
self
.
sess
)
tf
.
get_default_graph
()
.
finalize
()
tf
.
train
.
start_queue_runners
(
sess
=
self
.
sess
,
coord
=
self
.
coord
,
daemon
=
True
,
start
=
True
)
@
abstractmethod
def
_setup
(
self
):
""" setup Trainer-specific stuff for training"""
...
...
@@ -176,7 +178,7 @@ class Trainer(object):
logger
.
info
(
"Start Epoch {} ..."
.
format
(
self
.
epoch_num
))
start_time
=
time
.
time
()
for
self
.
local_step
in
range
(
self
.
config
.
steps_per_epoch
):
if
self
.
coord
.
should_stop
():
if
self
.
monitored_sess
.
should_stop
():
return
fetch_data
=
self
.
run_step
()
# implemented by subclass
if
fetch_data
is
None
:
...
...
@@ -197,9 +199,8 @@ class Trainer(object):
raise
finally
:
callbacks
.
after_train
()
self
.
coord
.
request_stop
()
self
.
summary_writer
.
close
()
self
.
sess
.
close
()
self
.
monitored_
sess
.
close
()
def
get_predict_func
(
self
,
input_names
,
output_names
):
"""
...
...
tensorpack/train/input_data.py
View file @
3e97f126
...
...
@@ -4,7 +4,6 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
tensorflow
as
tf
import
threading
from
abc
import
ABCMeta
,
abstractmethod
import
six
...
...
@@ -12,6 +11,7 @@ from ..dataflow import DataFlow, RepeatedData
from
..tfutils.summary
import
add_moving_summary
from
..tfutils
import
get_op_tensor_name
from
..utils
import
logger
from
..utils.concurrency
import
ShareSessionThread
from
..callbacks.concurrency
import
StartProcOrThread
__all__
=
[
'InputData'
,
'FeedfreeInput'
,
...
...
@@ -72,8 +72,8 @@ class FeedfreeInput(InputData):
"""
class
EnqueueThread
(
threading
.
Thread
):
def
__init__
(
self
,
trainer
,
queue
,
ds
,
input_placehdrs
):
class
EnqueueThread
(
ShareSession
Thread
):
def
__init__
(
self
,
queue
,
ds
,
input_placehdrs
):
super
(
EnqueueThread
,
self
)
.
__init__
()
self
.
name
=
'EnqueueThread'
self
.
daemon
=
True
...
...
@@ -81,8 +81,6 @@ class EnqueueThread(threading.Thread):
self
.
dataflow
=
ds
self
.
queue
=
queue
self
.
sess
=
trainer
.
sess
self
.
coord
=
trainer
.
coord
self
.
placehdrs
=
input_placehdrs
self
.
op
=
self
.
queue
.
enqueue
(
self
.
placehdrs
)
...
...
@@ -92,27 +90,20 @@ class EnqueueThread(threading.Thread):
self
.
size_op
,
tf
.
float32
,
name
=
'input_queue_size'
))
def
run
(
self
):
try
:
self
.
dataflow
.
reset_state
()
with
self
.
sess
.
as_default
():
with
self
.
default_sess
()
:
try
:
self
.
dataflow
.
reset_state
()
while
True
:
for
dp
in
self
.
dataflow
.
get_data
():
if
self
.
coord
.
should_stop
():
return
feed
=
dict
(
zip
(
self
.
placehdrs
,
dp
))
# print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
self
.
op
.
run
(
feed_dict
=
feed
)
except
tf
.
errors
.
CancelledError
:
pass
except
Exception
:
logger
.
exception
(
"Exception in EnqueueThread:"
)
finally
:
self
.
coord
.
request_stop
()
try
:
self
.
sess
.
run
(
self
.
close_op
)
except
RuntimeError
:
# session already closed
except
tf
.
errors
.
CancelledError
:
pass
logger
.
info
(
"Enqueue Thread Exited."
)
except
Exception
:
logger
.
exception
(
"Exception in EnqueueThread:"
)
finally
:
logger
.
info
(
"EnqueueThread Exited."
)
class
QueueInput
(
FeedfreeInput
):
...
...
@@ -141,8 +132,7 @@ class QueueInput(FeedfreeInput):
self
.
queue
=
tf
.
FIFOQueue
(
50
,
[
x
.
dtype
for
x
in
self
.
input_placehdrs
],
name
=
'input_queue'
)
self
.
thread
=
EnqueueThread
(
trainer
,
self
.
queue
,
self
.
ds
,
self
.
input_placehdrs
)
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
self
.
input_placehdrs
)
trainer
.
config
.
callbacks
.
append
(
StartProcOrThread
(
self
.
thread
))
def
_get_input_tensors
(
self
):
...
...
@@ -203,8 +193,7 @@ class BatchQueueInput(FeedfreeInput):
for
shp
in
self
.
queue
.
shapes
:
assert
shp
.
is_fully_defined
(),
shape_err
self
.
thread
=
EnqueueThread
(
trainer
,
self
.
queue
,
self
.
ds
,
placehdrs_nobatch
)
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
placehdrs_nobatch
)
trainer
.
config
.
callbacks
.
append
(
StartProcOrThread
(
self
.
thread
))
def
_get_input_tensors
(
self
):
...
...
tensorpack/train/multigpu.py
View file @
3e97f126
...
...
@@ -199,7 +199,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
# pretend to average the grads, in order to make async and
# sync have consistent effective learning rate
gradproc
=
ScaleGradient
((
'.*'
,
1.0
/
self
.
config
.
nr_tower
),
log
=
False
)
grad_list
=
[
apply_grad_processors
(
g
,
[
gradproc
])
for
g
in
grad_list
]
grad_list
=
apply_grad_processors
(
grad_list
,
[
gradproc
])
# use grad from the first tower for iteration in main thread
self
.
train_op
=
self
.
config
.
optimizer
.
apply_gradients
(
grad_list
[
0
],
name
=
'min_op'
)
...
...
@@ -216,7 +216,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
def
f
(
op
=
train_op
):
# avoid late-binding
self
.
sess
.
run
([
op
])
next
(
self
.
async_step_counter
)
next
(
self
.
async_step_counter
)
# atomic due to GIL
th
=
LoopThread
(
f
)
th
.
name
=
"AsyncLoopThread-{}"
.
format
(
k
)
th
.
pause
()
...
...
tensorpack/train/trainer.py
View file @
3e97f126
...
...
@@ -18,19 +18,20 @@ __all__ = ['SimpleTrainer', 'MultiPredictorTowerTrainer']
class
PredictorFactory
(
object
):
""" Make predictors for a trainer"""
def
__init__
(
self
,
sess
,
model
,
towers
):
def
__init__
(
self
,
model
,
towers
):
"""
:param towers: list of gpu relative id
"""
self
.
sess
=
sess
self
.
model
=
model
self
.
towers
=
towers
self
.
tower_built
=
False
def
get_predictor
(
self
,
input_names
,
output_names
,
tower
):
"""
:param tower: need the kth tower (not the gpu id)
:returns: an online predictor
Args:
tower: need the kth tower (not the gpu id)
Returns:
an online predictor (which has to be used under a default session)
"""
if
not
self
.
tower_built
:
self
.
_build_predict_tower
()
...
...
@@ -53,7 +54,7 @@ class PredictorFactory(object):
output_names
=
map
(
get_name_in_tower
,
output_names
)
output_vars
=
get_tensors_by_names
(
output_names
)
return
OnlinePredictor
(
self
.
sess
,
raw_input_vars
,
output_vars
)
return
OnlinePredictor
(
raw_input_vars
,
output_vars
)
def
_build_predict_tower
(
self
):
# build_predict_tower might get called anywhere, but 'PREDICT_TOWER' should be the outermost name scope
...
...
@@ -76,7 +77,7 @@ class SimpleTrainer(Trainer):
config (TrainConfig): the training config.
"""
super
(
SimpleTrainer
,
self
)
.
__init__
(
config
)
self
.
_predictor_factory
=
PredictorFactory
(
self
.
sess
,
self
.
model
,
[
0
])
self
.
_predictor_factory
=
PredictorFactory
(
self
.
model
,
[
0
])
if
config
.
dataflow
is
None
:
self
.
_input_method
=
config
.
data
assert
isinstance
(
self
.
_input_method
,
FeedInput
),
type
(
self
.
_input_method
)
...
...
@@ -118,7 +119,7 @@ class MultiPredictorTowerTrainer(Trainer):
def
_setup_predictor_factory
(
self
):
# by default, use the first training gpu for prediction
self
.
_predictor_factory
=
PredictorFactory
(
self
.
sess
,
self
.
model
,
self
.
config
.
predict_tower
)
self
.
model
,
self
.
config
.
predict_tower
)
def
get_predict_func
(
self
,
input_names
,
output_names
,
tower
=
0
):
"""
...
...
tensorpack/utils/concurrency.py
View file @
3e97f126
...
...
@@ -21,7 +21,8 @@ else:
import
subprocess
__all__
=
[
'StoppableThread'
,
'LoopThread'
,
'ensure_proc_terminate'
,
__all__
=
[
'StoppableThread'
,
'LoopThread'
,
'ShareSessionThread'
,
'ensure_proc_terminate'
,
'OrderedResultGatherProc'
,
'OrderedContainer'
,
'DIE'
,
'mask_sigint'
,
'start_proc_mask_signal'
]
...
...
@@ -97,6 +98,39 @@ class LoopThread(StoppableThread):
self
.
_lock
.
release
()
class ShareSessionThread(threading.Thread):
    """ A wrapper around thread so that the thread
        uses the default session at "start()" time.

    It can be used in two ways:

    1. Wrap an existing thread: ``ShareSessionThread(th)``. ``th.run()`` is
       then executed by this thread, under the session captured in ``start()``.
    2. Subclass it, override ``run()``, and put the body of ``run()`` inside
       ``with self.default_sess():``.
    """

    def __init__(self, th=None):
        """
        Args:
            th (threading.Thread or None): the thread to wrap. Its ``name``
                and ``daemon`` flag are copied onto this thread. If None,
                ``run()`` has to be overridden by a subclass.
        """
        super(ShareSessionThread, self).__init__()
        # Always define both attributes, so that run() and default_sess()
        # never fail with AttributeError when no thread was wrapped, or when
        # they are used before start().
        self._th = None
        self._sess = None
        if th is not None:
            assert isinstance(th, threading.Thread), th
            self._th = th
            self.name = th.name
            self.daemon = th.daemon

    @contextmanager
    def default_sess(self):
        """ Context manager that makes the session captured by ``start()``
        the default session inside the ``with`` body.

        If no session was captured (``start()`` was not called under a
        default session, or has not been called at all), the body simply
        runs without a default session instead of crashing.
        """
        # getattr() keeps this safe even if a subclass skipped __init__.
        sess = getattr(self, '_sess', None)
        if sess is None:
            yield
        else:
            with sess.as_default():
                yield

    def start(self):
        """ Remember the caller's current default session, then start the thread. """
        # Imported lazily, as in the original, so that importing this module
        # does not itself require tensorflow.
        import tensorflow as tf
        self._sess = tf.get_default_session()
        super(ShareSessionThread, self).start()

    def run(self):
        """ Run the wrapped thread's ``run()`` under the captured session.

        Raises:
            NotImplementedError: if no thread was wrapped and a subclass
                did not override this method.
        """
        wrapped = getattr(self, '_th', None)
        if wrapped is None:
            raise NotImplementedError()
        with self.default_sess():
            wrapped.run()
class
DIE
(
object
):
""" A placeholder class indicating end of queue """
pass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment