Commit 552c2b3b authored by Yuxin Wu

lint with flake8-comprehensions

parent a9950705
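
The changes below mostly replace set([...]) / dict([...]) calls and list arguments passed to builtins such as sum/min/max/all with comprehensions or generator expressions, which is what flake8-comprehensions reports. A minimal illustrative sketch of the rewrite patterns (the variable names are hypothetical, not taken from the diff):

    # before: patterns flake8-comprehensions flags
    variables = [("weight", 1.0), ("bias", 2.0)]
    names = set([name for name, _ in variables])               # unnecessary list inside set()
    by_name = dict([(name, val) for name, val in variables])   # unnecessary list inside dict()
    total = sum([val for _, val in variables])                 # sum() accepts a generator

    # after: the forms the linter suggests
    names = {name for name, _ in variables}
    by_name = {name: val for name, val in variables}
    total = sum(val for _, val in variables)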
@@ -17,7 +17,7 @@ jobs:
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
-          pip install flake8
+          pip install flake8 flake8-comprehensions flake8-bugbear
          flake8 --version
      - name: Lint
        run: |
...
@@ -108,8 +108,8 @@ if __name__ == '__main__':
    if len(set(var_to_dump)) != len(var_to_dump):
        logger.warn("TRAINABLE and MODEL variables have duplication!")
    var_to_dump = list(set(var_to_dump))
-    globvarname = set([k.name for k in tf.global_variables()])
-    var_to_dump = set([k.name for k in var_to_dump if k.name in globvarname])
+    globvarname = {k.name for k in tf.global_variables()}
+    var_to_dump = {k.name for k in var_to_dump if k.name in globvarname}
    for name in var_to_dump:
        assert name in dic, "Variable {} not found in the model!".format(name)
...
@@ -425,7 +425,7 @@ class ScalarPrinter(MonitorBase):
        def compile_regex(rs):
            if rs is None:
                return None
-            rs = set([re.compile(r) for r in rs])
+            rs = {re.compile(r) for r in rs}
            return rs

        self._whitelist = compile_regex(whitelist)
...
@@ -49,12 +49,12 @@ class KerasModelCaller(object):
        """
        reuse = tf.get_variable_scope().reuse
-        old_trainable_names = set([x.name for x in tf.trainable_variables()])
+        old_trainable_names = {x.name for x in tf.trainable_variables()}
        trainable_backup = backup_collection([tf.GraphKeys.TRAINABLE_VARIABLES])
        update_ops_backup = backup_collection([tf.GraphKeys.UPDATE_OPS])

        def post_process_model(model):
-            added_trainable_names = set([x.name for x in tf.trainable_variables()])
+            added_trainable_names = {x.name for x in tf.trainable_variables()}
            restore_collection(trainable_backup)

            for v in model.weights:
@@ -62,7 +62,7 @@ class KerasModelCaller(object):
                # We put M.weights into the collection instead.
                if v.name not in old_trainable_names and v.name in added_trainable_names:
                    tf.add_to_collection(tf.GraphKeys.TRAINABLE_VARIABLES, v)
-            new_trainable_names = set([x.name for x in tf.trainable_variables()])
+            new_trainable_names = {x.name for x in tf.trainable_variables()}
            for n in added_trainable_names:
                if n not in new_trainable_names:
...
@@ -431,7 +431,7 @@ class RandomChooseData(RNGDataFlow):
        """
        super(RandomChooseData, self).__init__()
        if isinstance(df_lists[0], (tuple, list)):
-            assert sum([v[1] for v in df_lists]) == 1.0
+            assert sum(v[1] for v in df_lists) == 1.0
            self.df_lists = df_lists
        else:
            prob = 1.0 / len(df_lists)
@@ -512,7 +512,7 @@ class ConcatData(DataFlow):
            d.reset_state()

    def __len__(self):
-        return sum([len(x) for x in self.df_lists])
+        return sum(len(x) for x in self.df_lists)

    def __iter__(self):
        for d in self.df_lists:
@@ -565,7 +565,7 @@ class JoinData(DataFlow):
        """
        Return the minimum size among all.
        """
-        return min([len(k) for k in self.df_lists])
+        return min(len(k) for k in self.df_lists)

    def __iter__(self):
        itrs = [k.__iter__() for k in self.df_lists]
...
@@ -45,7 +45,7 @@ class HDF5Data(RNGDataFlow):
        logger.info("Loading {} to memory...".format(filename))
        self.dps = [self.f[k].value for k in data_paths]
        lens = [len(k) for k in self.dps]
-        assert all([k == lens[0] for k in lens])
+        assert all(k == lens[0] for k in lens)
        self._size = lens[0]
        self.shuffle = shuffle
...
@@ -230,7 +230,7 @@ class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
            list of (shadow_model_var, local_model_var) used for syncing.
        """
        G = tf.get_default_graph()
-        curr_shadow_vars = set([v.name for v in shadow_vars])
+        curr_shadow_vars = {v.name for v in shadow_vars}
        model_vars = tf.model_variables()
        shadow_model_vars = []
        for v in model_vars:
@@ -346,7 +346,7 @@ class DistributedReplicatedBuilder(DataParallelBuilder, DistributedBuilderBase):
                return s[:-2]
            return s
        local_vars = tf.local_variables()
-        local_var_by_name = dict([(strip_port(v.name), v) for v in local_vars])
+        local_var_by_name = {strip_port(v.name): v for v in local_vars}
        ops = []
        nr_shadow_vars = len(self._shadow_vars)
        for v in self._shadow_vars:
...
@@ -63,7 +63,7 @@ class DataParallelBuilder(GraphBuilder):
            return re.sub('tower[0-9]+/', '', x.op.name)

        if len(set(nvars)) != 1:
-            names_per_gpu = [set([basename(k[1]) for k in grad_and_vars]) for grad_and_vars in grad_list]
+            names_per_gpu = [{basename(k[1]) for k in grad_and_vars} for grad_and_vars in grad_list]
            inters = copy.copy(names_per_gpu[0])
            for s in names_per_gpu:
                inters &= s
@@ -247,11 +247,11 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
        DataParallelBuilder._check_grad_list(grad_list)

-        dtypes = set([x[0].dtype.base_dtype for x in grad_list[0]])
+        dtypes = {x[0].dtype.base_dtype for x in grad_list[0]}
        dtypes_nccl_supported = [tf.float32, tf.float64]
        if get_tf_version_tuple() >= (1, 8):
            dtypes_nccl_supported.append(tf.float16)
-        valid_for_nccl = all([k in dtypes_nccl_supported for k in dtypes])
+        valid_for_nccl = all(k in dtypes_nccl_supported for k in dtypes)

        if self._mode == 'nccl' and not valid_for_nccl:
            logger.warn("Cannot use mode='nccl' because some gradients have unsupported types. Fallback to mode='cpu'")
            self._mode = 'cpu'
@@ -314,8 +314,8 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
        """
        # literally all variables, because it's better to sync optimizer-internal variables as well
        all_vars = tf.global_variables() + tf.local_variables()
-        var_by_name = dict([(v.name, v) for v in all_vars])
-        trainable_names = set([x.name for x in tf.trainable_variables()])
+        var_by_name = {v.name: v for v in all_vars}
+        trainable_names = {x.name for x in tf.trainable_variables()}
        post_init_ops = []

        def log_failure(name, reason):
...
@@ -25,7 +25,7 @@ def _replace_global_by_local(kwargs):
    if 'collections' in kwargs:
        collections = kwargs['collections']
        if not collections:
-            collections = set([tf.GraphKeys.GLOBAL_VARIABLES])
+            collections = {tf.GraphKeys.GLOBAL_VARIABLES}
        else:
            collections = set(collections.copy())
        collections.remove(tf.GraphKeys.GLOBAL_VARIABLES)
@@ -343,7 +343,7 @@ class GradientPacker(object):
            logger.info("Skip GradientPacker due to too few gradients.")
            return False
        # should have the same dtype
-        dtypes = set([g.dtype for g in grads])
+        dtypes = {g.dtype for g in grads}
        if len(dtypes) != 1:
            logger.info("Skip GradientPacker due to inconsistent gradient types.")
            return False
...
@@ -471,7 +471,7 @@ class TFDatasetInput(FeedfreeInput):
        self._spec = input_signature
        if self._dataset is not None:
            types = self._dataset.output_types
-            spec_types = tuple([k.dtype for k in input_signature])
+            spec_types = tuple(k.dtype for k in input_signature)
            assert len(types) == len(spec_types), \
                "Dataset and input signature have different length! {} != {}".format(
                    len(types), len(spec_types))
...
@@ -100,7 +100,7 @@ def Conv2D(
        filter_shape = kernel_shape + [in_channel / split, out_channel]
        stride = shape4d(strides, data_format=data_format)

-        kwargs = dict(data_format=data_format)
+        kwargs = {"data_format": data_format}
        if get_tf_version_tuple() >= (1, 5):
            kwargs['dilations'] = shape4d(dilation_rate, data_format=data_format)
...
@@ -42,14 +42,14 @@ def describe_trainable_vars():
        data.append([get_op_tensor_name(v.name)[0], shape, ele, v.device, v.dtype.base_dtype.name])

    headers = ['name', 'shape', '#elements', 'device', 'dtype']
-    dtypes = list(set([x[4] for x in data]))
+    dtypes = list({x[4] for x in data})
    if len(dtypes) == 1 and dtypes[0] == "float32":
        # don't log the dtype if all vars are float32 (default dtype)
        for x in data:
            del x[4]
        del headers[4]

-    devices = set([x[3] for x in data])
+    devices = {x[3] for x in data}
    if len(devices) == 1:
        # don't log the device if all vars on the same device
        for x in data:
...
@@ -150,7 +150,7 @@ def dump_session_params(path):
    var.extend(tf.get_collection(tf.GraphKeys.MODEL_VARIABLES))
    # TODO dedup
    assert len(set(var)) == len(var), "TRAINABLE and MODEL variables have duplication!"
-    gvars = set([k.name for k in tf.global_variables()])
+    gvars = {k.name for k in tf.global_variables()}
    var = [v for v in var if v.name in gvars]
    result = {}
    for v in var:
@@ -167,7 +167,7 @@ def save_chkpt_vars(dic, path):
        path: save as npz if the name ends with '.npz', otherwise save as a checkpoint.
    """
    logger.info("Variables to save to {}:".format(path))
-    keys = sorted(list(dic.keys()))
+    keys = sorted(dic.keys())
    logger.info(pprint.pformat(keys))

    assert not path.endswith('.npy')
...
@@ -91,8 +91,8 @@ def _pad_patch_list(plist, bgcolor):
    plist = _pad_channel(plist)
    shapes = [x.shape for x in plist]
-    ph = max([s[0] for s in shapes])
-    pw = max([s[1] for s in shapes])
+    ph = max(s[0] for s in shapes)
+    pw = max(s[1] for s in shapes)

    ret = np.zeros((len(plist), ph, pw, 3), dtype=plist[0].dtype)
    ret[:, :, :] = bgcolor
...
@@ -39,11 +39,11 @@ def benchmark_serializer(dumps, loads, data, num):

def display_results(name, results):
    logger.info("Encoding benchmark for {}:".format(name))
-    data = sorted([(x, y[0]) for x, y in results], key=operator.itemgetter(1))
+    data = sorted(((x, y[0]) for x, y in results), key=operator.itemgetter(1))
    print(tabulate(data, floatfmt='.5f'))

    logger.info("Decoding benchmark for {}:".format(name))
-    data = sorted([(x, y[1]) for x, y in results], key=operator.itemgetter(1))
+    data = sorted(((x, y[1]) for x, y in results), key=operator.itemgetter(1))
    print(tabulate(data, floatfmt='.5f'))
@@ -64,8 +64,8 @@ def fake_json_data():
        pellentesque quis sollicitudin id, adipiscing.
        """ * 100,
        'list': list(range(100)) * 500,
-        'dict': dict((str(i), 'a') for i in range(50000)),
-        'dict2': dict((i, 'a') for i in range(50000)),
+        'dict': {str(i): 'a' for i in range(50000)},
+        'dict2': {i: 'a' for i in range(50000)},
        'int': 3000,
        'float': 100.123456
    }
...
[flake8]
max-line-length = 120
# See https://pep8.readthedocs.io/en/latest/intro.html#error-codes
-ignore = E265,E741,E742,E743,W504,W605
+ignore = E265,E741,E742,E743,W504,W605,C408
exclude = .git,
          __init__.py,
          setup.py,
...
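
C408, which asks for dict()/list()/tuple() calls such as dict(key=value) to be written as literals, is added to the ignore list above; the Conv2D hunk still applies that rewrite by hand. A hedged sketch of the pattern C408 would otherwise flag (the value here is made up for illustration):

    data_format = "NHWC"                       # hypothetical value, for illustration only
    kwargs = dict(data_format=data_format)     # C408: unnecessary dict call
    kwargs = {"data_format": data_format}      # suggested literal form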