Commit 3efce3ae authored by Yuxin Wu

fix tower-summary bug

parent d5fe531d
...@@ -100,8 +100,6 @@ def get_data(train_or_test): ...@@ -100,8 +100,6 @@ def get_data(train_or_test):
ds = PrefetchData(ds, 10, 5) ds = PrefetchData(ds, 10, 5)
return ds return ds
def get_config(): def get_config():
# prepare dataset # prepare dataset
dataset_train = get_data('train') dataset_train = get_data('train')
......
...@@ -11,6 +11,12 @@ import json ...@@ -11,6 +11,12 @@ import json
__all__ = ['VisualQA'] __all__ = ['VisualQA']
def read_json(fname):
    """Load and return the parsed JSON content of the file at `fname`.

    Args:
        fname: path of the JSON file to read.

    Returns:
        The Python object produced by `json.load` for the file's content.
    """
    # The context manager closes the file even when json.load raises on
    # malformed input, which a manual open()/close() pair would not.
    with open(fname) as f:
        return json.load(f)
# TODO shuffle # TODO shuffle
class VisualQA(DataFlow): class VisualQA(DataFlow):
""" """
...@@ -19,12 +25,11 @@ class VisualQA(DataFlow): ...@@ -19,12 +25,11 @@ class VisualQA(DataFlow):
""" """
def __init__(self, question_file, annotation_file): def __init__(self, question_file, annotation_file):
with timed_operation('Reading VQA JSON file'): with timed_operation('Reading VQA JSON file'):
qobj = json.load(open(question_file)) qobj, aobj = list(map(read_json, [question_file, annotation_file]))
self.task_type = qobj['task_type'] self.task_type = qobj['task_type']
self.questions = qobj['questions'] self.questions = qobj['questions']
self._size = len(self.questions) self._size = len(self.questions)
aobj = json.load(open(annotation_file))
self.anno = aobj['annotations'] self.anno = aobj['annotations']
assert len(self.anno) == len(self.questions), \ assert len(self.anno) == len(self.questions), \
"{}!={}".format(len(self.anno), len(self.questions)) "{}!={}".format(len(self.anno), len(self.questions))
...@@ -49,7 +54,7 @@ class VisualQA(DataFlow): ...@@ -49,7 +54,7 @@ class VisualQA(DataFlow):
""" """
cnt = Counter() cnt = Counter()
for anno in self.anno: for anno in self.anno:
cnt[anno['multiple_choice_answer']] += 1 cnt[anno['multiple_choice_answer'].lower()] += 1
return [k[0] for k in cnt.most_common(n)] return [k[0] for k in cnt.most_common(n)]
def get_common_question_words(self, n): def get_common_question_words(self, n):
......
...@@ -107,7 +107,6 @@ class SaverRestore(SessionInit): ...@@ -107,7 +107,6 @@ class SaverRestore(SessionInit):
logger.warn("Param {} not found in checkpoint! Will not restore.".format(v.op.name)) logger.warn("Param {} not found in checkpoint! Will not restore.".format(v.op.name))
return var_dict return var_dict
class ParamRestore(SessionInit): class ParamRestore(SessionInit):
""" """
Restore trainable variables from a dictionary. Restore trainable variables from a dictionary.
......
...@@ -69,7 +69,7 @@ def add_param_summary(summary_lists): ...@@ -69,7 +69,7 @@ def add_param_summary(summary_lists):
for act in actions: for act in actions:
perform(p, act) perform(p, act)
# TODO use name of cost_var # TODO get rid of the cost_var thing...
def summary_moving_average(cost_var): def summary_moving_average(cost_var):
""" Create a MovingAverage op and summary for all variables in """ Create a MovingAverage op and summary for all variables in
MOVING_SUMMARY_VARS_KEY, as well as `cost_var`. MOVING_SUMMARY_VARS_KEY, as well as `cost_var`.
......
...@@ -137,8 +137,8 @@ class QueueInputTrainer(Trainer): ...@@ -137,8 +137,8 @@ class QueueInputTrainer(Trainer):
kept_summaries[k] = copy.copy(tf.get_collection(k)) kept_summaries[k] = copy.copy(tf.get_collection(k))
logger.info("Graph built for tower {}.".format(i)) logger.info("Graph built for tower {}.".format(i))
for k in coll_keys: for k in coll_keys:
del tf.get_collection(k)[:] del tf.get_collection_ref(k)[:]
tf.get_collection(k).extend(kept_summaries[k]) tf.get_collection_ref(k).extend(kept_summaries[k])
grads = QueueInputTrainer._average_grads(grad_list) grads = QueueInputTrainer._average_grads(grad_list)
cost_var = cost_var_t0 cost_var = cost_var_t0
else: else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment