Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
ae985fc4
Commit
ae985fc4
authored
Feb 21, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
ugly fix of MODEL_KEY
parent
24f898ec
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
28 additions
and
29 deletions
+28
-29
example_mnist.py
example_mnist.py
+1
-0
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+15
-18
tensorpack/callbacks/validation_callback.py
tensorpack/callbacks/validation_callback.py
+1
-1
tensorpack/train/base.py
tensorpack/train/base.py
+1
-1
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+4
-4
tensorpack/utils/concurrency.py
tensorpack/utils/concurrency.py
+5
-4
tensorpack/utils/naming.py
tensorpack/utils/naming.py
+1
-1
No files found.
example_mnist.py
View file @
ae985fc4
...
@@ -92,6 +92,7 @@ def get_config():
...
@@ -92,6 +92,7 @@ def get_config():
dataset_train
=
BatchData
(
dataset
.
Mnist
(
'train'
),
128
)
dataset_train
=
BatchData
(
dataset
.
Mnist
(
'train'
),
128
)
dataset_test
=
BatchData
(
dataset
.
Mnist
(
'test'
),
256
,
remainder
=
True
)
dataset_test
=
BatchData
(
dataset
.
Mnist
(
'test'
),
256
,
remainder
=
True
)
step_per_epoch
=
dataset_train
.
size
()
step_per_epoch
=
dataset_train
.
size
()
step_per_epoch
=
20
# prepare session
# prepare session
sess_config
=
get_default_sess_config
()
sess_config
=
get_default_sess_config
()
...
...
tensorpack/callbacks/group.py
View file @
ae985fc4
...
@@ -13,22 +13,22 @@ from ..utils import *
...
@@ -13,22 +13,22 @@ from ..utils import *
__all__
=
[
'Callbacks'
]
__all__
=
[
'Callbacks'
]
@
contextmanager
@
contextmanager
def
create_test_graph
():
def
create_test_graph
(
trainer
):
G
=
tf
.
get_default_graph
()
model
=
trainer
.
model
.
__class__
()
model
=
G
.
get_collection
(
MODEL_KEY
)[
0
]
with
tf
.
Graph
()
.
as_default
()
as
Gtest
:
with
tf
.
Graph
()
.
as_default
()
as
Gtest
:
# create a global step var in test graph
# create a global step var in test graph
global_step_var
=
tf
.
Variable
(
global_step_var
=
tf
.
Variable
(
0
,
trainable
=
False
,
name
=
GLOBAL_STEP_OP_NAME
)
0
,
trainable
=
False
,
name
=
GLOBAL_STEP_OP_NAME
)
new_model
=
model
.
__class__
()
input_vars
=
model
.
get_input_vars
()
input_vars
=
new_model
.
get_input_vars
()
for
v
in
input_vars
:
cost
=
new_model
.
get_cost
(
input_vars
,
is_training
=
False
)
tf
.
add_to_collection
(
INPUT_VARS_KEY
,
v
)
Gtest
.
add_to_collection
(
MODEL_KEY
,
new_model
)
cost
=
model
.
get_cost
(
input_vars
,
is_training
=
False
)
yield
Gtest
yield
Gtest
@
contextmanager
@
contextmanager
def
create_test_session
():
def
create_test_session
(
trainer
):
with
create_test_graph
():
""" create a test-time session from trainer"""
with
create_test_graph
(
trainer
):
with
tf
.
Session
()
as
sess
:
with
tf
.
Session
()
as
sess
:
yield
sess
yield
sess
...
@@ -66,16 +66,13 @@ class TestCallbackContext(object):
...
@@ -66,16 +66,13 @@ class TestCallbackContext(object):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
sess
=
None
self
.
sess
=
None
def
_init_test_sess
(
self
):
with
create_test_session
()
as
sess
:
self
.
sess
=
sess
self
.
graph
=
sess
.
graph
self
.
saver
=
tf
.
train
.
Saver
()
@
contextmanager
@
contextmanager
def
before_train_context
(
self
):
def
before_train_context
(
self
,
trainer
):
if
self
.
sess
is
None
:
if
self
.
sess
is
None
:
self
.
_init_test_sess
()
with
create_test_session
(
trainer
)
as
sess
:
self
.
sess
=
sess
self
.
graph
=
sess
.
graph
self
.
saver
=
tf
.
train
.
Saver
()
with
self
.
graph
.
as_default
(),
self
.
sess
.
as_default
():
with
self
.
graph
.
as_default
(),
self
.
sess
.
as_default
():
yield
yield
...
@@ -112,7 +109,7 @@ class Callbacks(Callback):
...
@@ -112,7 +109,7 @@ class Callbacks(Callback):
if
isinstance
(
cb
.
type
,
TrainCallback
):
if
isinstance
(
cb
.
type
,
TrainCallback
):
cb
.
before_train
(
self
.
trainer
)
cb
.
before_train
(
self
.
trainer
)
else
:
else
:
with
self
.
test_callback_context
.
before_train_context
():
with
self
.
test_callback_context
.
before_train_context
(
self
.
trainer
):
cb
.
before_train
(
self
.
trainer
)
cb
.
before_train
(
self
.
trainer
)
def
_after_train
(
self
):
def
_after_train
(
self
):
...
...
tensorpack/callbacks/validation_callback.py
View file @
ae985fc4
...
@@ -26,7 +26,7 @@ class ValidationCallback(PeriodicCallback):
...
@@ -26,7 +26,7 @@ class ValidationCallback(PeriodicCallback):
self
.
cost_var_name
=
cost_var_name
self
.
cost_var_name
=
cost_var_name
def
_before_train
(
self
):
def
_before_train
(
self
):
self
.
input_vars
=
tf
.
get_collection
(
MODEL_KEY
)[
0
]
.
get_input_vars
(
)
self
.
input_vars
=
tf
.
get_collection
(
INPUT_VARS_KEY
)
self
.
cost_var
=
self
.
get_tensor
(
self
.
cost_var_name
)
self
.
cost_var
=
self
.
get_tensor
(
self
.
cost_var_name
)
self
.
_find_output_vars
()
self
.
_find_output_vars
()
...
...
tensorpack/train/base.py
View file @
ae985fc4
...
@@ -24,7 +24,7 @@ class Trainer(object):
...
@@ -24,7 +24,7 @@ class Trainer(object):
"""
"""
assert
isinstance
(
config
,
TrainConfig
),
type
(
config
)
assert
isinstance
(
config
,
TrainConfig
),
type
(
config
)
self
.
config
=
config
self
.
config
=
config
tf
.
add_to_collection
(
MODEL_KEY
,
config
.
model
)
self
.
model
=
config
.
model
@
abstractmethod
@
abstractmethod
def
train
(
self
):
def
train
(
self
):
...
...
tensorpack/train/train.py
→
tensorpack/train/train
er
.py
View file @
ae985fc4
#!/usr/bin/env python2
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: train.py
# File: train
er
.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -47,7 +47,7 @@ class SimpleTrainer(Trainer):
...
@@ -47,7 +47,7 @@ class SimpleTrainer(Trainer):
self
.
sess
.
run
([
self
.
train_op
],
feed_dict
=
feed
)
# faster since train_op return None
self
.
sess
.
run
([
self
.
train_op
],
feed_dict
=
feed
)
# faster since train_op return None
def
train
(
self
):
def
train
(
self
):
model
=
self
.
config
.
model
model
=
self
.
model
input_vars
=
model
.
get_input_vars
()
input_vars
=
model
.
get_input_vars
()
self
.
input_vars
=
input_vars
self
.
input_vars
=
input_vars
cost_var
=
model
.
get_cost
(
input_vars
,
is_training
=
True
)
cost_var
=
model
.
get_cost
(
input_vars
,
is_training
=
True
)
...
@@ -91,7 +91,7 @@ class QueueInputTrainer(Trainer):
...
@@ -91,7 +91,7 @@ class QueueInputTrainer(Trainer):
return
ret
return
ret
def
train
(
self
):
def
train
(
self
):
model
=
self
.
config
.
model
model
=
self
.
model
input_vars
=
model
.
get_input_vars
()
input_vars
=
model
.
get_input_vars
()
input_queue
=
model
.
get_input_queue
()
input_queue
=
model
.
get_input_queue
()
...
@@ -144,7 +144,7 @@ class QueueInputTrainer(Trainer):
...
@@ -144,7 +144,7 @@ class QueueInputTrainer(Trainer):
self
.
init_session_and_coord
()
self
.
init_session_and_coord
()
# create a thread that keeps filling the queue
# create a thread that keeps filling the queue
input_th
=
EnqueueThread
(
self
.
sess
,
self
.
coord
,
enqueue_op
,
self
.
config
.
dataset
,
input_queue
)
input_th
=
EnqueueThread
(
self
,
enqueue_op
,
self
.
config
.
dataset
,
input_queue
)
input_th
.
start
()
input_th
.
start
()
self
.
main_loop
()
self
.
main_loop
()
...
...
tensorpack/utils/concurrency.py
View file @
ae985fc4
...
@@ -23,11 +23,12 @@ class StoppableThread(threading.Thread):
...
@@ -23,11 +23,12 @@ class StoppableThread(threading.Thread):
class
EnqueueThread
(
threading
.
Thread
):
class
EnqueueThread
(
threading
.
Thread
):
def
__init__
(
self
,
sess
,
coord
,
enqueue_op
,
dataflow
,
queue
):
def
__init__
(
self
,
trainer
,
enqueue_op
,
dataflow
,
queue
):
super
(
EnqueueThread
,
self
)
.
__init__
()
super
(
EnqueueThread
,
self
)
.
__init__
()
self
.
sess
=
sess
self
.
sess
=
trainer
.
sess
self
.
coord
=
coord
self
.
coord
=
trainer
.
coord
self
.
input_vars
=
sess
.
graph
.
get_collection
(
MODEL_KEY
)[
0
]
.
get_input_vars
()
self
.
input_vars
=
trainer
.
model
.
get_input_vars
()
self
.
dataflow
=
dataflow
self
.
dataflow
=
dataflow
self
.
op
=
enqueue_op
self
.
op
=
enqueue_op
self
.
queue
=
queue
self
.
queue
=
queue
...
...
tensorpack/utils/naming.py
View file @
ae985fc4
...
@@ -8,7 +8,7 @@ GLOBAL_STEP_VAR_NAME = 'global_step:0'
...
@@ -8,7 +8,7 @@ GLOBAL_STEP_VAR_NAME = 'global_step:0'
# extra variables to summarize during training in a moving-average way
# extra variables to summarize during training in a moving-average way
MOVING_SUMMARY_VARS_KEY
=
'MOVING_SUMMARY_VARIABLES'
MOVING_SUMMARY_VARS_KEY
=
'MOVING_SUMMARY_VARIABLES'
MODEL_KEY
=
'MODEL
'
INPUT_VARS_KEY
=
'INPUT_VARS
'
# export all upper case variables
# export all upper case variables
all_local_names
=
locals
()
.
keys
()
all_local_names
=
locals
()
.
keys
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment