Commit 6221d17d authored by Yuxin Wu's avatar Yuxin Wu

allow a trainer without get_predict_func

parent 608ad4a9
......@@ -101,6 +101,7 @@ def add_moving_summary(v, *args):
assert x.get_shape().ndims == 0
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, x)
@memoized
def summary_moving_average(tensors=None):
"""
Create a MovingAverage op and summary for tensors
......
......@@ -59,10 +59,9 @@ class Trainer(object):
""" run an iteration"""
pass
@abstractmethod
def get_predict_func(self, input_names, output_names):
""" return a online predictor"""
pass
raise NotImplementedError()
def get_predict_funcs(self, input_names, output_names, n):
""" return n predictor functions.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment