Commit bf8acbfb authored by Yuxin Wu's avatar Yuxin Wu

lint with flake8-bugbear

parent 552c2b3b
......@@ -90,7 +90,7 @@ if __name__ == '__main__':
op_name = e.args[0]
_import_external_ops(op_name)
except tf.errors.NotFoundError as e:
_import_external_ops(e.message)
_import_external_ops(str(e))
else:
break
......
......@@ -55,11 +55,11 @@ class HyperParam(object):
class GraphVarParam(HyperParam):
""" A variable in the graph (e.g. learning_rate) can be a hyperparam."""
def __init__(self, name, shape=[]):
def __init__(self, name, shape=()):
"""
Args:
name(str): name of the variable.
shape(list): shape of the variable.
shape(tuple): shape of the variable.
"""
self.name = name
self.shape = shape
......
......@@ -244,7 +244,7 @@ class GPUMemoryTracker(Callback):
_chief_only = False
def __init__(self, devices=[0]):
def __init__(self, devices=(0,)):
"""
Args:
devices([int] or [str]): list of GPU devices to track memory on.
......
......@@ -80,7 +80,7 @@ class ModelSaver(Callback):
global_step=tf.train.get_global_step(),
write_meta_graph=False)
logger.info("Model saved to %s." % tf.train.get_checkpoint_state(self.checkpoint_dir).model_checkpoint_path)
except (OSError, IOError, tf.errors.PermissionDeniedError,
except (IOError, tf.errors.PermissionDeniedError,
tf.errors.ResourceExhaustedError): # disk error sometimes.. just ignore it
logger.exception("Exception in ModelSaver!")
......
......@@ -51,10 +51,10 @@ class ProgressBar(Callback):
_chief_only = False
def __init__(self, names=[]):
def __init__(self, names=()):
"""
Args:
names(list): list of string, the names of the tensors to monitor
names(tuple[str]): the names of the tensors to monitor
on the progress bar.
"""
super(ProgressBar, self).__init__()
......
......@@ -82,7 +82,7 @@ except ImportError:
if __name__ == "__main__":
ds = Caltech101Silhouettes("train")
ds.reset_state()
for (img, label) in ds:
for _ in ds:
from IPython import embed
embed()
......
......@@ -300,7 +300,7 @@ if __name__ == '__main__':
ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', shuffle=False)
ds.reset_state()
for k in ds:
for _ in ds:
from IPython import embed
embed()
break
......@@ -132,7 +132,7 @@ class FashionMnist(Mnist):
if __name__ == '__main__':
ds = Mnist('train')
ds.reset_state()
for (img, label) in ds:
for _ in ds:
from IPython import embed
embed()
break
......@@ -145,7 +145,7 @@ class LMDBData(RNGDataFlow):
with self._guard:
if not self._shuffle:
c = self._txn.cursor()
while c.next():
while next(c):
k, v = c.item()
if k != b'__keys__':
yield [k, v]
......
......@@ -50,7 +50,7 @@ def _default_repr(self):
defaults[k] = argspec.kwonlydefaults[k]
argstr = []
for idx, f in enumerate(fields):
for f in fields:
assert hasattr(self, f), \
"Attribute {} in {} not found! Default __repr__ only works if " \
"the instance has attributes that match the constructor.".format(f, classname)
......
......@@ -375,7 +375,7 @@ if __name__ == '__main__':
draw_points(orig_image, coords)
print(coords)
for k in range(1):
for _ in range(1):
coords = trans.apply_coords(coords)
image = trans.apply_image(image)
print(coords)
......
......@@ -570,7 +570,7 @@ class StagingInput(FeedfreeInput):
def _prefill(self, sess):
logger.info("Pre-filling StagingArea ...")
for k in range(self.nr_stage):
for _ in range(self.nr_stage):
self.stage_op.run(session=sess)
logger.info("{} element{} put into StagingArea on each tower.".format(
self.nr_stage, "s were" if self.nr_stage > 1 else " was"))
......
......@@ -80,7 +80,7 @@ class CollectionGuard(object):
original = None
def __init__(self, name, check_diff,
freeze_keys=[],
freeze_keys=(),
diff_whitelist=None):
"""
Args:
......
......@@ -89,7 +89,7 @@ class ModelExporter(object):
logger.info("Output graph written to {}.".format(filename))
def export_serving(self, filename,
tags=[tf.saved_model.SERVING if is_tfv2() else tf.saved_model.tag_constants.SERVING],
tags=(tf.saved_model.SERVING if is_tfv2() else tf.saved_model.tag_constants.SERVING,),
signature_name='prediction_pipeline'):
"""
Converts a checkpoint and graph to a servable for TensorFlow Serving.
......@@ -97,7 +97,7 @@ class ModelExporter(object):
Args:
filename (str): path for export directory
tags (list): list of user specified tags
tags (tuple): tuple of user specified tags
signature_name (str): name of signature for prediction
Note:
......@@ -140,7 +140,7 @@ class ModelExporter(object):
method_name=tfv1.saved_model.signature_constants.PREDICT_METHOD_NAME)
builder.add_meta_graph_and_variables(
sess, tags,
sess, list(tags),
signature_def_map={signature_name: prediction_signature})
builder.save()
logger.info("SavedModel created at {}.".format(filename))
......@@ -236,7 +236,7 @@ if __name__ == '__main__':
sess = tf.Session()
sess.run(tf.global_variables_initializer())
with sess.as_default():
for k in range(20):
for _ in range(20):
min_op.run()
print(x.eval())
print(tf.train.get_or_create_global_step().eval())
......@@ -91,12 +91,12 @@ class SaverRestore(SessionInit):
"""
Restore a tensorflow checkpoint saved by :class:`tf.train.Saver` or :class:`ModelSaver`.
"""
def __init__(self, model_path, prefix=None, ignore=[]):
def __init__(self, model_path, prefix=None, ignore=()):
"""
Args:
model_path (str): a model name (model-xxxx) or a ``checkpoint`` file.
prefix (str): during restore, add a ``prefix/`` for every variable in this checkpoint.
ignore (list[str]): list of tensor names that should be ignored during loading, e.g. learning-rate
ignore (tuple[str]): tensor names that should be ignored during loading, e.g. learning-rate
"""
if model_path.endswith('.npy') or model_path.endswith('.npz'):
logger.warn("SaverRestore expect a TF checkpoint, but got a model path '{}'.".format(model_path) +
......
......@@ -261,7 +261,7 @@ def find_library_full_path(name):
if 'lib' + name + '.so' in basename:
if os.path.isfile(sofile):
return os.path.realpath(sofile)
except (OSError, IOError):
except IOError:
# can fail in certain environment (e.g. chroot)
# if the pids are incorrectly mapped
pass
......
[flake8]
max-line-length = 120
# See https://pep8.readthedocs.io/en/latest/intro.html#error-codes
ignore = E265,E741,E742,E743,W504,W605,C408
ignore = E265,E741,E742,E743,W504,W605,C408,B007,B008
exclude = .git,
__init__.py,
setup.py,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment