Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
2ce43d70
Commit
2ce43d70
authored
Jan 01, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
make code importable under tf2
parent
7b4980c9
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
36 additions
and
29 deletions
+36
-29
tensorpack/callbacks/graph.py
tensorpack/callbacks/graph.py
+3
-1
tensorpack/callbacks/hooks.py
tensorpack/callbacks/hooks.py
+2
-1
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+2
-1
tensorpack/callbacks/saver.py
tensorpack/callbacks/saver.py
+3
-1
tensorpack/callbacks/summary.py
tensorpack/callbacks/summary.py
+3
-1
tensorpack/models/regularize.py
tensorpack/models/regularize.py
+1
-1
tensorpack/tfutils/common.py
tensorpack/tfutils/common.py
+16
-2
tensorpack/tfutils/optimizer.py
tensorpack/tfutils/optimizer.py
+2
-2
tensorpack/tfutils/sesscreate.py
tensorpack/tfutils/sesscreate.py
+4
-19
No files found.
tensorpack/callbacks/graph.py
View file @
2ce43d70
...
...
@@ -80,11 +80,13 @@ class RunUpdateOps(RunOp):
each `sess.run` call.
"""
def
__init__
(
self
,
collection
=
tf
.
GraphKeys
.
UPDATE_OPS
):
def
__init__
(
self
,
collection
=
None
):
"""
Args:
collection (str): collection of ops to run. Defaults to ``tf.GraphKeys.UPDATE_OPS``
"""
if
collection
is
None
:
collection
=
tf
.
GraphKeys
.
UPDATE_OPS
name
=
'UPDATE_OPS'
if
collection
==
tf
.
GraphKeys
.
UPDATE_OPS
else
collection
def
f
():
...
...
tensorpack/callbacks/hooks.py
View file @
2ce43d70
...
...
@@ -6,12 +6,13 @@
import
tensorflow
as
tf
from
..tfutils.common
import
tfv1
from
.base
import
Callback
__all__
=
[
'CallbackToHook'
,
'HookToCallback'
]
class
CallbackToHook
(
tf
.
train
.
SessionRunHook
):
class
CallbackToHook
(
tf
v1
.
train
.
SessionRunHook
):
""" This is only for internal implementation of
before_run/after_run callbacks.
You shouldn't need to use this.
...
...
tensorpack/callbacks/inference_runner.py
View file @
2ce43d70
...
...
@@ -13,6 +13,7 @@ from tensorflow.python.training.monitored_session import _HookedSession as Hooke
from
..dataflow.base
import
DataFlow
from
..input_source
import
FeedInput
,
InputSource
,
QueueInput
,
StagingInput
from
..tfutils.tower
import
PredictTowerContext
from
..tfutils.common
import
tfv1
from
..utils
import
logger
from
..utils.utils
import
get_tqdm_kwargs
from
.base
import
Callback
...
...
@@ -27,7 +28,7 @@ def _device_from_int(dev):
return
'/gpu:{}'
.
format
(
dev
)
if
dev
>=
0
else
'/cpu:0'
class
InferencerToHook
(
tf
.
train
.
SessionRunHook
):
class
InferencerToHook
(
tf
v1
.
train
.
SessionRunHook
):
def
__init__
(
self
,
inf
,
fetches
):
self
.
_inf
=
inf
self
.
_fetches
=
fetches
...
...
tensorpack/callbacks/saver.py
View file @
2ce43d70
...
...
@@ -20,7 +20,7 @@ class ModelSaver(Callback):
def
__init__
(
self
,
max_to_keep
=
10
,
keep_checkpoint_every_n_hours
=
0.5
,
checkpoint_dir
=
None
,
var_collections
=
[
tf
.
GraphKeys
.
GLOBAL_VARIABLES
]
):
var_collections
=
None
):
"""
Args:
max_to_keep (int): the same as in ``tf.train.Saver``.
...
...
@@ -29,6 +29,8 @@ class ModelSaver(Callback):
checkpoint_dir (str): Defaults to ``logger.get_logger_dir()``.
var_collections (str or list of str): collection of the variables (or list of collections) to save.
"""
if
var_collections
is
None
:
var_collections
=
[
tf
.
GraphKeys
.
GLOBAL_VARIABLES
]
self
.
_max_to_keep
=
max_to_keep
self
.
_keep_every_n_hours
=
keep_checkpoint_every_n_hours
...
...
tensorpack/callbacks/summary.py
View file @
2ce43d70
...
...
@@ -116,7 +116,7 @@ class MergeAllSummaries_RunWithOp(Callback):
self
.
trainer
.
monitors
.
put_summary
(
summary
)
def
MergeAllSummaries
(
period
=
0
,
run_alone
=
False
,
key
=
tf
.
GraphKeys
.
SUMMARIES
):
def
MergeAllSummaries
(
period
=
0
,
run_alone
=
False
,
key
=
None
):
"""
This callback is enabled by default.
Evaluate all summaries by `tf.summary.merge_all`, and write them to logs.
...
...
@@ -133,6 +133,8 @@ def MergeAllSummaries(period=0, run_alone=False, key=tf.GraphKeys.SUMMARIES):
key (str): the collection of summary tensors. Same as in `tf.summary.merge_all`.
Default is ``tf.GraphKeys.SUMMARIES``.
"""
if
key
is
None
:
key
=
tf
.
GraphKeys
.
SUMMARIES
period
=
int
(
period
)
if
run_alone
:
return
MergeAllSummaries_RunAlone
(
period
,
key
)
...
...
tensorpack/models/regularize.py
View file @
2ce43d70
...
...
@@ -39,7 +39,7 @@ def regularize_cost(regex, func, name='regularize_cost'):
Args:
regex (str): a regex to match variable names, e.g. "conv.*/W"
func: the regularization function, which takes a tensor and returns a scalar tensor.
E.g., ``tf.
contrib.layers.l2_regularizer
``.
E.g., ``tf.
nn.l2_loss, tf.contrib.layers.l1_regularizer(0.001)
``.
Returns:
tf.Tensor: a scalar, the total regularization cost.
...
...
tensorpack/tfutils/common.py
View file @
2ce43d70
...
...
@@ -150,11 +150,25 @@ def gpu_available_in_session():
@
deprecated
(
"Use get_tf_version_tuple instead."
,
"2019-01-31"
)
def
get_tf_version_number
():
return
float
(
'.'
.
join
(
tf
.
VERSION
.
split
(
'.'
)[:
2
]))
return
float
(
'.'
.
join
(
tf
.
__version__
.
split
(
'.'
)[:
2
]))
def
get_tf_version_tuple
():
"""
Return TensorFlow version as a 2-element tuple (for comparison).
"""
return
tuple
(
map
(
int
,
tf
.
VERSION
.
split
(
'.'
)[:
2
]))
return
tuple
(
map
(
int
,
tf
.
__version__
.
split
(
'.'
)[:
2
]))
def
is_tf2
():
try
:
from
tensorflow.python
import
tf2
return
tf2
.
enabled
()
except
Exception
:
return
False
if
is_tf2
():
tfv1
=
tf
.
compat
.
v1
else
:
tfv1
=
tf
tensorpack/tfutils/optimizer.py
View file @
2ce43d70
...
...
@@ -5,7 +5,7 @@
from
contextlib
import
contextmanager
import
tensorflow
as
tf
from
..tfutils.common
import
get_tf_version_tuple
from
..tfutils.common
import
get_tf_version_tuple
,
tfv1
from
..utils.develop
import
HIDE_DOC
from
.gradproc
import
FilterNoneGrad
,
GradientProcessor
...
...
@@ -14,7 +14,7 @@ __all__ = ['apply_grad_processors', 'ProxyOptimizer',
'AccumGradOptimizer'
]
class
ProxyOptimizer
(
tf
.
train
.
Optimizer
):
class
ProxyOptimizer
(
tf
v1
.
train
.
Optimizer
):
"""
A transparent proxy which delegates all methods of :class:`tf.train.Optimizer`
"""
...
...
tensorpack/tfutils/sesscreate.py
View file @
2ce43d70
...
...
@@ -4,10 +4,11 @@
import
tensorflow
as
tf
from
..tfutils.common
import
tfv1
from
..utils
import
logger
from
.common
import
get_default_sess_config
__all__
=
[
'NewSessionCreator'
,
'ReuseSessionCreator'
,
'SessionCreatorAdapter'
]
__all__
=
[
'NewSessionCreator'
,
'ReuseSessionCreator'
]
"""
A SessionCreator should:
...
...
@@ -18,7 +19,7 @@ A SessionCreator should:
"""
class
NewSessionCreator
(
tf
.
train
.
SessionCreator
):
class
NewSessionCreator
(
tf
v1
.
train
.
SessionCreator
):
def
__init__
(
self
,
target
=
''
,
config
=
None
):
"""
Args:
...
...
@@ -47,7 +48,7 @@ bugs. See https://github.com/tensorpack/tensorpack/issues/497 for workarounds.")
return
sess
class
ReuseSessionCreator
(
tf
.
train
.
SessionCreator
):
class
ReuseSessionCreator
(
tf
v1
.
train
.
SessionCreator
):
def
__init__
(
self
,
sess
):
"""
Args:
...
...
@@ -57,19 +58,3 @@ class ReuseSessionCreator(tf.train.SessionCreator):
def
create_session
(
self
):
return
self
.
sess
class
SessionCreatorAdapter
(
tf
.
train
.
SessionCreator
):
def
__init__
(
self
,
session_creator
,
func
):
"""
Args:
session_creator (tf.train.SessionCreator): a session creator
func (tf.Session -> tf.Session): takes a session created by
``session_creator``, and return a new session to be returned by ``self.create_session``
"""
self
.
_creator
=
session_creator
self
.
_func
=
func
def
create_session
(
self
):
sess
=
self
.
_creator
.
create_session
()
return
self
.
_func
(
sess
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment