seminar-breakout · Commit 9bb0b8f6
authored Dec 30, 2015 by ppwwyyxx
parent 4185a222

trainconfig and predictconfig

Showing 6 changed files, with 294 additions and 81 deletions:
example_alexnet.py              +164    -0
example_cifar10.py                +6    -8
example_mnist.py                  +7    -9
tensorpack/predict.py            +50   -24
tensorpack/train.py              +61   -39
tensorpack/utils/__init__.py      +6    -1
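In short: training and prediction configuration moves from loosely checked dicts to two classes, TrainConfig and PredictConfig, which validate argument types in __init__ and reject unknown keys. A minimal before/after sketch of the calling code this affects (names mirror the example_cifar10.py and example_mnist.py diffs below):

    # before this commit: a plain dict, unpacked and checked inside start_train()
    config = get_config()                        # returned dict(...)
    if args.load:
        config['session_init'] = SaverRestore(args.load)
    train.start_train(config)

    # after: a TrainConfig object with typed, validated attributes
    config = get_config()                        # now returns TrainConfig(...)
    if args.load:
        config.session_init = SaverRestore(args.load)
    start_train(config)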
example_alexnet.py  (new file, mode 100755)
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: example_alexnet.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import tensorflow as tf
import numpy as np
import os
import argparse
import cPickle as pkl

from tensorpack.train import TrainConfig, start_train
from tensorpack.predict import PredictConfig, get_predict_func
from tensorpack.models import *
from tensorpack.utils import *
from tensorpack.utils.symbolic_functions import *
from tensorpack.utils.summary import *
from tensorpack.utils.callback import *
from tensorpack.dataflow import *

BATCH_SIZE = 10
MIN_AFTER_DEQUEUE = 500
CAPACITY = MIN_AFTER_DEQUEUE + 3 * BATCH_SIZE

def get_model(inputs, is_training):
    # img: 227x227x3
    is_training = bool(is_training)
    keep_prob = tf.constant(0.5 if is_training else 1.0)

    image, label = inputs

    l = Conv2D('conv1', image, out_channel=96, kernel_shape=11, stride=4, padding='VALID')
    l = tf.nn.lrn(l, 2, bias=1.0, alpha=2e-5, beta=0.75, name='norm1')
    l = MaxPooling('pool1', l, 3, stride=2, padding='VALID')

    l = Conv2D('conv2', l, out_channel=256, kernel_shape=5, padding='SAME', split=2)
    l = tf.nn.lrn(l, 2, bias=1.0, alpha=2e-5, beta=0.75, name='norm2')
    l = MaxPooling('pool2', l, 3, stride=2, padding='VALID')

    l = Conv2D('conv3', l, out_channel=384, kernel_shape=3, padding='SAME')
    l = Conv2D('conv4', l, out_channel=384, kernel_shape=3, padding='SAME', split=2)
    l = Conv2D('conv5', l, out_channel=256, kernel_shape=3, padding='SAME', split=2)
    l = MaxPooling('pool3', l, 3, stride=2, padding='VALID')

    l = FullyConnected('fc6', l, 4096)
    l = FullyConnected('fc7', l, out_dim=4096)
    # fc will have activation summary by default. disable this for the output layer
    logits = FullyConnected('fc8', l, out_dim=1000,
                            summary_activation=False, nl=tf.identity)
    prob = tf.nn.softmax(logits, name='output')

    y = one_hot(label, 1000)
    cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
    cost = tf.reduce_mean(cost, name='cross_entropy_loss')
    tf.add_to_collection(COST_VARS_KEY, cost)

    # compute the number of failed samples, for ValidationError to use at test time
    wrong = tf.not_equal(tf.cast(tf.argmax(prob, 1), tf.int32), label)
    wrong = tf.cast(wrong, tf.float32)
    nr_wrong = tf.reduce_sum(wrong, name='wrong')
    # monitor training error
    tf.add_to_collection(
        SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))

    # weight decay on all W of fc layers
    wd_cost = tf.mul(1e-4, regularize_cost('fc.*/W', tf.nn.l2_loss),
                     name='regularize_loss')
    tf.add_to_collection(COST_VARS_KEY, wd_cost)

    add_histogram_summary('.*/W')   # monitor histogram of all W
    # this won't work with multigpu
    #return [prob, nr_wrong], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')
    return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')

def get_config():
    log_dir = os.path.join('train_log', os.path.basename(__file__)[:-3])
    logger.set_logger_dir(log_dir)

    dataset_train = FakeData([(227, 227, 3), tuple()], 10)
    dataset_train = BatchData(dataset_train, 10)
    step_per_epoch = 3

    sess_config = get_default_sess_config()
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5

    # prepare model
    input_vars = [
        tf.placeholder(tf.float32, shape=(None, 227, 227, 3), name='input'),
        tf.placeholder(tf.int32, shape=(None,), name='label')
    ]
    input_queue = tf.RandomShuffleQueue(
        10, 3, [x.dtype for x in input_vars], name='queue')

    lr = tf.train.exponential_decay(
        learning_rate=1e-8,
        global_step=get_global_step_var(),
        decay_steps=dataset_train.size() * 50,
        decay_rate=0.1, staircase=True, name='learning_rate')
    tf.scalar_summary('learning_rate', lr)

    param_dict = np.load('alexnet1.npy').item()

    return TrainConfig(
        dataset=dataset_train,
        optimizer=tf.train.AdamOptimizer(lr),
        callbacks=Callbacks([
            SummaryWriter(),
            PeriodicSaver(),
            #ValidationError(dataset_test, prefix='test'),
        ]),
        session_config=sess_config,
        inputs=input_vars,
        input_queue=input_queue,
        get_model_func=get_model,
        step_per_epoch=step_per_epoch,
        session_init=ParamRestore(param_dict),
        max_epoch=100,
    )

def run_test(path):
    input_vars = [
        tf.placeholder(tf.float32, shape=(None, 227, 227, 3), name='input'),
        tf.placeholder(tf.int32, shape=(None,), name='label')
    ]
    param_dict = np.load(path).item()
    pred_config = PredictConfig(
        inputs=input_vars,
        get_model_func=get_model,
        session_init=ParamRestore(param_dict),
        output_var_names=['output:0']   # output:0 is the probability distribution
    )
    predict_func = get_predict_func(pred_config)

    import cv2
    im = cv2.imread('cat.jpg')
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    im = cv2.resize(im, (227, 227))
    im = np.reshape(im, (1, 227, 227, 3))
    outputs = predict_func([im, (1,)])[0]
    prob = outputs[0]
    print prob.shape
    print prob.argsort()[-10:][::-1]

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')  # nargs='*' in multi mode
    args = parser.parse_args()
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    #start_train(get_config())
    # run alexnet with given model (in npy format)
    run_test('alexnet.npy')
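A note on get_config() above: FakeData([(227, 227, 3), tuple()], 10) stands in for ImageNet, producing datapoints whose shapes match the two input placeholders, so the whole pipeline can be smoke-tested without real data (step_per_epoch is likewise pinned to 3). A sketch of that pattern in isolation, assuming the usual DataFlow get_data() iterator:

    # FakeData yields random datapoints shaped like the declared inputs;
    # BatchData stacks 10 of them, matching the (None, 227, 227, 3) placeholder.
    df = BatchData(FakeData([(227, 227, 3), tuple()], 10), 10)
    for image_batch, label_batch in df.get_data():
        print image_batch.shape    # (10, 227, 227, 3)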
example_cifar10.py
@@ -8,6 +8,7 @@ import argparse
 import numpy as np
 import os
+from tensorpack.train import TrainConfig, start_train
 from tensorpack.models import *
 from tensorpack.utils import *
 from tensorpack.utils.symbolic_functions import *
@@ -102,18 +103,17 @@ def get_config():
     input_queue = tf.RandomShuffleQueue(
         100, 50, [x.dtype for x in input_vars], name='queue')

-    global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
     lr = tf.train.exponential_decay(
         learning_rate=1e-4,
-        global_step=global_step_var,
+        global_step=get_global_step_var(),
         decay_steps=dataset_train.size() * 50,
         decay_rate=0.1, staircase=True, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)

-    return dict(
+    return TrainConfig(
         dataset=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr),
-        callback=Callbacks([
+        callbacks=Callbacks([
             SummaryWriter(),
             PeriodicSaver(),
             ValidationError(dataset_test, prefix='test'),
@@ -127,7 +127,6 @@ def get_config():
     )

 if __name__ == '__main__':
-    from tensorpack import train
     parser = argparse.ArgumentParser()
     parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')  # nargs='*' in multi mode
     parser.add_argument('--load', help='load model')
@@ -136,9 +135,8 @@ if __name__ == '__main__':
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

     with tf.Graph().as_default():
-        train.prepare()
         config = get_config()
         if args.load:
-            config['session_init'] = SaverRestore(args.load)
-        train.start_train(config)
+            config.session_init = SaverRestore(args.load)
+        start_train(config)
example_mnist.py
@@ -10,6 +10,7 @@ import numpy as np
 import os, sys
 import argparse
+from tensorpack.train import TrainConfig, start_train
 from tensorpack.models import *
 from tensorpack.utils import *
 from tensorpack.utils.symbolic_functions import *
@@ -97,7 +98,7 @@ def get_config():
     dataset_train = BatchData(dataset.Mnist('train'), 128)
     dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
     step_per_epoch = dataset_train.size()
-    #step_per_epoch = 20
+    step_per_epoch = 2
     #dataset_test = FixedSizeData(dataset_test, 20)

     sess_config = get_default_sess_config()
@@ -115,18 +116,17 @@ def get_config():
     #input_queue = tf.FIFOQueue(
         #100, [x.dtype for x in input_vars], name='queue')

-    global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
     lr = tf.train.exponential_decay(
         learning_rate=1e-4,
-        global_step=global_step_var,
+        global_step=get_global_step_var(),
         decay_steps=dataset_train.size() * 50,
         decay_rate=0.1, staircase=True, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)

-    return dict(
+    return TrainConfig(
         dataset=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr),
-        callback=Callbacks([
+        callbacks=Callbacks([
             SummaryWriter(),
             PeriodicSaver(),
             ValidationError(dataset_test, prefix='test'),
@@ -140,7 +140,6 @@ def get_config():
     )

 if __name__ == '__main__':
-    from tensorpack import train
     parser = argparse.ArgumentParser()
     parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')  # nargs='*' in multi mode
     parser.add_argument('--load', help='load model')
@@ -149,8 +148,7 @@ if __name__ == '__main__':
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

     with tf.Graph().as_default():
-        train.prepare()
         config = get_config()
         if args.load:
-            config['session_init'] = SaverRestore(args.load)
-        train.start_train(config)
+            config.session_init = SaverRestore(args.load)
+        start_train(config)
tensorpack/predict.py
@@ -13,47 +13,73 @@ from utils.modelutils import describe_model
 from utils import logger
 from dataflow import DataFlow, BatchData

+class PredictConfig(object):
+    def __init__(self, **kwargs):
+        """
+        The config used by `get_predict_func`
+        Args:
+            session_config: a tf.ConfigProto instance to instantiate the
+                session. default to a session running 1 GPU.
+            session_init: a tensorpack.utils.sessinit.SessionInit instance to
+                initialize variables of a session.
+            inputs: a list of input variables. must match the dataset later
+                used for prediction.
+            get_model_func: a function taking `inputs` and `is_training` and
+                returning a tuple of the output list as well as the cost to minimize
+            output_var_names: a list of names of the output variables to predict.
+                the variables can be any computable tensor in the graph.
+                if None, will predict everything returned by `get_model_func`
+                (all outputs as well as the cost). Predicting only specific
+                outputs might be faster and might require only some of the
+                input variables.
+        """
+        def assert_type(v, tp):
+            assert isinstance(v, tp), v.__class__
+        self.session_config = kwargs.pop('session_config', get_default_sess_config())
+        assert_type(self.session_config, tf.ConfigProto)
+        self.session_init = kwargs.pop('session_init')
+        self.inputs = kwargs.pop('inputs')
+        [assert_type(i, tf.Tensor) for i in self.inputs]
+        self.get_model_func = kwargs.pop('get_model_func')
+        self.output_var_names = kwargs.pop('output_var_names', None)
+        assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
+
 def get_predict_func(config):
     """
     Args:
-        config: a tensorpack config dictionary
+        config: a PredictConfig
     Returns:
-        a function that takes a list of inputs to run the model
+        A prediction function that takes a list of input values, and returns
+        one/a list of output values.
+        If `output_var_names` is set, the prediction function will return
+        a list of output values. If not, it will return a list of output
+        values and a cost.
     """
-    sess_config = config.get('session_config', None)
-    if sess_config is None:
-        sess_config = get_default_sess_config()
-    assert isinstance(sess_config, tf.ConfigProto), sess_config.__class__
-    sess_init = config['session_init']
-
-    # Provide this if only specific output is needed.
-    # by default will evaluate all outputs as well as cost
-    output_var_name = config.get('output_var', None)
+    output_var_names = config.output_var_names

     # input/output variables
-    input_vars = config['inputs']
-    get_model_func = config['get_model_func']
-    output_vars, cost_var = get_model_func(input_vars, is_training=False)
+    input_vars = config.inputs
+    output_vars, cost_var = config.get_model_func(input_vars, is_training=False)
+    # check output_var_names against output_vars
+    if output_var_names is not None:
+        output_vars = [tf.get_default_graph().get_tensor_by_name(n)
+                       for n in output_var_names]

     describe_model()

-    sess = tf.Session(config=sess_config)
-    sess_init.init(sess)
+    sess = tf.Session(config=config.session_config)
+    config.session_init.init(sess)

     def run_input(dp):
         # TODO if input and dp not aligned?
         feed = dict(zip(input_vars, dp))
-        if output_var_name is not None:
-            fetches = tf.get_default_graph().get_tensor_by_name(output_var_name)
-            results = sess.run(fetches, feed_dict=feed)
-            return results[0]
+        if output_var_names is not None:
+            results = sess.run(output_vars, feed_dict=feed)
+            return results
         else:
-            fetches = [cost_var] + output_vars
-            results = sess.run(fetches, feed_dict=feed)
+            results = sess.run([cost_var] + output_vars, feed_dict=feed)
             cost = results[0]
             outputs = results[1:]
-            return cost, outputs
+            return outputs, cost
     return run_input

 class DatasetPredictor(object):
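Net effect of the predict.py change: the old single 'output_var' string (whose result came back as results[0]) becomes a list of output_var_names, and the returned function's contract is now spelled out. The intended call pattern, as exercised by example_alexnet.py above (param_dict being a dict of numpy arrays, im a (1, 227, 227, 3) array):

    pred_config = PredictConfig(
        inputs=input_vars,                       # placeholders matching get_model
        get_model_func=get_model,                # same function used for training
        session_init=ParamRestore(param_dict),   # initialize weights from numpy
        output_var_names=['output:0'])           # fetch only the softmax tensor
    predict_func = get_predict_func(pred_config)
    outputs = predict_func([im, (1,)])           # a list, one array per name given;
                                                 # without output_var_names: (outputs, cost)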
tensorpack/train.py
@@ -9,14 +9,60 @@ import argparse
 from utils import *
 from utils.concurrency import EnqueueThread, coordinator_guard
 from utils.callback import Callbacks
 from utils.summary import summary_moving_average
 from utils.modelutils import describe_model
 from utils import logger
 from dataflow import DataFlow

-def prepare():
-    global_step_var = tf.Variable(
-        0, trainable=False, name=GLOBAL_STEP_OP_NAME)
+class TrainConfig(object):
+    """ config for training"""
+    def __init__(self, **kwargs):
+        """
+        Args:
+            dataset: the dataset to train. a tensorpack.dataflow.DataFlow instance.
+            optimizer: a tf.train.Optimizer instance defining the optimizer
+                for training. default to an AdamOptimizer
+            callbacks: a tensorpack.utils.callback.Callbacks instance. Define
+                the callbacks to perform during training. has to contain a
+                SummaryWriter and a PeriodicSaver
+            session_config: a tf.ConfigProto instance to instantiate the
+                session. default to a session running 1 GPU.
+            session_init: a tensorpack.utils.sessinit.SessionInit instance to
+                initialize variables of a session. default to a new session.
+            inputs: a list of input variables. must match what is returned by
+                the dataset
+            input_queue: the queue used for input. default to a FIFO queue
+                with capacity 5
+            get_model_func: a function taking `inputs` and `is_training` and
+                returning a tuple of the output list as well as the cost to minimize
+            step_per_epoch: the number of steps (parameter updates) to perform
+                in each epoch. default to dataset.size()
+            max_epoch: maximum number of epochs to run training. default to 100
+        """
+        def assert_type(v, tp):
+            assert isinstance(v, tp), v.__class__
+        self.dataset = kwargs.pop('dataset')
+        assert_type(self.dataset, DataFlow)
+        self.optimizer = kwargs.pop('optimizer', tf.train.AdamOptimizer())
+        assert_type(self.optimizer, tf.train.Optimizer)
+        self.callbacks = kwargs.pop('callbacks')
+        assert_type(self.callbacks, Callbacks)
+        self.session_config = kwargs.pop('session_config', get_default_sess_config())
+        assert_type(self.session_config, tf.ConfigProto)
+        self.session_init = kwargs.pop('session_init', NewSession())
+        assert_type(self.session_init, SessionInit)
+        self.inputs = kwargs.pop('inputs')
+        [assert_type(i, tf.Tensor) for i in self.inputs]
+        self.input_queue = kwargs.pop('input_queue',
+            tf.FIFOQueue(5, [x.dtype for x in self.inputs], name='input_queue'))
+        assert_type(self.input_queue, tf.QueueBase)
+        assert self.input_queue.dtypes == [x.dtype for x in self.inputs]
+        self.get_model_func = kwargs.pop('get_model_func')
+        self.step_per_epoch = int(kwargs.pop('step_per_epoch', self.dataset.size()))
+        self.max_epoch = int(kwargs.pop('max_epoch', 100))
+        assert self.step_per_epoch > 0 and self.max_epoch > 0
+        assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))

 def get_train_op(optimizer, cost_var):
     global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
@@ -37,57 +83,36 @@ def start_train(config):
     """
     Start training with the given config
     Args:
-        config: a tensorpack config dictionary
+        config: a TrainConfig instance
     """
-    dataset = config['dataset']
-    assert isinstance(dataset, DataFlow), dataset.__class__
-
-    # a tf.train.Optimizer instance
-    optimizer = config['optimizer']
-    assert isinstance(optimizer, tf.train.Optimizer), optimizer.__class__
-
-    # a list of Callback instance
-    callbacks = config['callback']
-
-    # a tf.ConfigProto instance
-    sess_config = config.get('session_config', None)
-    assert isinstance(sess_config, tf.ConfigProto), sess_config.__class__
-
-    sess_init = config.get('session_init', NewSession())
-
-    # input/output variables
-    input_vars = config['inputs']
-    input_queue = config['input_queue']
-    get_model_func = config['get_model_func']
-
-    step_per_epoch = int(config['step_per_epoch'])
-    max_epoch = int(config['max_epoch'])
-    assert step_per_epoch > 0 and max_epoch > 0
+    input_vars = config.inputs
+    input_queue = config.input_queue
+    callbacks = config.callbacks

     enqueue_op = input_queue.enqueue(tuple(input_vars))
     model_inputs = input_queue.dequeue()
     # set dequeue shape
     for qv, v in zip(model_inputs, input_vars):
         qv.set_shape(v.get_shape())
-    output_vars, cost_var = get_model_func(model_inputs, is_training=True)
+    output_vars, cost_var = config.get_model_func(model_inputs, is_training=True)

     # build graph
-    tf.add_to_collection(FORWARD_FUNC_KEY, get_model_func)
+    tf.add_to_collection(FORWARD_FUNC_KEY, config.get_model_func)
     for v in input_vars:
         tf.add_to_collection(INPUT_VARS_KEY, v)
     for v in output_vars:
         tf.add_to_collection(OUTPUT_VARS_KEY, v)
     describe_model()

-    train_op = get_train_op(optimizer, cost_var)
+    train_op = get_train_op(config.optimizer, cost_var)

-    sess = tf.Session(config=sess_config)
-    sess_init.init(sess)
+    sess = tf.Session(config=config.session_config)
+    config.session_init.init(sess)

     # start training:
     coord = tf.train.Coordinator()
     # a thread that keeps filling the queue
-    input_th = EnqueueThread(sess, coord, enqueue_op, dataset, input_queue)
+    input_th = EnqueueThread(sess, coord, enqueue_op, config.dataset, input_queue)
     model_th = tf.train.start_queue_runners(
         sess=sess, coord=coord, daemon=True, start=True)
     input_th.start()
@@ -95,15 +120,13 @@ def start_train(config):
     with sess.as_default(), \
             coordinator_guard(sess, coord):
         callbacks.before_train()
-        for epoch in xrange(1, max_epoch):
+        for epoch in xrange(1, config.max_epoch):
             with timed_operation('epoch {}'.format(epoch)):
-                for step in xrange(step_per_epoch):
+                for step in xrange(config.step_per_epoch):
                     if coord.should_stop():
                         return
                     fetches = [train_op, cost_var] + output_vars + model_inputs
-                    print 'before'
                     results = sess.run(fetches)
-                    print 'after'
                     cost = results[1]
                     outputs = results[2:2 + len(output_vars)]
                     inputs = results[-len(model_inputs):]
@@ -111,4 +134,3 @@ def start_train(config):
                 # note that summary_op will take a data from the queue.
             callbacks.trigger_epoch()
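One thing the diff takes for granted: in start_train, the placeholders in config.inputs exist only to feed the enqueue op; the model graph is built on the dequeued tensors, while an EnqueueThread pumps datapoints from config.dataset into the queue. The same pattern reduced to bare TensorFlow of this era (a sketch, independent of tensorpack):

    x = tf.placeholder(tf.float32, shape=(None, 4), name='input')
    queue = tf.FIFOQueue(5, [x.dtype])
    enqueue_op = queue.enqueue((x,))       # run by a background thread with real data
    model_input = queue.dequeue()          # the model consumes this tensor, not x
    model_input.set_shape(x.get_shape())   # dequeue() drops static shapes, hence
                                           # the "set dequeue shape" loop above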
tensorpack/utils/__init__.py
@@ -11,7 +11,6 @@ from contextlib import contextmanager
 import tensorflow as tf
 import collections
-import logger

 def global_import(name):
@@ -98,3 +97,9 @@ class memoized(object):
     def __get__(self, obj, objtype):
         '''Support instance methods.'''
         return functools.partial(self.__call__, obj)
+
+@memoized
+def get_global_step_var():
+    global_step_var = tf.Variable(
+        0, trainable=False, name=GLOBAL_STEP_OP_NAME)
+    return global_step_var
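The @memoized decorator (defined just above in this file) is what lets the examples drop train.prepare(): the first call to get_global_step_var() creates the variable, and every later call, e.g. from exponential_decay in the examples, returns the same cached object:

    v1 = get_global_step_var()   # first call creates the tf.Variable
    v2 = get_global_step_var()   # memoized: same object, no duplicate graph node
    assert v1 is v2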