Commit 6306da7e
authored Nov 05, 2016 by Yuxin Wu

small refactor in train/base

parent 8bfab811

Showing 13 changed files with 39 additions and 55 deletions (+39 −55)
.gitignore                          +1   −0
README.md                           +2   −0
examples/OpenAIGym/train-atari.py   +1   −1
examples/ResNet/cifar10-resnet.py   +1   −1
scripts/plot-point.py               +7   −6
tensorpack/callbacks/base.py        +1   −8
tensorpack/callbacks/common.py      +2   −1
tensorpack/dataflow/raw.py          +2   −0
tensorpack/train/base.py            +16  −32
tensorpack/train/config.py          +4   −1
tensorpack/train/multigpu.py        +0   −2
tensorpack/train/trainer.py         +0   −2
tensorpack/utils/loadcaffe.py       +2   −1
.gitignore

@@ -73,3 +73,4 @@ model-*
 checkpoint
 *.json
 *.prototxt
+snippet
README.md

 # tensorpack
 Neural Network Toolbox on TensorFlow
+
+Still in development. Underlying design may change.

 See some [examples](examples) to learn about the framework.
 You can actually train them and reproduce the performance... not just to see how to write code.
examples/OpenAIGym/train-atari.py

@@ -204,9 +204,9 @@ def get_config():
             HumanHyperParamSetter('entropy_beta'),
             HumanHyperParamSetter('explore_factor'),
-            master,
+            StartProcOrThread(master),
             PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['logits']), 2),
         ]),
         extra_threads_procs=[master],
         session_config=get_default_sess_config(0.5),
         model=M,
         step_per_epoch=STEP_PER_EPOCH,
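The `master` thread now enters the config through the new `StartProcOrThread` callback rather than through the (now deprecated, see `train/config.py` below) `extra_threads_procs` option. A minimal standard-library sketch of what such a callback does; the real class lives in `tensorpack.callbacks.concurrency`, and this stand-in only mirrors its start-on-`before_train` behaviour:

    import threading

    class StartProcOrThreadSketch(object):
        """Stand-in for tensorpack's StartProcOrThread: hold threads or
        processes and start them once, right before training begins."""
        def __init__(self, startable):
            if not isinstance(startable, list):
                startable = [startable]
            self._startable = startable

        def before_train(self):
            # the trainer invokes before_train() on every callback after
            # the session is created, so workers start exactly once
            for t in self._startable:
                t.start()

    worker = threading.Thread(target=lambda: None)
    worker.daemon = True
    StartProcOrThreadSketch(worker).before_train()
    worker.join()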
examples/ResNet/cifar10-resnet.py

@@ -22,7 +22,7 @@ Identity Mappings in Deep Residual Networks, arxiv:1603.05027
 I can reproduce the results on 2 TitanX for
 n=5, about 7.1% val error after 67k steps (8.6 step/s)
-n=18, about 5.8% val error after 80k steps (2.6 step/s)
+n=18, about 5.9% val error after 80k steps (2.6 step/s)
 n=30: a 182-layer network, about 5.6% val error after 51k steps (1.55 step/s)
 This model uses the whole training set instead of a train-val split.
 """
scripts/plot-point.py

@@ -22,6 +22,7 @@ import matplotlib.font_manager as fontm
 import argparse, sys
 from collections import defaultdict
 from itertools import chain
+import six
 from matplotlib import rc
 #rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']})

@@ -52,11 +53,9 @@ def get_args():
                         help='title of the graph', default='')
-    parser.add_argument('--xlabel',
-                        help='x label', default='x')
-    parser.add_argument('--ylabel',
-                        help='y label', default='y')
+    parser.add_argument('--xlabel', help='x label', type=six.text_type)
+    parser.add_argument('--ylabel', help='y label', type=six.text_type)
     parser.add_argument('-s', '--scale',
                         help='scale of each y, separated by comma')
     parser.add_argument('--annotate-maximum',

@@ -215,8 +214,10 @@ def do_plot(data_xs, data_ys):
         if args.annotate_maximum or args.annotate_minimum:
             annotate_min_max(truncate_data_x, data_y, ax)
-    plt.xlabel(args.xlabel.decode('utf-8'), fontsize='xx-large')
-    plt.ylabel(args.ylabel.decode('utf-8'), fontsize='xx-large')
+    if args.xlabel:
+        plt.xlabel(args.xlabel, fontsize='xx-large')
+    if args.ylabel:
+        plt.ylabel(args.ylabel, fontsize='xx-large')
     plt.legend(loc='best', fontsize='xx-large')
     # adjust maxx
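The `--xlabel`/`--ylabel` change replaces a manual `.decode('utf-8')` at plot time with `type=six.text_type` at parse time, so argparse hands back unicode text on Python 2 and `str` on Python 3. A small self-contained check of that behaviour (assuming `six` is installed):

    import argparse
    import six

    parser = argparse.ArgumentParser()
    # argparse coerces the raw value, so no .decode('utf-8') is needed later
    parser.add_argument('--xlabel', help='x label', type=six.text_type)
    args = parser.parse_args(['--xlabel', 'epoch'])
    assert isinstance(args.xlabel, six.text_type)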
tensorpack/callbacks/base.py

@@ -56,13 +56,6 @@ class Callback(object):
         Could be useful to apply some tricks on parameters (clipping, low-rank, etc)
         """

-    @property
-    def global_step(self):
-        """
-        Access the global step value of this training.
-        """
-        return self.trainer.global_step
-
     def trigger_epoch(self):
         """
         Triggered after every epoch.

@@ -95,7 +88,7 @@ class ProxyCallback(Callback):
         self.cb.trigger_epoch()

     def __str__(self):
-        return str(self.cb)
+        return "Proxy-" + str(self.cb)

 class PeriodicCallback(ProxyCallback):
     """
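With the `global_step` property gone from `Callback`, callbacks read the step from the graph helper instead, as `ModelSaver` does in the next file. A sketch of a user callback under this commit's API (import paths taken from this commit's own imports; they may differ in later releases):

    from tensorpack.callbacks.base import Callback
    from tensorpack.tfutils import get_global_step

    class StepLogger(Callback):
        """Illustrative callback: report the current global step each epoch."""
        def trigger_epoch(self):
            print('global_step = {}'.format(get_global_step()))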
tensorpack/callbacks/common.py

@@ -9,6 +9,7 @@ import re
 from .base import Callback
 from ..utils import logger
 from ..tfutils.varmanip import get_savename_from_varname
+from ..tfutils import get_global_step

 __all__ = ['ModelSaver', 'MinSaver', 'MaxSaver']

@@ -72,7 +73,7 @@ due to an alternative in a different tower".format(v.name, var_dict[name].name))
             self.saver.save(
                 tf.get_default_session(),
                 self.path,
-                global_step=self.global_step,
+                global_step=get_global_step(),
                 write_meta_graph=False)
             # create a symbolic link for the latest model
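Reading the step via `get_global_step()` keeps `ModelSaver` consistent with the trainer, which (in `train/base.py` below) stops maintaining its own Python-side counter: `self.global_step += 1` is deleted. A toy illustration of why a cached copy drifts from the authoritative one:

    class Graph(object):
        """Stand-in for the TF graph owning the real global-step variable."""
        step = 0

    graph = Graph()
    cached = graph.step          # like the old cached self.global_step
    graph.step += 100            # training advances the authoritative counter
    assert cached != graph.step  # the cache is stale; re-read instead
    print('cached={}, actual={}'.format(cached, graph.step))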
tensorpack/dataflow/raw.py

@@ -22,6 +22,8 @@ class FakeData(RNGDataFlow):
         """
         :param shapes: a list of lists/tuples
         :param size: size of this DataFlow
+        :param random: whether to randomly generate data every iteration. note
+            that only generating the data could be time-consuming!
         """
         super(FakeData, self).__init__()
         self.shapes = shapes
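The new `random` docstring documents a real cost trade-off in `FakeData`. A numpy-only sketch of the two modes (hypothetical function; the real class is `tensorpack.dataflow.FakeData`):

    import numpy as np

    def fake_data(shapes, size, random=True):
        """Yield `size` datapoints shaped like `shapes`."""
        rng = np.random.RandomState(0)
        fixed = [rng.rand(*s) for s in shapes]
        for _ in range(size):
            if random:
                # regenerate every iteration: realistic, but generation
                # itself can dominate an input-pipeline benchmark
                yield [rng.rand(*s) for s in shapes]
            else:
                yield fixed  # replay one sample: cheap input for speed tests

    sample = next(fake_data([(2, 3), (4,)], size=10, random=False))
    assert sample[0].shape == (2, 3)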
tensorpack/train/base.py

@@ -13,7 +13,6 @@ import tensorflow as tf
 from .config import TrainConfig
 from ..utils import logger, get_tqdm_kwargs
 from ..utils.timer import timed_operation
-from ..utils.concurrency import start_proc_mask_signal
 from ..callbacks import StatHolder
 from ..tfutils import get_global_step, get_global_step_var
 from ..tfutils.summary import create_summary

@@ -32,7 +31,6 @@ class Trainer(object):
         summary_writer: a `tf.SummaryWriter`
         config: a `TrainConfig`
         model: a `ModelDesc`
-        global_step: a `int`
     """
     __metaclass__ = ABCMeta

@@ -44,7 +42,7 @@ class Trainer(object):
         self.config = config
         self.model = config.model
         self.model.get_input_vars()   # ensure they are present
-        self._extra_threads_procs = config.extra_threads_procs
+        self.init_session_and_coord()

     @abstractmethod
     def train(self):

@@ -84,15 +82,6 @@ class Trainer(object):
         """ This is called right after all steps in an epoch are finished"""
         pass

-    def _init_summary(self):
-        if not hasattr(logger, 'LOG_DIR'):
-            raise RuntimeError("Please use logger.set_logger_dir at the beginning of your script.")
-        self.summary_writer = tf.train.SummaryWriter(logger.LOG_DIR, graph=self.sess.graph)
-        self.summary_op = tf.merge_all_summaries()
-        # create an empty StatHolder
-        self.stat_holder = StatHolder(logger.LOG_DIR)
-
     def _process_summary(self, summary_str):
         summary = tf.Summary.FromString(summary_str)
         for val in summary.value:

@@ -107,31 +96,39 @@ class Trainer(object):
                     get_global_step())
             self.stat_holder.add_stat(name, val)

-    def main_loop(self):
+    def finalize_graph(self):
         # some final operations that might modify the graph
         get_global_step_var()   # ensure there is such var, before finalizing the graph
         logger.info("Setup callbacks ...")
         callbacks = self.config.callbacks
         callbacks.setup_graph(weakref.proxy(self))
-        self._init_summary()
+
+        if not hasattr(logger, 'LOG_DIR'):
+            raise RuntimeError("logger directory wasn't set!")
+        self.summary_writer = tf.train.SummaryWriter(logger.LOG_DIR, graph=self.sess.graph)
+        self.summary_op = tf.merge_all_summaries()
+        # create an empty StatHolder
+        self.stat_holder = StatHolder(logger.LOG_DIR)
+
         logger.info("Initializing graph variables ...")
         self.sess.run(tf.initialize_all_variables())
         self.config.session_init.init(self.sess)
         tf.get_default_graph().finalize()
-        self._start_concurrency()
+        tf.train.start_queue_runners(
+            sess=self.sess, coord=self.coord, daemon=True, start=True)
+
+    def main_loop(self):
+        self.finalize_graph()
+        callbacks = self.config.callbacks

         with self.sess.as_default():
             try:
-                self.global_step = get_global_step()
-                logger.info("Start training with global_step={}".format(self.global_step))
+                logger.info("Start training with global_step={}".format(get_global_step()))
                 callbacks.before_train()
                 for self.epoch_num in range(
                         self.config.starting_epoch, self.config.max_epoch+1):
                     with timed_operation(
                         'Epoch {} (global_step {})'.format(
-                            self.epoch_num, self.global_step + self.config.step_per_epoch)):
+                            self.epoch_num, get_global_step() + self.config.step_per_epoch)):
                         for step in tqdm.trange(
                                 self.config.step_per_epoch,
                                 **get_tqdm_kwargs(leave=True)):

@@ -139,7 +136,6 @@ class Trainer(object):
                                 return
                             self.run_step()  # implemented by subclass
                             #callbacks.trigger_step()   # not useful?
-                            self.global_step += 1
                     self.trigger_epoch()
             except StopTraining:
                 logger.info("Training was stopped.")

@@ -155,18 +151,6 @@ class Trainer(object):
         self.sess = tf.Session(config=self.config.session_config)
         self.coord = tf.train.Coordinator()

-    def _start_concurrency(self):
-        """
-        Run all threads before starting training
-        """
-        logger.info("Starting all threads & procs ...")
-        tf.train.start_queue_runners(
-            sess=self.sess, coord=self.coord, daemon=True, start=True)
-
-        with self.sess.as_default():
-            # avoid sigint get handled by other processes
-            start_proc_mask_signal(self._extra_threads_procs)
-
     def process_grads(self, grads):
         g = []
         for grad, var in grads:
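The core of the refactor: `_init_summary()` and `_start_concurrency()` are folded into a single `finalize_graph()` that performs every graph-mutating step (callback setup, summary writer, variable init, queue runners) before `tf.get_default_graph().finalize()` locks the graph and `main_loop()` starts stepping. A graph-free sketch of that ordering contract:

    class TrainerSketch(object):
        """Illustrative only: mirrors the finalize-then-loop split above."""
        def __init__(self):
            self.finalized = False

        def finalize_graph(self):
            # everything that may still modify the graph happens here:
            # setup_graph(callbacks), SummaryWriter, variable init,
            # session_init, then queue runners
            self.finalized = True

        def run_step(self):
            assert self.finalized, "no graph mutation after finalize"

        def main_loop(self):
            self.finalize_graph()
            for _ in range(3):   # epochs x steps in the real loop
                self.run_step()

    TrainerSketch().main_loop()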
tensorpack/train/config.py

@@ -32,7 +32,6 @@ class TrainConfig(object):
         :param max_epoch: maximum number of epoch to run training. default to inf
         :param nr_tower: int. number of training towers. default to 1.
         :param tower: list of training towers in relative id. default to `range(nr_tower)` if nr_tower is given.
-        :param extra_threads_procs: list of `Startable` threads or processes
     """
         def assert_type(v, tp):
             assert isinstance(v, tp), v.__class__

@@ -72,6 +71,10 @@ class TrainConfig(object):
             self.tower = [0]

         self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
+        if self.extra_threads_procs:
+            logger.warn("[DEPRECATED] use the Callback StartProcOrThread instead of _extra_threads_procs")
+            from ..callbacks.concurrency import StartProcOrThread
+            self.callbacks.cbs.append(StartProcOrThread(self.extra_threads_procs))
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))

     def set_tower(self, nr_tower=None, tower=None):
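The `config.py` hunk is a standard deprecation shim: the old kwarg is still popped, but it is converted into the new callback and a warning is emitted. The same pattern in isolation (illustrative names, not tensorpack's):

    import warnings

    class ConfigSketch(object):
        def __init__(self, callbacks=None, **kwargs):
            self.callbacks = list(callbacks or [])
            legacy = kwargs.pop('extra_threads_procs', [])
            if legacy:
                warnings.warn(
                    "extra_threads_procs is deprecated; pass a "
                    "StartProcOrThread callback instead", DeprecationWarning)
                # translate the old option into the new mechanism
                self.callbacks.append(('StartProcOrThread', legacy))
            # reject anything unrecognized, exactly as TrainConfig does
            assert not kwargs, 'Unknown arguments: {}'.format(list(kwargs.keys()))

    cfg = ConfigSketch(extra_threads_procs=['worker'])
    assert cfg.callbacks  # the legacy option arrived as a callback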
tensorpack/train/multigpu.py

@@ -73,7 +73,6 @@ class MultiGPUTrainer(QueueInputTrainer):

 class SyncMultiGPUTrainer(MultiGPUTrainer):
     def train(self):
-        self.init_session_and_coord()
         self._build_enque_thread()

         grad_list = self._multi_tower_grads()

@@ -92,7 +91,6 @@ class SyncMultiGPUTrainer(MultiGPUTrainer):

 class AsyncMultiGPUTrainer(MultiGPUTrainer):
     def train(self):
-        self.init_session_and_coord()
         self._build_enque_thread()

         grad_list = self._multi_tower_grads()
tensorpack/train/trainer.py

@@ -76,7 +76,6 @@ class SimpleTrainer(Trainer):
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             avg_maintain_op)

-        self.init_session_and_coord()
         describe_model()
         # create an infinte data producer
         self.config.dataset.reset_state()

@@ -196,7 +195,6 @@ class QueueInputTrainer(Trainer):
     def train(self):
         assert len(self.config.tower) == 1, \
             "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
-        self.init_session_and_coord()
         self._build_enque_thread()

         grads = self._single_tower_grad()
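Both `multigpu.py` and `trainer.py` simply delete their `init_session_and_coord()` calls: the base `Trainer.__init__` now performs it (see the `train/base.py` hunk above), so no subclass can forget it or run it twice. The pattern in miniature:

    class Base(object):
        def __init__(self):
            self.init_session_and_coord()   # hoisted from every subclass

        def init_session_and_coord(self):
            self.sess = object()            # stand-in for tf.Session(...)

    class Sub(Base):
        def train(self):
            # previously each train() had to call init_session_and_coord()
            assert self.sess is not None

    Sub().train()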
tensorpack/utils/loadcaffe.py

@@ -113,9 +113,10 @@ def get_caffe_pb():
     caffe_pb_file = os.path.join(dir, 'caffe_pb2.py')
     if not os.path.isfile(caffe_pb_file):
         proto_path = download(CAFFE_PROTO_URL, dir)
         assert os.path.isfile(os.path.join(dir, 'caffe.proto'))
         ret = os.system('cd {} && protoc caffe.proto --python_out .'.format(dir))
-        assert ret == 0, "caffe proto compilation failed! Did you install protoc?"
+        assert ret == 0, \
+            "Command `protoc caffe.proto --python_out .` failed!"
     import imp
     return imp.load_source('caffepb', caffe_pb_file)
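The reworded assertion surfaces the exact failing command instead of a generic message. The same idea written with `subprocess` (a sketch; the original uses `os.system`):

    import subprocess

    def compile_proto(proto='caffe.proto', out_dir='.'):
        cmd = 'protoc {} --python_out {}'.format(proto, out_dir)
        ret = subprocess.call(cmd, shell=True)
        # putting the literal command in the error makes failures reproducible
        assert ret == 0, 'Command `{}` failed!'.format(cmd)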