Commit bf8acbfb authored by Yuxin Wu's avatar Yuxin Wu

lint with flake8-bugbear

parent 552c2b3b
...@@ -90,7 +90,7 @@ if __name__ == '__main__': ...@@ -90,7 +90,7 @@ if __name__ == '__main__':
op_name = e.args[0] op_name = e.args[0]
_import_external_ops(op_name) _import_external_ops(op_name)
except tf.errors.NotFoundError as e: except tf.errors.NotFoundError as e:
_import_external_ops(e.message) _import_external_ops(str(e))
else: else:
break break
......
...@@ -55,11 +55,11 @@ class HyperParam(object): ...@@ -55,11 +55,11 @@ class HyperParam(object):
class GraphVarParam(HyperParam): class GraphVarParam(HyperParam):
""" A variable in the graph (e.g. learning_rate) can be a hyperparam.""" """ A variable in the graph (e.g. learning_rate) can be a hyperparam."""
def __init__(self, name, shape=[]): def __init__(self, name, shape=()):
""" """
Args: Args:
name(str): name of the variable. name(str): name of the variable.
shape(list): shape of the variable. shape(tuple): shape of the variable.
""" """
self.name = name self.name = name
self.shape = shape self.shape = shape
......
...@@ -244,7 +244,7 @@ class GPUMemoryTracker(Callback): ...@@ -244,7 +244,7 @@ class GPUMemoryTracker(Callback):
_chief_only = False _chief_only = False
def __init__(self, devices=[0]): def __init__(self, devices=(0,)):
""" """
Args: Args:
devices([int] or [str]): list of GPU devices to track memory on. devices([int] or [str]): list of GPU devices to track memory on.
......
...@@ -80,7 +80,7 @@ class ModelSaver(Callback): ...@@ -80,7 +80,7 @@ class ModelSaver(Callback):
global_step=tf.train.get_global_step(), global_step=tf.train.get_global_step(),
write_meta_graph=False) write_meta_graph=False)
logger.info("Model saved to %s." % tf.train.get_checkpoint_state(self.checkpoint_dir).model_checkpoint_path) logger.info("Model saved to %s." % tf.train.get_checkpoint_state(self.checkpoint_dir).model_checkpoint_path)
except (OSError, IOError, tf.errors.PermissionDeniedError, except (IOError, tf.errors.PermissionDeniedError,
tf.errors.ResourceExhaustedError): # disk error sometimes.. just ignore it tf.errors.ResourceExhaustedError): # disk error sometimes.. just ignore it
logger.exception("Exception in ModelSaver!") logger.exception("Exception in ModelSaver!")
......
...@@ -51,10 +51,10 @@ class ProgressBar(Callback): ...@@ -51,10 +51,10 @@ class ProgressBar(Callback):
_chief_only = False _chief_only = False
def __init__(self, names=[]): def __init__(self, names=()):
""" """
Args: Args:
names(list): list of string, the names of the tensors to monitor names(tuple[str]): the names of the tensors to monitor
on the progress bar. on the progress bar.
""" """
super(ProgressBar, self).__init__() super(ProgressBar, self).__init__()
......
...@@ -82,7 +82,7 @@ except ImportError: ...@@ -82,7 +82,7 @@ except ImportError:
if __name__ == "__main__": if __name__ == "__main__":
ds = Caltech101Silhouettes("train") ds = Caltech101Silhouettes("train")
ds.reset_state() ds.reset_state()
for (img, label) in ds: for _ in ds:
from IPython import embed from IPython import embed
embed() embed()
......
...@@ -300,7 +300,7 @@ if __name__ == '__main__': ...@@ -300,7 +300,7 @@ if __name__ == '__main__':
ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', shuffle=False) ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', shuffle=False)
ds.reset_state() ds.reset_state()
for k in ds: for _ in ds:
from IPython import embed from IPython import embed
embed() embed()
break break
...@@ -132,7 +132,7 @@ class FashionMnist(Mnist): ...@@ -132,7 +132,7 @@ class FashionMnist(Mnist):
if __name__ == '__main__': if __name__ == '__main__':
ds = Mnist('train') ds = Mnist('train')
ds.reset_state() ds.reset_state()
for (img, label) in ds: for _ in ds:
from IPython import embed from IPython import embed
embed() embed()
break break
...@@ -145,7 +145,7 @@ class LMDBData(RNGDataFlow): ...@@ -145,7 +145,7 @@ class LMDBData(RNGDataFlow):
with self._guard: with self._guard:
if not self._shuffle: if not self._shuffle:
c = self._txn.cursor() c = self._txn.cursor()
while c.next(): while next(c):
k, v = c.item() k, v = c.item()
if k != b'__keys__': if k != b'__keys__':
yield [k, v] yield [k, v]
......
...@@ -50,7 +50,7 @@ def _default_repr(self): ...@@ -50,7 +50,7 @@ def _default_repr(self):
defaults[k] = argspec.kwonlydefaults[k] defaults[k] = argspec.kwonlydefaults[k]
argstr = [] argstr = []
for idx, f in enumerate(fields): for f in fields:
assert hasattr(self, f), \ assert hasattr(self, f), \
"Attribute {} in {} not found! Default __repr__ only works if " \ "Attribute {} in {} not found! Default __repr__ only works if " \
"the instance has attributes that match the constructor.".format(f, classname) "the instance has attributes that match the constructor.".format(f, classname)
......
...@@ -375,7 +375,7 @@ if __name__ == '__main__': ...@@ -375,7 +375,7 @@ if __name__ == '__main__':
draw_points(orig_image, coords) draw_points(orig_image, coords)
print(coords) print(coords)
for k in range(1): for _ in range(1):
coords = trans.apply_coords(coords) coords = trans.apply_coords(coords)
image = trans.apply_image(image) image = trans.apply_image(image)
print(coords) print(coords)
......
...@@ -570,7 +570,7 @@ class StagingInput(FeedfreeInput): ...@@ -570,7 +570,7 @@ class StagingInput(FeedfreeInput):
def _prefill(self, sess): def _prefill(self, sess):
logger.info("Pre-filling StagingArea ...") logger.info("Pre-filling StagingArea ...")
for k in range(self.nr_stage): for _ in range(self.nr_stage):
self.stage_op.run(session=sess) self.stage_op.run(session=sess)
logger.info("{} element{} put into StagingArea on each tower.".format( logger.info("{} element{} put into StagingArea on each tower.".format(
self.nr_stage, "s were" if self.nr_stage > 1 else " was")) self.nr_stage, "s were" if self.nr_stage > 1 else " was"))
......
...@@ -80,7 +80,7 @@ class CollectionGuard(object): ...@@ -80,7 +80,7 @@ class CollectionGuard(object):
original = None original = None
def __init__(self, name, check_diff, def __init__(self, name, check_diff,
freeze_keys=[], freeze_keys=(),
diff_whitelist=None): diff_whitelist=None):
""" """
Args: Args:
......
...@@ -89,7 +89,7 @@ class ModelExporter(object): ...@@ -89,7 +89,7 @@ class ModelExporter(object):
logger.info("Output graph written to {}.".format(filename)) logger.info("Output graph written to {}.".format(filename))
def export_serving(self, filename, def export_serving(self, filename,
tags=[tf.saved_model.SERVING if is_tfv2() else tf.saved_model.tag_constants.SERVING], tags=(tf.saved_model.SERVING if is_tfv2() else tf.saved_model.tag_constants.SERVING,),
signature_name='prediction_pipeline'): signature_name='prediction_pipeline'):
""" """
Converts a checkpoint and graph to a servable for TensorFlow Serving. Converts a checkpoint and graph to a servable for TensorFlow Serving.
...@@ -97,7 +97,7 @@ class ModelExporter(object): ...@@ -97,7 +97,7 @@ class ModelExporter(object):
Args: Args:
filename (str): path for export directory filename (str): path for export directory
tags (list): list of user specified tags tags (tuple): tuple of user specified tags
signature_name (str): name of signature for prediction signature_name (str): name of signature for prediction
Note: Note:
...@@ -140,7 +140,7 @@ class ModelExporter(object): ...@@ -140,7 +140,7 @@ class ModelExporter(object):
method_name=tfv1.saved_model.signature_constants.PREDICT_METHOD_NAME) method_name=tfv1.saved_model.signature_constants.PREDICT_METHOD_NAME)
builder.add_meta_graph_and_variables( builder.add_meta_graph_and_variables(
sess, tags, sess, list(tags),
signature_def_map={signature_name: prediction_signature}) signature_def_map={signature_name: prediction_signature})
builder.save() builder.save()
logger.info("SavedModel created at {}.".format(filename)) logger.info("SavedModel created at {}.".format(filename))
...@@ -236,7 +236,7 @@ if __name__ == '__main__': ...@@ -236,7 +236,7 @@ if __name__ == '__main__':
sess = tf.Session() sess = tf.Session()
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
with sess.as_default(): with sess.as_default():
for k in range(20): for _ in range(20):
min_op.run() min_op.run()
print(x.eval()) print(x.eval())
print(tf.train.get_or_create_global_step().eval()) print(tf.train.get_or_create_global_step().eval())
...@@ -91,12 +91,12 @@ class SaverRestore(SessionInit): ...@@ -91,12 +91,12 @@ class SaverRestore(SessionInit):
""" """
Restore a tensorflow checkpoint saved by :class:`tf.train.Saver` or :class:`ModelSaver`. Restore a tensorflow checkpoint saved by :class:`tf.train.Saver` or :class:`ModelSaver`.
""" """
def __init__(self, model_path, prefix=None, ignore=[]): def __init__(self, model_path, prefix=None, ignore=()):
""" """
Args: Args:
model_path (str): a model name (model-xxxx) or a ``checkpoint`` file. model_path (str): a model name (model-xxxx) or a ``checkpoint`` file.
prefix (str): during restore, add a ``prefix/`` for every variable in this checkpoint. prefix (str): during restore, add a ``prefix/`` for every variable in this checkpoint.
ignore (list[str]): list of tensor names that should be ignored during loading, e.g. learning-rate ignore (tuple[str]): tensor names that should be ignored during loading, e.g. learning-rate
""" """
if model_path.endswith('.npy') or model_path.endswith('.npz'): if model_path.endswith('.npy') or model_path.endswith('.npz'):
logger.warn("SaverRestore expect a TF checkpoint, but got a model path '{}'.".format(model_path) + logger.warn("SaverRestore expect a TF checkpoint, but got a model path '{}'.".format(model_path) +
......
...@@ -261,7 +261,7 @@ def find_library_full_path(name): ...@@ -261,7 +261,7 @@ def find_library_full_path(name):
if 'lib' + name + '.so' in basename: if 'lib' + name + '.so' in basename:
if os.path.isfile(sofile): if os.path.isfile(sofile):
return os.path.realpath(sofile) return os.path.realpath(sofile)
except (OSError, IOError): except IOError:
# can fail in certain environment (e.g. chroot) # can fail in certain environment (e.g. chroot)
# if the pids are incorrectly mapped # if the pids are incorrectly mapped
pass pass
......
[flake8] [flake8]
max-line-length = 120 max-line-length = 120
# See https://pep8.readthedocs.io/en/latest/intro.html#error-codes # See https://pep8.readthedocs.io/en/latest/intro.html#error-codes
ignore = E265,E741,E742,E743,W504,W605,C408 ignore = E265,E741,E742,E743,W504,W605,C408,B007,B008
exclude = .git, exclude = .git,
__init__.py, __init__.py,
setup.py, setup.py,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment