seminar-breakout / Commits

Commit bbe4faf4, authored Dec 26, 2015 by ppwwyyxx

    use train_config

parent 341a5d43
Showing 3 changed files, with 123 additions and 69 deletions (+123 −69).
example_mnist.py    +30 −56
train.py            +63 −0
utils/callback.py   +30 −13
example_mnist.py (view file @ bbe4faf4)

```diff
@@ -11,8 +11,6 @@ sys.path.insert(0, os.path.expanduser('~/.local/lib/python2.7/site-packages'))
 import tensorflow as tf
 import numpy as np
 from itertools import count
 from layers import *
 from utils import *
@@ -20,23 +18,21 @@ from dataflow.dataset import Mnist
 from dataflow import *

 IMAGE_SIZE = 28
-NUM_CLASS = 10
-batch_size = 128
 LOG_DIR = 'train_log'

 def get_model(inputs):
     """
     Args:
         inputs: a list of input variable,
-            e.g.: [input, label] with:
-                input: bx28x28
-                label: bx1 integer
+            e.g.: [input_var, label_var] with:
+                input_var: bx28x28
+                label_var: bx1 integer
     Returns:
         (outputs, cost)
         outputs: a list of output variable
         cost: scalar variable
     """
-    # use this dropout variable! it will be set to 1 at test time
+    # use this variable in dropout! Tensorpack will automatically set it to 1 at test time
     keep_prob = tf.placeholder(tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
     input, label = inputs
@@ -62,7 +58,7 @@ def get_model(inputs):
     fc1 = FullyConnected('lr', fc0, out_dim=10)
     prob = tf.nn.softmax(fc1, name='output')
-    y = one_hot(label, NUM_CLASS)
+    y = one_hot(label, 10)
     cost = tf.nn.softmax_cross_entropy_with_logits(fc1, y)
     cost = tf.reduce_mean(cost, name='cost')
@@ -74,59 +70,37 @@ def get_model(inputs):
     return [prob, correct], cost

 def main():
-    dataset_train = BatchData(Mnist('train'), batch_size)
-    dataset_test = BatchData(Mnist('test'), batch_size, remainder=True)
-    callbacks = [
-        SummaryWriter(LOG_DIR),
-        AccuracyValidation(dataset_test, prefix='test', period=1),
-        TrainingAccuracy(),
-        PeriodicSaver(LOG_DIR, period=1)
-    ]
-    optimizer = tf.train.AdamOptimizer(1e-4)
+    dataset_train = BatchData(Mnist('train'), 128)
+    dataset_test = BatchData(Mnist('test'), 128, remainder=True)
     sess_config = tf.ConfigProto()
     sess_config.device_count['GPU'] = 1

     with tf.Graph().as_default():
         G = tf.get_default_graph()
         # prepare model
         image_var = tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input')
         label_var = tf.placeholder(tf.int32, shape=(None,), name='label')
         input_vars = [image_var, label_var]
-        for v in input_vars:
-            G.add_to_collection(INPUT_VARS_KEY, v)
         output_vars, cost_var = get_model(input_vars)
-        for v in output_vars:
-            G.add_to_collection(OUTPUT_VARS_KEY, v)
-        train_op = optimizer.minimize(cost_var)
-        sess = tf.Session(config=sess_config)
-        sess.run(tf.initialize_all_variables())
-        with sess.as_default():
-            for ext in callbacks:
-                ext.before_train()
-            keep_prob_var = G.get_tensor_by_name(DROPOUT_PROB_VAR_NAME)
-            for epoch in count(1):
-                for dp in dataset_train.get_data():
-                    feed = {keep_prob_var: 0.5}
-                    feed.update(dict(zip(input_vars, dp)))
-                    results = sess.run([train_op, cost_var] + output_vars, feed_dict=feed)
-                    cost = results[1]
-                    outputs = results[2:]
-                    assert len(outputs) == len(output_vars)
-                    for cb in callbacks:
-                        cb.trigger_step(dp, outputs, cost)
-                for cb in callbacks:
-                    cb.trigger_epoch()
-            summary_writer.close()
+        config = dict(
+            dataset_train=dataset_train,
+            optimizer=tf.train.AdamOptimizer(1e-4),
+            callbacks=[
+                TrainingAccuracy(),
+                AccuracyValidation(dataset_test, prefix='test', period=1),
+                PeriodicSaver(LOG_DIR, period=1),
+                SummaryWriter(LOG_DIR),
+            ],
+            session_config=sess_config,
+            inputs=input_vars,
+            outputs=output_vars,
+            cost=cost_var,
+            max_epoch=100,
+        )
+        from train import start_train
+        start_train(config)

 if __name__ == '__main__':
```
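The interesting mechanics above is how the rewritten main() stops looping itself and only describes the experiment for start_train(). The feed construction that used to live here (and now lives in train.py) pairs each input variable with the matching component of a datapoint by position. Below is a minimal, TensorFlow-free sketch of that zip-into-feed-dict pattern; all names are illustrative stand-ins, not identifiers from this repository.

```python
# Stand-ins for the graph's two input placeholders and one datapoint: a dp
# yielded by the DataFlow is a list whose entries line up with input_vars.
input_vars = ['input:0', 'label:0']
dp = [[[0.0] * 28] * 28, [3]]           # fake image batch, fake label batch

feed = {'keep_prob:0': 0.5}             # dropout keep-probability, fed each step
feed.update(dict(zip(input_vars, dp)))  # positional pairing builds the feed dict

assert set(feed) == {'keep_prob:0', 'input:0', 'label:0'}
```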
train.py (new file, mode 100644, view file @ bbe4faf4)

```python
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: train.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import tensorflow as tf
from utils import *
from itertools import count

def start_train(config):
    """
    Start training with the given config
    Args:
        config: a tensorpack config dictionary
    """
    # a DataFlow instance
    dataset_train = config['dataset_train']
    # a tf.train.Optimizer instance
    optimizer = config['optimizer']
    # a list of Callback instances
    callbacks = Callbacks(config.get('callbacks', []))
    # a tf.ConfigProto instance
    sess_config = config.get('session_config', None)
    # a list of input/output variables
    input_vars = config['inputs']
    output_vars = config['outputs']
    cost_var = config['cost']
    max_epoch = int(config['max_epoch'])

    # build graph
    G = tf.get_default_graph()
    for v in input_vars:
        G.add_to_collection(INPUT_VARS_KEY, v)
    for v in output_vars:
        G.add_to_collection(OUTPUT_VARS_KEY, v)
    train_op = optimizer.minimize(cost_var)

    sess = tf.Session(config=sess_config)

    # start training
    with sess.as_default():
        sess.run(tf.initialize_all_variables())
        callbacks.before_train()
        keep_prob_var = G.get_tensor_by_name(DROPOUT_PROB_VAR_NAME)
        # note: xrange(1, max_epoch) runs epochs 1 .. max_epoch - 1
        for epoch in xrange(1, max_epoch):
            for dp in dataset_train.get_data():
                feed = {keep_prob_var: 0.5}
                feed.update(dict(zip(input_vars, dp)))
                results = sess.run([train_op, cost_var] + output_vars,
                                   feed_dict=feed)
                cost = results[1]
                outputs = results[2:]
                callbacks.trigger_step(dp, outputs, cost)
            callbacks.trigger_epoch()
```
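start_train() fetches [train_op, cost_var] + output_vars in a single sess.run and unpacks the result by position: index 0 is the (discarded) train_op value, index 1 the scalar cost, and the remainder mirrors output_vars. A pure-Python sketch of that convention, with fabricated values:

```python
# Fabricated stand-ins for what sess.run([train_op, cost_var] + output_vars)
# would return; only the slicing convention matters here.
output_vars = ['prob', 'correct']       # hypothetical fetch names
results = [None, 0.42, 'prob_val', 'correct_val']

cost = results[1]                       # the scalar cost comes second
outputs = results[2:]                   # the rest line up with output_vars
assert len(outputs) == len(output_vars)
```

One behavioral detail worth noting: xrange(1, max_epoch) iterates max_epoch − 1 times, so max_epoch=100 in the example performs 99 epochs.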
utils/callback.py (view file @ bbe4faf4)

```diff
@@ -8,6 +8,7 @@ import sys
 import numpy as np
 import os
 from abc import abstractmethod
 from .stat import *
 from .utils import *
 from .naming import *
@@ -20,12 +21,12 @@ class Callback(object):
     def _before_train(self):
         """
-        Called before training
+        Called before starting iterative training
         """

     # trigger after every step
     def trigger_step(self, dp, outputs, cost):
         """
         Callback to be triggered after every step (every backpropagation)
         Args:
             dp: the input dict fed into the graph
             outputs: list of output values after running this dp
@@ -33,8 +34,10 @@ class Callback(object):
         """
         pass

     # trigger after every epoch
     def trigger_epoch(self):
         """
         Callback to be triggered after every epoch (full iteration of input dataset)
         """
         pass

 class PeriodicCallback(Callback):
@@ -77,11 +80,7 @@ class AccuracyValidation(PeriodicCallback):
         self.dropout_var = self.get_tensor(DROPOUT_PROB_VAR_NAME)
         self.correct_var = self.get_tensor(self.correct_var_name)
         self.cost_var = self.get_tensor(self.cost_var_name)
-        try:
-            self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
-        except Exception as e:
-            print "SummaryWriter should be the first extension!"
-            raise

     def _trigger(self):
         cnt = 0
@@ -121,11 +120,7 @@ class TrainingAccuracy(Callback):
         self.epoch_num = 0

     def _before_train(self):
-        try:
-            self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
-        except Exception as e:
-            print "SummaryWriter should be the first extension!"
-            raise
         output_vars = self.graph.get_collection(OUTPUT_VARS_KEY)
         for idx, var in enumerate(output_vars):
             if var.name == self.correct_var_name:
@@ -194,3 +189,25 @@ class SummaryWriter(Callback):
         self.epoch_num += 1
         self.writer.add_summary(summary_str, self.epoch_num)

+class Callbacks(Callback):
+    def __init__(self, callbacks):
+        # put SummaryWriter to the first
+        for idx, cb in enumerate(callbacks):
+            if type(cb) == SummaryWriter:
+                callbacks.insert(0, callbacks.pop(idx))
+                break
+        self.callbacks = callbacks
+
+    def before_train(self):
+        for cb in self.callbacks:
+            cb.before_train()
+
+    def trigger_step(self, dp, outputs, cost):
+        for cb in self.callbacks:
+            cb.trigger_step(dp, outputs, cost)
+
+    def trigger_epoch(self):
+        for cb in self.callbacks:
+            cb.trigger_epoch()
```
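The new Callbacks container is what makes the removed try/except checks unnecessary: rather than demanding that users register SummaryWriter first ("SummaryWriter should be the first extension!"), __init__ reorders the list so the writer is set up before any other callback's before_train() looks it up. A runnable sketch of just that reordering, using empty stand-in classes instead of the real ones from utils/callback.py:

```python
# Empty stand-ins; the real classes live in utils/callback.py.
class SummaryWriter(object):
    pass

class TrainingAccuracy(object):
    pass

callbacks = [TrainingAccuracy(), SummaryWriter()]

# Same logic as Callbacks.__init__: move the SummaryWriter to the front.
for idx, cb in enumerate(callbacks):
    if type(cb) == SummaryWriter:
        callbacks.insert(0, callbacks.pop(idx))
        break

assert type(callbacks[0]) == SummaryWriter
```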