Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
fa988c58
Commit
fa988c58
authored
Dec 29, 2015
by
ppwwyyxx
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
predictor, init/restore session
parent
ee0bca2d
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
131 additions
and
92 deletions
+131
-92
example_cifar10.py
example_cifar10.py
+12
-15
example_mnist.py
example_mnist.py
+12
-13
tensorpack/infer.py
tensorpack/infer.py
+31
-44
tensorpack/train.py
tensorpack/train.py
+4
-6
tensorpack/utils/__init__.py
tensorpack/utils/__init__.py
+13
-0
tensorpack/utils/callback.py
tensorpack/utils/callback.py
+11
-2
tensorpack/utils/logger.py
tensorpack/utils/logger.py
+2
-0
tensorpack/utils/modelutils.py
tensorpack/utils/modelutils.py
+0
-12
tensorpack/utils/sessinit.py
tensorpack/utils/sessinit.py
+46
-0
No files found.
example_cifar10.py
View file @
fa988c58
...
@@ -36,22 +36,22 @@ def get_model(inputs, is_training):
...
@@ -36,22 +36,22 @@ def get_model(inputs, is_training):
#[image, label], BATCH_SIZE, CAPACITY, MIN_AFTER_DEQUEUE,
#[image, label], BATCH_SIZE, CAPACITY, MIN_AFTER_DEQUEUE,
#num_threads=2, enqueue_many=False)
#num_threads=2, enqueue_many=False)
conv0
=
Conv2D
(
'conv0'
,
image
,
out_channel
=
64
,
kernel_shape
=
5
,
padding
=
'SAME'
)
l
=
Conv2D
(
'conv0'
,
image
,
out_channel
=
64
,
kernel_shape
=
5
,
padding
=
'SAME'
)
pool0
=
MaxPooling
(
'pool0'
,
conv0
,
3
,
stride
=
2
,
padding
=
'SAME'
)
l
=
MaxPooling
(
'pool0'
,
l
,
3
,
stride
=
2
,
padding
=
'SAME'
)
norm0
=
tf
.
nn
.
lrn
(
pool0
,
4
,
bias
=
1.0
,
alpha
=
0.001
/
9.0
,
beta
=
0.75
,
name
=
'norm0'
)
l
=
tf
.
nn
.
lrn
(
l
,
4
,
bias
=
1.0
,
alpha
=
0.001
/
9.0
,
beta
=
0.75
,
name
=
'norm0'
)
conv1
=
Conv2D
(
'conv1'
,
norm0
,
out_channel
=
64
,
kernel_shape
=
5
,
padding
=
'SAME'
)
l
=
Conv2D
(
'conv1'
,
l
,
out_channel
=
64
,
kernel_shape
=
5
,
padding
=
'SAME'
)
norm1
=
tf
.
nn
.
lrn
(
conv1
,
4
,
bias
=
1.0
,
alpha
=
0.001
/
9.0
,
beta
=
0.75
,
name
=
'norm1'
)
l
=
tf
.
nn
.
lrn
(
l
,
4
,
bias
=
1.0
,
alpha
=
0.001
/
9.0
,
beta
=
0.75
,
name
=
'norm1'
)
pool1
=
MaxPooling
(
'pool1'
,
norm1
,
3
,
stride
=
2
,
padding
=
'SAME'
)
l
=
MaxPooling
(
'pool1'
,
l
,
3
,
stride
=
2
,
padding
=
'SAME'
)
fc0
=
FullyConnected
(
'fc0'
,
pool1
,
384
)
l
=
FullyConnected
(
'fc0'
,
l
,
384
)
fc1
=
FullyConnected
(
'fc1'
,
fc0
,
out_dim
=
192
)
l
=
FullyConnected
(
'fc1'
,
l
,
out_dim
=
192
)
# fc will have activation summary by default. disable this for the output layer
# fc will have activation summary by default. disable this for the output layer
fc2
=
FullyConnected
(
'fc2'
,
fc1
,
out_dim
=
10
,
summary_activation
=
False
,
nl
=
tf
.
identity
)
logits
=
FullyConnected
(
'fc2'
,
l
,
out_dim
=
10
,
summary_activation
=
False
,
nl
=
tf
.
identity
)
prob
=
tf
.
nn
.
softmax
(
fc2
,
name
=
'output'
)
prob
=
tf
.
nn
.
softmax
(
logits
,
name
=
'output'
)
y
=
one_hot
(
label
,
10
)
y
=
one_hot
(
label
,
10
)
cost
=
tf
.
nn
.
softmax_cross_entropy_with_logits
(
fc2
,
y
)
cost
=
tf
.
nn
.
softmax_cross_entropy_with_logits
(
logits
,
y
)
cost
=
tf
.
reduce_mean
(
cost
,
name
=
'cross_entropy_loss'
)
cost
=
tf
.
reduce_mean
(
cost
,
name
=
'cross_entropy_loss'
)
tf
.
add_to_collection
(
COST_VARS_KEY
,
cost
)
tf
.
add_to_collection
(
COST_VARS_KEY
,
cost
)
...
@@ -88,11 +88,8 @@ def get_config():
...
@@ -88,11 +88,8 @@ def get_config():
#step_per_epoch = 20
#step_per_epoch = 20
#dataset_test = FixedSizeData(dataset_test, 20)
#dataset_test = FixedSizeData(dataset_test, 20)
sess_config
=
tf
.
ConfigProto
()
sess_config
=
get_default_sess_config
()
sess_config
.
device_count
[
'GPU'
]
=
1
sess_config
.
gpu_options
.
per_process_gpu_memory_fraction
=
0.5
sess_config
.
gpu_options
.
per_process_gpu_memory_fraction
=
0.5
sess_config
.
gpu_options
.
allocator_type
=
'BFC'
sess_config
.
allow_soft_placement
=
True
# prepare model
# prepare model
input_vars
=
[
input_vars
=
[
...
...
example_mnist.py
View file @
fa988c58
...
@@ -51,21 +51,21 @@ def get_model(inputs, is_training):
...
@@ -51,21 +51,21 @@ def get_model(inputs, is_training):
#[image, label], BATCH_SIZE, CAPACITY, MIN_AFTER_DEQUEUE,
#[image, label], BATCH_SIZE, CAPACITY, MIN_AFTER_DEQUEUE,
#num_threads=2, enqueue_many=False)
#num_threads=2, enqueue_many=False)
conv0
=
Conv2D
(
'conv0'
,
image
,
out_channel
=
32
,
kernel_shape
=
5
)
l
=
Conv2D
(
'conv0'
,
image
,
out_channel
=
32
,
kernel_shape
=
5
)
pool0
=
MaxPooling
(
'pool0'
,
conv0
,
2
)
l
=
MaxPooling
(
'pool0'
,
l
,
2
)
conv1
=
Conv2D
(
'conv1'
,
pool0
,
out_channel
=
40
,
kernel_shape
=
3
)
l
=
Conv2D
(
'conv1'
,
l
,
out_channel
=
40
,
kernel_shape
=
3
)
pool1
=
MaxPooling
(
'pool1'
,
conv1
,
2
)
l
=
MaxPooling
(
'pool1'
,
l
,
2
)
fc0
=
FullyConnected
(
'fc0'
,
pool1
,
1024
)
l
=
FullyConnected
(
'fc0'
,
l
,
1024
)
fc0
=
tf
.
nn
.
dropout
(
fc0
,
keep_prob
)
l
=
tf
.
nn
.
dropout
(
l
,
keep_prob
)
# fc will have activation summary by default. disable this for the output layer
# fc will have activation summary by default. disable this for the output layer
fc1
=
FullyConnected
(
'fc1'
,
fc0
,
out_dim
=
10
,
logits
=
FullyConnected
(
'fc1'
,
l
,
out_dim
=
10
,
summary_activation
=
False
,
nl
=
tf
.
identity
)
summary_activation
=
False
,
nl
=
tf
.
identity
)
prob
=
tf
.
nn
.
softmax
(
fc1
,
name
=
'output'
)
prob
=
tf
.
nn
.
softmax
(
logits
,
name
=
'output'
)
y
=
one_hot
(
label
,
10
)
y
=
one_hot
(
label
,
10
)
cost
=
tf
.
nn
.
softmax_cross_entropy_with_logits
(
fc1
,
y
)
cost
=
tf
.
nn
.
softmax_cross_entropy_with_logits
(
logits
,
y
)
cost
=
tf
.
reduce_mean
(
cost
,
name
=
'cross_entropy_loss'
)
cost
=
tf
.
reduce_mean
(
cost
,
name
=
'cross_entropy_loss'
)
tf
.
add_to_collection
(
COST_VARS_KEY
,
cost
)
tf
.
add_to_collection
(
COST_VARS_KEY
,
cost
)
...
@@ -101,11 +101,8 @@ def get_config():
...
@@ -101,11 +101,8 @@ def get_config():
#step_per_epoch = 20
#step_per_epoch = 20
#dataset_test = FixedSizeData(dataset_test, 20)
#dataset_test = FixedSizeData(dataset_test, 20)
sess_config
=
tf
.
ConfigProto
()
sess_config
=
get_default_sess_config
()
sess_config
.
device_count
[
'GPU'
]
=
1
sess_config
.
gpu_options
.
per_process_gpu_memory_fraction
=
0.5
sess_config
.
gpu_options
.
per_process_gpu_memory_fraction
=
0.5
sess_config
.
gpu_options
.
allocator_type
=
'BFC'
sess_config
.
allow_soft_placement
=
True
# prepare model
# prepare model
input_vars
=
[
input_vars
=
[
...
@@ -116,6 +113,8 @@ def get_config():
...
@@ -116,6 +113,8 @@ def get_config():
]
]
input_queue
=
tf
.
RandomShuffleQueue
(
input_queue
=
tf
.
RandomShuffleQueue
(
100
,
50
,
[
x
.
dtype
for
x
in
input_vars
],
name
=
'queue'
)
100
,
50
,
[
x
.
dtype
for
x
in
input_vars
],
name
=
'queue'
)
#input_queue = tf.FIFOQueue(
#100, [x.dtype for x in input_vars], name='queue')
global_step_var
=
tf
.
get_default_graph
()
.
get_tensor_by_name
(
GLOBAL_STEP_VAR_NAME
)
global_step_var
=
tf
.
get_default_graph
()
.
get_tensor_by_name
(
GLOBAL_STEP_VAR_NAME
)
lr
=
tf
.
train
.
exponential_decay
(
lr
=
tf
.
train
.
exponential_decay
(
...
...
tensorpack/infer.py
View file @
fa988c58
...
@@ -9,66 +9,53 @@ import argparse
...
@@ -9,66 +9,53 @@ import argparse
import
numpy
as
np
import
numpy
as
np
from
utils
import
*
from
utils
import
*
from
utils.modelutils
import
describe_model
,
restore_params
from
utils.modelutils
import
describe_model
from
utils
import
logger
from
utils
import
logger
from
dataflow
import
DataFlow
from
dataflow
import
DataFlow
def
start_infer
(
config
):
def
get_predict_func
(
config
):
"""
"""
Args:
Args:
config: a tensorpack config dictionary
config: a tensorpack config dictionary
Returns:
a function that takes a list of inputs to run the model
"""
"""
dataset
=
config
[
'dataset'
]
assert
isinstance
(
dataset
,
DataFlow
),
dataset
.
__class__
# a tf.ConfigProto instance
sess_config
=
config
.
get
(
'session_config'
,
None
)
sess_config
=
config
.
get
(
'session_config'
,
None
)
if
sess_config
is
None
:
sess_config
=
get_default_sess_config
()
assert
isinstance
(
sess_config
,
tf
.
ConfigProto
),
sess_config
.
__class__
assert
isinstance
(
sess_config
,
tf
.
ConfigProto
),
sess_config
.
__class__
# TODO callback should have trigger_step and trigger_end?
sess_init
=
config
[
'session_init'
]
callback
=
config
[
'callback'
]
# restore saved params
params
=
config
.
get
(
'restore_params'
,
{})
# input/output variables
# input/output variables
input_vars
=
config
[
'inputs'
]
input_vars
=
config
[
'inputs'
]
get_model_func
=
config
[
'get_model_func'
]
get_model_func
=
config
[
'get_model_func'
]
output_vars
,
cost_var
=
get_model_func
(
input_vars
,
is_training
=
False
)
output_vars
,
cost_var
=
get_model_func
(
input_vars
,
is_training
=
False
)
# build graph
G
=
tf
.
get_default_graph
()
G
.
add_to_collection
(
FORWARD_FUNC_KEY
,
get_model_func
)
for
v
in
input_vars
:
G
.
add_to_collection
(
INPUT_VARS_KEY
,
v
)
for
v
in
output_vars
:
G
.
add_to_collection
(
OUTPUT_VARS_KEY
,
v
)
describe_model
()
describe_model
()
sess
=
tf
.
Session
(
config
=
sess_config
)
sess
=
tf
.
Session
(
config
=
sess_config
)
sess
.
run
(
tf
.
initialize_all_variables
())
sess_init
.
init
(
sess
)
restore_params
(
sess
,
params
)
def
run_input
(
dp
):
feed
=
dict
(
zip
(
input_vars
,
dp
))
with
sess
.
as_default
():
results
=
sess
.
run
(
with
timed_operation
(
'running one batch'
):
[
cost_var
]
+
output_vars
,
feed_dict
=
feed
)
for
dp
in
dataset
.
get_data
():
cost
=
results
[
0
]
feed
=
dict
(
zip
(
input_vars
,
dp
))
outputs
=
results
[
1
:]
fetches
=
[
cost_var
]
+
output_vars
return
cost
,
outputs
results
=
sess
.
run
(
fetches
,
feed_dict
=
feed
)
return
run_input
cost
=
results
[
0
]
outputs
=
results
[
1
:]
class
DatasetPredictor
(
object
):
prob
=
outputs
[
0
]
def
__init__
(
self
,
predict_config
,
dataset
):
callback
(
dp
,
outputs
,
cost
)
assert
isinstance
(
dataset
,
DataFlow
)
self
.
ds
=
dataset
def
main
(
get_config_func
):
self
.
predict_func
=
get_predict_func
(
predict_config
)
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--gpu'
,
help
=
'comma separated list of GPU(s) to use.'
)
# nargs='*' in multi mode
def
get_result
(
self
):
args
=
parser
.
parse_args
()
""" a generator to return prediction for each data"""
if
args
.
gpu
:
for
dp
in
self
.
ds
.
get_data
():
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
yield
self
.
predict_func
(
dp
)
with
tf
.
Graph
()
.
as_default
():
def
get_all_result
(
self
):
config
=
get_config_func
()
return
list
(
self
.
get_result
())
start_infer
(
config
)
tensorpack/train.py
View file @
fa988c58
...
@@ -10,7 +10,8 @@ import argparse
...
@@ -10,7 +10,8 @@ import argparse
from
utils
import
*
from
utils
import
*
from
utils.concurrency
import
EnqueueThread
,
coordinator_guard
from
utils.concurrency
import
EnqueueThread
,
coordinator_guard
from
utils.summary
import
summary_moving_average
from
utils.summary
import
summary_moving_average
from
utils.modelutils
import
restore_params
,
describe_model
from
utils.modelutils
import
describe_model
from
utils.sessinit
import
NewSession
from
utils
import
logger
from
utils
import
logger
from
dataflow
import
DataFlow
from
dataflow
import
DataFlow
...
@@ -53,8 +54,7 @@ def start_train(config):
...
@@ -53,8 +54,7 @@ def start_train(config):
sess_config
=
config
.
get
(
'session_config'
,
None
)
sess_config
=
config
.
get
(
'session_config'
,
None
)
assert
isinstance
(
sess_config
,
tf
.
ConfigProto
),
sess_config
.
__class__
assert
isinstance
(
sess_config
,
tf
.
ConfigProto
),
sess_config
.
__class__
# restore saved params
sess_init
=
config
.
get
(
'session_init'
,
NewSession
())
params
=
config
.
get
(
'restore_params'
,
{})
# input/output variables
# input/output variables
input_vars
=
config
[
'inputs'
]
input_vars
=
config
[
'inputs'
]
...
@@ -83,9 +83,7 @@ def start_train(config):
...
@@ -83,9 +83,7 @@ def start_train(config):
train_op
=
get_train_op
(
optimizer
,
cost_var
)
train_op
=
get_train_op
(
optimizer
,
cost_var
)
sess
=
tf
.
Session
(
config
=
sess_config
)
sess
=
tf
.
Session
(
config
=
sess_config
)
sess
.
run
(
tf
.
initialize_all_variables
())
sess_init
.
init
(
sess
)
restore_params
(
sess
,
params
)
# start training:
# start training:
coord
=
tf
.
train
.
Coordinator
()
coord
=
tf
.
train
.
Coordinator
()
...
...
tensorpack/utils/__init__.py
View file @
fa988c58
...
@@ -53,3 +53,16 @@ def create_test_session():
...
@@ -53,3 +53,16 @@ def create_test_session():
with
create_test_graph
():
with
create_test_graph
():
with
tf
.
Session
()
as
sess
:
with
tf
.
Session
()
as
sess
:
yield
sess
yield
sess
def get_default_sess_config(mem_fraction=0.8):
    """
    Return a better config to use as default.
    TensorFlow's default session config consumes too many resources.

    Args:
        mem_fraction: value assigned to
            ``gpu_options.per_process_gpu_memory_fraction``. Defaults to 0.8,
            the value that used to be hard-coded here.
    Returns:
        a new tf.ConfigProto instance. Callers are free to adjust it further
        before creating a session.
    """
    conf = tf.ConfigProto()
    conf.device_count['GPU'] = 1
    conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
    conf.gpu_options.allocator_type = 'BFC'
    conf.allow_soft_placement = True
    return conf
tensorpack/utils/callback.py
View file @
fa988c58
...
@@ -59,12 +59,16 @@ class PeriodicCallback(Callback):
...
@@ -59,12 +59,16 @@ class PeriodicCallback(Callback):
pass
pass
class
PeriodicSaver
(
PeriodicCallback
):
class
PeriodicSaver
(
PeriodicCallback
):
def
__init__
(
self
,
period
=
1
):
def
__init__
(
self
,
period
=
1
,
keep_recent
=
50
,
keep_freq
=
0.5
):
super
(
PeriodicSaver
,
self
)
.
__init__
(
period
)
super
(
PeriodicSaver
,
self
)
.
__init__
(
period
)
self
.
path
=
os
.
path
.
join
(
logger
.
LOG_DIR
,
'model'
)
self
.
path
=
os
.
path
.
join
(
logger
.
LOG_DIR
,
'model'
)
self
.
keep_recent
=
keep_recent
self
.
keep_freq
=
keep_freq
def
_before_train
(
self
):
def
_before_train
(
self
):
self
.
saver
=
tf
.
train
.
Saver
(
max_to_keep
=
99999
)
self
.
saver
=
tf
.
train
.
Saver
(
max_to_keep
=
self
.
keep_recent
,
keep_checkpoint_every_n_hours
=
self
.
keep_freq
)
def
_trigger
(
self
):
def
_trigger
(
self
):
self
.
saver
.
save
(
tf
.
get_default_session
(),
self
.
path
,
self
.
saver
.
save
(
tf
.
get_default_session
(),
self
.
path
,
...
@@ -142,6 +146,11 @@ class TrainCallbacks(Callback):
...
@@ -142,6 +146,11 @@ class TrainCallbacks(Callback):
tm
.
log
()
tm
.
log
()
class
TestCallbacks
(
Callback
):
class
TestCallbacks
(
Callback
):
"""
Hold callbacks to be run in testing graph.
Will set a context with testing graph and testing session, for
each test-time callback to run
"""
def
__init__
(
self
,
callbacks
):
def
__init__
(
self
,
callbacks
):
self
.
cbs
=
callbacks
self
.
cbs
=
callbacks
...
...
tensorpack/utils/logger.py
View file @
fa988c58
...
@@ -7,6 +7,7 @@ import logging
...
@@ -7,6 +7,7 @@ import logging
import
os
import
os
import
os.path
import
os.path
from
termcolor
import
colored
from
termcolor
import
colored
from
.utils
import
mkdir_p
__all__
=
[]
__all__
=
[]
...
@@ -56,5 +57,6 @@ LOG_DIR = "train_log"
...
@@ -56,5 +57,6 @@ LOG_DIR = "train_log"
def
set_logger_dir
(
dirname
):
def
set_logger_dir
(
dirname
):
global
LOG_DIR
global
LOG_DIR
LOG_DIR
=
dirname
LOG_DIR
=
dirname
mkdir_p
(
LOG_DIR
)
set_file
(
os
.
path
.
join
(
LOG_DIR
,
'training.log'
))
set_file
(
os
.
path
.
join
(
LOG_DIR
,
'training.log'
))
tensorpack/utils/modelutils.py
View file @
fa988c58
...
@@ -6,18 +6,6 @@
...
@@ -6,18 +6,6 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
import
logger
import
logger
def restore_params(sess, params):
    """
    Assign saved values to the trainable variables of the current graph.

    Args:
        sess: the session used to run the assign ops
        params: a dict mapping variable name to the value to assign.
            Names that match no trainable variable are skipped with a warning.
    """
    trainable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    by_name = {v.name: v for v in trainable}
    for name, value in params.iteritems():
        if name in by_name:
            logger.info("Restoring param {}".format(name))
            sess.run(by_name[name].assign(value))
        else:
            logger.warn("Param {} not found in this graph".format(name))
def
describe_model
():
def
describe_model
():
""" describe the current model parameters"""
""" describe the current model parameters"""
train_vars
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
train_vars
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
...
...
tensorpack/utils/sessinit.py
0 → 100644
View file @
fa988c58
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: sessinit.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from
abc
import
abstractmethod
import
tensorflow
as
tf
from
.
import
logger
class SessionInit(object):
    """
    Base class of the strategies used to initialize a session.
    Subclasses implement `init(sess)`.
    """
    # NOTE(review): the class does not use ABCMeta as its metaclass, so
    # @abstractmethod only marks the method; instantiation is not blocked.

    @abstractmethod
    def init(self, sess):
        """
        Initialize a session.

        Args:
            sess: the session to initialize
        """
class NewSession(SessionInit):
    """ Initialize a session by running the initializer of all variables."""

    def init(self, sess):
        """
        Args:
            sess: the session in which the initializer op is run
        """
        init_op = tf.initialize_all_variables()
        sess.run(init_op)
class SaverRestore(SessionInit):
    """
    Initialize a session by restoring variables from a checkpoint with
    tf.train.Saver.
    """

    def __init__(self, model_path):
        """
        Args:
            model_path: path of the checkpoint, later passed to Saver.restore
        """
        self.set_path(model_path)

    def init(self, sess):
        """
        Restore the checkpoint at `self.path` into `sess`.

        Args:
            sess: the session to restore variables into
        """
        saver = tf.train.Saver()
        saver.restore(sess, self.path)
        # Log the path that was actually restored. This line used to read
        # `ckpt.model_checkpoint_path`, but `ckpt` is not defined anywhere,
        # so every restore ended with a NameError.
        logger.info("Restore checkpoint from {}".format(self.path))

    def set_path(self, model_path):
        """ Set the checkpoint path used by the next call to `init`."""
        self.path = model_path
class ParamRestore(SessionInit):
    """
    Initialize a session from a dict of parameter values.
    """

    def __init__(self, param_dict):
        """
        Args:
            param_dict: a dict mapping variable name to the value to assign
        """
        self.prms = param_dict

    def init(self, sess):
        """
        Initialize all variables, then overwrite the trainable variables
        listed in the param dict with their saved values.

        Args:
            sess: the session to initialize
        """
        # The dict may cover only part of the graph, and only trainable
        # variables are looked up below. Initialize everything first so no
        # variable is left uninitialized, as the training code did before
        # calling restore_params.
        sess.run(tf.initialize_all_variables())

        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        var_dict = {v.name: v for v in variables}
        # items() rather than iteritems(): works on both Python 2 and 3.
        for name, value in self.prms.items():
            var = var_dict.get(name)
            if var is None:
                logger.warn("Param {} not found in this graph".format(name))
                continue
            logger.info("Restoring param {}".format(name))
            sess.run(var.assign(value))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment