Commit 6d921e36 authored by Yuxin Wu's avatar Yuxin Wu

misc small changes

parent b8349bcf
......@@ -116,7 +116,7 @@ if __name__ == '__main__':
parser.add_argument('-d', '--depth', help='resnet depth',
type=int, default=18, choices=[18, 34, 50, 101, 152])
parser.add_argument('--eval', action='store_true')
parser.add_argument('--batch', help='total batch size. 256 gives best accuracy.',
parser.add_argument('--batch', help='total batch size. 32 per GPU gives best accuracy, higher values should be similarly good',
default=256, type=int)
parser.add_argument('--mode', choices=['resnet', 'preact', 'se'],
help='variants of resnet to use', default='resnet')
......
......@@ -40,7 +40,6 @@ class Model(ModelDesc):
image = tf.expand_dims(image, 3)
image = image * 2 - 1 # center the pixels values at zero
# The context manager `argscope` sets the default option for all the layers under
# this context. Here we use 32 channel convolution with shape 3x3
with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu, out_channel=32):
......
......@@ -173,18 +173,21 @@ class GraphProfiler(Callback):
class PeakMemoryTracker(Callback):
"""
Track peak memory in each session run, by
:mod:`tf.contrib.memory_stats`.
It can only be used for GPUs.
Track peak memory used on each GPU device, by :mod:`tf.contrib.memory_stats`.
The peak memory comes from the `MaxBytesInUse` op, which might span
multiple session.run.
See https://github.com/tensorflow/tensorflow/pull/13107.
"""
_chief_only = False
def __init__(self, devices=['/gpu:0']):
def __init__(self, devices=[0]):
"""
Args:
devices([str]): list of devices to track memory on.
devices([int] or [str]): list of GPU devices to track memory on.
"""
assert isinstance(devices, (list, tuple)), devices
devices = ['/gpu:{}'.format(x) if isinstance(x, int) else x for x in devices]
self._devices = devices
def _setup_graph(self):
......
......@@ -111,15 +111,15 @@ class _MultiProcessZMQDataFlow(DataFlow):
start_proc_mask_signal(self._procs)
def __del__(self):
if not self._reset_done:
return
if not self.context.closed:
self.socket.close(0)
self.context.destroy(0)
for x in self._procs:
x.terminate()
x.join(5)
try:
if not self._reset_done:
return
if not self.context.closed:
self.socket.close(0)
self.context.destroy(0)
for x in self._procs:
x.terminate()
x.join(5)
print("{} successfully cleaned-up.".format(type(self).__name__))
except Exception:
pass
......
......@@ -496,7 +496,7 @@ class StagingInput(FeedfreeInput):
logger.info("Pre-filling StagingArea ...")
for k in range(self.nr_stage):
self.stage_op.run()
logger.info("Put {} element(s) to StagingArea.")
logger.info("Successfully put {} element(s) to StagingArea.".format(self.nr_stage))
def _before_run(self, ctx):
# This has to happen once, right before the first iteration.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment