Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
7b0782d6
Commit
7b0782d6
authored
Oct 31, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Clearly define the argument of `build_graph`
parent
4cc00393
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
13 additions
and
8 deletions
+13
-8
tensorpack/RL/README.md
tensorpack/RL/README.md
+4
-0
tensorpack/graph_builder/model_desc.py
tensorpack/graph_builder/model_desc.py
+6
-3
tensorpack/tfutils/export.py
tensorpack/tfutils/export.py
+1
-3
tensorpack/train/base.py
tensorpack/train/base.py
+1
-1
tensorpack/trainv1/base.py
tensorpack/trainv1/base.py
+1
-1
No files found.
tensorpack/RL/README.md
0 → 100644
View file @
7b0782d6
## DEPRECATED
Please use gym or other APIs.
tensorpack/graph_builder/model_desc.py
View file @
7b0782d6
...
@@ -9,6 +9,7 @@ import tensorflow as tf
...
@@ -9,6 +9,7 @@ import tensorflow as tf
import
six
import
six
from
..utils.argtools
import
memoized
from
..utils.argtools
import
memoized
from
..utils.develop
import
log_deprecated
from
..tfutils.gradproc
import
FilterNoneGrad
from
..tfutils.gradproc
import
FilterNoneGrad
from
..tfutils.tower
import
get_current_tower_context
from
..tfutils.tower
import
get_current_tower_context
from
..input_source
import
InputSource
from
..input_source
import
InputSource
...
@@ -96,15 +97,17 @@ class ModelDescBase(object):
...
@@ -96,15 +97,17 @@ class ModelDescBase(object):
Build the whole symbolic graph.
Build the whole symbolic graph.
Args:
Args:
args (
list
[tf.Tensor]): a list of tensors,
args ([tf.Tensor]): a list of tensors,
that match the list of :class:`InputDesc` defined by ``_get_inputs``.
that match
es
the list of :class:`InputDesc` defined by ``_get_inputs``.
"""
"""
if
len
(
args
)
==
1
:
if
len
(
args
)
==
1
:
arg
=
args
[
0
]
arg
=
args
[
0
]
if
isinstance
(
arg
,
InputSource
):
if
isinstance
(
arg
,
InputSource
):
inputs
=
arg
.
get_input_tensors
()
# remove in the future?
inputs
=
arg
.
get_input_tensors
()
# remove in the future?
log_deprecated
(
"build_graph(InputSource)"
,
"Call with tensors in positional args instead."
)
elif
isinstance
(
arg
,
(
list
,
tuple
)):
elif
isinstance
(
arg
,
(
list
,
tuple
)):
inputs
=
arg
inputs
=
arg
log_deprecated
(
"build_graph([Tensor])"
,
"Call with positional args instead."
)
else
:
else
:
inputs
=
[
arg
]
inputs
=
[
arg
]
else
:
else
:
...
@@ -163,7 +166,7 @@ class ModelDesc(ModelDescBase):
...
@@ -163,7 +166,7 @@ class ModelDesc(ModelDescBase):
raise
NotImplementedError
()
raise
NotImplementedError
()
def
_build_graph_get_cost
(
self
,
*
inputs
):
def
_build_graph_get_cost
(
self
,
*
inputs
):
self
.
build_graph
(
inputs
)
self
.
build_graph
(
*
inputs
)
return
self
.
get_cost
()
return
self
.
get_cost
()
def
_build_graph_get_grads
(
self
,
*
inputs
):
def
_build_graph_get_grads
(
self
,
*
inputs
):
...
...
tensorpack/tfutils/export.py
View file @
7b0782d6
...
@@ -89,7 +89,7 @@ class ModelExport(object):
...
@@ -89,7 +89,7 @@ class ModelExport(object):
"""
"""
logger
.
info
(
'[export] build model for
%
s'
%
checkpoint
)
logger
.
info
(
'[export] build model for
%
s'
%
checkpoint
)
with
TowerContext
(
''
,
is_training
=
False
):
with
TowerContext
(
''
,
is_training
=
False
):
self
.
model
.
build_graph
(
self
.
input
)
self
.
model
.
build_graph
(
*
self
.
input
.
get_input_tensors
()
)
self
.
sess
=
tf
.
Session
(
config
=
tf
.
ConfigProto
(
allow_soft_placement
=
True
))
self
.
sess
=
tf
.
Session
(
config
=
tf
.
ConfigProto
(
allow_soft_placement
=
True
))
# load values from latest checkpoint
# load values from latest checkpoint
...
@@ -129,8 +129,6 @@ class ModelExport(object):
...
@@ -129,8 +129,6 @@ class ModelExport(object):
outputs
=
outputs_signature
,
outputs
=
outputs_signature
,
method_name
=
tf
.
saved_model
.
signature_constants
.
PREDICT_METHOD_NAME
)
method_name
=
tf
.
saved_model
.
signature_constants
.
PREDICT_METHOD_NAME
)
# legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
builder
.
add_meta_graph_and_variables
(
builder
.
add_meta_graph_and_variables
(
self
.
sess
,
tags
,
self
.
sess
,
tags
,
signature_def_map
=
{
signature_name
:
prediction_signature
})
signature_def_map
=
{
signature_name
:
prediction_signature
})
...
...
tensorpack/train/base.py
View file @
7b0782d6
...
@@ -86,7 +86,7 @@ class Trainer(object):
...
@@ -86,7 +86,7 @@ class Trainer(object):
self
.
_config
=
config
self
.
_config
=
config
self
.
inputs_desc
=
config
.
model
.
get_inputs_desc
()
self
.
inputs_desc
=
config
.
model
.
get_inputs_desc
()
self
.
tower_func
=
TowerFuncWrapper
(
self
.
tower_func
=
TowerFuncWrapper
(
lambda
*
inputs
:
config
.
model
.
build_graph
(
inputs
),
lambda
*
inputs
:
config
.
model
.
build_graph
(
*
inputs
),
self
.
inputs_desc
)
self
.
inputs_desc
)
self
.
_main_tower_vs_name
=
""
self
.
_main_tower_vs_name
=
""
...
...
tensorpack/trainv1/base.py
View file @
7b0782d6
...
@@ -123,7 +123,7 @@ class Trainer(object):
...
@@ -123,7 +123,7 @@ class Trainer(object):
if
self
.
model
is
not
None
:
if
self
.
model
is
not
None
:
def
f
(
*
inputs
):
def
f
(
*
inputs
):
self
.
model
.
build_graph
(
inputs
)
self
.
model
.
build_graph
(
*
inputs
)
"""
"""
Only to mimic new trainer interafce on inference.
Only to mimic new trainer interafce on inference.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment