Commit f3644ce9 authored by Yuxin Wu

Merge branch 'dev'

parents 05f7ba8f 0f2eaeea
@@ -77,10 +77,6 @@ class GANTrainer(FeedfreeTrainerBase):
         self.d_min = opt.minimize(self.model.d_loss, var_list=self.model.d_vars, name='d_op')
         self.train_op = self.d_min

-    def run_step(self):
-        ret = self.sess.run([self.train_op] + self.get_extra_fetches())
-        return ret[1:]
-
 class RandomZData(DataFlow):
     def __init__(self, shape):
......
@@ -88,9 +88,8 @@ class WGANTrainer(FeedfreeTrainerBase):
     def run_step(self):
         for k in range(5):
-            self.sess.run(self.d_min)
-        ret = self.sess.run([self.g_min] + self.get_extra_fetches())
-        return ret[1:]
+            self.hooked_sess.run(self.d_min)
+        self.hooked_sess.run(self.g_min)

 if __name__ == '__main__':
......
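Note: the two hunks above drop the old pattern of returning extra fetches from run_step. For reference, a minimal sketch (hypothetical class name, assuming the tensorpack API as of this commit) of a custom run_step under the new scheme — every hooked_sess.run() call also drives the callbacks' before_run/after_run hooks:

    class MyGANTrainer(FeedfreeTrainerBase):
        def run_step(self):
            # callbacks fetch their own tensors via hooks on each run()
            self.hooked_sess.run(self.d_min)
            self.hooked_sess.run(self.g_min)
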
@@ -54,42 +54,44 @@ class Callback(object):
     def _before_train(self):
         pass

-    def trigger_step(self, *args):
+    def trigger_step(self):
         """
-        Callback to be triggered after every step (every backpropagation).
-
-        Args:
-            args: a list of values corresponding to :meth:`extra_fetches`.
-
-        Could be useful to apply some tricks on parameters (clipping, low-rank, etc)
+        Callback to be triggered after every run_step.
         """
-        self._trigger_step(*args)
+        self._trigger_step()

-    def _trigger_step(self, *args):
+    def _trigger_step(self):
         pass

-    def extra_fetches(self):
-        """
-        Returns:
-            list: a list of elements to be fetched in every step and
-                passed to :meth:`trigger_step`. Elements can be
-                Operations/Tensors, or names of Operations/Tensors.
-
-        This function will be called only after the graph is finalized.
-
-        This function should be a pure function (i.e. no side-effect when called)
+    def after_run(self, run_context, run_values):
+        self._after_run(run_context, run_values)
+
+    def _after_run(self, run_context, run_values):
+        pass
+
+    def before_run(self, ctx):
+        """
+        Same as ``tf.train.SessionRunHook.before_run``.
         """
-        fetches = self._extra_fetches()
+        fetches = self._before_run(ctx)
+        if fetches is None:
+            return None
+        if isinstance(fetches, tf.train.SessionRunArgs):
+            return fetches
+        # also support list of names
+        assert isinstance(fetches, list), fetches
         ret = []
         for f in fetches:
             if isinstance(f, (tf.Tensor, tf.Operation)):
                 ret.append(f)
             else:
+                # warn about speed
                 ret.append(get_op_or_tensor_by_name(f))
-        return ret
+        return tf.train.SessionRunArgs(fetches=ret)

-    def _extra_fetches(self):
-        return []
+    def _before_run(self, ctx):
+        return None

     def trigger_epoch(self):
         """
@@ -113,11 +115,6 @@ class Callback(object):
     def epoch_num(self):
         return self.trainer.epoch_num

-    @property
-    def local_step(self):
-        # inside trainer, we're still in the 'local_step' loop, so the number is off by 1
-        return self.trainer.local_step + 1
-
     @property
     def global_step(self):
         return self.trainer.global_step
@@ -177,12 +174,18 @@ class ProxyCallback(Callback):
     def _trigger_epoch(self):
         self.cb.trigger_epoch()

-    def _trigger_step(self, *args):
-        self.cb.trigger_step(*args)
+    def _trigger_step(self):
+        self.cb.trigger_step()

     def _after_train(self):
         self.cb.after_train()

+    def _before_run(self, ctx):
+        self.cb._before_run(ctx)
+
+    def _after_run(self, ctx, run_values):
+        self.cb._after_run(ctx, run_values)
+
     def __str__(self):
         return "Proxy-" + str(self.cb)
......
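Note: with extra_fetches/trigger_step(*args) replaced by the before_run/after_run pair, a callback now requests and consumes its own fetches. A minimal sketch (hypothetical callback and tensor name; get_op_or_tensor_by_name and logger as used elsewhere in this commit):

    class GradNormPrinter(Callback):
        def _before_train(self):
            # resolve names once; the graph is finalized by now
            self._fetch = get_op_or_tensor_by_name('grad_norm:0')

        def _before_run(self, ctx):
            # returning a plain list is enough: Callback.before_run
            # wraps it into tf.train.SessionRunArgs
            return [self._fetch]

        def _after_run(self, ctx, run_values):
            logger.info("grad_norm: {}".format(run_values.results[0]))
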
@@ -4,7 +4,6 @@
 import tensorflow as tf
 from contextlib import contextmanager
-from collections import defaultdict
 import time
 import traceback
@@ -15,8 +14,18 @@ from ..utils import logger
 __all__ = ['Callbacks']

+class CallbackHook(tf.train.SessionRunHook):
+    def __init__(self, cb):
+        self.cb = cb
+
+    def before_run(self, ctx):
+        return self.cb.before_run(ctx)
+
+    def after_run(self, ctx, vals):
+        self.cb.after_run(ctx, vals)
+
+
 class CallbackTimeLogger(object):
     def __init__(self):
         self.times = []
         self.tot = 0
@@ -71,7 +80,6 @@ class Callbacks(Callback):
                 break
         self.cbs = cbs
-        self._extra_fetches_cache = None

     def _setup_graph(self):
         with tf.name_scope(None):
@@ -90,30 +98,12 @@ class Callbacks(Callback):
         except Exception:
             traceback.print_exc()

-    def _extra_fetches(self):
-        if self._extra_fetches_cache is not None:
-            return self._extra_fetches_cache
-        # TODO use dispatch mechanism to avoid duplication
-        self._cbid_to_fetchid = defaultdict(list)
-        ret = []
-        for idx, cb in enumerate(self.cbs):
-            fetch = cb.extra_fetches()
-            if len(fetch) == 0:
-                continue
-            for f in fetch:
-                ret.append(f)
-                self._cbid_to_fetchid[idx].append(len(ret) - 1)
-        self._extra_fetches_cache = ret
-        return ret
-
-    def _trigger_step(self, *args):
-        for idx, cb in enumerate(self.cbs):
-            fid = self._cbid_to_fetchid[idx]
-            if len(fid) == 0:
-                cb.trigger_step()
-            else:
-                data = [args[k] for k in fid]
-                cb.trigger_step(*data)
+    def get_hooks(self):
+        return [CallbackHook(cb) for cb in self.cbs]
+
+    def trigger_step(self):
+        for cb in self.cbs:
+            cb.trigger_step()

     def _trigger_epoch(self):
         tm = CallbackTimeLogger()
......
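Note: CallbackHook is a thin adapter that lets tf.train.MonitoredSession drive any Callback, and Callbacks.get_hooks() builds one per callback. A sketch of the wiring in isolation (session-creator arguments elided; GradNormPrinter is the hypothetical example above):

    hooks = [CallbackHook(GradNormPrinter())]
    sess = tf.train.MonitoredSession(
        session_creator=tf.train.ChiefSessionCreator(),
        hooks=hooks)
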
@@ -10,9 +10,10 @@ from six.moves import zip
 import tqdm

 from ..utils import logger, get_tqdm_kwargs
-from ..utils.naming import (GLOBAL_STEP_INCR_OP_NAME,
-                            LOCAL_STEP_OP_NAME)
-from ..tfutils.common import get_op_tensor_name, get_global_step_var, get_global_step_value
+from ..utils.naming import GLOBAL_STEP_INCR_OP_NAME
+from ..tfutils.common import (
+    get_op_tensor_name, get_global_step_var,
+    get_global_step_value, get_op_or_tensor_by_name)
 from .base import Callback

 __all__ = ['StepTensorPrinter', 'MaintainStepCounter',
@@ -33,10 +34,14 @@ class StepTensorPrinter(Callback):
         logger.warn("Using print_stat or tf.Print in the graph is much faster than StepTensorPrinter!")
         self._names = names

-    def _extra_fetches(self):
-        return self._names
+    def _before_train(self):
+        self._fetches = get_op_or_tensor_by_name(self._names)

-    def _trigger_step(self, *args):
+    def _before_run(self, _):
+        return self._fetches
+
+    def _after_run(self, _, vals):
+        args = vals.results
         assert len(args) == len(self._names), len(args)
         for n, v in zip(self._names, args):
             logger.info("{}: {}".format(n, v))
@@ -55,17 +60,24 @@ class MaintainStepCounter(Callback):
             self.gs_incr_var = tf.assign_add(
                 gs_var, 1,
                 name=GLOBAL_STEP_INCR_OP_NAME)
-            tf.mod(
-                self.gs_incr_var, self.trainer.config.steps_per_epoch,
-                name=LOCAL_STEP_OP_NAME)
+            # tf.mod(
+            #     self.gs_incr_var, self.trainer.config.steps_per_epoch,
+            #     name=LOCAL_STEP_OP_NAME)
+        self._fetches = tf.train.SessionRunArgs(self.gs_incr_var)

     def _before_train(self):
         gs_val = get_global_step_value()
         if gs_val != 0:
             logger.info("Start training with global_step={}".format(gs_val))
+        self._last_updated = self.trainer.local_step

-    def _extra_fetches(self):
-        return [self.gs_incr_var.op]
+    def _before_run(self, _):
+        # increase global_step, when trainer.local_step changed
+        if self.trainer.local_step != self._last_updated:
+            self._last_updated = self.trainer.local_step
+            return self._fetches
+        else:
+            return None

 class ProgressBar(Callback):
@@ -80,21 +92,34 @@ class ProgressBar(Callback):
         self._names = [get_op_tensor_name(n)[1] for n in names]
         self._tags = [get_op_tensor_name(n)[0].split("/")[-1] for n in names]

-    def _extra_fetches(self):
-        return self._names
-
     def _before_train(self):
+        self._last_updated = self.trainer.local_step
+
         self._total = self.trainer.config.steps_per_epoch
         self._tqdm_args = get_tqdm_kwargs(leave=True)
-        if len(self._names):
+        self._fetches = get_op_or_tensor_by_name(self._names) or None
+        if self._fetches:
+            self._fetches = tf.train.SessionRunArgs(self._fetches)
             self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} "

-    def _trigger_step(self, *args):
-        if self.local_step == 1:
-            self._bar = tqdm.trange(self._total, **self._tqdm_args)
-        if len(self._names):
-            self._bar.set_postfix(zip(self._tags, args))
-        self._bar.update()
-        if self.local_step == self._total:
+    def _before_run(self, _):
+        if self.trainer.local_step != self._last_updated:
+            self._last_updated = self.trainer.local_step
+            if self.trainer.local_step == 0:
+                self._bar = tqdm.trange(self._total, **self._tqdm_args)
+            return self._fetches
+        else:
+            return None
+
+    def _after_run(self, _, run_values):
+        res = run_values.results
+        if res:
+            self._bar.set_postfix(zip(self._tags, res))
+
+    def _trigger_step(self):
+        self._bar.update()
+        if self.trainer.local_step == self._total - 1:
             self._bar.close()
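Note: MaintainStepCounter and ProgressBar share the same guard because one trainer step may issue several session runs (WGANTrainer above calls hooked_sess.run() six times per run_step), and before_run fires on every one of them. The idiom in isolation:

    def _before_run(self, _):
        # trainer.local_step changes once per run_step(), so this branch
        # is taken at most once per training step, regardless of how many
        # session.run() calls the step issues
        if self.trainer.local_step != self._last_updated:
            self._last_updated = self.trainer.local_step
            return self._fetches
        return None
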
@@ -28,5 +28,5 @@ class MovingAverageSummary(Callback):
             ops = tf.get_collection(self._collection)
             self.ema_op = tf.group(*ops, name='summary_moving_averages')

-    def _extra_fetches(self):
+    def _before_run(self, _):
         return [self.ema_op]
@@ -31,13 +31,15 @@ class PeriodicTrigger(ProxyCallback):
         self._step_k = every_k_steps
         self._epoch_k = every_k_epochs

-    def _trigger_step(self, *args):
+    def _trigger_step(self):
         if self._step_k is None:
             return
-        if self.local_step % self._step_k == 0:
+        # trigger_step is triggered after run_step, so
+        # local_step + 1 is the number of steps that have finished
+        if (self.trainer.local_step + 1) % self._step_k == 0:
             self.cb.trigger()

-    def _trigger_epoch(self, *args):
+    def _trigger_epoch(self):
         if self._epoch_k is None:
             return
         if self.epoch_num % self._epoch_k == 0:
@@ -62,10 +64,6 @@ class PeriodicCallback(ProxyCallback):
         Args:
             cb(Callback): the callback to be triggered periodically
             period(int): the period, the number of epochs for a callback to be triggered.
-
-        Note:
-            In ``cb``, ``self.epoch_num`` will not be the true number of
-            epochs any more.
         """
         super(PeriodicCallback, self).__init__(cb)
         self.period = int(period)
......
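Note: since trigger_step now fires after run_step completes, local_step + 1 counts finished steps; e.g. with every_k_steps=3 the proxied callback triggers after the 3rd, 6th, ... steps:

    every_k_steps = 3
    for local_step in range(6):
        finished = local_step + 1
        if finished % every_k_steps == 0:
            print("trigger after {} finished steps".format(finished))  # 3, 6
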
@@ -141,7 +141,7 @@ def Deconv2D(x, out_shape, kernel_shape,
         for k in out_shape:
             if not isinstance(k, int):
                 raise ValueError("[Deconv2D] out_shape {} is invalid!".format(k))
-        out_channel = out_shape[channel_axis - 1]
+        out_channel = out_shape[channel_axis - 1]  # out_shape doesn't have batch
         shp3_static = shp3_dyn = out_shape
     filter_shape = kernel_shape + [out_channel, in_channel]
......
@@ -4,19 +4,18 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 import tensorflow as tf
+from six.moves import map

 from ..utils.naming import (
     GLOBAL_STEP_VAR_NAME,
-    GLOBAL_STEP_OP_NAME,
-    LOCAL_STEP_VAR_NAME)
-from ..utils import logger
+    GLOBAL_STEP_OP_NAME)
 from ..utils.argtools import memoized

 __all__ = ['get_default_sess_config',
            'get_global_step_value',
            'get_global_step_var',
-           'get_local_step_var',
+           # 'get_local_step_var',
            'get_op_tensor_name',
            'get_tensors_by_names',
@@ -74,13 +73,13 @@ def get_global_step_value():
         get_global_step_var())

-@memoized
-def get_local_step_var():
-    try:
-        return tf.get_default_graph().get_tensor_by_name(LOCAL_STEP_VAR_NAME)
-    except KeyError:
-        logger.warn("get_local_step_var() is only available to use in callbacks!")
-        raise
+# @memoized
+# def get_local_step_var():
+#     try:
+#         return tf.get_default_graph().get_tensor_by_name(LOCAL_STEP_VAR_NAME)
+#     except KeyError:
+#         logger.warn("get_local_step_var() is only available to use in callbacks!")
+#         raise

 def get_op_tensor_name(name):
@@ -116,11 +115,24 @@ def get_tensors_by_names(names):

 def get_op_or_tensor_by_name(name):
+    """
+    Get either tf.Operation or tf.Tensor from names.
+
+    Args:
+        name (list[str] or str): names of operations or tensors.
+    """
     G = tf.get_default_graph()
-    if len(name) >= 3 and name[-2] == ':':
-        return G.get_tensor_by_name(name)
+
+    def f(n):
+        if len(n) >= 3 and n[-2] == ':':
+            return G.get_tensor_by_name(n)
+        else:
+            return G.get_operation_by_name(n)
+
+    if not isinstance(name, list):
+        return f(name)
     else:
-        return G.get_operation_by_name(name)
+        return list(map(f, name))

 def get_name_scope_name():
......
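Note: usage of the extended helper (the 'cost:0' name is made up for illustration; 'summary_moving_averages' is the op created in the summary hunk above): a name with a ':index' suffix resolves to a Tensor, a bare name to an Operation, and a list is mapped elementwise:

    cost_t = get_op_or_tensor_by_name('cost:0')                   # tf.Tensor
    ema_op = get_op_or_tensor_by_name('summary_moving_averages')  # tf.Operation
    both = get_op_or_tensor_by_name(['cost:0', 'summary_moving_averages'])
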
@@ -129,6 +129,7 @@ def add_moving_summary(v, *args, **kwargs):
         decay, num_updates=get_global_step_var(), name='EMA')
     avg_maintain_op = averager.apply(v)
     for c in v:
+        # TODO do this in the EMA callback?
         name = re.sub('tower[p0-9]+/', '', c.op.name)
         tf.summary.scalar(name + '-summary', averager.average(c))
......
@@ -40,8 +40,8 @@ class Trainer(object):
         summary_writer (tf.summary.FileWriter)
         summary_op (tf.Operation): an Op which outputs all summaries.

-        epoch_num (int): the current epoch number.
-        local_step (int): the current step number (in an epoch).
+        epoch_num (int): the number of epochs that have finished.
+        local_step (int): the number of steps that have finished in the current epoch.
     """

     def __init__(self, config):
@@ -54,7 +54,7 @@ class Trainer(object):
         self.model = config.model

         self.epoch_num = self.config.starting_epoch - 1
-        self.local_step = 0
+        self.local_step = -1

     def train(self):
         """ Start training """
@@ -65,15 +65,6 @@ class Trainer(object):
     def run_step(self):
         """ Abstract method. Run one iteration. """

-    def get_extra_fetches(self):
-        """
-        Returns:
-            list: list of tensors/ops to fetch in each step.
-
-        This function should only get called after :meth:`setup()` has finished.
-        """
-        return self._extra_fetches
-
     def trigger_epoch(self):
         """
         Called after each epoch.
@@ -130,7 +121,6 @@ class Trainer(object):
         # some final operations that might modify the graph
         logger.info("Setup callbacks graph ...")
         self.config.callbacks.setup_graph(weakref.proxy(self))
-        self._extra_fetches = self.config.callbacks.extra_fetches()

         logger.info("Setup summaries ...")
         self.summary_writer = tf.summary.FileWriter(logger.LOG_DIR, graph=tf.get_default_graph())
@@ -149,8 +139,10 @@ class Trainer(object):
         self.monitored_sess = tf.train.MonitoredSession(
             session_creator=tf.train.ChiefSessionCreator(
                 scaffold=scaffold, config=self.config.session_config),
-            hooks=None)
-        self.sess = self.monitored_sess._tf_sess()
+            hooks=self.config.callbacks.get_hooks())
+        self.hooked_sess = self.monitored_sess  # just create an alias
+        self.sess = self.monitored_sess._tf_sess()  # expose the underlying session also
         self.config.session_init._run_init(self.sess)

     @abstractmethod
@@ -162,7 +154,7 @@ class Trainer(object):
         try:
             return self._starting_step + \
                 self.config.steps_per_epoch * (self.epoch_num - 1) + \
-                self.local_step + 1
+                self.local_step + 1  # +1: the ongoing step
         except AttributeError:
             return get_global_step_value()
@@ -182,12 +174,8 @@ class Trainer(object):
             for self.local_step in range(self.config.steps_per_epoch):
                 if self.monitored_sess.should_stop():
                     return
-                fetch_data = self.run_step()  # implemented by subclass
-                if fetch_data is None:
-                    # old trainer doesn't return fetch data
-                    callbacks.trigger_step()
-                else:
-                    callbacks.trigger_step(*fetch_data)
+                self.run_step()  # implemented by subclass
+                callbacks.trigger_step()
             logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
                 self.epoch_num, self.global_step, time.time() - start_time))
......
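Note: a worked check of the global_step arithmetic above, with hypothetical numbers (local_step counts finished steps in the epoch, so the +1 accounts for the ongoing step):

    starting_step = 0
    steps_per_epoch = 100
    epoch_num = 2    # the 2nd epoch is running
    local_step = 9   # nine steps finished; the 10th is ongoing
    global_step = starting_step + steps_per_epoch * (epoch_num - 1) + local_step + 1
    assert global_step == 110  # 100 steps from epoch 1, plus the ongoing 10th
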
@@ -46,25 +46,9 @@ class FeedfreeTrainerBase(Trainer):
         assert isinstance(self._input_method, FeedfreeInput), type(self._input_method)
         self._input_method._setup(self)

-class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
-    """ A feedfree Trainer which assumes a single cost. """
-    def _get_cost_and_grad(self):
-        """ get the cost and gradient"""
-        self.build_train_tower()
-        cost = self.model.get_cost()
-        opt = self.config.optimizer
-        # GATE_NONE faster?
-        grads = opt.compute_gradients(
-            cost,
-            gate_gradients=tf.train.Optimizer.GATE_NONE,
-            colocate_gradients_with_ops=True)
-        return cost, grads
-
     def run_step(self):
         """ Simply run ``self.train_op``, which minimizes the cost."""
-        ret = self.sess.run([self.train_op] + self.get_extra_fetches())
-        return ret[1:]
+        self.hooked_sess.run(self.train_op)
         # if not hasattr(self, 'cnt'):
         #     self.cnt = 0
         # else:
@@ -83,6 +67,21 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
         #     import sys; sys.exit()

+class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
+    """ A feedfree Trainer which assumes a single cost. """
+    def _get_cost_and_grad(self):
+        """ get the cost and gradient"""
+        self.build_train_tower()
+        cost = self.model.get_cost()
+        opt = self.config.optimizer
+        # GATE_NONE faster?
+        grads = opt.compute_gradients(
+            cost,
+            gate_gradients=tf.train.Optimizer.GATE_NONE,
+            colocate_gradients_with_ops=True)
+        return cost, grads
+
 class SimpleFeedfreeTrainer(
         SingleCostFeedfreeTrainer,
         MultiPredictorTowerTrainer):
......
@@ -87,9 +87,7 @@ class SimpleTrainer(Trainer):
     def run_step(self):
         """ Feed data into the graph and run the updates. """
         feed = self._input_method.next_feed()
-        ret = self.sess.run([self.train_op] + self.get_extra_fetches(),
-                            feed_dict=feed)
-        return ret[1:]
+        self.hooked_sess.run(self.train_op, feed_dict=feed)

     def _setup(self):
         self._input_method._setup(self)
......