Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
a349c558
Commit
a349c558
authored
Feb 23, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
small internal rename
parent
6c68f8aa
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
36 additions
and
31 deletions
+36
-31
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+3
-3
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+1
-1
tensorpack/train/input_data.py
tensorpack/train/input_data.py
+28
-22
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+4
-5
No files found.
tensorpack/callbacks/inference_runner.py
View file @
a349c558
...
@@ -12,7 +12,7 @@ from ..utils import logger, get_tqdm
...
@@ -12,7 +12,7 @@ from ..utils import logger, get_tqdm
from
..dataflow
import
DataFlow
from
..dataflow
import
DataFlow
from
..tfutils.common
import
get_op_tensor_name
from
..tfutils.common
import
get_op_tensor_name
from
..tfutils
import
TowerContext
from
..tfutils
import
TowerContext
from
..train.input_data
import
Feedfree
Input
from
..train.input_data
import
Tensor
Input
from
..predict
import
PredictorTowerBuilder
from
..predict
import
PredictorTowerBuilder
from
.base
import
Triggerable
from
.base
import
Triggerable
...
@@ -161,7 +161,7 @@ class FeedfreeInferenceRunner(Triggerable):
...
@@ -161,7 +161,7 @@ class FeedfreeInferenceRunner(Triggerable):
prefix(str): an prefix used to build the tower. Must be set
prefix(str): an prefix used to build the tower. Must be set
differently if more than one :class:`FeedfreeInferenceRunner` are used.
differently if more than one :class:`FeedfreeInferenceRunner` are used.
"""
"""
assert
isinstance
(
input
,
Feedfree
Input
),
input
assert
isinstance
(
input
,
Tensor
Input
),
input
self
.
_input_data
=
input
self
.
_input_data
=
input
if
not
isinstance
(
infs
,
list
):
if
not
isinstance
(
infs
,
list
):
self
.
infs
=
[
infs
]
self
.
infs
=
[
infs
]
...
@@ -192,7 +192,7 @@ class FeedfreeInferenceRunner(Triggerable):
...
@@ -192,7 +192,7 @@ class FeedfreeInferenceRunner(Triggerable):
self
.
_find_output_tensors
()
self
.
_find_output_tensors
()
def
_find_input_tensors
(
self
):
def
_find_input_tensors
(
self
):
self
.
_input_data
.
_setup
(
self
.
trainer
)
self
.
_input_data
.
setup
(
self
.
trainer
.
model
)
# only 1 prediction tower will be used for inference
# only 1 prediction tower will be used for inference
self
.
_input_tensors
=
self
.
_input_data
.
get_input_tensors
()
self
.
_input_tensors
=
self
.
_input_data
.
get_input_tensors
()
model_placehdrs
=
self
.
trainer
.
model
.
get_reused_placehdrs
()
model_placehdrs
=
self
.
trainer
.
model
.
get_reused_placehdrs
()
...
...
tensorpack/train/feedfree.py
View file @
a349c558
...
@@ -36,7 +36,7 @@ class FeedfreeTrainerBase(Trainer):
...
@@ -36,7 +36,7 @@ class FeedfreeTrainerBase(Trainer):
def
_setup
(
self
):
def
_setup
(
self
):
assert
isinstance
(
self
.
_input_method
,
FeedfreeInput
),
type
(
self
.
_input_method
)
assert
isinstance
(
self
.
_input_method
,
FeedfreeInput
),
type
(
self
.
_input_method
)
self
.
_input_method
.
_setup
(
self
)
self
.
_input_method
.
setup_training
(
self
)
def
run_step
(
self
):
def
run_step
(
self
):
""" Simply run ``self.train_op``."""
""" Simply run ``self.train_op``."""
...
...
tensorpack/train/input_data.py
View file @
a349c558
...
@@ -22,8 +22,13 @@ __all__ = ['InputData', 'FeedfreeInput',
...
@@ -22,8 +22,13 @@ __all__ = ['InputData', 'FeedfreeInput',
@
six
.
add_metaclass
(
ABCMeta
)
@
six
.
add_metaclass
(
ABCMeta
)
class
InputData
(
object
):
class
InputData
(
object
):
""" Base class for the abstract InputData. """
""" Base class for the abstract InputData. """
def
setup
(
self
,
model
):
pass
pass
def
setup_training
(
self
,
trainer
):
self
.
setup
(
trainer
.
model
)
class
FeedInput
(
InputData
):
class
FeedInput
(
InputData
):
""" Input by iterating over a DataFlow and feed datapoints. """
""" Input by iterating over a DataFlow and feed datapoints. """
...
@@ -38,8 +43,8 @@ class FeedInput(InputData):
...
@@ -38,8 +43,8 @@ class FeedInput(InputData):
def
size
(
self
):
def
size
(
self
):
return
self
.
ds
.
size
()
return
self
.
ds
.
size
()
def
_setup
(
self
,
trainer
):
def
setup
(
self
,
model
):
self
.
input_placehdrs
=
trainer
.
model
.
get_reused_placehdrs
()
self
.
input_placehdrs
=
model
.
get_reused_placehdrs
()
rds
=
RepeatedData
(
self
.
ds
,
-
1
)
rds
=
RepeatedData
(
self
.
ds
,
-
1
)
rds
.
reset_state
()
rds
.
reset_state
()
self
.
data_producer
=
rds
.
get_data
()
self
.
data_producer
=
rds
.
get_data
()
...
@@ -58,18 +63,16 @@ class FeedfreeInput(InputData):
...
@@ -58,18 +63,16 @@ class FeedfreeInput(InputData):
""" Abstract base for input without feed,
""" Abstract base for input without feed,
e.g. by queue or other operations. """
e.g. by queue or other operations. """
@
abstractmethod
def
get_input_tensors
(
self
):
def
get_input_tensors
(
self
):
"""
"""
Returns:
Returns:
list: A list of tensors corresponding to the inputs of the model.
list: A list of tensors corresponding to the inputs of the model.
Always create and return a list of new input tensors when called.
"""
"""
return
self
.
_get_input_tensors
()
@
abstractmethod
def
get_client_threads
(
self
):
def
_get_input_tensors
(
self
):
return
[]
"""
always create and return a list of new input tensors
"""
class
EnqueueThread
(
ShareSessionThread
):
class
EnqueueThread
(
ShareSessionThread
):
...
@@ -125,18 +128,21 @@ class QueueInput(FeedfreeInput):
...
@@ -125,18 +128,21 @@ class QueueInput(FeedfreeInput):
return
self
.
ds
.
size
()
return
self
.
ds
.
size
()
# TODO XXX use input data mapping. not all placeholders are needed
# TODO XXX use input data mapping. not all placeholders are needed
def
_setup
(
self
,
trainer
):
def
setup
(
self
,
model
):
self
.
input_placehdrs
=
trainer
.
model
.
get_reused_placehdrs
()
self
.
input_placehdrs
=
model
.
get_reused_placehdrs
()
assert
len
(
self
.
input_placehdrs
)
>
0
,
\
assert
len
(
self
.
input_placehdrs
)
>
0
,
\
"QueueInput
can only
be used with input placeholders!"
"QueueInput
has to
be used with input placeholders!"
if
self
.
queue
is
None
:
if
self
.
queue
is
None
:
self
.
queue
=
tf
.
FIFOQueue
(
self
.
queue
=
tf
.
FIFOQueue
(
50
,
[
x
.
dtype
for
x
in
self
.
input_placehdrs
],
50
,
[
x
.
dtype
for
x
in
self
.
input_placehdrs
],
name
=
'input_queue'
)
name
=
'input_queue'
)
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
self
.
input_placehdrs
)
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
self
.
input_placehdrs
)
def
setup_training
(
self
,
trainer
):
self
.
setup
(
trainer
.
model
)
trainer
.
config
.
callbacks
.
append
(
StartProcOrThread
(
self
.
thread
))
trainer
.
config
.
callbacks
.
append
(
StartProcOrThread
(
self
.
thread
))
def
_
get_input_tensors
(
self
):
def
get_input_tensors
(
self
):
ret
=
self
.
queue
.
dequeue
(
name
=
'input_deque'
)
ret
=
self
.
queue
.
dequeue
(
name
=
'input_deque'
)
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
ret
=
[
ret
]
ret
=
[
ret
]
...
@@ -166,10 +172,10 @@ class BatchQueueInput(FeedfreeInput):
...
@@ -166,10 +172,10 @@ class BatchQueueInput(FeedfreeInput):
def
size
(
self
):
def
size
(
self
):
return
self
.
ds
.
size
()
//
self
.
batch_size
return
self
.
ds
.
size
()
//
self
.
batch_size
def
_setup
(
self
,
trainer
):
def
setup
(
self
,
model
):
self
.
input_placehdrs
=
trainer
.
model
.
get_reused_placehdrs
()
self
.
input_placehdrs
=
model
.
get_reused_placehdrs
()
assert
len
(
self
.
input_placehdrs
)
>
0
,
\
assert
len
(
self
.
input_placehdrs
)
>
0
,
\
"
QueueInput can only
be used with input placeholders!"
"
BatchQueueInput has to
be used with input placeholders!"
# prepare placeholders without the first dimension
# prepare placeholders without the first dimension
placehdrs_nobatch
=
[]
placehdrs_nobatch
=
[]
...
@@ -195,9 +201,12 @@ class BatchQueueInput(FeedfreeInput):
...
@@ -195,9 +201,12 @@ class BatchQueueInput(FeedfreeInput):
assert
shp
.
is_fully_defined
(),
shape_err
assert
shp
.
is_fully_defined
(),
shape_err
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
placehdrs_nobatch
)
self
.
thread
=
EnqueueThread
(
self
.
queue
,
self
.
ds
,
placehdrs_nobatch
)
def
setup_training
(
self
,
trainer
):
self
.
setup
(
trainer
.
model
)
trainer
.
config
.
callbacks
.
append
(
StartProcOrThread
(
self
.
thread
))
trainer
.
config
.
callbacks
.
append
(
StartProcOrThread
(
self
.
thread
))
def
_
get_input_tensors
(
self
):
def
get_input_tensors
(
self
):
ret
=
self
.
queue
.
dequeue_many
(
self
.
batch_size
,
name
=
'input_deque'
)
ret
=
self
.
queue
.
dequeue_many
(
self
.
batch_size
,
name
=
'input_deque'
)
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
ret
=
[
ret
]
ret
=
[
ret
]
...
@@ -221,7 +230,7 @@ class DummyConstantInput(FeedfreeInput):
...
@@ -221,7 +230,7 @@ class DummyConstantInput(FeedfreeInput):
self
.
shapes
=
shapes
self
.
shapes
=
shapes
logger
.
warn
(
"Using dummy input for debug!"
)
logger
.
warn
(
"Using dummy input for debug!"
)
def
_
get_input_tensors
(
self
):
def
get_input_tensors
(
self
):
placehdrs
=
self
.
input_placehdrs
placehdrs
=
self
.
input_placehdrs
assert
len
(
self
.
shapes
)
==
len
(
placehdrs
)
assert
len
(
self
.
shapes
)
==
len
(
placehdrs
)
ret
=
[]
ret
=
[]
...
@@ -253,8 +262,5 @@ class TensorInput(FeedfreeInput):
...
@@ -253,8 +262,5 @@ class TensorInput(FeedfreeInput):
raise
NotImplementedError
(
"size of TensorInput is undefined!"
)
raise
NotImplementedError
(
"size of TensorInput is undefined!"
)
return
self
.
_size
return
self
.
_size
def
_setup
(
self
,
trainer
):
def
get_input_tensors
(
self
):
pass
def
_get_input_tensors
(
self
):
return
self
.
get_tensor_fn
()
return
self
.
get_tensor_fn
()
tensorpack/train/trainer.py
View file @
a349c558
...
@@ -32,13 +32,12 @@ class SimpleTrainer(Trainer):
...
@@ -32,13 +32,12 @@ class SimpleTrainer(Trainer):
self
.
hooked_sess
.
run
(
self
.
train_op
,
feed_dict
=
feed
)
self
.
hooked_sess
.
run
(
self
.
train_op
,
feed_dict
=
feed
)
def
_setup
(
self
):
def
_setup
(
self
):
self
.
_input_method
.
_setup
(
self
)
self
.
_input_method
.
setup_training
(
self
)
model
=
self
.
model
model
=
self
.
model
self
.
input
_var
s
=
model
.
get_reused_placehdrs
()
self
.
inputs
=
model
.
get_reused_placehdrs
()
with
TowerContext
(
''
,
is_training
=
True
):
with
TowerContext
(
''
,
is_training
=
True
):
model
.
build_graph
(
self
.
input
_var
s
)
model
.
build_graph
(
self
.
inputs
)
cost_var
=
model
.
get_cost
()
cost_var
=
model
.
get_cost
()
opt
=
self
.
config
.
optimizer
opt
=
self
.
config
.
optimizer
grads
=
opt
.
compute_gradients
(
cost_var
)
self
.
train_op
=
opt
.
minimize
(
cost_var
,
name
=
'min_op'
)
self
.
train_op
=
opt
.
apply_gradients
(
grads
,
name
=
'min_op'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment