Commit 552c2b3b authored by Yuxin Wu

lint with flake8-comprehensions

parent a9950705
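
All of the changes below follow the handful of patterns that flake8-comprehensions flags: redundant list/set/dict constructions that can be written as comprehension literals, or throwaway lists passed to builtins that accept generators. A minimal sketch of the rewrites involved (rule codes as I understand them from the plugin's documentation; the variable names are illustrative, not taken from the repo):

# C405: set built from a list literal -> set literal
s = set([1, 2, 3])
s = {1, 2, 3}

# C401 / C402: generator passed to set() / dict() -> set / dict comprehension
names = set(v.name for v in variables)
names = {v.name for v in variables}
lookup = dict((v.name, v) for v in variables)
lookup = {v.name: v for v in variables}

# C407-style: a reducing builtin does not need an intermediate list
total = sum([len(x) for x in dataflows])
total = sum(len(x) for x in dataflows)
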
@@ -17,7 +17,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install flake8
pip install flake8 flake8-comprehensions flake8-bugbear
flake8 --version
- name: Lint
run: |
......
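
Besides flake8-comprehensions, the CI step now also installs flake8-bugbear, which adds likely-bug checks rather than pure style checks. A typical example of what it reports (hypothetical function, not from this diff) is B006, the shared mutable default argument:

# B006: the default list is created once and shared across all calls
def append_log(entry, log=[]):
    log.append(entry)
    return log

# conventional fix: create the list inside the function
def append_log(entry, log=None):
    if log is None:
        log = []
    log.append(entry)
    return log
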
@@ -108,8 +108,8 @@ if __name__ == '__main__':
if len(set(var_to_dump)) != len(var_to_dump):
logger.warn("TRAINABLE and MODEL variables have duplication!")
var_to_dump = list(set(var_to_dump))
globvarname = set([k.name for k in tf.global_variables()])
var_to_dump = set([k.name for k in var_to_dump if k.name in globvarname])
globvarname = {k.name for k in tf.global_variables()}
var_to_dump = {k.name for k in var_to_dump if k.name in globvarname}
for name in var_to_dump:
assert name in dic, "Variable {} not found in the model!".format(name)
......
@@ -425,7 +425,7 @@ class ScalarPrinter(MonitorBase):
def compile_regex(rs):
if rs is None:
return None
rs = set([re.compile(r) for r in rs])
rs = {re.compile(r) for r in rs}
return rs
self._whitelist = compile_regex(whitelist)
......
@@ -49,12 +49,12 @@ class KerasModelCaller(object):
"""
reuse = tf.get_variable_scope().reuse
old_trainable_names = set([x.name for x in tf.trainable_variables()])
old_trainable_names = {x.name for x in tf.trainable_variables()}
trainable_backup = backup_collection([tf.GraphKeys.TRAINABLE_VARIABLES])
update_ops_backup = backup_collection([tf.GraphKeys.UPDATE_OPS])
def post_process_model(model):
added_trainable_names = set([x.name for x in tf.trainable_variables()])
added_trainable_names = {x.name for x in tf.trainable_variables()}
restore_collection(trainable_backup)
for v in model.weights:
@@ -62,7 +62,7 @@ class KerasModelCaller(object):
# We put M.weights into the collection instead.
if v.name not in old_trainable_names and v.name in added_trainable_names:
tf.add_to_collection(tf.GraphKeys.TRAINABLE_VARIABLES, v)
new_trainable_names = set([x.name for x in tf.trainable_variables()])
new_trainable_names = {x.name for x in tf.trainable_variables()}
for n in added_trainable_names:
if n not in new_trainable_names:
......
@@ -431,7 +431,7 @@ class RandomChooseData(RNGDataFlow):
"""
super(RandomChooseData, self).__init__()
if isinstance(df_lists[0], (tuple, list)):
assert sum([v[1] for v in df_lists]) == 1.0
assert sum(v[1] for v in df_lists) == 1.0
self.df_lists = df_lists
else:
prob = 1.0 / len(df_lists)
@@ -512,7 +512,7 @@ class ConcatData(DataFlow):
d.reset_state()
def __len__(self):
return sum([len(x) for x in self.df_lists])
return sum(len(x) for x in self.df_lists)
def __iter__(self):
for d in self.df_lists:
@@ -565,7 +565,7 @@ class JoinData(DataFlow):
"""
Return the minimum size among all.
"""
return min([len(k) for k in self.df_lists])
return min(len(k) for k in self.df_lists)
def __iter__(self):
itrs = [k.__iter__() for k in self.df_lists]
......
@@ -45,7 +45,7 @@ class HDF5Data(RNGDataFlow):
logger.info("Loading {} to memory...".format(filename))
self.dps = [self.f[k].value for k in data_paths]
lens = [len(k) for k in self.dps]
assert all([k == lens[0] for k in lens])
assert all(k == lens[0] for k in lens)
self._size = lens[0]
self.shuffle = shuffle
......
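
Dropping the square brackets in these sum()/min()/all() calls is not purely cosmetic: the builtin then consumes a generator instead of a fully materialized list, and all()/any() can also short-circuit on the first failing element. A small hypothetical illustration of the short-circuiting difference:

def check(x):
    print("checking", x)   # side effect to make evaluation order visible
    return x > 0

all([check(x) for x in [1, -1, 2]])   # list form: check() runs for 1, -1 and 2
all(check(x) for x in [1, -1, 2])     # generator form: stops after 1 and -1
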
@@ -230,7 +230,7 @@ class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
list of (shadow_model_var, local_model_var) used for syncing.
"""
G = tf.get_default_graph()
curr_shadow_vars = set([v.name for v in shadow_vars])
curr_shadow_vars = {v.name for v in shadow_vars}
model_vars = tf.model_variables()
shadow_model_vars = []
for v in model_vars:
@@ -346,7 +346,7 @@ class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
return s[:-2]
return s
local_vars = tf.local_variables()
local_var_by_name = dict([(strip_port(v.name), v) for v in local_vars])
local_var_by_name = {strip_port(v.name): v for v in local_vars}
ops = []
nr_shadow_vars = len(self._shadow_vars)
for v in self._shadow_vars:
......
@@ -63,7 +63,7 @@ class DataParallelBuilder(GraphBuilder):
return re.sub('tower[0-9]+/', '', x.op.name)
if len(set(nvars)) != 1:
names_per_gpu = [set([basename(k[1]) for k in grad_and_vars]) for grad_and_vars in grad_list]
names_per_gpu = [{basename(k[1]) for k in grad_and_vars} for grad_and_vars in grad_list]
inters = copy.copy(names_per_gpu[0])
for s in names_per_gpu:
inters &= s
@@ -247,11 +247,11 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
DataParallelBuilder._check_grad_list(grad_list)
dtypes = set([x[0].dtype.base_dtype for x in grad_list[0]])
dtypes = {x[0].dtype.base_dtype for x in grad_list[0]}
dtypes_nccl_supported = [tf.float32, tf.float64]
if get_tf_version_tuple() >= (1, 8):
dtypes_nccl_supported.append(tf.float16)
valid_for_nccl = all([k in dtypes_nccl_supported for k in dtypes])
valid_for_nccl = all(k in dtypes_nccl_supported for k in dtypes)
if self._mode == 'nccl' and not valid_for_nccl:
logger.warn("Cannot use mode='nccl' because some gradients have unsupported types. Fallback to mode='cpu'")
self._mode = 'cpu'
@@ -314,8 +314,8 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
"""
# literally all variables, because it's better to sync optimizer-internal variables as well
all_vars = tf.global_variables() + tf.local_variables()
var_by_name = dict([(v.name, v) for v in all_vars])
trainable_names = set([x.name for x in tf.trainable_variables()])
var_by_name = {v.name: v for v in all_vars}
trainable_names = {x.name for x in tf.trainable_variables()}
post_init_ops = []
def log_failure(name, reason):
......
@@ -25,7 +25,7 @@ def _replace_global_by_local(kwargs):
if 'collections' in kwargs:
collections = kwargs['collections']
if not collections:
collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
collections = {tf.GraphKeys.GLOBAL_VARIABLES}
else:
collections = set(collections.copy())
collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
@@ -343,7 +343,7 @@ class GradientPacker(object):
logger.info("Skip GradientPacker due to too few gradients.")
return False
# should have the same dtype
dtypes = set([g.dtype for g in grads])
dtypes = {g.dtype for g in grads}
if len(dtypes) != 1:
logger.info("Skip GradientPacker due to inconsistent gradient types.")
return False
......
@@ -471,7 +471,7 @@ class TFDatasetInput(FeedfreeInput):
self._spec = input_signature
if self._dataset is not None:
types = self._dataset.output_types
spec_types = tuple([k.dtype for k in input_signature])
spec_types = tuple(k.dtype for k in input_signature)
assert len(types) == len(spec_types), \
"Dataset and input signature have different length! {} != {}".format(
len(types), len(spec_types))
......
@@ -100,7 +100,7 @@ def Conv2D(
filter_shape = kernel_shape + [in_channel / split, out_channel]
stride = shape4d(strides, data_format=data_format)
kwargs = dict(data_format=data_format)
kwargs = {"data_format": data_format}
if get_tf_version_tuple() >= (1, 5):
kwargs['dilations'] = shape4d(dilation_rate, data_format=data_format)
......
@@ -42,14 +42,14 @@ def describe_trainable_vars():
data.append([get_op_tensor_name(v.name)[0], shape, ele, v.device, v.dtype.base_dtype.name])
headers = ['name', 'shape', '#elements', 'device', 'dtype']
dtypes = list(set([x[4] for x in data]))
dtypes = list({x[4] for x in data})
if len(dtypes) == 1 and dtypes[0] == "float32":
# don't log the dtype if all vars are float32 (default dtype)
for x in data:
del x[4]
del headers[4]
devices = set([x[3] for x in data])
devices = {x[3] for x in data}
if len(devices) == 1:
# don't log the device if all vars on the same device
for x in data:
......
@@ -150,7 +150,7 @@ def dump_session_params(path):
var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
# TODO dedup
assert len(set(var)) == len(var), "TRAINABLE and MODEL variables have duplication!"
gvars = set([k.name for k in tf.global_variables()])
gvars = {k.name for k in tf.global_variables()}
var = [v for v in var if v.name in gvars]
result = {}
for v in var:
@@ -167,7 +167,7 @@ def save_chkpt_vars(dic, path):
path: save as npz if the name ends with '.npz', otherwise save as a checkpoint.
"""
logger.info("Variables to save to {}:".format(path))
keys = sorted(list(dic.keys()))
keys = sorted(dic.keys())
logger.info(pprint.pformat(keys))
assert not path.endswith('.npy')
......
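
sorted() accepts any iterable and always returns a new list, so wrapping dict.keys() in list() first only creates a throwaway copy; flake8-comprehensions reports this family of redundant nested calls as well (C413/C414, if I recall its numbering correctly). A tiny example with made-up data:

dic = {"b": 2, "a": 1}
sorted(list(dic.keys()))   # redundant list() copy, flagged by the plugin
sorted(dic.keys())         # same result: ['a', 'b']
sorted(dic)                # iterating a dict yields its keys as well
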
@@ -91,8 +91,8 @@ def _pad_patch_list(plist, bgcolor):
plist = _pad_channel(plist)
shapes = [x.shape for x in plist]
ph = max([s[0] for s in shapes])
pw = max([s[1] for s in shapes])
ph = max(s[0] for s in shapes)
pw = max(s[1] for s in shapes)
ret = np.zeros((len(plist), ph, pw, 3), dtype=plist[0].dtype)
ret[:, :, :] = bgcolor
......
@@ -39,11 +39,11 @@ def benchmark_serializer(dumps, loads, data, num):
def display_results(name, results):
logger.info("Encoding benchmark for {}:".format(name))
data = sorted([(x, y[0]) for x, y in results], key=operator.itemgetter(1))
data = sorted(((x, y[0]) for x, y in results), key=operator.itemgetter(1))
print(tabulate(data, floatfmt='.5f'))
logger.info("Decoding benchmark for {}:".format(name))
data = sorted([(x, y[1]) for x, y in results], key=operator.itemgetter(1))
data = sorted(((x, y[1]) for x, y in results), key=operator.itemgetter(1))
print(tabulate(data, floatfmt='.5f'))
@@ -64,8 +64,8 @@ def fake_json_data():
pellentesque quis sollicitudin id, adipiscing.
""" * 100,
'list': list(range(100)) * 500,
'dict': dict((str(i), 'a') for i in range(50000)),
'dict2': dict((i, 'a') for i in range(50000)),
'dict': {str(i): 'a' for i in range(50000)},
'dict2': {i: 'a' for i in range(50000)},
'int': 3000,
'float': 100.123456
}
......
[flake8]
max-line-length = 120
# See https://pep8.readthedocs.io/en/latest/intro.html#error-codes
ignore = E265,E741,E742,E743,W504,W605
ignore = E265,E741,E742,E743,W504,W605,C408
exclude = .git,
__init__.py,
setup.py,
......
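
The C408 code ("unnecessary dict/list/tuple call, rewrite as a literal") is added to the ignore list, so keyword-style dict(...) construction stays allowed across the repo even though one occurrence in Conv2D was rewritten above. A hypothetical example of what C408 would otherwise flag (argument values are illustrative):

kwargs = dict(data_format="NHWC", dilations=[1, 1, 1, 1])    # C408 would flag this
kwargs = {"data_format": "NHWC", "dilations": [1, 1, 1, 1]}  # equivalent literal form
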