Commit d5f3350d authored by Yuxin Wu's avatar Yuxin Wu

update resnet example to contain #139

parent 74c80d57
...@@ -86,25 +86,31 @@ def get_config(fake=False, data_format='NCHW'): ...@@ -86,25 +86,31 @@ def get_config(fake=False, data_format='NCHW'):
if fake: if fake:
logger.info("For benchmark, batch size is fixed to 64 per tower.") logger.info("For benchmark, batch size is fixed to 64 per tower.")
dataset_train = dataset_val = FakeData( dataset_train = FakeData(
[[64, 224, 224, 3], [64]], 1000, random=False, dtype='uint8') [[64, 224, 224, 3], [64]], 1000, random=False, dtype='uint8')
callbacks = []
else: else:
logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, BATCH_SIZE)) logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, BATCH_SIZE))
dataset_train = get_data('train') dataset_train = get_data('train')
dataset_val = get_data('val') dataset_val = get_data('val')
callbacks = [
return TrainConfig(
model=Model(data_format=data_format),
dataflow=dataset_train,
callbacks=[
ModelSaver(), ModelSaver(),
InferenceRunner(dataset_val, [
ClassificationError('wrong-top1', 'val-error-top1'),
ClassificationError('wrong-top5', 'val-error-top5')]),
ScheduledHyperParamSetter('learning_rate', ScheduledHyperParamSetter('learning_rate',
[(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5), (105, 1e-6)]), [(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5), (105, 1e-6)]),
HumanHyperParamSetter('learning_rate'), HumanHyperParamSetter('learning_rate'),
], ]
infs = [ClassificationError('wrong-top1', 'val-error-top1'),
ClassificationError('wrong-top5', 'val-error-top5')]
if nr_tower == 1:
callbacks.append(InferenceRunner(QueueInput(dataset_val), infs))
else:
callbacks.append(DataParallelInferenceRunner(
dataset_val, infs, list(range(nr_tower))))
return TrainConfig(
model=Model(data_format=data_format),
dataflow=dataset_train,
callbacks=callbacks,
steps_per_epoch=5000, steps_per_epoch=5000,
max_epoch=110, max_epoch=110,
nr_tower=nr_tower nr_tower=nr_tower
......
...@@ -78,6 +78,7 @@ class InferenceRunnerBase(Callback): ...@@ -78,6 +78,7 @@ class InferenceRunnerBase(Callback):
self._size = input.size() self._size = input.size()
except NotImplementedError: except NotImplementedError:
raise ValueError("Input used in InferenceRunner must have a size!") raise ValueError("Input used in InferenceRunner must have a size!")
logger.info("InferenceRunner will eval on an InputSource of size {}".format(self._size))
if extra_hooks is None: if extra_hooks is None:
extra_hooks = [] extra_hooks = []
......
...@@ -137,8 +137,11 @@ class BatchData(ProxyDataFlow): ...@@ -137,8 +137,11 @@ class BatchData(ProxyDataFlow):
raise raise
except: except:
logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?") logger.exception("Cannot batch data. Perhaps they are of inconsistent shape?")
import IPython as IP try:
IP.embed(config=IP.terminal.ipapp.load_default_config()) # open an ipython shell if possible
import IPython as IP; IP.embed() # noqa
except:
pass
return result return result
...@@ -679,14 +682,14 @@ class PrintData(ProxyDataFlow): ...@@ -679,14 +682,14 @@ class PrintData(ProxyDataFlow):
""" """
Dump gathered debugging information to stdout. Dump gathered debugging information to stdout.
""" """
msg = [""] label = "" if self.name is None else " (" + self.label + ")"
logger.info(colored("DataFlow Info%s:" % label, 'cyan'))
for i, dummy in enumerate(itertools.islice(self.ds.get_data(), self.num)): for i, dummy in enumerate(itertools.islice(self.ds.get_data(), self.num)):
if isinstance(dummy, list): if isinstance(dummy, list):
msg.append("datapoint %i<%i with %i components consists of" % (i, self.num, len(dummy))) msg = "datapoint %i<%i with %i components consists of\n" % (i, self.num, len(dummy))
for k, entry in enumerate(dummy): for k, entry in enumerate(dummy):
msg.append(self._analyze_input_data(entry, k)) msg += self._analyze_input_data(entry, k) + '\n'
label = "" if self.name is None else " (" + self.label + ")" print(msg)
logger.info(colored("DataFlow Info%s:" % label, 'cyan') + '\n'.join(msg))
# reset again after print # reset again after print
self.ds.reset_state() self.ds.reset_state()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment