Commit 243e957f authored by Yuxin Wu

[breaking] deprecate callbacks=Callbacks(); use a plain list of callbacks instead.

parent c003b1c1
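
Migration: drop the Callbacks([...]) wrapper and the explicit StatPrinter(), which the config now handles internally (note the extra_callbacks default in the TrainConfig hunk below). A minimal before/after sketch, where ModelSaver() stands in for whatever callbacks a config actually uses:

    # before
    callbacks=Callbacks([
        StatPrinter(), ModelSaver(),
    ]),
    # after
    callbacks=[ModelSaver()],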
@@ -207,8 +207,8 @@ def get_config():
     return TrainConfig(
         dataflow=dataflow,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
-        callbacks=Callbacks([
-            StatPrinter(), ModelSaver(),
+        callbacks=[
+            ModelSaver(),
             ScheduledHyperParamSetter('learning_rate', [(80, 0.0003), (120, 0.0001)]),
             ScheduledHyperParamSetter('entropy_beta', [(80, 0.005)]),
             ScheduledHyperParamSetter('explore_factor',
@@ -216,7 +216,7 @@ def get_config():
             master,
             StartProcOrThread(master),
             PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['logits']), 2),
-        ]),
+        ],
         session_config=get_default_sess_config(0.5),
         model=M,
         step_per_epoch=STEP_PER_EPOCH,
...
@@ -96,14 +96,14 @@ def get_config(ds_train, ds_test):
     return TrainConfig(
         dataflow=ds_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
-        callbacks=Callbacks([
-            StatPrinter(), ModelSaver(),
+        callbacks=[
+            ModelSaver(),
             StatMonitorParamSetter('learning_rate', 'error',
                                    lambda x: x * 0.2, 0, 5),
             HumanHyperParamSetter('learning_rate'),
             PeriodicCallback(
                 InferenceRunner(ds_test, [ScalarStats('error')]), 2),
-        ]),
+        ],
         model=Model(),
         step_per_epoch=step_per_epoch,
         max_epoch=70,
...
@@ -110,10 +110,10 @@ def get_config():
     return TrainConfig(
         dataflow=ds,
         optimizer=tf.train.AdamOptimizer(lr),
-        callbacks=Callbacks([
-            StatPrinter(), ModelSaver(),
+        callbacks=[
+            ModelSaver(),
             ScheduledHyperParamSetter('learning_rate', [(25, 2e-4)])
-        ]),
+        ],
         model=Model(),
         step_per_epoch=step_per_epoch,
         max_epoch=50,
...
@@ -177,8 +177,8 @@ def get_config():
     return TrainConfig(
         dataflow=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
-        callbacks=Callbacks([
-            StatPrinter(), ModelSaver(),
+        callbacks=[
+            ModelSaver(),
             ScheduledHyperParamSetter('learning_rate',
                                       [(150, 4e-4), (250, 1e-4), (350, 5e-5)]),
             RunOp(lambda: M.update_target_param()),
@@ -186,7 +186,7 @@ def get_config():
             PeriodicCallback(Evaluator(EVAL_EPISODE, ['state'], ['Qvalue']), 3),
             # HumanHyperParamSetter('learning_rate', 'hyper.txt'),
             # HumanHyperParamSetter(ObjAttrParam(dataset_train, 'exploration'), 'hyper.txt'),
-        ]),
+        ],
         # save memory for multiprocess evaluator
         session_config=get_default_sess_config(0.6),
         model=M,
...
@@ -236,8 +236,8 @@ def get_config():
     return TrainConfig(
         dataflow=data_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-5),
-        callbacks=Callbacks([
-            StatPrinter(), ModelSaver(),
+        callbacks=[
+            ModelSaver(),
             # HumanHyperParamSetter('learning_rate'),
             ScheduledHyperParamSetter(
                 'learning_rate', [(56, 2e-5), (64, 4e-6)]),
@@ -245,7 +245,7 @@ def get_config():
                 [ScalarStats('cost'),
                  ClassificationError('wrong-top1', 'val-error-top1'),
                  ClassificationError('wrong-top5', 'val-error-top5')])
-        ]),
+        ],
         model=Model(),
         step_per_epoch=10000,
         max_epoch=100,
...
@@ -163,12 +163,11 @@ def get_config():
     return TrainConfig(
         dataflow=data_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-5),
-        callbacks=Callbacks([
-            StatPrinter(),
+        callbacks=[
             ModelSaver(),
             InferenceRunner(data_test,
                             [ScalarStats('cost'), ClassificationError()])
-        ]),
+        ],
         model=Model(),
         step_per_epoch=step_per_epoch,
         max_epoch=200,
...
@@ -109,9 +109,7 @@ def get_config():
     return TrainConfig(
         dataflow=dataset,
         optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3),
-        callbacks=Callbacks([
-            StatPrinter(), ModelSaver(),
-        ]),
+        callbacks=[ModelSaver()],
         session_config=get_default_sess_config(0.5),
         model=Model(),
         step_per_epoch=300,
...
@@ -102,7 +102,6 @@ class GANTrainer(FeedfreeTrainerBase):
 class RandomZData(DataFlow):
     def __init__(self, shape):
         super(RandomZData, self).__init__()
         self.shape = shape
...
@@ -168,10 +168,10 @@ def get_config():
     return TrainConfig(
         dataflow=dataset,
         optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-3),
-        callbacks=Callbacks([
-            StatPrinter(), PeriodicCallback(ModelSaver(), 3),
+        callbacks=[
+            PeriodicCallback(ModelSaver(), 3),
             ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
-        ]),
+        ],
         model=Model(),
         step_per_epoch=dataset.size(),
         max_epoch=300,
...
@@ -158,9 +158,7 @@ def get_config():
     return TrainConfig(
         dataflow=dataset,
         optimizer=tf.train.AdamOptimizer(lr, beta1=0.5, epsilon=1e-6),
-        callbacks=Callbacks([
-            StatPrinter(), ModelSaver(),
-        ]),
+        callbacks=[ModelSaver()],
         session_config=get_default_sess_config(0.5),
         model=Model(),
         step_per_epoch=500,
...
@@ -173,13 +173,13 @@ def get_config():
     return TrainConfig(
         dataflow=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
-        callbacks=Callbacks([
-            StatPrinter(), ModelSaver(),
+        callbacks=[
+            ModelSaver(),
             ScheduledHyperParamSetter('learning_rate', [(30, 6e-6), (45, 1e-6), (60, 8e-7)]),
             HumanHyperParamSetter('learning_rate'),
             InferenceRunner(dataset_val,
                             BinaryClassificationStats('prediction', 'edgemap4d'))
-        ]),
+        ],
         model=Model(),
         step_per_epoch=step_per_epoch,
         max_epoch=100,
...
@@ -160,17 +160,16 @@ def get_config():
     return TrainConfig(
         dataflow=dataset_train,
         optimizer=tf.train.MomentumOptimizer(lr, 0.9),
-        callbacks=Callbacks([
-            StatPrinter(), ModelSaver(),
+        callbacks=[
+            ModelSaver(),
             InferenceRunner(dataset_val, [
                 ClassificationError('wrong-top1', 'val-top1-error'),
                 ClassificationError('wrong-top5', 'val-top5-error')]),
-            # HumanHyperParamSetter('learning_rate', 'hyper-googlenet.txt')
             ScheduledHyperParamSetter('learning_rate',
                                       [(8, 0.03), (14, 0.02), (17, 5e-3),
                                        (19, 3e-3), (24, 1e-3), (26, 2e-4),
                                        (30, 5e-5)])
-        ]),
+        ],
         session_config=get_default_sess_config(0.99),
         model=Model(),
         step_per_epoch=step_per_epoch,
...
@@ -268,8 +268,8 @@ def get_config():
     return TrainConfig(
         dataflow=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
-        callbacks=Callbacks([
-            StatPrinter(), ModelSaver(),
+        callbacks=[
+            ModelSaver(),
             InferenceRunner(dataset_val, [
                 ClassificationError('wrong-top1', 'val-error-top1'),
                 ClassificationError('wrong-top5', 'val-error-top5')]),
@@ -278,7 +278,7 @@ def get_config():
                                        (17, 0.003), (22, 1e-3), (36, 2e-4),
                                        (41, 8e-5), (48, 1e-5), (53, 2e-6)]),
             HumanHyperParamSetter('learning_rate')
-        ]),
+        ],
         session_config=get_default_sess_config(0.9),
         model=Model(),
         step_per_epoch=5000,
...
@@ -126,8 +126,8 @@ def get_config():
         data=train_data,
         model=M,
         optimizer=tf.train.GradientDescentOptimizer(lr),
-        callbacks=Callbacks([
-            StatPrinter(), ModelSaver(),
+        callbacks=[
+            ModelSaver(),
             HyperParamSetterWithFunc(
                 'learning_rate',
                 lambda e, x: x * 0.80 if e > 6 else x),
@@ -139,7 +139,7 @@ def get_config():
                 'validation_perplexity',
                 np.exp(self.trainer.stat_holder.get_stat_now('validation_cost') / SEQ_LEN))),
             RunOp(lambda: M.reset_lstm_state()),
-        ]),
+        ],
         max_epoch=70,
     )
...
@@ -141,13 +141,13 @@ def get_config():
     return TrainConfig(
         dataflow=dataset_train,
         optimizer=tf.train.MomentumOptimizer(lr, 0.9),
-        callbacks=Callbacks([
-            StatPrinter(), ModelSaver(),
+        callbacks=[
+            ModelSaver(),
             InferenceRunner(dataset_test,
                             [ScalarStats('cost'), ClassificationError()]),
             ScheduledHyperParamSetter('learning_rate',
                                       [(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
-        ]),
+        ],
         model=Model(n=NUM_UNITS),
         step_per_epoch=step_per_epoch,
         max_epoch=400,
...
@@ -189,15 +189,15 @@ def get_config():
     return TrainConfig(
         dataflow=dataset_train,
         optimizer=tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True),
-        callbacks=Callbacks([
-            StatPrinter(), ModelSaver(),
+        callbacks=[
+            ModelSaver(),
             InferenceRunner(dataset_val, [
                 ClassificationError('wrong-top1', 'val-error-top1'),
                 ClassificationError('wrong-top5', 'val-error-top5')]),
             ScheduledHyperParamSetter('learning_rate',
                                       [(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5)]),
             HumanHyperParamSetter('learning_rate'),
-        ]),
+        ],
         model=Model(),
         step_per_epoch=5000,
         max_epoch=110,
...
@@ -70,14 +70,13 @@ def get_config():
     return TrainConfig(
         dataflow=dataset_train,
         optimizer=tf.train.MomentumOptimizer(lr, 0.9),
-        callbacks=Callbacks([
-            StatPrinter(),
+        callbacks=[
             ModelSaver(),
             InferenceRunner(dataset_test,
                             [ScalarStats('cost'), ClassificationError()]),
             ScheduledHyperParamSetter('learning_rate',
                                       [(1, 0.1), (20, 0.01), (28, 0.001), (50, 0.0001)])
-        ]),
+        ],
         model=Model(n=18),
         step_per_epoch=step_per_epoch,
         max_epoch=500,
...
@@ -141,11 +141,10 @@ def get_config(model):
         dataflow=dataset,
         model=model(),
         optimizer=tf.train.GradientDescentOptimizer(lr),
-        callbacks=Callbacks([
-            StatPrinter(),
+        callbacks=[
             ModelSaver(),
             ScheduledHyperParamSetter('learning_rate', [(10, 1e-5), (20, 1e-6)])
-        ]),
+        ],
         step_per_epoch=step_per_epoch,
         max_epoch=20,
     )
...
@@ -155,12 +155,12 @@ def get_config():
     return TrainConfig(
         dataflow=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
-        callbacks=Callbacks([
-            StatPrinter(), ModelSaver(),
+        callbacks=[
+            ModelSaver(),
             InferenceRunner(dataset_test,
                             [ScalarStats('cost'), ClassificationError()]),
             ScheduledHyperParamSetter('learning_rate', [(200, 1e-4)])
-        ]),
+        ],
         session_config=get_default_sess_config(0.5),
         model=Model(),
         step_per_epoch=step_per_epoch,
...
@@ -122,12 +122,12 @@ def get_config(cifar_classnum):
     return TrainConfig(
         dataflow=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
-        callbacks=Callbacks([
-            StatPrinter(), ModelSaver(),
+        callbacks=[
+            ModelSaver(),
             InferenceRunner(dataset_test, ClassificationError()),
             StatMonitorParamSetter('learning_rate', 'val_error', lr_func,
                                    threshold=0.001, last_k=10),
-        ]),
+        ],
         session_config=sess_config,
         model=Model(cifar_classnum),
         step_per_epoch=step_per_epoch,
...
@@ -140,14 +140,13 @@ def get_config():
     return TrainConfig(
         dataflow=dataset_train,  # the DataFlow instance for training
         optimizer=tf.train.AdamOptimizer(lr),
-        callbacks=Callbacks([
-            StatPrinter(),  # print statistics in terminal after every epoch
+        callbacks=[
             ModelSaver(),   # save the model after every epoch
             InferenceRunner(   # run inference(for validation) after every epoch
                 dataset_test,  # the DataFlow instance used for validation
                 # Calculate both the cost and the error for this DataFlow
                 [ScalarStats('cross_entropy_loss'), ClassificationError('incorrect')]),
-        ]),
+        ],
         model=Model(),
         step_per_epoch=step_per_epoch,
         max_epoch=100,
...
@@ -101,11 +101,11 @@ def get_config():
     return TrainConfig(
         dataflow=data_train,
         optimizer=tf.train.AdamOptimizer(lr),
-        callbacks=Callbacks([
-            StatPrinter(), ModelSaver(),
+        callbacks=[
+            ModelSaver(),
             InferenceRunner(data_test,
                             [ScalarStats('cost'), ClassificationError()])
-        ]),
+        ],
         model=Model(),
         step_per_epoch=step_per_epoch,
         max_epoch=350,
...
@@ -59,13 +59,20 @@ class Callbacks(Callback):
         for cb in cbs:
             assert isinstance(cb, Callback), cb.__class__
         # move "StatPrinter" to the last
-        for cb in cbs:
+        # TODO don't need to manually move in the future.
+        found = False
+        for idx, cb in enumerate(cbs):
             if isinstance(cb, StatPrinter):
+                if found:
+                    raise ValueError("Callbacks cannot contain two StatPrinter!")
                 sp = cb
                 cbs.remove(sp)
                 cbs.append(sp)
-                break
-        else:
+                if idx != len(cbs) - 1:
+                    logger.warn("StatPrinter should appear as the last element of callbacks! "
+                                "This is now fixed automatically, but may not work in the future.")
+                found = True
+        if not found:
             raise ValueError("Callbacks must contain StatPrinter for stat and writer to work properly!")
         self.cbs = cbs
...
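For the transition, Callbacks still requires exactly one StatPrinter and reorders a misplaced one itself. A hypothetical illustration of the intended paths in the hunk above (assuming tensorpack's StatPrinter and ModelSaver):

    Callbacks([ModelSaver(), StatPrinter()])   # OK: StatPrinter already last, no warning
    Callbacks([StatPrinter(), ModelSaver()])   # intended: warn, then move StatPrinter to the end
    Callbacks([ModelSaver()])                  # ValueError: must contain StatPrinter
    Callbacks([StatPrinter(), StatPrinter()])  # ValueError: cannot contain two StatPrinter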
@@ -79,9 +79,9 @@ class TrainConfig(object):
         if isinstance(callbacks, Callbacks):
             # keep quiet now because I haven't determined the final API yet.
-            # logger.warn("[Deprecated] API of TrainConfig(callbacks=) has changed!")
-            # logger.warn("[Deprecated] Please change the option 'callbacks=' to a list of "
-            #             "callbacks without StatPrinter().")
+            logger.warn("[Deprecated] API of TrainConfig(callbacks=) has changed!")
+            logger.warn("[Deprecated] Please change the argument 'callbacks=' to a *list* of "
+                        "callbacks without StatPrinter().")
             callbacks = callbacks.cbs[:-1]  # the last one is StatPrinter()
         assert_type(callbacks, list)
         if extra_callbacks is None:
...
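During the transition the old wrapper is still accepted: TrainConfig unwraps it, emits the two deprecation warnings now enabled above, and strips the trailing StatPrinter() (which Callbacks guarantees is last). A minimal sketch of both call styles, with unrelated arguments elided:

    # old style: still works, but logs "[Deprecated] ..." twice
    TrainConfig(..., callbacks=Callbacks([StatPrinter(), ModelSaver()]))
    # new style: a plain list, without StatPrinter()
    TrainConfig(..., callbacks=[ModelSaver()])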