Commit 6221d17d authored by Yuxin Wu's avatar Yuxin Wu

allow a trainer without get_predict_func

parent 608ad4a9
...@@ -101,6 +101,7 @@ def add_moving_summary(v, *args): ...@@ -101,6 +101,7 @@ def add_moving_summary(v, *args):
assert x.get_shape().ndims == 0 assert x.get_shape().ndims == 0
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, x) tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, x)
@memoized
def summary_moving_average(tensors=None): def summary_moving_average(tensors=None):
""" """
Create a MovingAverage op and summary for tensors Create a MovingAverage op and summary for tensors
......
...@@ -59,10 +59,9 @@ class Trainer(object): ...@@ -59,10 +59,9 @@ class Trainer(object):
""" run an iteration""" """ run an iteration"""
pass pass
@abstractmethod
def get_predict_func(self, input_names, output_names): def get_predict_func(self, input_names, output_names):
""" return a online predictor""" """ return a online predictor"""
pass raise NotImplementedError()
def get_predict_funcs(self, input_names, output_names, n): def get_predict_funcs(self, input_names, output_names, n):
""" return n predictor functions. """ return n predictor functions.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment