Commit f18314d6 authored by Yuxin Wu's avatar Yuxin Wu

fix validation bug again

parent 741f404d
...@@ -48,6 +48,7 @@ class ValidationCallback(PeriodicCallback): ...@@ -48,6 +48,7 @@ class ValidationCallback(PeriodicCallback):
output_vars = self._get_output_vars() output_vars = self._get_output_vars()
output_vars.append(self.cost_var) output_vars.append(self.cost_var)
sess = tf.get_default_session()
with tqdm(total=self.ds.size(), ascii=True) as pbar: with tqdm(total=self.ds.size(), ascii=True) as pbar:
for dp in self.ds.get_data(): for dp in self.ds.get_data():
feed = dict(itertools.izip(self.input_vars, dp)) feed = dict(itertools.izip(self.input_vars, dp))
...@@ -55,7 +56,7 @@ class ValidationCallback(PeriodicCallback): ...@@ -55,7 +56,7 @@ class ValidationCallback(PeriodicCallback):
batch_size = dp[0].shape[0] # assume batched input batch_size = dp[0].shape[0] # assume batched input
cnt += batch_size cnt += batch_size
outputs = self.sess.run(output_vars, feed_dict=feed) outputs = sess.run(output_vars, feed_dict=feed)
cost = outputs[-1] cost = outputs[-1]
# each batch might not have the same size in validation # each batch might not have the same size in validation
cost_sum += cost * batch_size cost_sum += cost * batch_size
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment