Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
77c8bde9
Commit
77c8bde9
authored
Feb 13, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
trainer.build_train_tower(). and some doc fix.
parent
a313662f
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
41 additions
and
27 deletions
+41
-27
docs/conf.py
docs/conf.py
+10
-12
docs/requirements.txt
docs/requirements.txt
+1
-0
examples/GAN/GAN.py
examples/GAN/GAN.py
+1
-3
tensorpack/tfutils/optimizer.py
tensorpack/tfutils/optimizer.py
+5
-2
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+4
-0
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+19
-8
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+1
-2
No files found.
docs/conf.py
View file @
77c8bde9
...
...
@@ -23,19 +23,17 @@ os.environ['TENSORPACK_DOC_BUILDING'] = '1'
MOCK_MODULES
=
[
'scipy'
,
'tensorflow'
,
'tensorflow.contrib'
,
'tensorflow.python.ops'
,
'tensorflow.contrib.framework'
,
'tensorflow.models'
,
'tensorflow.models.rnn'
,
'tensorflow.models.rnn.ptb'
,
'tensorflow.python'
,
'tensorflow.python.training'
,
'sklearn.datasets'
,
#'tensorflow', 'tensorflow.contrib',
#'tensorflow.python.ops',
#'tensorflow.contrib.framework',
#'tensorflow.python',
#'tensorflow.python.training',
'sklearn.datasets'
,
'sklearn'
,
'scipy.misc'
,
'h5py'
,
'nltk'
,
'cv2'
,
'scipy.io'
,
'dill'
,
'zmq'
,
'subprocess32'
,
'lmdb'
,
'tornado.concurrent'
,
'tornado'
,
'msgpack'
,
'msgpack_numpy'
,
'ale_python_interface'
,
'sklearn'
,
'functools32'
]
'cv2'
,
'scipy.io'
,
'dill'
,
'zmq'
,
'subprocess32'
,
'lmdb'
,
'tornado.concurrent'
,
'tornado'
,
'msgpack'
,
'msgpack_numpy'
,
'gym'
,
'functools32'
]
for
mod_name
in
MOCK_MODULES
:
sys
.
modules
[
mod_name
]
=
mock
.
Mock
(
name
=
mod_name
)
...
...
docs/requirements.txt
View file @
77c8bde9
...
...
@@ -2,5 +2,6 @@ termcolor
numpy
tqdm
decorator
tensorflow
Sphinx==1.5.1
recommonmark==0.4.0
examples/GAN/GAN.py
View file @
77c8bde9
...
...
@@ -73,9 +73,7 @@ class GANTrainer(FeedfreeTrainerBase):
def
_setup
(
self
):
super
(
GANTrainer
,
self
)
.
_setup
()
with
TowerContext
(
''
):
actual_inputs
=
self
.
_get_input_tensors
()
self
.
model
.
build_graph
(
actual_inputs
)
self
.
build_train_tower
()
opt
=
self
.
model
.
get_optimizer
()
self
.
g_min
=
opt
.
minimize
(
self
.
model
.
g_loss
,
var_list
=
self
.
model
.
g_vars
,
name
=
'g_op'
)
...
...
tensorpack/tfutils/optimizer.py
View file @
77c8bde9
...
...
@@ -11,6 +11,9 @@ __all__ = ['apply_grad_processors', 'ProxyOptimizer',
class
ProxyOptimizer
(
tf
.
train
.
Optimizer
):
"""
A transparent proxy which delegates all methods of :class:`tf.train.Optimizer`
"""
def
__init__
(
self
,
opt
):
self
.
_opt
=
opt
...
...
@@ -54,8 +57,8 @@ def apply_grad_processors(opt, gradprocs):
class
PostProcessVariablesOptimizer
(
ProxyOptimizer
):
"""
An optimizer which applies an operation to variables
(e.g. clipping,
quantization) after updating the gradient.
An optimizer which applies an operation to variables
(e.g. clipping,
quantization) after updating the gradient.
"""
def
__init__
(
self
,
opt
,
func
,
colocate
=
True
):
"""
...
...
tensorpack/tfutils/tower.py
View file @
77c8bde9
...
...
@@ -102,6 +102,10 @@ class TowerContext(object):
self
.
_scope
.
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
return
False
def
__str__
(
self
):
return
"TowerContext(name={}, is_training={})"
.
format
(
self
.
_name
,
self
.
_is_training
)
def
get_current_tower_context
():
global
_CurrentTowerContext
...
...
tensorpack/train/feedfree.py
View file @
77c8bde9
...
...
@@ -6,7 +6,7 @@
import
tensorflow
as
tf
from
..utils
import
log_deprecated
from
..tfutils.tower
import
TowerContext
from
..tfutils.tower
import
TowerContext
,
get_current_tower_context
from
.input_data
import
QueueInput
,
FeedfreeInput
from
.base
import
Trainer
...
...
@@ -27,8 +27,20 @@ class FeedfreeTrainerBase(Trainer):
summary_str
=
self
.
summary_op
.
eval
()
self
.
add_summary
(
summary_str
)
def
_get_input_tensors
(
self
):
return
self
.
_input_method
.
get_input_tensors
()
def
build_train_tower
(
self
):
"""
Get input tensors from `self.input_method` and build the graph.
"""
def
f
():
inputs
=
self
.
_input_method
.
get_input_tensors
()
self
.
model
.
build_graph
(
inputs
)
ctx
=
get_current_tower_context
()
if
ctx
is
None
:
with
TowerContext
(
''
):
f
()
else
:
assert
ctx
.
is_training
,
ctx
f
()
def
_setup
(
self
):
assert
isinstance
(
self
.
_input_method
,
FeedfreeInput
),
type
(
self
.
_input_method
)
...
...
@@ -39,16 +51,15 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
""" A feedfree Trainer which assumes a single cost. """
def
_get_cost_and_grad
(
self
):
""" get the cost and gradient"""
actual_inputs
=
self
.
_get_input_tensors
()
self
.
model
.
build_graph
(
actual_inputs
)
cost_var
=
self
.
model
.
get_cost
()
self
.
build_train_tower
()
cost
=
self
.
model
.
get_cost
()
opt
=
self
.
config
.
optimizer
# GATE_NONE faster?
grads
=
opt
.
compute_gradients
(
cost
_var
,
cost
,
gate_gradients
=
tf
.
train
.
Optimizer
.
GATE_NONE
,
colocate_gradients_with_ops
=
True
)
return
cost
_var
,
grads
return
cost
,
grads
def
run_step
(
self
):
""" Simply run ``self.train_op``, which minimizes the cost."""
...
...
tensorpack/train/multigpu.py
View file @
77c8bde9
...
...
@@ -141,8 +141,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
# grads = grad_list[0]
else
:
def
get_cost
():
actual_inputs
=
self
.
_get_input_tensors
()
self
.
model
.
build_graph
(
actual_inputs
)
self
.
build_train_tower
()
return
self
.
model
.
get_cost
()
cost_list
=
MultiGPUTrainer
.
_multi_tower_costs
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment