Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
d9e7c6bf
Commit
d9e7c6bf
authored
Dec 03, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
allow data / dataset as input to trainconfig
parent
8db5bcd3
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
75 additions
and
42 deletions
+75
-42
examples/GAN/GAN.py
examples/GAN/GAN.py
+3
-3
examples/GAN/InfoGAN-mnist.py
examples/GAN/InfoGAN-mnist.py
+1
-1
tensorpack/train/config.py
tensorpack/train/config.py
+15
-3
tensorpack/train/input_data.py
tensorpack/train/input_data.py
+13
-4
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+18
-6
tensorpack/train/queue.py
tensorpack/train/queue.py
+14
-13
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+11
-12
No files found.
examples/GAN/GAN.py
View file @
d9e7c6bf
...
@@ -5,15 +5,15 @@
...
@@ -5,15 +5,15 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
import
numpy
as
np
import
numpy
as
np
from
tensorpack
import
(
QueueInput
Trainer
,
TowerContext
,
from
tensorpack
import
(
Feedfree
Trainer
,
TowerContext
,
get_global_step_var
,
QueueInput
)
get_global_step_var
,
QueueInput
)
from
tensorpack.tfutils.summary
import
summary_moving_average
,
add_moving_summary
from
tensorpack.tfutils.summary
import
summary_moving_average
,
add_moving_summary
from
tensorpack.dataflow
import
DataFlow
from
tensorpack.dataflow
import
DataFlow
class
GANTrainer
(
QueueInput
Trainer
):
class
GANTrainer
(
Feedfree
Trainer
):
def
__init__
(
self
,
config
,
g_vs_d
=
1
):
def
__init__
(
self
,
config
,
g_vs_d
=
1
):
super
(
GANTrainer
,
self
)
.
__init__
(
config
)
self
.
_input_method
=
QueueInput
(
config
.
dataset
)
self
.
_input_method
=
QueueInput
(
config
.
dataset
)
super
(
GANTrainer
,
self
)
.
__init__
(
config
)
if
g_vs_d
>
1
:
if
g_vs_d
>
1
:
self
.
_opt_g
=
g_vs_d
self
.
_opt_g
=
g_vs_d
self
.
_opt_d
=
1
self
.
_opt_d
=
1
...
...
examples/GAN/InfoGAN-mnist.py
View file @
d9e7c6bf
...
@@ -82,7 +82,7 @@ class Model(ModelDesc):
...
@@ -82,7 +82,7 @@ class Model(ModelDesc):
self
.
g_loss
,
self
.
d_loss
=
build_GAN_losses
(
vecpos
,
vecneg
)
self
.
g_loss
,
self
.
d_loss
=
build_GAN_losses
(
vecpos
,
vecneg
)
self
.
g_loss
=
tf
.
add
(
self
.
g_loss
,
MIloss
,
name
=
'total_g_loss'
)
self
.
g_loss
=
tf
.
add
(
self
.
g_loss
,
MIloss
,
name
=
'total_g_loss'
)
self
.
d_loss
=
tf
.
add
(
self
.
d_loss
,
MIloss
,
name
=
'total_
g
_loss'
)
self
.
d_loss
=
tf
.
add
(
self
.
d_loss
,
MIloss
,
name
=
'total_
d
_loss'
)
summary
.
add_moving_summary
(
MIloss
,
self
.
g_loss
,
self
.
d_loss
,
Hc
,
Elog_qc_given_x
)
summary
.
add_moving_summary
(
MIloss
,
self
.
g_loss
,
self
.
d_loss
,
Hc
,
Elog_qc_given_x
)
all_vars
=
tf
.
trainable_variables
()
all_vars
=
tf
.
trainable_variables
()
...
...
tensorpack/train/config.py
View file @
d9e7c6bf
...
@@ -10,6 +10,7 @@ from ..utils import logger
...
@@ -10,6 +10,7 @@ from ..utils import logger
from
..tfutils
import
(
JustCurrentSession
,
from
..tfutils
import
(
JustCurrentSession
,
get_default_sess_config
,
SessionInit
)
get_default_sess_config
,
SessionInit
)
from
..dataflow
import
DataFlow
from
..dataflow
import
DataFlow
from
.input_data
import
InputData
__all__
=
[
'TrainConfig'
]
__all__
=
[
'TrainConfig'
]
...
@@ -20,6 +21,7 @@ class TrainConfig(object):
...
@@ -20,6 +21,7 @@ class TrainConfig(object):
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
**
kwargs
):
"""
"""
:param dataset: the dataset to train. a `DataFlow` instance.
:param dataset: the dataset to train. a `DataFlow` instance.
:param data: an `InputData` instance
:param optimizer: a `tf.train.Optimizer` instance defining the optimizer for trainig.
:param optimizer: a `tf.train.Optimizer` instance defining the optimizer for trainig.
:param callbacks: a `callback.Callbacks` instance. Define
:param callbacks: a `callback.Callbacks` instance. Define
the callbacks to perform during training.
the callbacks to perform during training.
...
@@ -35,8 +37,14 @@ class TrainConfig(object):
...
@@ -35,8 +37,14 @@ class TrainConfig(object):
"""
"""
def
assert_type
(
v
,
tp
):
def
assert_type
(
v
,
tp
):
assert
isinstance
(
v
,
tp
),
v
.
__class__
assert
isinstance
(
v
,
tp
),
v
.
__class__
self
.
dataset
=
kwargs
.
pop
(
'dataset'
)
if
'dataset'
in
kwargs
:
assert_type
(
self
.
dataset
,
DataFlow
)
assert
'data'
not
in
kwargs
,
"dataset and data cannot be both presented in TrainConfig!"
self
.
dataset
=
kwargs
.
pop
(
'dataset'
)
assert_type
(
self
.
dataset
,
DataFlow
)
else
:
self
.
data
=
kwargs
.
pop
(
'data'
)
assert_type
(
self
.
data
,
InputData
)
self
.
optimizer
=
kwargs
.
pop
(
'optimizer'
)
self
.
optimizer
=
kwargs
.
pop
(
'optimizer'
)
assert_type
(
self
.
optimizer
,
tf
.
train
.
Optimizer
)
assert_type
(
self
.
optimizer
,
tf
.
train
.
Optimizer
)
self
.
callbacks
=
kwargs
.
pop
(
'callbacks'
)
self
.
callbacks
=
kwargs
.
pop
(
'callbacks'
)
...
@@ -52,7 +60,10 @@ class TrainConfig(object):
...
@@ -52,7 +60,10 @@ class TrainConfig(object):
self
.
step_per_epoch
=
kwargs
.
pop
(
'step_per_epoch'
,
None
)
self
.
step_per_epoch
=
kwargs
.
pop
(
'step_per_epoch'
,
None
)
if
self
.
step_per_epoch
is
None
:
if
self
.
step_per_epoch
is
None
:
try
:
try
:
self
.
step_per_epoch
=
self
.
dataset
.
size
()
if
hasattr
(
self
,
'dataset'
):
self
.
step_per_epoch
=
self
.
dataset
.
size
()
else
:
self
.
step_per_epoch
=
self
.
data
.
size
()
except
NotImplementedError
:
except
NotImplementedError
:
logger
.
exception
(
"You must set `step_per_epoch` if dataset.size() is not implemented."
)
logger
.
exception
(
"You must set `step_per_epoch` if dataset.size() is not implemented."
)
else
:
else
:
...
@@ -70,6 +81,7 @@ class TrainConfig(object):
...
@@ -70,6 +81,7 @@ class TrainConfig(object):
else
:
else
:
self
.
tower
=
[
0
]
self
.
tower
=
[
0
]
# TODO deprecated @Dec20
self
.
extra_threads_procs
=
kwargs
.
pop
(
'extra_threads_procs'
,
[])
self
.
extra_threads_procs
=
kwargs
.
pop
(
'extra_threads_procs'
,
[])
if
self
.
extra_threads_procs
:
if
self
.
extra_threads_procs
:
logger
.
warn
(
"[DEPRECATED] use the Callback StartProcOrThread instead of _extra_threads_procs"
)
logger
.
warn
(
"[DEPRECATED] use the Callback StartProcOrThread instead of _extra_threads_procs"
)
...
...
tensorpack/train/input
method
.py
→
tensorpack/train/input
_data
.py
View file @
d9e7c6bf
#!/usr/bin/env python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
# File: input
method
.py
# File: input
_data
.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
tensorflow
as
tf
import
tensorflow
as
tf
import
threading
import
threading
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
from
..dataflow.common
import
RepeatedData
from
..tfutils.summary
import
add_moving_summary
from
..tfutils.summary
import
add_moving_summary
from
..utils
import
logger
from
..utils
import
logger
from
..callbacks.concurrency
import
StartProcOrThread
from
..callbacks.concurrency
import
StartProcOrThread
__all__
=
[
'QueueInput'
,
'FeedfreeInput'
,
'TensorInput'
]
__all__
=
[
'QueueInput'
,
'FeedfreeInput'
,
'TensorInput'
]
class
Input
Method
(
object
):
class
Input
Data
(
object
):
__metaclass__
=
ABCMeta
__metaclass__
=
ABCMeta
pass
pass
class
FeedInput
(
Input
Method
):
class
FeedInput
(
Input
Data
):
def
__init__
(
self
,
ds
):
def
__init__
(
self
,
ds
):
self
.
ds
=
ds
self
.
ds
=
ds
...
@@ -26,8 +27,16 @@ class FeedInput(InputMethod):
...
@@ -26,8 +27,16 @@ class FeedInput(InputMethod):
def
_setup
(
self
,
trainer
):
def
_setup
(
self
,
trainer
):
self
.
input_vars
=
trainer
.
model
.
get_input_vars
()
self
.
input_vars
=
trainer
.
model
.
get_input_vars
()
rds
=
RepeatedData
(
self
.
ds
,
-
1
)
rds
.
reset_state
()
self
.
data_producer
=
rds
.
get_data
()
class
FeedfreeInput
(
InputMethod
):
def
next_feed
(
self
):
data
=
next
(
self
.
data_producer
)
feed
=
dict
(
zip
(
self
.
input_vars
,
data
))
return
feed
class
FeedfreeInput
(
InputData
):
def
get_input_tensors
(
self
):
def
get_input_tensors
(
self
):
return
self
.
_get_input_tensors
()
return
self
.
_get_input_tensors
()
...
...
tensorpack/train/multigpu.py
View file @
d9e7c6bf
...
@@ -17,7 +17,7 @@ from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
...
@@ -17,7 +17,7 @@ from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
from
.trainer
import
FeedfreeTrainer
,
SingleCostFeedfreeTrainer
,
MultiPredictorTowerTrainer
from
.trainer
import
FeedfreeTrainer
,
SingleCostFeedfreeTrainer
,
MultiPredictorTowerTrainer
from
.queue
import
QueueInputTrainer
from
.queue
import
QueueInputTrainer
from
.input
method
import
QueueInput
from
.input
_data
import
QueueInput
__all__
=
[
'AsyncMultiGPUTrainer'
,
'SyncMultiGPUTrainer'
]
__all__
=
[
'AsyncMultiGPUTrainer'
,
'SyncMultiGPUTrainer'
]
...
@@ -47,10 +47,16 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
...
@@ -47,10 +47,16 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
SingleCostFeedfreeTrainer
,
SingleCostFeedfreeTrainer
,
MultiPredictorTowerTrainer
):
MultiPredictorTowerTrainer
):
def
__init__
(
self
,
config
,
input_queue
=
None
,
predict_tower
=
None
):
def
__init__
(
self
,
config
,
input_queue
=
None
,
predict_tower
=
None
):
assert
len
(
config
.
tower
)
>=
1
,
"MultiGPUTrainer must be used with at least one GPU."
if
hasattr
(
config
,
'dataset'
):
self
.
_input_method
=
QueueInput
(
config
.
dataset
,
input_queue
)
else
:
self
.
_input_method
=
config
.
data
assert
isinstance
(
self
.
_input_method
,
QueueInput
)
super
(
SyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
super
(
SyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
self
.
_setup_predictor_factory
(
predict_tower
)
self
.
_setup_predictor_factory
(
predict_tower
)
self
.
_input_method
=
QueueInput
(
config
.
dataset
,
input_queue
)
assert
len
(
config
.
tower
)
>=
1
,
"MultiGPUTrainer must be used with at least one GPU."
@
staticmethod
@
staticmethod
def
_average_grads
(
tower_grads
):
def
_average_grads
(
tower_grads
):
...
@@ -95,17 +101,23 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
...
@@ -95,17 +101,23 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
input_queue
=
None
,
input_queue
=
None
,
predict_tower
=
None
,
predict_tower
=
None
,
average_gradient
=
True
):
average_gradient
=
True
):
if
hasattr
(
config
,
'dataset'
):
self
.
_input_method
=
QueueInput
(
config
.
dataset
,
input_queue
)
else
:
self
.
_input_method
=
config
.
data
assert
isinstance
(
self
.
_input_method
,
QueueInput
)
super
(
AsyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
super
(
AsyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
self
.
_setup_predictor_factory
(
predict_tower
)
self
.
_setup_predictor_factory
(
predict_tower
)
self
.
_
input_method
=
QueueInput
(
config
.
dataset
,
input_queue
)
self
.
_
average_gradient
=
average_gradient
self
.
average_gradient
=
average_gradient
def
_setup
(
self
):
def
_setup
(
self
):
super
(
SyncMultiGPUTrainer
,
self
)
.
_setup
()
super
(
SyncMultiGPUTrainer
,
self
)
.
_setup
()
grad_list
=
MultiGPUTrainer
.
_multi_tower_grads
(
grad_list
=
MultiGPUTrainer
.
_multi_tower_grads
(
self
.
config
.
tower
,
lambda
:
self
.
_get_cost_and_grad
()[
1
])
self
.
config
.
tower
,
lambda
:
self
.
_get_cost_and_grad
()[
1
])
gradprocs
=
self
.
model
.
get_gradient_processor
()
gradprocs
=
self
.
model
.
get_gradient_processor
()
if
self
.
average_gradient
and
self
.
config
.
nr_tower
>
1
:
if
self
.
_
average_gradient
and
self
.
config
.
nr_tower
>
1
:
# pretend to average the grads, in order to make async and
# pretend to average the grads, in order to make async and
# sync have consistent effective learning rate
# sync have consistent effective learning rate
gradprocs
.
insert
(
0
,
ScaleGradient
((
'.*'
,
1.0
/
self
.
config
.
nr_tower
),
log
=
False
))
gradprocs
.
insert
(
0
,
ScaleGradient
((
'.*'
,
1.0
/
self
.
config
.
nr_tower
),
log
=
False
))
...
...
tensorpack/train/queue.py
View file @
d9e7c6bf
...
@@ -3,19 +3,16 @@
...
@@ -3,19 +3,16 @@
# File: queue.py
# File: queue.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
threading
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..dataflow.common
import
RepeatedData
from
..tfutils.summary
import
summary_moving_average
,
add_moving_summary
from
..tfutils
import
get_global_step_var
,
TowerContext
from
..utils
import
logger
from
..utils
import
logger
from
..callbacks.concurrency
import
StartProcOrThread
from
..tfutils
import
get_global_step_var
from
..tfutils.tower
import
TowerContext
from
..tfutils.gradproc
import
apply_grad_processors
from
..tfutils.gradproc
import
apply_grad_processors
from
.inputmethod
import
QueueInput
from
..tfutils.summary
import
summary_moving_average
from
.input_data
import
QueueInput
from
.trainer
import
(
FeedfreeTrainer
,
MultiPredictorTowerTrainer
,
from
.trainer
import
(
MultiPredictorTowerTrainer
,
SingleCostFeedfreeTrainer
)
SingleCostFeedfreeTrainer
)
__all__
=
[
'QueueInputTrainer'
]
__all__
=
[
'QueueInputTrainer'
]
...
@@ -30,14 +27,19 @@ class QueueInputTrainer(MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer):
...
@@ -30,14 +27,19 @@ class QueueInputTrainer(MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer):
:param predict_tower: list of gpu relative idx to run prediction. default to be [0].
:param predict_tower: list of gpu relative idx to run prediction. default to be [0].
Use -1 for cpu.
Use -1 for cpu.
"""
"""
if
hasattr
(
config
,
'dataset'
):
self
.
_input_method
=
QueueInput
(
config
.
dataset
,
input_queue
)
else
:
self
.
_input_method
=
config
.
data
assert
isinstance
(
self
.
_input_method
,
QueueInput
)
super
(
QueueInputTrainer
,
self
)
.
__init__
(
config
)
super
(
QueueInputTrainer
,
self
)
.
__init__
(
config
)
self
.
_setup_predictor_factory
(
predict_tower
)
self
.
_input_method
=
QueueInput
(
config
.
dataset
,
input_queue
)
def
_setup
(
self
):
self
.
_setup_predictor_factory
(
predict_tower
)
super
(
QueueInputTrainer
,
self
)
.
_setup
()
assert
len
(
self
.
config
.
tower
)
==
1
,
\
assert
len
(
self
.
config
.
tower
)
==
1
,
\
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
def
_setup
(
self
):
super
(
SingleCostFeedfreeTrainer
,
self
)
.
_setup
()
with
TowerContext
(
''
):
with
TowerContext
(
''
):
cost
,
grads
=
self
.
_get_cost_and_grad
()
cost
,
grads
=
self
.
_get_cost_and_grad
()
grads
=
apply_grad_processors
(
grads
,
self
.
model
.
get_gradient_processor
())
grads
=
apply_grad_processors
(
grads
,
self
.
model
.
get_gradient_processor
())
...
@@ -47,4 +49,3 @@ class QueueInputTrainer(MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer):
...
@@ -47,4 +49,3 @@ class QueueInputTrainer(MultiPredictorTowerTrainer, SingleCostFeedfreeTrainer):
summary_moving_average
(),
name
=
'train_op'
)
summary_moving_average
(),
name
=
'train_op'
)
# skip training
# skip training
#self.train_op = tf.group(*self.dequed_inputs)
#self.train_op = tf.group(*self.dequed_inputs)
tensorpack/train/trainer.py
View file @
d9e7c6bf
...
@@ -8,15 +8,13 @@ from six.moves import zip
...
@@ -8,15 +8,13 @@ from six.moves import zip
from
.base
import
Trainer
from
.base
import
Trainer
from
..dataflow.common
import
RepeatedData
from
..utils
import
logger
,
SUMMARY_BACKUP_KEYS
from
..utils
import
logger
,
SUMMARY_BACKUP_KEYS
from
..tfutils
import
(
get_tensors_by_names
,
freeze_collection
,
from
..tfutils
import
(
get_tensors_by_names
,
freeze_collection
,
get_global_step_var
,
TowerContext
)
get_global_step_var
,
TowerContext
)
from
..tfutils.summary
import
summary_moving_average
,
add_moving_summary
from
..tfutils.summary
import
summary_moving_average
,
add_moving_summary
from
..predict
import
OnlinePredictor
,
build_multi_tower_prediction_graph
from
..predict
import
OnlinePredictor
,
build_multi_tower_prediction_graph
from
..tfutils.gradproc
import
apply_grad_processors
from
..tfutils.gradproc
import
apply_grad_processors
from
.input
method
import
FeedfreeInput
from
.input
_data
import
FeedInput
,
FeedfreeInput
__all__
=
[
'SimpleTrainer'
,
'FeedfreeTrainer'
,
'MultiPredictorTowerTrainer'
,
__all__
=
[
'SimpleTrainer'
,
'FeedfreeTrainer'
,
'MultiPredictorTowerTrainer'
,
'SingleCostFeedfreeTrainer'
]
'SingleCostFeedfreeTrainer'
]
...
@@ -59,13 +57,18 @@ class SimpleTrainer(Trainer):
...
@@ -59,13 +57,18 @@ class SimpleTrainer(Trainer):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
(
SimpleTrainer
,
self
)
.
__init__
(
config
)
super
(
SimpleTrainer
,
self
)
.
__init__
(
config
)
self
.
_predictor_factory
=
PredictorFactory
(
self
.
sess
,
self
.
model
,
[
0
])
self
.
_predictor_factory
=
PredictorFactory
(
self
.
sess
,
self
.
model
,
[
0
])
if
not
hasattr
(
config
,
'dataset'
):
self
.
_input_method
=
config
.
data
assert
isinstance
(
self
.
_input_method
,
FeedInput
)
else
:
self
.
_input_method
=
FeedInput
(
config
.
dataset
)
def
run_step
(
self
):
def
run_step
(
self
):
data
=
next
(
self
.
data_producer
)
feed
=
self
.
_input_method
.
next_feed
()
feed
=
dict
(
zip
(
self
.
input_vars
,
data
))
self
.
sess
.
run
([
self
.
train_op
],
feed_dict
=
feed
)
# faster since train_op return None
self
.
sess
.
run
([
self
.
train_op
],
feed_dict
=
feed
)
# faster since train_op return None
def
_setup
(
self
):
def
_setup
(
self
):
self
.
_input_method
.
_setup
(
self
)
model
=
self
.
model
model
=
self
.
model
self
.
input_vars
=
model
.
get_input_vars
()
self
.
input_vars
=
model
.
get_input_vars
()
with
TowerContext
(
''
):
with
TowerContext
(
''
):
...
@@ -81,14 +84,9 @@ class SimpleTrainer(Trainer):
...
@@ -81,14 +84,9 @@ class SimpleTrainer(Trainer):
self
.
config
.
optimizer
.
apply_gradients
(
grads
,
get_global_step_var
()),
self
.
config
.
optimizer
.
apply_gradients
(
grads
,
get_global_step_var
()),
summary_moving_average
(),
name
=
'train_op'
)
summary_moving_average
(),
name
=
'train_op'
)
# create an infinte data producer
self
.
config
.
dataset
.
reset_state
()
self
.
data_producer
=
RepeatedData
(
self
.
config
.
dataset
,
-
1
)
.
get_data
()
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
if
self
.
summary_op
is
not
None
:
if
self
.
summary_op
is
not
None
:
data
=
next
(
self
.
data_producer
)
feed
=
self
.
_input_method
.
next_feed
()
feed
=
dict
(
zip
(
self
.
input_vars
,
data
))
summary_str
=
self
.
summary_op
.
eval
(
feed_dict
=
feed
)
summary_str
=
self
.
summary_op
.
eval
(
feed_dict
=
feed
)
self
.
_process_summary
(
summary_str
)
self
.
_process_summary
(
summary_str
)
...
@@ -126,7 +124,7 @@ class FeedfreeTrainer(Trainer):
...
@@ -126,7 +124,7 @@ class FeedfreeTrainer(Trainer):
return
self
.
_input_method
.
get_input_tensors
()
return
self
.
_input_method
.
get_input_tensors
()
def
_setup
(
self
):
def
_setup
(
self
):
assert
isinstance
(
self
.
_input_method
,
FeedfreeInput
)
assert
isinstance
(
self
.
_input_method
,
FeedfreeInput
)
,
type
(
self
.
_input_method
)
self
.
_input_method
.
_setup
(
self
)
self
.
_input_method
.
_setup
(
self
)
class
SingleCostFeedfreeTrainer
(
FeedfreeTrainer
):
class
SingleCostFeedfreeTrainer
(
FeedfreeTrainer
):
...
@@ -155,3 +153,4 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
...
@@ -155,3 +153,4 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
#trace_file = open('timeline.ctf.json', 'w')
#trace_file = open('timeline.ctf.json', 'w')
#trace_file.write(trace.generate_chrome_trace_format())
#trace_file.write(trace.generate_chrome_trace_format())
#import sys; sys.exit()
#import sys; sys.exit()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment