seminar-breakout (Shashank Suhas) · Commit b9e2bd1b

Merge branch 'better-a3c'

Authored Jun 08, 2016 by Yuxin Wu
Parents: 13d4171f, 313723df

Showing 19 changed files with 315 additions and 214 deletions (+315 / -214).
Files changed:

  examples/Atari2600/DQN.py              +0    -4
  examples/Inception/inception-bn.py     +6    -7
  examples/ResNet/cifar10-resnet.py      +11   -18
  examples/ResNet/svhn-resnet.py         +11   -20
  requirements.txt                       +2    -1
  tensorpack/RL/simulator.py             +13   -3
  tensorpack/__init__.py                 +1    -0
  tensorpack/callbacks/common.py         +1    -1
  tensorpack/dataflow/prefetch.py        +2    -3
  tensorpack/models/batch_norm.py        +32   -3
  tensorpack/predict/concurrency.py      +10   -15
  tensorpack/tfutils/common.py           +29   -2
  tensorpack/tfutils/sessinit.py         +3    -1
  tensorpack/tfutils/summary.py          +1    -1
  tensorpack/train/base.py               +1    -3
  tensorpack/train/multigpu.py           +131  -0
  tensorpack/train/trainer.py            +51   -122
  tensorpack/utils/serialize.py          +8    -8
  tensorpack/utils/timer.py              +2    -2
examples/Atari2600/DQN.py  (+0 / -4)

@@ -14,12 +14,8 @@ import multiprocessing, threading
 from collections import deque
 from tensorpack import *
-from tensorpack.models import *
-from tensorpack.utils import *
 from tensorpack.utils.concurrency import *
 from tensorpack.tfutils import symbolic_functions as symbf
-from tensorpack.callbacks import *
 from tensorpack.RL import *
 import common
 from common import play_model, Evaluator, eval_model_multithread
examples/Inception/inception-bn.py  (+6 / -7)

@@ -196,10 +196,9 @@ if __name__ == '__main__':
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
-    with tf.Graph().as_default():
     config = get_config()
     if args.load:
         config.session_init = SaverRestore(args.load)
     if args.gpu:
         config.nr_tower = len(args.gpu.split(','))
-    QueueInputTrainer(config).train()
+    SyncMultiGPUTrainer(config).train()
examples/ResNet/cifar10-resnet.py  (+11 / -18)

@@ -8,14 +8,9 @@ import tensorflow as tf
 import argparse
 import os
-from tensorpack.train import TrainConfig, QueueInputTrainer
-from tensorpack.models import *
-from tensorpack.callbacks import *
-from tensorpack.utils import *
-from tensorpack.tfutils import *
+from tensorpack import *
 from tensorpack.tfutils.symbolic_functions import *
 from tensorpack.tfutils.summary import *
-from tensorpack.dataflow import *
 """
 CIFAR10-resnet example.

@@ -146,6 +141,10 @@ def get_data(train_or_test):
     return ds

 def get_config():
+    basename = os.path.basename(__file__)
+    logger.set_logger_dir(os.path.join('train_log', basename[:basename.rfind('.')]))
     # prepare dataset
     dataset_train = get_data('train')
     step_per_epoch = dataset_train.size()

@@ -179,18 +178,12 @@ if __name__ == '__main__':
     parser.add_argument('--load', help='load model')
     args = parser.parse_args()
-    basename = os.path.basename(__file__)
-    logger.set_logger_dir(os.path.join('train_log', basename[:basename.rfind('.')]))
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
-    with tf.Graph().as_default():
-        with tf.device('/cpu:0'):
     config = get_config()
     if args.load:
         config.session_init = SaverRestore(args.load)
     if args.gpu:
         config.nr_tower = len(args.gpu.split(','))
-    QueueInputTrainer(config).train()
+    SyncMultiGPUTrainer(config).train()
examples/ResNet/svhn-resnet.py  (+11 / -20)

@@ -8,16 +8,9 @@ import argparse
 import numpy as np
 import os
-from tensorpack.train import TrainConfig, QueueInputTrainer
-from tensorpack.models import *
-from tensorpack.callbacks import *
-from tensorpack.utils import *
-from tensorpack.tfutils import *
+from tensorpack import *
 from tensorpack.tfutils.symbolic_functions import *
 from tensorpack.tfutils.summary import *
-from tensorpack.dataflow import *
-from tensorpack.dataflow import imgaug
 """
 ResNet-110 for SVHN Digit Classification.

@@ -151,6 +144,10 @@ def get_data(train_or_test):
     return ds

 def get_config():
+    basename = os.path.basename(__file__)
+    logger.set_logger_dir(os.path.join('train_log', basename[:basename.rfind('.')]))
     # prepare dataset
     dataset_train = get_data('train')
     step_per_epoch = dataset_train.size()

@@ -184,18 +181,12 @@ if __name__ == '__main__':
     parser.add_argument('--load', help='load model')
     args = parser.parse_args()
-    basename = os.path.basename(__file__)
-    logger.set_logger_dir(os.path.join('train_log', basename[:basename.rfind('.')]))
     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
-    with tf.Graph().as_default():
-        with tf.device('/cpu:0'):
     config = get_config()
     if args.load:
         config.session_init = SaverRestore(args.load)
     if args.gpu:
         config.nr_tower = len(args.gpu.split(','))
-    QueueInputTrainer(config).train()
+    SyncMultiGPUTrainer(config).train()
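
All four example scripts switch their launch call from QueueInputTrainer to SyncMultiGPUTrainer. A minimal sketch of the resulting driver pattern, using only calls that appear in this diff (get_config, SaverRestore, nr_tower, the two trainers); the single-vs-multi-GPU branch at the end is illustrative and not part of the commit, which calls SyncMultiGPUTrainer unconditionally:

    # illustrative driver sketch, not part of this commit
    config = get_config()
    if args.load:
        config.session_init = SaverRestore(args.load)      # resume from a checkpoint
    if args.gpu:
        # e.g. '0,2,3' -> 3 training towers, one per visible GPU
        config.nr_tower = len(args.gpu.split(','))
    if config.nr_tower > 1:
        SyncMultiGPUTrainer(config).train()   # multi-GPU path, as the examples now use
    else:
        QueueInputTrainer(config).train()     # single GPU; now asserts nr_tower == 1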
requirements.txt  (+2 / -1)

@@ -2,4 +2,5 @@ termcolor
 pillow
 scipy
 tqdm
-dill
+msgpack
+msgpack-numpy
tensorpack/RL/simulator.py  (+13 / -3)

@@ -9,6 +9,7 @@ import threading
 import weakref
 from abc import abstractmethod, ABCMeta
 from collections import defaultdict, namedtuple
+import numpy as np
 from six.moves import queue

 from ..utils.timer import *

@@ -42,7 +43,7 @@ class SimulatorProcess(multiprocessing.Process):
         context = zmq.Context()
         c2s_socket = context.socket(zmq.DEALER)
         c2s_socket.identity = 'simulator-{}'.format(self.idx)
-        #c2s_socket.set_hwm(2)
+        c2s_socket.set_hwm(2)
         c2s_socket.connect(self.c2s)
         s2c_socket = context.socket(zmq.DEALER)

@@ -59,7 +60,8 @@ class SimulatorProcess(multiprocessing.Process):
                 action = loads(data)
                 reward, isOver = player.action(action)
                 c2s_socket.send(dumps((reward, isOver)), copy=False)
-                noop = s2c_socket.recv(copy=False)
+                #with total_timer('client recv_ack'):
+                ACK = s2c_socket.recv(copy=False)
                 #cnt += 1
                 #if cnt % 100 == 0:
                     #print_total_timer()

@@ -102,6 +104,14 @@ class SimulatorMaster(threading.Thread):
         self.socket_lock = threading.Lock()
         self.daemon = True

+        # queueing messages to client
+        self.send_queue = queue.Queue(maxsize=100)
+        self.send_thread = LoopThread(lambda: self.s2c_socket.send_multipart(self.send_queue.get()))
+        self.send_thread.daemon = True
+        self.send_thread.start()
+
+        # make sure socket get closed at the end
         def clean_context(soks, context):
             for s in soks:
                 s.close()

@@ -113,7 +123,6 @@ class SimulatorMaster(threading.Thread):
         self.clients = defaultdict(SimulatorMaster.ClientState)
         while True:
             ident, msg = self.c2s_socket.recv_multipart()
-            #assert _ == ""
             client = self.clients[ident]
             client.protocol_state = 1 - client.protocol_state   # first flip the state
             if not client.protocol_state == 0:   # state-action

@@ -126,6 +135,7 @@ class SimulatorMaster(threading.Thread):
                     self._on_episode_over(ident)
                 else:
                     self._on_datapoint(ident)
+            self.send_queue.put([ident, 'Thanks'])    # just an ACK
 
     @abstractmethod
     def _on_state(self, state, ident):
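
The change above moves replies off the receive loop: the master only pushes an ACK onto a bounded queue, and a dedicated thread drains that queue into the server-to-client socket. A minimal standalone sketch of that pattern, assuming pyzmq; the endpoint and helper names are hypothetical, not tensorpack's:

    # Sketch of the send-queue pattern only; assumes pyzmq. Names are hypothetical.
    import threading
    import zmq
    from six.moves import queue

    ctx = zmq.Context()
    s2c = ctx.socket(zmq.ROUTER)           # replies routed back by client identity
    s2c.bind('ipc://example-s2c')          # hypothetical endpoint

    send_queue = queue.Queue(maxsize=100)  # bounded, so a slow client applies backpressure

    def sender_loop():
        while True:
            ident, msg = send_queue.get()          # block until there is something to send
            s2c.send_multipart([ident, msg])       # first frame = routing identity

    th = threading.Thread(target=sender_loop)
    th.daemon = True
    th.start()

    # the receive loop then only needs: send_queue.put([ident, b'Thanks'])   # ACK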
tensorpack/__init__.py  (+1 / -0)

@@ -18,3 +18,4 @@ from .utils import *
 from .tfutils import *
 from .callbacks import *
 from .dataflow import *
+from .predict import *
tensorpack/callbacks/common.py  (+1 / -1)

@@ -36,7 +36,7 @@ class ModelSaver(Callback):
         var_dict = {}
         for v in vars:
             name = v.op.name
-            if re.match('tower[1-9]', name):
+            if re.match('tower[p1-9]', name):
                 #logger.info("Skip {} when saving model.".format(name))
                 continue
             if 'tower0/' in name:
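
The one-character change widens the skip filter from replica training towers 1..9 to also cover prediction towers, whose scopes start with "towerp"; tower0 still falls through and gets saved. A quick sketch of what the character class matches (illustrative variable names):

    import re

    # 'tower[p1-9]' matches tower1.. (replica towers) and towerp.. (prediction towers);
    # tower0 does not match, so its variables are still written to the checkpoint.
    for name in ['tower0/conv1/W', 'tower1/conv1/W', 'towerp0/conv1/W']:
        print(name, bool(re.match('tower[p1-9]', name)))
    # tower0/conv1/W  False
    # tower1/conv1/W  True
    # towerp0/conv1/W True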
tensorpack/dataflow/prefetch.py  (+2 / -3)

@@ -10,7 +10,7 @@ import uuid
 import os

 from .base import ProxyDataFlow
-from ..utils.concurrency import ensure_proc_terminate
+from ..utils.concurrency import *
 from ..utils.serialize import *
 from ..utils import logger

@@ -107,8 +107,7 @@ class PrefetchDataZMQ(ProxyDataFlow):
         self.procs = [PrefetchProcessZMQ(self.ds, self.pipename)
                       for _ in range(self.nr_proc)]
-        for x in self.procs:
-            x.start()
+        start_proc_mask_signal(self.procs)

         # __del__ not guranteed to get called at exit
         import atexit
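
start_proc_mask_signal replaces the plain x.start() loop. The usual point of such a helper is to start worker processes while SIGINT is ignored, so that Ctrl-C is delivered to and handled by the parent only. A sketch of that general technique, assuming nothing about tensorpack's actual implementation:

    # Sketch of the general technique only; tensorpack's start_proc_mask_signal
    # may differ in detail.
    import signal

    def start_procs_masking_sigint(procs):
        # children forked while SIG_IGN is installed inherit that disposition
        old_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
        try:
            for p in procs:
                p.start()
        finally:
            signal.signal(signal.SIGINT, old_handler)   # parent keeps handling Ctrl-C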
tensorpack/models/batch_norm.py  (+32 / -3)

@@ -5,7 +5,9 @@
 import tensorflow as tf
 from copy import copy
+import re

+from ..utils import logger
 from ._common import layer_register

 __all__ = ['BatchNorm']

@@ -48,9 +50,35 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
     else:
         batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], keep_dims=False)

-    ema = tf.train.ExponentialMovingAverage(decay=decay)
-    ema_apply_op = ema.apply([batch_mean, batch_var])
-    ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
+    emaname = 'EMA'
+    if not batch_mean.name.startswith('towerp'):
+        ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
+        ema_apply_op = ema.apply([batch_mean, batch_var])
+        ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
+    else:
+        # use training-statistics in prediction
+        assert not use_local_stat
+        # have to do this again to get actual name. see issue:
+        # https://github.com/tensorflow/tensorflow/issues/2740
+        ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
+        ema_apply_op = ema.apply([batch_mean, batch_var])
+        ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
+        G = tf.get_default_graph()
+        try:
+            mean_name = re.sub('towerp[0-9]+/', '', ema_mean.name)
+            var_name = re.sub('towerp[0-9]+/', '', ema_var.name)
+            #var_name = batch_var.op.name[prefixlen:] + '/' + emaname + ':0'
+            ema_mean = G.get_tensor_by_name(mean_name)
+            ema_var = G.get_tensor_by_name(var_name)
+        except KeyError:
+            mean_name = re.sub('towerp[0-9]+/', 'tower0/', ema_mean.name)
+            var_name = re.sub('towerp[0-9]+/', 'tower0/', ema_var.name)
+            ema_mean = G.get_tensor_by_name(mean_name)
+            ema_var = G.get_tensor_by_name(var_name)
+        #logger.info("In prediction, using {} instead of {} for {}".format(
+            #mean_name, ema_mean.name, batch_mean.name))

     if use_local_stat:
         with tf.control_dependencies([ema_apply_op]):

@@ -58,6 +86,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
                 x, batch_mean, batch_var, beta, gamma, epsilon, 'bn')
     else:
         batch = tf.cast(tf.shape(x)[0], tf.float32)
+        # XXX TODO batch==1?
         mean, var = ema_mean, ema_var * batch / (batch - 1)   # unbiased variance estimator
         return tf.nn.batch_normalization(
             x, mean, var, beta, gamma, epsilon, 'bn')
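
The new else branch handles BatchNorm inside a prediction tower ("towerp..."): instead of creating fresh moving averages, it looks up the EMA tensors that the training tower already created, by rewriting the tensor name. A compact sketch of just that name-remapping idea, with illustrative tensor names:

    # Sketch of the lookup idea only; the tensor names are illustrative.
    import re
    import tensorflow as tf

    def lookup_training_ema(ema_tensor_name):
        G = tf.get_default_graph()
        try:
            # EMA variables created outside any tower scope: strip the towerp prefix,
            # e.g. 'towerp0/bn/mean/EMA:0' -> 'bn/mean/EMA:0'
            name = re.sub('towerp[0-9]+/', '', ema_tensor_name)
            return G.get_tensor_by_name(name)
        except KeyError:
            # EMA variables created under the first training tower: redirect the prefix,
            # e.g. 'towerp0/bn/mean/EMA:0' -> 'tower0/bn/mean/EMA:0'
            name = re.sub('towerp[0-9]+/', 'tower0/', ema_tensor_name)
            return G.get_tensor_by_name(name)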
tensorpack/predict/concurrency.py  (+10 / -15)

@@ -81,43 +81,38 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
                 self.outqueue.put((tid, self.func(dp)))

 class PredictorWorkerThread(threading.Thread):
-    def __init__(self, queue, pred_func, id):
+    def __init__(self, queue, pred_func, id, batch_size=5):
         super(PredictorWorkerThread, self).__init__()
         self.queue = queue
         self.func = pred_func
         self.daemon = True
+        self.batch_size = batch_size
         self.id = id

     def run(self):
-        #self.xxx = None
         def fetch():
-            batched = []
-            futures = []
+            batched, futures = [], []
             inp, f = self.queue.get()
             batched.append(inp)
             futures.append(f)
-            #print "func queue:", self.queue.qsize()
-            #return batched, futures
+            if self.batch_size == 1:
+                return batched, futures
             while True:
                 try:
                     inp, f = self.queue.get_nowait()
                     batched.append(inp)
                     futures.append(f)
-                    if len(batched) == 5:
+                    if len(batched) == self.batch_size:
                         break
                 except queue.Empty:
                     break
             return batched, futures
         #self.xxx = None
         while True:
-            # normal input
-            #inputs, f = self.queue.get()
-            #outputs = self.func(inputs)
-            #f.set_result(outputs)
             batched, futures = fetch()
-            #print "batched size: ", len(batched)
+            #print "batched size: ", len(batched), "queuesize: ", self.queue.qsize()
             outputs = self.func([batched])
+            # debug, for speed testing
             #if self.xxx is None:
                 #outputs = self.func([batched])
                 #self.xxx = outputs

@@ -134,13 +129,13 @@ class MultiThreadAsyncPredictor(object):
     An online predictor (use the current active session) that works with
     QueueInputTrainer. Use async interface, support multi-thread and multi-GPU.
     """
-    def __init__(self, trainer, input_names, output_names, nr_thread):
+    def __init__(self, trainer, input_names, output_names, nr_thread, batch_size=5):
         """
         :param trainer: a `QueueInputTrainer` instance.
         """
         self.input_queue = queue.Queue(maxsize=nr_thread*10)
         self.threads = [
-            PredictorWorkerThread(self.input_queue, f, id)
+            PredictorWorkerThread(self.input_queue, f, id, batch_size)
             for id, f in enumerate(
                 trainer.get_predict_funcs(
                     input_names, output_names, nr_thread))]
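
The change parameterizes the previously hard-coded batch of 5. The batching idea itself: block for the first request, then opportunistically drain whatever else is already queued, up to batch_size, so one sess.run serves several clients. A standalone sketch of that pattern (a hypothetical helper, not the tensorpack class):

    from six.moves import queue

    def fetch_batch(q, batch_size):
        """Block for one item, then greedily take more without blocking."""
        batched, futures = [], []
        inp, f = q.get()                    # wait until at least one request arrives
        batched.append(inp)
        futures.append(f)
        if batch_size == 1:
            return batched, futures
        while True:
            try:
                inp, f = q.get_nowait()     # grab whatever is already waiting
                batched.append(inp)
                futures.append(f)
                if len(batched) == batch_size:
                    break
            except queue.Empty:
                break
        return batched, futures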
tensorpack/tfutils/common.py  (+29 / -2)

@@ -5,13 +5,19 @@
 from ..utils.naming import *
 import tensorflow as tf
+from copy import copy
+import six
+from contextlib import contextmanager

 __all__ = ['get_default_sess_config',
            'get_global_step',
            'get_global_step_var',
            'get_op_var_name',
-           'get_vars_by_names']
+           'get_vars_by_names',
+           'backup_collection', 'restore_collection',
+           'clear_collection', 'freeze_collection']

 def get_default_sess_config(mem_fraction=0.9):
     """

@@ -66,3 +72,24 @@ def get_vars_by_names(names):
         opn, varn = get_op_var_name(n)
         ret.append(G.get_tensor_by_name(varn))
     return ret

+def backup_collection(keys):
+    ret = {}
+    for k in keys:
+        ret[k] = copy(tf.get_collection(k))
+    return ret
+
+def restore_collection(backup):
+    for k, v in six.iteritems(backup):
+        del tf.get_collection_ref(k)[:]
+        tf.get_collection_ref(k).extend(v)
+
+def clear_collection(keys):
+    for k in keys:
+        del tf.get_collection_ref(k)[:]
+
+@contextmanager
+def freeze_collection(keys):
+    backup = backup_collection(keys)
+    yield
+    restore_collection(backup)
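
These collection helpers exist so a graph-building step can run without permanently polluting summary collections: snapshot the collections, build, then put the snapshot back. An illustrative usage of the functions added above (build_prediction_graph is a hypothetical builder):

    import tensorflow as tf

    keys = [tf.GraphKeys.SUMMARIES]

    with freeze_collection(keys):
        build_prediction_graph()    # hypothetical; any summaries it adds to the
                                    # collection are discarded when the block exits

    # equivalent long form:
    backup = backup_collection(keys)
    build_prediction_graph()
    restore_collection(backup)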
tensorpack/tfutils/sessinit.py  (+3 / -1)

@@ -106,8 +106,10 @@ class SaverRestore(SessionInit):
         var_dict = defaultdict(list)
         for v in vars_to_restore:
             name = v.op.name
+            if 'towerp' in name:
+                logger.warn("Anything from prediction tower shouldn't be saved.")
             if 'tower' in name:
-                new_name = re.sub('tower[0-9]+/', '', name)
+                new_name = re.sub('tower[p0-9]+/', '', name)
                 name = new_name
             if name in vars_available:
                 var_dict[name].append(v)
tensorpack/tfutils/summary.py  (+1 / -1)

@@ -90,7 +90,7 @@ def summary_moving_average():
     vars_to_summary = tf.get_collection(MOVING_SUMMARY_VARS_KEY)
     avg_maintain_op = averager.apply(vars_to_summary)
     for idx, c in enumerate(vars_to_summary):
-        name = re.sub('tower[0-9]+/', '', c.op.name)
+        name = re.sub('tower[p0-9]+/', '', c.op.name)
         tf.scalar_summary(name, averager.average(c))
     return avg_maintain_op
tensorpack/train/base.py  (+1 / -3)

@@ -16,7 +16,6 @@ from ..utils.concurrency import start_proc_mask_signal
 from ..callbacks import StatHolder
 from ..tfutils import *
 from ..tfutils.summary import create_summary
-from ..tfutils.modelutils import describe_model

 __all__ = ['Trainer']

@@ -89,7 +88,7 @@ class Trainer(object):
         summary = tf.Summary.FromString(summary_str)
         for val in summary.value:
             if val.WhichOneof('value') == 'simple_value':
-                val.tag = re.sub('tower[0-9]+/', '', val.tag)   # TODO move to subclasses
+                val.tag = re.sub('tower[p0-9]+/', '', val.tag)   # TODO move to subclasses
                 self.stat_holder.add_stat(val.tag, val.simple_value)
         self.summary_writer.add_summary(summary, self.global_step)

@@ -141,7 +140,6 @@ class Trainer(object):
         self.sess.close()

     def init_session_and_coord(self):
-        describe_model()
         self.sess = tf.Session(config=self.config.session_config)
         self.coord = tf.train.Coordinator()
tensorpack/train/multigpu.py  (new file, mode 100644, +131 / -0)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: multigpu.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf
from six.moves import zip, range

from ..utils import *
from ..utils.concurrency import LoopThread
from ..tfutils.summary import summary_moving_average
from ..tfutils.modelutils import describe_model
from ..tfutils import *

from .trainer import QueueInputTrainer

__all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']

class MultiGPUTrainer(QueueInputTrainer):
    """ Base class for multi-gpu training"""
    def __init__(self, config, input_queue=None, predict_tower=None):
        super(MultiGPUTrainer, self).__init__(config, input_queue, predict_tower)
        self.dequed_inputs = []

    @staticmethod
    def _average_grads(tower_grads):
        ret = []
        for grad_and_vars in zip(*tower_grads):
            v = grad_and_vars[0][1]
            try:
                grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
            except:
                logger.error("Error while processing gradients of {}".format(v.name))
                raise
            ret.append((grad, v))
        return ret

    def _multi_tower_grads(self):
        logger.info("Training a model of {} tower".format(self.config.nr_tower))

        grad_list = []
        for i in range(self.config.nr_tower):
            with tf.device('/gpu:{}'.format(i)), \
                    tf.name_scope('tower{}'.format(i)) as scope:
                logger.info("Building graph for training tower {}...".format(i))
                model_inputs = self._get_model_inputs()    # each tower dequeue from input queue
                self.dequed_inputs.append(model_inputs)
                self.model.build_graph(model_inputs, True)
                cost_var = self.model.get_cost()    # build tower

                # TODO gate_gradienst=0 seems to be faster?
                grad_list.append(
                    self.config.optimizer.compute_gradients(cost_var, gate_gradients=0))

                if i == 0:
                    tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
                    tf.get_variable_scope().reuse_variables()
                    # avoid repeated summary from each device
                    backup = backup_collection(self.SUMMARY_BACKUP_KEYS)
        restore_collection(backup)
        return grad_list

class SyncMultiGPUTrainer(MultiGPUTrainer):
    def train(self):
        self.init_session_and_coord()
        self._build_enque_thread()

        grad_list = self._multi_tower_grads()

        grads = MultiGPUTrainer._average_grads(grad_list)
        grads = self.process_grads(grads)

        self.train_op = tf.group(
            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
            summary_moving_average())
        describe_model()

        with freeze_collection(self.SUMMARY_BACKUP_KEYS):
            self._build_predict_tower()

        # [debug]: do nothing in training
        #self.train_op = self.dequed_inputs[0][0] + self.dequed_inputs[1][0]
        self.main_loop()

class AsyncMultiGPUTrainer(MultiGPUTrainer):
    def train(self):
        self.init_session_and_coord()
        self._build_enque_thread()

        grad_list = self._multi_tower_grads()
        # pretend to average the grads, in order to make async and
        # sync have consistent effective learning rate
        def scale(grads):
            return [(grad / self.config.nr_tower, var) for grad, var in grads]
        grad_list = map(scale, grad_list)
        grad_list = [self.process_grads(g) for g in grad_list]

        # use grad from the first tower for iteration in main thread
        self.train_op = tf.group(
            self.config.optimizer.apply_gradients(grad_list[0], get_global_step_var()),
            summary_moving_average())
        describe_model()

        # prepare train_op for the rest of the towers
        self.training_threads = []
        for k in range(1, self.config.nr_tower):
            train_op = self.config.optimizer.apply_gradients(grad_list[k])
            f = lambda op=train_op: self.sess.run([op])    # avoid late-binding
            th = LoopThread(f)
            th.pause()
            th.start()
            self.training_threads.append(th)
        self.async_running = False

        with freeze_collection(self.SUMMARY_BACKUP_KEYS):
            self._build_predict_tower()
        self.main_loop()

    def run_step(self):
        if not self.async_running:
            self.async_running = True
            for th in self.training_threads:    # resume all threads
                th.resume()
        super(AsyncMultiGPUTrainer, self).run_step()

    def _trigger_epoch(self):
        self.async_running = False
        for th in self.training_threads:
            th.pause()
        super(AsyncMultiGPUTrainer, self)._trigger_epoch()
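
The core of SyncMultiGPUTrainer is _average_grads: each tower produces its own [(grad, var), ...] list over the same variables, and the synchronized update averages gradients variable by variable before applying them once. A standalone sketch of that reduction, same idea as the method above but trimmed of tensorpack specifics:

    import tensorflow as tf

    def average_tower_grads(tower_grads):
        """tower_grads: one list per tower of [(grad, var), ...] over the same variables."""
        averaged = []
        for grad_and_vars in zip(*tower_grads):      # the k-th (grad, var) of every tower
            var = grad_and_vars[0][1]                # same variable object in every tower
            grad = tf.add_n([g for g, _ in grad_and_vars]) / float(len(tower_grads))
            averaged.append((grad, var))
        return averaged

AsyncMultiGPUTrainer skips the averaging and instead applies each tower's (scaled) gradients from its own thread, which is why it divides every gradient by nr_tower: the effective learning rate then matches the synchronous case.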
tensorpack/train/trainer.py  (+51 / -122)

@@ -5,20 +5,16 @@
 import tensorflow as tf
 import threading
 import time
-import copy
-import re
-import functools
 from six.moves import zip

 from .base import Trainer
 from ..dataflow.common import RepeatedData
 from ..utils import *
-from ..utils.concurrency import LoopThread
 from ..tfutils.summary import summary_moving_average
+from ..tfutils.modelutils import describe_model
 from ..tfutils import *

-__all__ = ['SimpleTrainer', 'QueueInputTrainer',
-           'AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
+__all__ = ['SimpleTrainer', 'QueueInputTrainer']

 class SimpleTrainer(Trainer):
     def run_step(self):

@@ -42,6 +38,7 @@ class SimpleTrainer(Trainer):
             avg_maintain_op)

         self.init_session_and_coord()
+        describe_model()
         # create an infinte data producer
         self.data_producer = RepeatedData(self.config.dataset, -1).get_data()
         self.main_loop()

@@ -76,6 +73,7 @@ class EnqueueThread(threading.Thread):
         self.queue = queue
         self.close_op = self.queue.close(cancel_pending_enqueues=True)
+        self.size_op = self.queue.size()
         self.daemon = True

     def run(self):

@@ -86,6 +84,8 @@ class EnqueueThread(threading.Thread):
                 if self.coord.should_stop():
                     return
                 feed = dict(zip(self.input_vars, dp))
+                #_, size = self.sess.run([self.op, self.size_op], feed_dict=feed)
+                #print size
                 self.op.run(feed_dict=feed)
         except tf.errors.CancelledError as e:
             pass

@@ -97,16 +97,16 @@ class EnqueueThread(threading.Thread):
             logger.info("Enqueue Thread Exited.")

 class QueueInputTrainer(Trainer):
-    """
-    Trainer which builds a FIFO queue for input.
-    Support multi GPU.
-    """
+    """ Single GPU Trainer, takes input from a queue"""
+
+    SUMMARY_BACKUP_KEYS = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]

-    def __init__(self, config, input_queue=None, async=False):
+    def __init__(self, config, input_queue=None, predict_tower=None):
         """
         :param config: a `TrainConfig` instance
         :param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
             Defaults to a FIFO queue of size 100.
+        :param predict_tower: list of gpu idx to run prediction. default to be [0].
         """
         super(QueueInputTrainer, self).__init__(config)
         self.input_vars = self.model.get_input_vars()

@@ -115,23 +115,11 @@ class QueueInputTrainer(Trainer):
                 100, [x.dtype for x in self.input_vars], name='input_queue')
         else:
             self.input_queue = input_queue
-        self.async = async
-        if self.async:
-            assert self.config.nr_tower > 1
-        self.dequed_inputs = []
-
-    @staticmethod
-    def _average_grads(tower_grads):
-        ret = []
-        for grad_and_vars in zip(*tower_grads):
-            v = grad_and_vars[0][1]
-            try:
-                grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
-            except AssertionError:
-                logger.error("Error while processing gradients of {}".format(v.name))
-                raise
-            ret.append((grad, v))
-        return ret
+        if predict_tower is None:
+            # by default, use the first training gpu for prediction
+            predict_tower = [0]
+        self.predict_tower = predict_tower
+        self.dequed_inputs = None

     def _get_model_inputs(self):
         """ Dequeue a datapoint from input_queue and return"""

@@ -141,104 +129,58 @@ class QueueInputTrainer(Trainer):
         assert len(ret) == len(self.input_vars)
         for qv, v in zip(ret, self.input_vars):
             qv.set_shape(v.get_shape())
-        self.dequed_inputs.append(ret)
         return ret

+    def _build_predict_tower(self):
+        inputs = self.model.get_input_vars()
+        tf.get_variable_scope().reuse_variables()
+        for k in self.predict_tower:
+            logger.info("Building graph for predict towerp{}...".format(k))
+            with tf.device('/gpu:{}'.format(k)), \
+                    tf.name_scope('towerp{}'.format(k)):
+                self.model.build_graph(inputs, False)
+
     def _single_tower_grad(self):
-        """ Get grad and cost for single-tower case"""
-        model_inputs = self._get_model_inputs()
+        """ Get grad and cost for single-tower"""
+        self.dequed_inputs = model_inputs = self._get_model_inputs()
         self.model.build_graph(model_inputs, True)
         cost_var = self.model.get_cost()
         grads = self.config.optimizer.compute_gradients(cost_var)
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
         return grads

-    def _multi_tower_grads(self):
-        logger.info("Training a model of {} tower".format(self.config.nr_tower))
-
-        # to avoid repeated summary from each device
-        collect_dedup = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
-        kept_summaries = {}
-        grad_list = []
-        for i in range(self.config.nr_tower):
-            with tf.device('/gpu:{}'.format(i)), \
-                    tf.name_scope('tower{}'.format(i)) as scope:
-                logger.info("Building graph for tower {}...".format(i))
-                model_inputs = self._get_model_inputs()    # each tower dequeue from input queue
-                self.model.build_graph(model_inputs, True)
-                cost_var = self.model.get_cost()    # build tower
-                # gate_gradienst=0 seems to be faster?
-                grad_list.append(
-                    self.config.optimizer.compute_gradients(cost_var, gate_gradients=0))
-                if i == 0:
-                    tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
-                    tf.get_variable_scope().reuse_variables()
-                    for k in collect_dedup:
-                        kept_summaries[k] = copy.copy(tf.get_collection(k))
-        for k in collect_dedup:
-            del tf.get_collection_ref(k)[:]
-            tf.get_collection_ref(k).extend(kept_summaries[k])
-        return grad_list
+    def _build_enque_thread(self):
+        """ create a thread that keeps filling the queue """
+        enqueue_op = self.input_queue.enqueue(self.input_vars)
+        self.input_th = EnqueueThread(self, self.input_queue, enqueue_op, self.input_vars)
+        self.extra_threads_procs.append(self.input_th)

     def train(self):
-        enqueue_op = self.input_queue.enqueue(self.input_vars)
-
-        if self.config.nr_tower > 1:
-            grad_list = self._multi_tower_grads()
-            if not self.async:
-                grads = QueueInputTrainer._average_grads(grad_list)
-                grads = self.process_grads(grads)
-            else:
-                # pretend to average the grads, in order to make async and
-                # sync have consistent effective learning rate
-                def scale(grads):
-                    return [(grad / self.config.nr_tower, var) for grad, var in grads]
-                grad_list = map(scale, grad_list)
-                grad_list = [self.process_grads(g) for g in grad_list]
-                grads = grad_list[0]    # use grad from the first tower for the main iteration
-        else:
-            grads = self._single_tower_grad()
-            grads = self.process_grads(grads)
+        assert self.config.nr_tower == 1, \
+            "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
+        self.init_session_and_coord()
+        self._build_enque_thread()
+
+        grads = self._single_tower_grad()
+        grads = self.process_grads(grads)
+        describe_model()
+
+        with freeze_collection(self.SUMMARY_BACKUP_KEYS):
+            self._build_predict_tower()

         self.train_op = tf.group(
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             summary_moving_average())

-        if self.async:
-            # prepare train_op for the rest of the towers
-            self.threads = []
-            for k in range(1, self.config.nr_tower):
-                train_op = self.config.optimizer.apply_gradients(grad_list[k])
-                f = lambda op=train_op: self.sess.run([op])    # avoid late-binding
-                th = LoopThread(f)
-                th.pause()
-                th.start()
-                self.threads.append(th)
-            self.async_running = False
-
-        self.init_session_and_coord()
-        # create a thread that keeps filling the queue
-        self.input_th = EnqueueThread(self, self.input_queue, enqueue_op, self.input_vars)
-        self.extra_threads_procs.append(self.input_th)
         self.main_loop()

     def run_step(self):
-        if self.async:
-            if not self.async_running:
-                self.async_running = True
-                for th in self.threads:    # resume all threads
-                    th.resume()
-        self.sess.run([self.train_op])    # faster since train_op return None
+        """ just run self.train_op"""
+        self.sess.run([self.train_op])

     def _trigger_epoch(self):
+        # need to run summary_op every epoch
         # note that summary_op will take a data from the queue
-        if self.async:
-            self.async_running = False
-            for th in self.threads:
-                th.pause()
         if self.summary_op is not None:
             summary_str = self.summary_op.eval()
             self._process_summary(summary_str)

@@ -246,33 +188,20 @@ class QueueInputTrainer(Trainer):
     def get_predict_func(self, input_names, output_names, tower=0):
         """
         :param tower: return the kth predict_func
+        :returns: a predictor function
         """
-        tower = tower % self.config.nr_tower
-        if self.config.nr_tower > 1:
-            logger.info("Prepare a predictor function for tower{} ...".format(tower))
+        tower = self.predict_tower[tower % len(self.predict_tower)]
         raw_input_vars = get_vars_by_names(input_names)
-        input_var_idxs = [self.input_vars.index(v) for v in raw_input_vars]
-        dequed = self.dequed_inputs[tower]
-        input_vars = [dequed[k] for k in input_var_idxs]
-        if self.config.nr_tower > 1:
-            output_names = ['tower{}/'.format(tower) + n for n in output_names]
+        output_names = ['towerp{}/'.format(tower) + n for n in output_names]
         output_vars = get_vars_by_names(output_names)

         def func(inputs):
-            assert len(inputs) == len(input_vars)
-            feed = dict(zip(input_vars, inputs))
+            assert len(inputs) == len(raw_input_vars)
+            feed = dict(zip(raw_input_vars, inputs))
             return self.sess.run(output_vars, feed_dict=feed)
         return func

     def get_predict_funcs(self, input_names, output_names, n):
+        """ return n predicts functions evenly on each predict_tower"""
         return [self.get_predict_func(input_names, output_names, k)
                 for k in range(n)]

-def AsyncMultiGPUTrainer(config):
-    return QueueInputTrainer(config, async=True)
-
-def SyncMultiGPUTrainer(config):
-    return QueueInputTrainer(config)
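
get_predict_func now targets the dedicated prediction towers: output names are prefixed with 'towerp{k}/', resolved to tensors once, and the returned closure simply feeds the raw input placeholders. A condensed sketch of that closure pattern, assuming only a running tf.Session and name-to-tensor lookup on the default graph (the ':0' suffixing stands in for what get_op_var_name does in the class above):

    # Condensed sketch of the predictor-closure pattern; not the exact class code.
    import tensorflow as tf

    def make_predictor(sess, input_names, output_names, tower=0):
        G = tf.get_default_graph()
        input_vars = [G.get_tensor_by_name(n + ':0') for n in input_names]
        output_vars = [G.get_tensor_by_name('towerp{}/{}:0'.format(tower, n))
                       for n in output_names]
        def func(inputs):
            assert len(inputs) == len(input_vars)
            return sess.run(output_vars, feed_dict=dict(zip(input_vars, inputs)))
        return func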
tensorpack/utils/serialize.py  (+8 / -8)

@@ -3,17 +3,17 @@
 # File: serialize.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

-#import msgpack
-#import msgpack_numpy
-#msgpack_numpy.patch()
-import dill
+import msgpack
+import msgpack_numpy
+msgpack_numpy.patch()
+#import dill

 __all__ = ['loads', 'dumps']

 def dumps(obj):
-    return dill.dumps(obj)
-    #return msgpack.dumps(obj, use_bin_type=True)
+    #return dill.dumps(obj)
+    return msgpack.dumps(obj, use_bin_type=True)

 def loads(buf):
-    return dill.loads(buf)
-    #return msgpack.loads(buf)
+    #return dill.loads(buf)
+    return msgpack.loads(buf)
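
With dill swapped for msgpack (plus msgpack-numpy, patched at import time so numpy arrays serialize transparently), dumps/loads round-trips the (reward, isOver) and state messages the RL simulator exchanges. A small usage sketch of the same calls; the array shape is just an example:

    import numpy as np
    import msgpack
    import msgpack_numpy
    msgpack_numpy.patch()   # teach msgpack to pack/unpack numpy arrays

    state = np.zeros((84, 84, 4), dtype='uint8')          # example Atari-style frame stack
    buf = msgpack.dumps((state, 1.0, False), use_bin_type=True)
    arr, reward, is_over = msgpack.loads(buf)             # tuple comes back as a list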
tensorpack/utils/timer.py  (+2 / -2)

@@ -37,7 +37,7 @@ def print_total_timer():
     if len(_TOTAL_TIMER_DATA) == 0:
         return
     for k, v in six.iteritems(_TOTAL_TIMER_DATA):
-        logger.info("Total Time: {} -> {} sec, {} times".format(
-            k, v.sum, v.count))
+        logger.info("Total Time: {} -> {} sec, {} times, {} sec/time".format(
+            k, v.sum, v.count, v.average))

 atexit.register(print_total_timer)