Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
dfacc168
Commit
dfacc168
authored
Oct 29, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
clean imports
parent
bcf8dbfe
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
22 changed files
with
122 additions
and
110 deletions
+122
-110
tensorpack/RL/gymenv.py
tensorpack/RL/gymenv.py
+2
-1
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+0
-2
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+2
-2
tensorpack/callbacks/inference.py
tensorpack/callbacks/inference.py
+3
-4
tensorpack/callbacks/stat.py
tensorpack/callbacks/stat.py
+1
-1
tensorpack/models/_common.py
tensorpack/models/_common.py
+3
-3
tensorpack/models/batch_norm.py
tensorpack/models/batch_norm.py
+1
-1
tensorpack/models/conv2d.py
tensorpack/models/conv2d.py
+1
-1
tensorpack/models/fc.py
tensorpack/models/fc.py
+2
-2
tensorpack/models/model_desc.py
tensorpack/models/model_desc.py
+3
-75
tensorpack/models/nonlin.py
tensorpack/models/nonlin.py
+1
-1
tensorpack/models/pool.py
tensorpack/models/pool.py
+4
-4
tensorpack/models/regularize.py
tensorpack/models/regularize.py
+2
-2
tensorpack/predict/base.py
tensorpack/predict/base.py
+1
-2
tensorpack/predict/common.py
tensorpack/predict/common.py
+2
-1
tensorpack/predict/concurrency.py
tensorpack/predict/concurrency.py
+0
-3
tensorpack/tfutils/__init__.py
tensorpack/tfutils/__init__.py
+1
-0
tensorpack/tfutils/modelutils.py
tensorpack/tfutils/modelutils.py
+2
-0
tensorpack/tfutils/summary.py
tensorpack/tfutils/summary.py
+7
-0
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+81
-0
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+1
-2
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+2
-3
No files found.
tensorpack/RL/gymenv.py
View file @
dfacc168
...
...
@@ -11,8 +11,10 @@ try:
gym
.
undo_logger_setup
()
# https://github.com/openai/gym/pull/199
# not sure whether it causes other problems
__all__
=
[
'GymEnv'
]
except
ImportError
:
logger
.
warn
(
"Cannot import gym. GymEnv won't be available."
)
__all__
=
[]
import
threading
...
...
@@ -20,7 +22,6 @@ from ..utils.fs import *
from
..utils.stat
import
*
from
.envbase
import
RLEnvironment
,
DiscreteActionSpace
__all__
=
[
'GymEnv'
]
_ALE_LOCK
=
threading
.
Lock
()
...
...
tensorpack/callbacks/base.py
View file @
dfacc168
...
...
@@ -8,8 +8,6 @@ import os
import
time
from
abc
import
abstractmethod
,
ABCMeta
from
..utils
import
*
__all__
=
[
'Callback'
,
'PeriodicCallback'
]
class
Callback
(
object
):
...
...
tensorpack/callbacks/group.py
View file @
dfacc168
...
...
@@ -7,8 +7,8 @@ from contextlib import contextmanager
import
time
from
.base
import
Callback
from
.stat
import
*
from
..utils
import
*
from
.stat
import
StatPrinter
from
..utils
import
logger
__all__
=
[
'Callbacks'
]
...
...
tensorpack/callbacks/inference.py
View file @
dfacc168
...
...
@@ -11,10 +11,9 @@ import six
from
six.moves
import
zip
,
map
from
..dataflow
import
DataFlow
from
..utils
import
*
from
..utils.stat
import
*
from
..tfutils
import
*
from
..tfutils.summary
import
*
from
..utils
import
get_tqdm_kwargs
,
logger
from
..utils.stat
import
RatioCounter
,
BinaryStatistics
from
..tfutils
import
get_op_tensor_name
from
.base
import
Callback
__all__
=
[
'InferenceRunner'
,
'ClassificationError'
,
...
...
tensorpack/callbacks/stat.py
View file @
dfacc168
...
...
@@ -8,7 +8,7 @@ import operator
import
json
from
.base
import
Callback
from
..utils
import
*
from
..utils
import
logger
__all__
=
[
'StatHolder'
,
'StatPrinter'
,
'SendStat'
]
...
...
tensorpack/models/_common.py
View file @
dfacc168
...
...
@@ -7,9 +7,9 @@ from functools import wraps
import
six
import
copy
,
os
from
..tfutils
import
*
from
..tfutils.modelutils
import
*
from
..tfutils.summary
import
*
from
..tfutils
.argscope
import
get_arg_scope
from
..tfutils.modelutils
import
get_shape_str
from
..tfutils.summary
import
add_activation_summary
from
..utils
import
logger
# make sure each layer is only logged once
...
...
tensorpack/models/batch_norm.py
View file @
dfacc168
...
...
@@ -7,7 +7,7 @@ import tensorflow as tf
from
copy
import
copy
import
re
from
.
model_desc
import
get_current_tower_context
from
.
.tfutils.tower
import
get_current_tower_context
from
..utils
import
logger
,
EXTRA_SAVE_VARS_KEY
from
._common
import
layer_register
...
...
tensorpack/models/conv2d.py
View file @
dfacc168
...
...
@@ -6,7 +6,7 @@
import
numpy
as
np
import
tensorflow
as
tf
import
math
from
._common
import
*
from
._common
import
layer_register
,
shape2d
,
shape4d
from
..utils
import
map_arg
,
logger
__all__
=
[
'Conv2D'
]
...
...
tensorpack/models/fc.py
View file @
dfacc168
...
...
@@ -7,7 +7,7 @@ import tensorflow as tf
import
math
from
._common
import
layer_register
from
..tfutils
.symbolic_functions
import
*
from
..tfutils
import
symbolic_functions
as
symbf
__all__
=
[
'FullyConnected'
]
...
...
@@ -26,7 +26,7 @@ def FullyConnected(x, out_dim,
:param use_bias: whether to use bias. a boolean default to True
:returns: a 2D tensor
"""
x
=
batch_flatten
(
x
)
x
=
symbf
.
batch_flatten
(
x
)
in_dim
=
x
.
get_shape
()
.
as_list
()[
1
]
if
W_init
is
None
:
...
...
tensorpack/models/model_desc.py
View file @
dfacc168
...
...
@@ -10,85 +10,13 @@ from collections import namedtuple
import
inspect
from
..utils
import
logger
,
INPUT_VARS_KEY
from
..tfutils
import
*
from
..tfutils.common
import
get_vars_by_names
from
..tfutils.gradproc
import
CheckGradient
__all__
=
[
'ModelDesc'
,
'InputVar'
,
'ModelFromMetaGraph'
,
'get_current_tower_context'
,
'TowerContext'
]
__all__
=
[
'ModelDesc'
,
'InputVar'
,
'ModelFromMetaGraph'
]
InputVar
=
namedtuple
(
'InputVar'
,
[
'type'
,
'shape'
,
'name'
])
_CurrentTowerContext
=
None
class
TowerContext
(
object
):
def
__init__
(
self
,
tower_name
,
is_training
=
None
):
""" tower_name: 'tower0', 'towerp0', or '' """
self
.
_name
=
tower_name
if
is_training
is
None
:
is_training
=
not
self
.
_name
.
startswith
(
'towerp'
)
self
.
_is_training
=
is_training
@
property
def
is_main_training_tower
(
self
):
return
self
.
is_training
and
(
self
.
_name
==
''
or
self
.
_name
==
'tower0'
)
@
property
def
is_main_tower
(
self
):
return
self
.
_name
==
''
or
self
.
_name
==
'tower0'
@
property
def
is_training
(
self
):
return
self
.
_is_training
@
property
def
name
(
self
):
return
self
.
_name
def
get_variable_on_tower
(
self
,
*
args
,
**
kwargs
):
"""
Get a variable for this tower specifically, without reusing.
Tensorflow doesn't allow reuse=False scope under a
reuse=True scope. This method provides a work around.
See https://www.tensorflow.org/versions/master/how_tos/variable_scope/index.html#basics-of-tfvariable-scope
:param args, kwargs: same as tf.get_variable()
"""
with
tf
.
variable_scope
(
self
.
_name
)
as
scope
:
with
tf
.
variable_scope
(
scope
,
reuse
=
False
):
scope
=
tf
.
get_variable_scope
()
assert
scope
.
reuse
==
False
return
tf
.
get_variable
(
*
args
,
**
kwargs
)
def
find_tensor_in_main_tower
(
self
,
graph
,
name
):
if
self
.
is_main_tower
:
return
graph
.
get_tensor_by_name
(
name
)
if
name
.
startswith
(
'towerp'
):
newname
=
re
.
sub
(
'towerp[0-9]+/'
,
''
,
name
)
try
:
return
graph
.
get_tensor_by_name
(
newname
)
except
KeyError
:
newname
=
re
.
sub
(
'towerp[0-9]+/'
,
'tower0/'
,
name
)
return
graph
.
get_tensor_by_name
(
newname
)
def
__enter__
(
self
):
global
_CurrentTowerContext
assert
_CurrentTowerContext
is
None
,
\
"Nesting TowerContext!"
_CurrentTowerContext
=
self
if
len
(
self
.
_name
):
self
.
_scope
=
tf
.
name_scope
(
self
.
_name
)
return
self
.
_scope
.
__enter__
()
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
global
_CurrentTowerContext
_CurrentTowerContext
=
None
if
len
(
self
.
_name
):
self
.
_scope
.
__exit__
(
exc_type
,
exc_val
,
exc_tb
)
return
False
def
get_current_tower_context
():
global
_CurrentTowerContext
return
_CurrentTowerContext
class
ModelDesc
(
object
):
""" Base class for a model description """
__metaclass__
=
ABCMeta
...
...
tensorpack/models/nonlin.py
View file @
dfacc168
...
...
@@ -6,7 +6,7 @@
import
tensorflow
as
tf
from
copy
import
copy
from
._common
import
*
from
._common
import
layer_register
from
.batch_norm
import
BatchNorm
__all__
=
[
'Maxout'
,
'PReLU'
,
'LeakyReLU'
,
'BNReLU'
]
...
...
tensorpack/models/pool.py
View file @
dfacc168
...
...
@@ -5,8 +5,8 @@
import
tensorflow
as
tf
import
numpy
from
._common
import
*
from
..tfutils
.symbolic_functions
import
*
from
._common
import
layer_register
,
shape2d
,
shape4d
from
..tfutils
import
symbolic_functions
as
symbf
__all__
=
[
'MaxPooling'
,
'FixedUnPooling'
,
'AvgPooling'
,
'GlobalAvgPooling'
,
'BilinearUpSample'
]
...
...
@@ -105,9 +105,9 @@ def FixedUnPooling(x, shape, unpool_mat=None):
assert
unpool_mat
.
get_shape
()
.
as_list
()
==
list
(
shape
)
# perform a tensor-matrix kronecker product
fx
=
flatten
(
tf
.
transpose
(
x
,
[
0
,
3
,
1
,
2
]))
fx
=
symbf
.
flatten
(
tf
.
transpose
(
x
,
[
0
,
3
,
1
,
2
]))
fx
=
tf
.
expand_dims
(
fx
,
-
1
)
# (bchw)x1
mat
=
tf
.
expand_dims
(
flatten
(
unpool_mat
),
0
)
#1x(shxsw)
mat
=
tf
.
expand_dims
(
symbf
.
flatten
(
unpool_mat
),
0
)
#1x(shxsw)
prod
=
tf
.
matmul
(
fx
,
mat
)
#(bchw) x(shxsw)
prod
=
tf
.
reshape
(
prod
,
tf
.
pack
(
[
-
1
,
input_shape
[
3
],
input_shape
[
1
],
input_shape
[
2
],
shape
[
0
],
shape
[
1
]]))
...
...
tensorpack/models/regularize.py
View file @
dfacc168
...
...
@@ -6,8 +6,8 @@ import tensorflow as tf
import
re
from
..utils
import
logger
from
..utils.utils
import
*
from
.
model_desc
import
get_current_tower_context
from
..utils.utils
import
memoized
from
.
.tfutils.tower
import
get_current_tower_context
from
._common
import
layer_register
__all__
=
[
'regularize_cost'
,
'l2_regularizer'
,
'l1_regularizer'
,
'Dropout'
]
...
...
tensorpack/predict/base.py
View file @
dfacc168
...
...
@@ -7,9 +7,8 @@ from abc import abstractmethod, ABCMeta, abstractproperty
import
tensorflow
as
tf
import
six
from
..models
import
TowerContext
from
..utils
import
logger
from
..tfutils
import
get_vars_by_names
from
..tfutils
import
get_vars_by_names
,
TowerContext
__all__
=
[
'OnlinePredictor'
,
'OfflinePredictor'
,
'AsyncPredictorBase'
,
...
...
tensorpack/predict/common.py
View file @
dfacc168
...
...
@@ -9,7 +9,8 @@ from six.moves import zip
from
tensorpack.models
import
ModelDesc
from
..utils
import
logger
from
..tfutils
import
*
from
..tfutils
import
get_default_sess_config
from
..tfutils.sessinit
import
SessionInit
,
JustCurrentSession
from
.base
import
OfflinePredictor
import
multiprocessing
...
...
tensorpack/predict/concurrency.py
View file @
dfacc168
...
...
@@ -9,12 +9,9 @@ import time
import
six
from
six.moves
import
queue
,
range
,
zip
from
..utils.concurrency
import
DIE
from
..tfutils.modelutils
import
describe_model
from
..utils
import
logger
from
..utils.timer
import
*
from
..tfutils
import
*
from
.base
import
*
...
...
tensorpack/tfutils/__init__.py
View file @
dfacc168
...
...
@@ -17,4 +17,5 @@ _global_import('sessinit')
_global_import
(
'common'
)
_global_import
(
'gradproc'
)
_global_import
(
'argscope'
)
_global_import
(
'tower'
)
tensorpack/tfutils/modelutils.py
View file @
dfacc168
...
...
@@ -6,6 +6,8 @@ import tensorflow as tf
from
..utils
import
logger
__all__
=
[
'describe_model'
,
'get_shape_str'
]
def
describe_model
():
""" print a description of the current model parameters """
train_vars
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
...
...
tensorpack/tfutils/summary.py
View file @
dfacc168
...
...
@@ -7,6 +7,7 @@ import tensorflow as tf
import
re
from
..utils
import
*
from
.tower
import
get_current_tower_context
from
.
import
get_global_step_var
from
.symbolic_functions
import
rms
...
...
@@ -28,6 +29,8 @@ def add_activation_summary(x, name=None):
Add summary to graph for an activation tensor x.
If name is None, use x.name.
"""
if
not
get_current_tower_context
()
.
is_main_training_tower
:
return
ndim
=
x
.
get_shape
()
.
ndims
assert
ndim
>=
2
,
\
"Summary a scalar with histogram? Maybe use scalar instead. FIXME!"
...
...
@@ -46,6 +49,8 @@ def add_param_summary(summary_lists):
:param summary_lists: list of (regex, [list of summary type to perform]).
Type can be 'mean', 'scalar', 'histogram', 'sparsity', 'rms'
"""
if
not
get_current_tower_context
()
.
is_main_training_tower
:
return
def
perform
(
var
,
action
):
ndim
=
var
.
get_shape
()
.
ndims
name
=
var
.
name
.
replace
(
':0'
,
''
)
...
...
@@ -84,6 +89,8 @@ def add_moving_summary(v, *args):
:param v: tensor or list of tensor to summary
:param args: tensors to summary
"""
if
not
get_current_tower_context
()
.
is_main_training_tower
:
return
if
not
isinstance
(
v
,
list
):
v
=
[
v
]
v
.
extend
(
args
)
...
...
tensorpack/tfutils/tower.py
0 → 100644
View file @
dfacc168
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: tower.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import re

import tensorflow as tf
__all__ = ['get_current_tower_context', 'TowerContext']

# The single active TowerContext, or None when no tower is being built.
_CurrentTowerContext = None


class TowerContext(object):
    """Context manager that marks which replication "tower" the graph is
    currently being built for, and whether that tower is a training tower.

    Naming convention (from the callers in this repo):
    '' or 'tower0' is the main tower; 'towerp<N>' is a prediction tower.
    """

    def __init__(self, tower_name, is_training=None):
        """
        :param tower_name: 'tower0', 'towerp0', or '' (the main tower).
        :param is_training: bool, or None to infer from the name
            (names starting with 'towerp' are treated as prediction towers).
        """
        self._name = tower_name
        if is_training is None:
            is_training = not self._name.startswith('towerp')
        self._is_training = is_training

    @property
    def is_main_training_tower(self):
        """True only for the main tower ('' or 'tower0') while training."""
        return self.is_training and (self._name == '' or self._name == 'tower0')

    @property
    def is_main_tower(self):
        """True for the main tower ('' or 'tower0'), training or not."""
        return self._name == '' or self._name == 'tower0'

    @property
    def is_training(self):
        return self._is_training

    @property
    def name(self):
        return self._name

    def get_variable_on_tower(self, *args, **kwargs):
        """
        Get a variable for this tower specifically, without reusing.
        Tensorflow doesn't allow reuse=False scope under a
        reuse=True scope. This method provides a work around.
        See https://www.tensorflow.org/versions/master/how_tos/variable_scope/index.html#basics-of-tfvariable-scope

        :param args, kwargs: same as tf.get_variable()
        """
        with tf.variable_scope(self._name) as scope:
            # Re-entering the scope object (not its name) resets reuse=False
            # even when an outer scope already set reuse=True.
            with tf.variable_scope(scope, reuse=False):
                scope = tf.get_variable_scope()
                assert not scope.reuse
                return tf.get_variable(*args, **kwargs)

    def find_tensor_in_main_tower(self, graph, name):
        """Map a tensor name in this tower to the equivalent tensor in the
        main tower, and fetch it from ``graph``.

        Tries stripping the 'towerp<N>/' prefix first, then falls back to
        replacing it with 'tower0/'. Returns None for non-main towers whose
        name doesn't start with 'towerp' (preserved historical behavior).
        """
        if self.is_main_tower:
            return graph.get_tensor_by_name(name)
        if name.startswith('towerp'):
            newname = re.sub('towerp[0-9]+/', '', name)
            try:
                return graph.get_tensor_by_name(newname)
            except KeyError:
                newname = re.sub('towerp[0-9]+/', 'tower0/', name)
                return graph.get_tensor_by_name(newname)

    def __enter__(self):
        global _CurrentTowerContext
        # Towers cannot nest: exactly one context may be active at a time.
        assert _CurrentTowerContext is None, "Nesting TowerContext!"
        _CurrentTowerContext = self
        if len(self._name):
            self._scope = tf.name_scope(self._name)
            return self._scope.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        global _CurrentTowerContext
        _CurrentTowerContext = None
        if len(self._name):
            self._scope.__exit__(exc_type, exc_val, exc_tb)
        return False


def get_current_tower_context():
    """Return the active TowerContext, or None if not inside one."""
    global _CurrentTowerContext
    return _CurrentTowerContext
tensorpack/train/multigpu.py
View file @
dfacc168
...
...
@@ -7,14 +7,13 @@ import tensorflow as tf
import
itertools
,
re
from
six.moves
import
zip
,
range
from
..models
import
TowerContext
from
..utils
import
logger
from
..utils.naming
import
*
from
..utils.concurrency
import
LoopThread
from
..tfutils.summary
import
summary_moving_average
from
..tfutils.modelutils
import
describe_model
from
..tfutils
import
(
backup_collection
,
restore_collection
,
get_global_step_var
)
get_global_step_var
,
TowerContext
)
from
.trainer
import
QueueInputTrainer
...
...
tensorpack/train/trainer.py
View file @
dfacc168
...
...
@@ -11,10 +11,9 @@ from .base import Trainer
from
..dataflow.common
import
RepeatedData
from
..models
import
TowerContext
from
..utils
import
logger
,
SUMMARY_BACKUP_KEYS
from
..tfutils
import
(
get_vars_by_names
,
freeze_collection
,
get_global_step_var
)
get_global_step_var
,
TowerContext
)
from
..tfutils.summary
import
summary_moving_average
,
add_moving_summary
from
..tfutils.modelutils
import
describe_model
from
..predict
import
OnlinePredictor
,
build_multi_tower_prediction_graph
...
...
@@ -67,7 +66,7 @@ class SimpleTrainer(Trainer):
with
TowerContext
(
''
):
model
.
build_graph
(
self
.
input_vars
)
cost_var
=
model
.
get_cost
()
# TODO assert scalar
add_moving_summary
(
cost_var
)
add_moving_summary
(
cost_var
)
grads
=
self
.
config
.
optimizer
.
compute_gradients
(
cost_var
)
grads
=
self
.
process_grads
(
grads
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment