Shashank Suhas / seminar-breakout · Commits

Commit 26edfabe, authored Jan 01, 2016 by Yuxin Wu
misc fix

parent f698a04d
Showing 7 changed files with 33 additions and 19 deletions
example_mnist.py                          +1  -1
tensorpack/train.py                       +1  -0
tensorpack/utils/__init__.py              +16 -4
tensorpack/utils/callback.py              +8  -7
tensorpack/utils/summary.py               +1  -1
tensorpack/utils/symbolic_functions.py    +2  -2
tensorpack/utils/validation_callback.py   +4  -4
example_mnist.py

@@ -98,7 +98,7 @@ def get_config():
     dataset_train = BatchData(dataset.Mnist('train'), 128)
     dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
     step_per_epoch = dataset_train.size()
-    # step_per_epoch = 30
+    step_per_epoch = 30
     #dataset_test = FixedSizeData(dataset_test, 20)
     sess_config = get_default_sess_config()
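For context (not part of the diff): remainder=True keeps the final, smaller batch of the test set instead of dropping it, and the uncommented step_per_epoch = 30 overrides the full epoch length, presumably for quick debugging runs. A rough sketch of the batch arithmetic, assuming the standard MNIST split sizes:

    # Batch arithmetic sketch (assumes 60,000 train / 10,000 test images).
    train_size, test_size = 60000, 10000
    train_batches = train_size // 128            # 468 full batches; incomplete batch dropped by default
    test_batches, last = divmod(test_size, 256)  # 39 full batches plus a final batch of 16, kept by remainder=True
    print(train_batches, test_batches, last)     # 468 39 16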
tensorpack/train.py

@@ -130,6 +130,7 @@ def start_train(config):
     with sess.as_default(), \
             coordinator_guard(sess, coord):
+        logger.info("Start with global_step={}".format(get_global_step()))
         callbacks.before_train()
         for epoch in xrange(1, config.max_epoch):
             with timed_operation('epoch {}'.format(epoch)):
tensorpack/utils/__init__.py

@@ -37,6 +37,9 @@ def create_test_graph():
     input_vars_train = G.get_collection(INPUT_VARS_KEY)
     forward_func = G.get_collection(FORWARD_FUNC_KEY)[0]
     with tf.Graph().as_default() as Gtest:
+        # create a global step var in test graph
+        global_step_var = tf.Variable(
+            0, trainable=False, name=GLOBAL_STEP_OP_NAME)
         input_vars = []
         for v in input_vars_train:
             name = v.name

@@ -99,11 +102,20 @@ class memoized(object):
         '''Support instance methods.'''
         return functools.partial(self.__call__, obj)

-@memoized
 def get_global_step_var():
-    global_step_var = tf.Variable(
-        0, trainable=False, name=GLOBAL_STEP_OP_NAME)
-    return global_step_var
+    """ get global_step variable in the current graph"""
+    try:
+        return tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
+    except KeyError:
+        var = tf.Variable(
+            0, trainable=False, name=GLOBAL_STEP_OP_NAME)
+        return var
+
+def get_global_step():
+    """ get global_step value with current graph and session"""
+    return tf.train.global_step(
+        tf.get_default_session(), get_global_step_var())

 def get_rng(self):
     return np.random.RandomState()
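The rewritten get_global_step_var drops the @memoized decorator in favor of looking the variable up in the current default graph, so each graph (the train graph and the test graph built in create_test_graph) gets its own global step. A minimal standalone sketch of the same look-up-or-create pattern, assuming TF 0.x-era APIs and the tensor name 'global_step:0' (the actual names come from GLOBAL_STEP_OP_NAME / GLOBAL_STEP_VAR_NAME in tensorpack):

    import tensorflow as tf

    def get_or_create_global_step(name='global_step'):
        g = tf.get_default_graph()
        try:
            # reuse the variable if this graph already has one
            return g.get_tensor_by_name(name + ':0')
        except KeyError:
            # otherwise create it once in this graph
            return tf.Variable(0, trainable=False, name=name)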
tensorpack/utils/callback.py

@@ -10,7 +10,7 @@ import os
 import time
 from abc import abstractmethod, ABCMeta

-from . import create_test_session
+from . import create_test_session, get_global_step
 from .naming import *
 import logger

@@ -53,6 +53,7 @@ class PeriodicCallback(Callback):
     def trigger_epoch(self):
         self.epoch_num += 1
         if self.epoch_num % self.__period == 0:
+            self.global_step = get_global_step()
             self._trigger()

     @abstractmethod

@@ -72,13 +73,14 @@ class PeriodicSaver(PeriodicCallback):
             keep_checkpoint_every_n_hours=self.keep_freq)

     def _trigger(self):
-        self.saver.save(tf.get_default_session(), self.path,
-                        global_step=self.epoch_num)
+        self.saver.save(tf.get_default_session(), self.path,
+                        global_step=self.global_step)

 class SummaryWriter(Callback):
     def __init__(self):
         self.log_dir = logger.LOG_DIR
         self.epoch_num = 0

     def _before_train(self):
         self.writer = tf.train.SummaryWriter(

@@ -91,9 +93,7 @@ class SummaryWriter(Callback):
         if self.summary_op is None:
             return
         summary_str = self.summary_op.eval()
-        self.epoch_num += 1
-        self.writer.add_summary(summary_str, self.epoch_num)
+        self.writer.add_summary(summary_str, get_global_step())

 class CallbackTimeLogger(object):
     def __init__(self):

@@ -214,5 +214,6 @@ class Callbacks(Callback):
     def trigger_epoch(self):
         self.train.trigger_epoch()
+        # TODO test callbacks can be run async?
         self.test.trigger_epoch()
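With this change, PeriodicCallback records the current global step right before calling _trigger, so subclasses such as PeriodicSaver checkpoint against the step count rather than the epoch number. A hypothetical subclass, only to illustrate the contract (the attribute name self.global_step is the one set above; the PeriodicCallback constructor signature is assumed):

    class PrintStep(PeriodicCallback):
        """ Hypothetical callback: print the step recorded by trigger_epoch. """
        def __init__(self, period=1):
            # assumed constructor; the real PeriodicCallback signature may differ
            super(PrintStep, self).__init__(period)

        def _trigger(self):
            # self.global_step was filled in by PeriodicCallback.trigger_epoch
            print('epoch {} ended at global step {}'.format(
                self.epoch_num, self.global_step))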
tensorpack/utils/summary.py

@@ -51,7 +51,7 @@ def summary_moving_average(cost_var):
     """
     global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
     averager = tf.train.ExponentialMovingAverage(
-        0.9, num_updates=global_step_var, name='avg')
+        0.9, num_updates=global_step_var, name='moving_averages')
     vars_to_summary = [cost_var] + \
         tf.get_collection(SUMMARY_VARS_KEY) + \
         tf.get_collection(COST_VARS_KEY)
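The rename presumably only changes the suffix under which the shadow (averaged) variables are created; the averaging itself is unchanged. As a reminder of what num_updates does here (per TensorFlow's ExponentialMovingAverage documentation of that era), the effective decay is the smaller of the requested decay and (1 + step) / (10 + step), so the averages warm up quickly early in training. A small check, assuming that formula:

    # Effective decay used by ExponentialMovingAverage(0.9, num_updates=step)
    def effective_decay(decay, step):
        return min(decay, (1.0 + step) / (10.0 + step))

    print(effective_decay(0.9, 0))     # 0.1  -> almost no smoothing at the start
    print(effective_decay(0.9, 100))   # 0.9  -> capped at the requested decay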
tensorpack/utils/symbolic_functions.py

@@ -8,7 +8,7 @@ import numpy as np
 __all__ = ['one_hot', 'batch_flatten', 'logSoftmax']

 def one_hot(y, num_labels):
-    with tf.variable_scope('one_hot'):
+    with tf.op_scope([y, num_labels], 'one_hot'):
         batch_size = tf.size(y)
         y = tf.expand_dims(y, 1)
         indices = tf.expand_dims(tf.range(0, batch_size), 1)

@@ -23,7 +23,7 @@ def batch_flatten(x):
     return tf.reshape(x, [-1, total_dim])

 def logSoftmax(x):
-    with tf.variable_scope('logSoftmax'):
+    with tf.op_scope([x], 'logSoftmax'):
         z = x - tf.reduce_max(x, 1, keep_dims=True)
         logprob = z - tf.log(tf.reduce_sum(tf.exp(z), 1, keep_dims=True))
         return logprob
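The switch from variable_scope to op_scope groups these purely functional ops under a name scope without opening a variable scope; the math is untouched. For reference, logSoftmax uses the usual max-subtraction trick: log softmax(x) = z - log(sum(exp(z))) with z = x - max(x). A quick NumPy check of that identity (illustrative only):

    import numpy as np

    x = np.array([[1000.0, 1001.0, 1002.0]])      # naive exp(x) would overflow
    z = x - x.max(axis=1, keepdims=True)
    logprob = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    print(logprob)                                 # ~[[-2.408, -1.408, -0.408]]
    print(np.exp(logprob).sum())                   # 1.0, i.e. a valid softmax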
tensorpack/utils/validation_callback.py

@@ -58,11 +58,11 @@ class ValidationError(PeriodicCallback):
         self.writer.add_summary(
             create_summary('{}_error'.format(self.prefix), err_stat.accuracy),
-            self.epoch_num)
+            self.global_step)
         self.writer.add_summary(
             create_summary('{}_cost'.format(self.prefix), cost_avg),
-            self.epoch_num)
+            self.global_step)
         logger.info(
-            "{} validation after epoch {}: err={:.4f}, cost={:.3f}".format(
-                self.prefix, self.epoch_num, err_stat.accuracy, cost_avg))
+            "{} validation after epoch {},step {}: err={:.4f}, cost={:.3f}".format(
+                self.prefix, self.epoch_num, self.global_step,
+                err_stat.accuracy, cost_avg))
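Indexing these summaries by the global step (instead of the callback's own epoch counter) means the validation curves share the same x-axis as the training summaries in TensorBoard. For reference, a manually built scalar summary of the kind create_summary presumably produces; a sketch using the stock tf.Summary proto, not tensorpack's actual implementation:

    import tensorflow as tf

    def make_scalar_summary(tag, value):
        # build a tf.Summary proto holding a single scalar value
        return tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])

    # writer.add_summary(make_scalar_summary('val_error', 0.02), global_step)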