Shashank Suhas / seminar-breakout · Commit 4cd01111
Authored Jan 23, 2017 by Yuxin Wu

[WIP] initial commit of step callbacks. should be compatible with the old examples.

Parent: 0652d859

Showing 7 changed files, with 96 additions and 44 deletions:
examples/GAN/GAN.py             +9   -5
tensorpack/callbacks/steps.py   +43  -0
tensorpack/train/base.py        +5   -3
tensorpack/train/config.py      +23  -11
tensorpack/train/feedfree.py    +4   -5
tensorpack/train/multigpu.py    +5   -13
tensorpack/train/trainer.py     +7   -7
examples/GAN/GAN.py

```diff
@@ -79,24 +79,28 @@ class GANTrainer(FeedfreeTrainerBase):
         with TowerContext(''):
             actual_inputs = self._get_input_tensors()
             self.model.build_graph(actual_inputs)
             # optimize G
             grads = self.config.optimizer.compute_gradients(
                 self.model.g_loss, var_list=self.model.g_vars)
             grads = apply_grad_processors(
                 grads, self.model.get_gradient_processor_g())
             self.g_min = self.config.optimizer.apply_gradients(grads, name='g_op')
             # optimize D
             with tf.control_dependencies([self.g_min]):
                 grads = self.config.optimizer.compute_gradients(
                     self.model.d_loss, var_list=self.model.d_vars)
                 grads = apply_grad_processors(
                     grads, self.model.get_gradient_processor_d())
-                self.d_min = self.config.optimizer.apply_gradients(grads, name='d_op')
-        self.gs_incr = tf.assign_add(get_global_step_var(), 1, name='global_step_incr')
-        self.summary_op = summary_moving_average()
-        self.train_op = tf.group(self.d_min, self.summary_op, self.gs_incr)
+                self.d_min = self.config.optimizer.apply_gradients(
+                    grads, get_global_step_var(), name='d_op')
+        self.train_op = self.d_min

     def run_step(self):
-        self.sess.run(self.train_op)
+        ret = self.sess.run([self.train_op] + self.extra_fetches)
+        return ret[1:]


 class RandomZData(DataFlow):
```
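
The explicit `tf.assign_add` that bumped the global step is gone: `get_global_step_var()` is now passed to `apply_gradients`, which increments the step as part of the returned op. A minimal self-contained sketch of that TF 1.x behavior (plain TensorFlow, nothing tensorpack-specific; the variable names are mine):

```python
import tensorflow as tf  # TF 1.x, matching the era of this commit

x = tf.Variable(1.0)
loss = tf.square(x)
global_step = tf.Variable(0, trainable=False, name='global_step')

opt = tf.train.GradientDescentOptimizer(0.1)
grads = opt.compute_gradients(loss)
# Passing a global-step variable makes apply_gradients increment it
# together with the weight update, so no separate assign_add is needed:
min_op = opt.apply_gradients(grads, global_step, name='min_op')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(min_op)
    print(sess.run(global_step))  # prints 1
```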
tensorpack/callbacks/steps.py (new file, mode 100644)

```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: steps.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

""" Some common step callbacks. """

from six.moves import zip

from ..utils import logger
from ..tfutils.common import get_op_tensor_name
from ..tfutils.summary import summary_moving_average
from .base import Callback

__all__ = ['StepStatPrinter', 'SummaryMovingAverage']


class StepStatPrinter(Callback):
    """ It prints the value of some tensors in each step.
    It's just a demo of how trigger_step works but you should in general use
    :func:`print_stat` or :func:`tf.Print` instead. """

    def __init__(self, names):
        names = [get_op_tensor_name(n)[1] for n in names]
        logger.warn("Using print_stat or tf.Print in the graph is much faster than StepStatPrinter!")
        self._names = names

    def _extra_fetches(self):
        return self._names

    def _trigger_step(self, *args):
        for n, v in zip(self._names, args):
            logger.info("{}: {}".format(n, v))


class SummaryMovingAverage(Callback):
    """ Maintain the moving average of the tensors added by :func:`summary.add_moving_summary`
    in every step, and summarize them.
    """
    def _setup_graph(self):
        self.ema_op = summary_moving_average()

    def _extra_fetches(self):
        return [self.ema_op]
```
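
With these two callbacks, per-step fetching becomes declarative. A usage sketch under the API at this commit (`MyModel` and `my_dataflow` are hypothetical stand-ins for a real `ModelDesc` and `DataFlow`, and exact import paths may differ):

```python
import tensorflow as tf
from tensorpack.train.config import TrainConfig
from tensorpack.train.trainer import SimpleTrainer
from tensorpack.callbacks.steps import StepStatPrinter

config = TrainConfig(
    dataflow=my_dataflow,                    # hypothetical DataFlow
    model=MyModel(),                         # hypothetical ModelDesc
    optimizer=tf.train.AdamOptimizer(1e-3),
    # user callbacks only; SummaryMovingAverage() and StatPrinter() are
    # appended automatically through the new extra_callbacks default:
    callbacks=[StepStatPrinter(['cost:0'])],
)
SimpleTrainer(config).train()
```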
tensorpack/train/base.py

```diff
@@ -40,6 +40,7 @@ class Trainer(object):
         model (ModelDesc)
         sess (tf.Session): the current session in use.
         coord (tf.train.Coordinator)
+        extra_fetches (list): list of tensors/ops to fetch by :meth:`run_step`.
     """

     def __init__(self, config):
@@ -133,7 +134,7 @@ class Trainer(object):
         # some final operations that might modify the graph
         logger.info("Setup callbacks ...")
         self.config.callbacks.setup_graph(weakref.proxy(self))
-        self._extra_fetches = self.config.callbacks.extra_fetches()
+        self.extra_fetches = self.config.callbacks.extra_fetches()
         if not hasattr(logger, 'LOG_DIR'):
             raise RuntimeError("logger directory wasn't set!")
@@ -175,8 +176,9 @@ class Trainer(object):
                       **get_tqdm_kwargs(leave=True)):
                 if self.coord.should_stop():
                     return
-                self.run_step()  # implemented by subclass
-                callbacks.trigger_step()   # not useful?
+                fetch_data = self.run_step()  # implemented by subclass
+                if fetch_data:
+                    callbacks.trigger_step(*fetch_data)
             # trigger epoch outside the timing region.
             self.trigger_epoch()
         except StopTraining:
```
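
Taken together, the main loop now runs a small round trip each step: callbacks declare `_extra_fetches()`, `run_step` evaluates them after the train op, and the values come back through `trigger_step`. A framework-free sketch of that dispatch (the classes below are mocks that mirror the names in this diff; how tensorpack's real `Callbacks` group slices the results per callback is my assumption):

```python
class CallbackMock(object):
    """Mimics the step API this commit introduces on Callback."""
    def _extra_fetches(self):
        return []              # tensors/ops to fetch alongside the train op
    def _trigger_step(self, *values):
        pass                   # receives the fetched values after each step


class PrintFetch(CallbackMock):
    def __init__(self, names):
        self._names = names
    def _extra_fetches(self):
        return self._names
    def _trigger_step(self, *args):
        for n, v in zip(self._names, args):
            print("{}: {}".format(n, v))


def run_step(sess_run, train_op, extra_fetches):
    # cf. the new run_step implementations: train op first, extras after.
    ret = sess_run([train_op] + extra_fetches)
    return ret[1:]


def main_loop_step(sess_run, train_op, callbacks):
    # cf. Trainer.main_loop: gather fetches, run, dispatch to callbacks.
    fetches = [f for cb in callbacks for f in cb._extra_fetches()]
    fetch_data = run_step(sess_run, train_op, fetches)
    if fetch_data:
        for cb in callbacks:
            n = len(cb._extra_fetches())
            cb._trigger_step(*fetch_data[:n])
            fetch_data = fetch_data[n:]


# Demo with a fake session that "evaluates" each fetch to itself:
main_loop_step(lambda fetches: fetches, 'train_op', [PrintFetch(['loss'])])
```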
tensorpack/train/config.py

```diff
@@ -4,7 +4,7 @@
 import tensorflow as tf

-from ..callbacks.group import Callbacks
+from ..callbacks import Callbacks, SummaryMovingAverage, StatPrinter
 from ..dataflow.base import DataFlow
 from ..models import ModelDesc
 from ..utils import logger
@@ -21,7 +21,8 @@ class TrainConfig(object):
     """

     def __init__(self, dataflow=None, data=None,
                  model=None, optimizer=None, callbacks=None,
+                 extra_callbacks=None,
                  session_config=get_default_sess_config(),
                  session_init=None,
                  starting_epoch=1, step_per_epoch=None, max_epoch=99999,
@@ -34,7 +35,11 @@ class TrainConfig(object):
             or ``data`` has to be present.
         model (ModelDesc): the model to train.
         optimizer (tf.train.Optimizer): the optimizer for trainig.
-        callbacks (Callbacks): the callbacks to perform during training.
+        callbacks (list): a list of :class:`Callback` to perform during training.
+        extra_callbacks (list): the same as ``callbacks``. This argument
+            is only used to provide the defaults. The defaults are
+            ``[SummaryMovingAverage(), StatPrinter()]``. The list of
+            callbacks that will be used in the end is ``callbacks + extra_callbacks``.
         session_config (tf.ConfigProto): the config used to instantiate the session.
         session_init (SessionInit): how to initialize variables of a session. Defaults to a new session.
         starting_epoch (int): The index of the first epoch.
@@ -50,6 +55,7 @@ class TrainConfig(object):
         def assert_type(v, tp):
             assert isinstance(v, tp), v.__class__

+        # process data
         if 'dataset' in kwargs:
             dataflow = kwargs.pop('dataset')
             logger.warn("[Deprecated] TrainConfig.dataset has been deprecated. Use TrainConfig.dataflow instead.")
@@ -65,8 +71,20 @@ class TrainConfig(object):
         self.optimizer = optimizer
         assert_type(self.optimizer, tf.train.Optimizer)

-        self.callbacks = callbacks
-        assert_type(self.callbacks, Callbacks)
+        if isinstance(callbacks, Callbacks):
+            # keep quiet now because I haven't determined the final API yet.
+            # logger.warn("[Deprecated] API of TrainConfig(callbacks=) has changed!")
+            # logger.warn("[Deprecated] Please change the option 'callbacks=' to a list of "
+            #             "callbacks without StatPrinter().")
+            callbacks = callbacks.cbs[:-1]  # the last one is StatPrinter()
+        assert_type(callbacks, list)
+        if extra_callbacks is None:
+            extra_callbacks = [SummaryMovingAverage(), StatPrinter()]
+        self.callbacks = callbacks + extra_callbacks
+        assert_type(self.callbacks, list)
+        self.callbacks = Callbacks(self.callbacks)

         self.model = model
         assert_type(self.model, ModelDesc)
@@ -102,12 +120,6 @@ class TrainConfig(object):
         if isinstance(self.predict_tower, int):
             self.predict_tower = [self.predict_tower]

-        # TODO deprecated @Jan20
-        self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
-        if self.extra_threads_procs:
-            logger.warn("[DEPRECATED] use the Callback StartProcOrThread instead of _extra_threads_procs")
-            from ..callbacks.concurrency import StartProcOrThread
-            self.callbacks.append(StartProcOrThread(self.extra_threads_procs))
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))

     def set_tower(self, nr_tower=None, tower=None):
```
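
The constructor stays backward compatible: an old-style `Callbacks(...)` group is unwrapped (dropping its trailing `StatPrinter()`), while the new style takes a plain list and lets `extra_callbacks` supply the defaults. A sketch of both call styles (`df`, `m`, and `opt` are placeholders for a real dataflow, model, and optimizer; `ModelSaver` stands in for any user callback):

```python
from tensorpack.train.config import TrainConfig
from tensorpack.callbacks import Callbacks, ModelSaver, StatPrinter

# Old style: a Callbacks group that had to end with StatPrinter().
config = TrainConfig(
    dataflow=df, model=m, optimizer=opt,
    callbacks=Callbacks([ModelSaver(), StatPrinter()]))

# New style: a plain list; SummaryMovingAverage() and StatPrinter() are
# appended automatically because extra_callbacks defaults to them.
config = TrainConfig(
    dataflow=df, model=m, optimizer=opt,
    callbacks=[ModelSaver()])
```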
tensorpack/train/feedfree.py

```diff
@@ -9,7 +9,6 @@ from ..utils import logger
 from ..tfutils import get_global_step_var
 from ..tfutils.tower import TowerContext
 from ..tfutils.gradproc import apply_grad_processors
-from ..tfutils.summary import summary_moving_average
 from .input_data import QueueInput, FeedfreeInput

 from .base import Trainer
@@ -55,7 +54,8 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
     def run_step(self):
         """ Simply run ``self.train_op``, which minimizes the cost."""
-        self.sess.run(self.train_op)
+        ret = self.sess.run([self.train_op] + self.extra_fetches)
+        return ret[1:]
         # if not hasattr(self, 'cnt'):
         #     self.cnt = 0
         # else:
@@ -101,9 +101,8 @@ class SimpleFeedfreeTrainer(
         cost, grads = self._get_cost_and_grad()
         grads = apply_grad_processors(grads, self.model.get_gradient_processor())

-        self.train_op = tf.group(
-            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
-            summary_moving_average(), name='train_op')
+        self.train_op = self.config.optimizer.apply_gradients(
+            grads, get_global_step_var(), name='min_op')
         # skip training
         # self.train_op = tf.group(*self.dequed_inputs)
```
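
The same transformation repeats in every trainer below: the old `train_op` fused the gradient update with `summary_moving_average()` at graph-construction time, while the new `min_op` is the bare update, and the EMA op is fetched per step by the `SummaryMovingAverage` callback instead. A self-contained TF 1.x illustration of the two shapes (the two ops here are trivial stand-ins, not tensorpack's real ones):

```python
import tensorflow as tf  # TF 1.x

w = tf.Variable(0.0)
min_op = tf.assign_add(w, -1.0, name='min_op')  # stand-in gradient update
ema_op = tf.no_op(name='ema')                   # stand-in for summary_moving_average()

# Before this commit: the two were fused into one grouped op.
train_op_old = tf.group(min_op, ema_op, name='train_op')

# After it: they are combined at run time; the callback simply adds
# ema_op to the step's fetch list.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run([min_op] + [ema_op])
```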
tensorpack/train/multigpu.py

```diff
@@ -11,7 +11,6 @@ from six.moves import zip, range
 from ..utils import logger
 from ..utils.naming import SUMMARY_BACKUP_KEYS
 from ..utils.concurrency import LoopThread
-from ..tfutils.summary import summary_moving_average
 from ..tfutils import (backup_collection, restore_collection,
                        get_global_step_var, TowerContext)
 from ..tfutils.gradproc import apply_grad_processors, ScaleGradient
@@ -113,13 +112,8 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
         grads = SyncMultiGPUTrainer._average_grads(grad_list)
         grads = apply_grad_processors(grads, self.model.get_gradient_processor())

-        self.train_op = tf.group(
-            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
-            summary_moving_average(), name='train_op')
-
-    def run_step(self):
-        self.sess.run(self.train_op)
+        self.train_op = self.config.optimizer.apply_gradients(
+            grads, get_global_step_var(), name='min_op')


 class AsyncMultiGPUTrainer(MultiGPUTrainer,
@@ -169,10 +163,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
         grad_list = [apply_grad_processors(g, gradprocs) for g in grad_list]
         # use grad from the first tower for iteration in main thread
-        self.train_op = tf.group(
-            self.config.optimizer.apply_gradients(grad_list[0], get_global_step_var()),
-            summary_moving_average(), name='train_op')
+        self.train_op = self.config.optimizer.apply_gradients(
+            grad_list[0], get_global_step_var(), name='min_op')

         self._start_async_threads(grad_list)
@@ -199,7 +191,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
         for th in self.training_threads:  # resume all threads
             th.resume()
         next(self.async_step_counter)
-        self.sess.run(self.train_op)
+        return super(AsyncMultiGPUTrainer, self).run_step()

     def _trigger_epoch(self):
         self.async_running = False
```
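
The async trainer's `run_step` change is pure delegation: since the inherited `run_step` (see feedfree.py above) now returns the extra fetch values, the subclass must pass them through rather than re-running the op itself. In miniature:

```python
class BaseStep(object):
    def run_step(self):
        # returns extra fetch values, as the base trainers now do
        return ['fetched-value']

class AsyncStep(BaseStep):
    def run_step(self):
        # resume threads / advance counters here, then reuse the parent
        # step so the extra fetches keep flowing to trigger_step():
        return super(AsyncStep, self).run_step()

print(AsyncStep().run_step())  # ['fetched-value']
```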
tensorpack/train/trainer.py

```diff
@@ -9,7 +9,6 @@ from .base import Trainer
 from ..utils import SUMMARY_BACKUP_KEYS, PREDICT_TOWER
 from ..tfutils import (get_tensors_by_names, freeze_collection,
                        get_global_step_var, TowerContext)
-from ..tfutils.summary import summary_moving_average
 from ..predict import OnlinePredictor, build_prediction_graph
 from ..tfutils.gradproc import apply_grad_processors
 from .input_data import FeedInput
@@ -43,10 +42,10 @@ class PredictorFactory(object):
         return OnlinePredictor(self.sess, raw_input_vars, output_vars)

     def _build_predict_tower(self):
-        tf.get_variable_scope().reuse_variables()
         # build_predict_tower might get called anywhere, but 'PREDICT_TOWER' should be the outermost name scope
         with tf.name_scope(None), \
-                freeze_collection(SUMMARY_BACKUP_KEYS):
+                freeze_collection(SUMMARY_BACKUP_KEYS), \
+                tf.variable_scope(tf.get_variable_scope(), reuse=True):
             def fn(_):
                 self.model.build_graph(self.model.get_input_vars())
             build_prediction_graph(fn, self.towers)
@@ -73,7 +72,9 @@ class SimpleTrainer(Trainer):
     def run_step(self):
         """ Feed data into the graph and run the updates. """
         feed = self._input_method.next_feed()
-        self.sess.run([self.train_op], feed_dict=feed)    # faster since train_op return None
+        ret = self.sess.run([self.train_op] + self.extra_fetches, feed_dict=feed)
+        return ret[1:]

     def _setup(self):
         self._input_method._setup(self)
@@ -87,9 +88,8 @@ class SimpleTrainer(Trainer):
         grads = apply_grad_processors(grads,
                                       self.model.get_gradient_processor())
-        self.train_op = tf.group(
-            self.config.optimizer.apply_gradients(grads, get_global_step_var()),
-            summary_moving_average(), name='train_op')
+        self.train_op = self.config.optimizer.apply_gradients(
+            grads, get_global_step_var(), name='min_op')

     def _trigger_epoch(self):
         if self.summary_op is not None:
```
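
One change here is unrelated to step callbacks: `_build_predict_tower` stops calling `tf.get_variable_scope().reuse_variables()`, which irreversibly flips the surrounding scope to reuse mode, and instead confines reuse to the prediction towers via a scoped context manager. The general TF 1.x pattern, as a standalone sketch (`build` and `w` are my own names):

```python
import tensorflow as tf  # TF 1.x

def build():
    return tf.get_variable('w', shape=[])

w1 = build()  # creates variable 'w'

# Scoped reuse: only graph construction inside this block sees reuse=True,
# and the ambient scope is untouched afterwards.
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    w2 = build()  # reuses the existing 'w'

assert w1 is w2
```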