Commit 132dcccd authored by Yuxin Wu's avatar Yuxin Wu

handle disk error

parent ecc33298
...@@ -115,3 +115,5 @@ weston, improving neural grammatical model flow images is allows belief networks ...@@ -115,3 +115,5 @@ weston, improving neural grammatical model flow images is allows belief networks
generating neural networks, there is not the initial particular marked pseudo-cameral rnns generating neural networks, there is not the initial particular marked pseudo-cameral rnns
sophett, pattern wlth designs for faster than the inference in deep learning. in nips (most), sophett, pattern wlth designs for faster than the inference in deep learning. in nips (most),
``` ```
See [blog of Andrej Karpathy](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) for more interesting stories on this topic.
...@@ -49,11 +49,14 @@ class ModelSaver(Callback): ...@@ -49,11 +49,14 @@ class ModelSaver(Callback):
return var_dict return var_dict
def _trigger_epoch(self): def _trigger_epoch(self):
try:
self.saver.save( self.saver.save(
tf.get_default_session(), tf.get_default_session(),
self.path, self.path,
global_step=self.global_step, global_step=self.global_step,
write_meta_graph=not self.meta_graph_written) write_meta_graph=not self.meta_graph_written)
except Exception: # disk error sometimes..
logger.exception("Exception in ModelSaver.trigger_epoch!")
if not self.meta_graph_written: if not self.meta_graph_written:
self.meta_graph_written = True self.meta_graph_written = True
......
...@@ -69,9 +69,12 @@ class StatHolder(object): ...@@ -69,9 +69,12 @@ class StatHolder(object):
def _write_stat(self): def _write_stat(self):
tmp_filename = self.filename + '.tmp' tmp_filename = self.filename + '.tmp'
try:
with open(tmp_filename, 'w') as f: with open(tmp_filename, 'w') as f:
json.dump(self.stat_history, f) json.dump(self.stat_history, f)
os.rename(tmp_filename, self.filename) os.rename(tmp_filename, self.filename)
except IOError: # disk error sometimes..
logger.exception("Exception in StatHolder.finalize()!")
class StatPrinter(Callback): class StatPrinter(Callback):
""" """
......
...@@ -91,7 +91,10 @@ class EnqueueThread(threading.Thread): ...@@ -91,7 +91,10 @@ class EnqueueThread(threading.Thread):
except Exception: except Exception:
logger.exception("Exception in EnqueueThread:") logger.exception("Exception in EnqueueThread:")
finally: finally:
try:
self.sess.run(self.close_op) self.sess.run(self.close_op)
except RuntimeError: # session already closed
pass
self.coord.request_stop() self.coord.request_stop()
logger.info("Enqueue Thread Exited.") logger.info("Enqueue Thread Exited.")
......
...@@ -3,59 +3,63 @@ ...@@ -3,59 +3,63 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np import numpy as np
__all__ = ['StatCounter', 'Accuracy', 'BinaryStatistics'] __all__ = ['StatCounter', 'Accuracy', 'BinaryStatistics', 'RatioStatistics']
class StatCounter(object): class StatCounter(object):
def __init__(self): def __init__(self):
self.reset() self.reset()
def feed(self, v): def feed(self, v):
self.values.append(v) self._values.append(v)
def reset(self): def reset(self):
self.values = [] self._values = []
@property @property
def count(self): def count(self):
return len(self.values) return len(self._values)
@property @property
def average(self): def average(self):
assert len(self.values) assert len(self._values)
return np.mean(self.values) return np.mean(self._values)
@property @property
def sum(self): def sum(self):
assert len(self.values) assert len(self._values)
return np.sum(self.values) return np.sum(self._values)
@property @property
def max(self): def max(self):
assert len(self.values) assert len(self._values)
return max(self.values) return max(self._values)
class Accuracy(object): class RatioStatistics(object):
def __init__(self): def __init__(self):
self.reset() self.reset()
def reset(self): def reset(self):
self.tot = 0 self._tot = 0
self.corr = 0 self._cnt = 0
def feed(self, corr, tot=1): def feed(self, cnt, tot=1):
self.tot += tot self._tot += tot
self.corr += corr self._cnt += cnt
@property @property
def accuracy(self): def ratio(self):
if self.tot == 0: if self._tot == 0:
return 0 return 0
return self.corr * 1.0 / self.tot return self._cnt * 1.0 / self._tot
@property @property
def count(self): def count(self):
return self.tot return self._tot
class Accuracy(RatioStatistics):
@property
def accuracy(self):
return self.ratio
class BinaryStatistics(object): class BinaryStatistics(object):
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment