Commit 7782e724 authored by Yuxin Wu's avatar Yuxin Wu

add SendMonitorData

parent eb57892e
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: imagenet-resnet.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import cv2
import sys
......@@ -27,6 +26,9 @@ class Model(ModelDesc):
self.data_format = data_format
def _get_inputs(self):
# uint8 instead of float32 is used as input type to reduce copy overhead.
# It might hurt the performance a liiiitle bit.
# The pretrained models were trained with float32.
return [InputDesc(tf.uint8, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
InputDesc(tf.int32, [None], 'label')]
......@@ -197,6 +199,7 @@ def get_config(fake=False, data_format='NCHW'):
dataset_val = get_data('val', fake=fake)
return TrainConfig(
model=Model(data_format=data_format),
dataflow=dataset_train,
callbacks=[
ModelSaver(),
......@@ -207,7 +210,6 @@ def get_config(fake=False, data_format='NCHW'):
[(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5)]),
HumanHyperParamSetter('learning_rate'),
],
model=Model(data_format=data_format),
steps_per_epoch=5000,
max_epoch=110,
)
......
......@@ -16,7 +16,7 @@ from ..utils import logger
from .base import Callback
__all__ = ['TrainingMonitor', 'Monitors',
'TFSummaryWriter', 'JSONWriter', 'ScalarPrinter']
'TFSummaryWriter', 'JSONWriter', 'ScalarPrinter', 'SendMonitorData']
class TrainingMonitor(Callback):
......@@ -262,3 +262,55 @@ class ScalarHistory(TrainingMonitor):
def get_history(self, name):
return self._dic[name]
class SendMonitorData(TrainingMonitor):
"""
Execute a command with some specific scalar monitor data.
This is useful for, e.g. building a custom statistics monitor.
It will try to send once receiving all the stats
"""
def __init__(self, command, names):
"""
Args:
command(str): a command to execute. Use format string with stat
names as keys.
names(list or str): data name(s) to use.
Example:
Send the stats to your phone through pushbullet:
.. code-block:: python
SendMonitorData('curl -u your_id: https://api.pushbullet.com/v2/pushes \\
-d type=note -d title="validation error" \\
-d body={validation_error} > /dev/null 2>&1',
'validation_error')
"""
self.command = command
if not isinstance(names, list):
names = [names]
self.names = names
self.dic = {}
def put_scalar(self, name, val):
if name in self.names:
self.dic[name] = val
def _trigger_step(self):
self._try_send()
def _trigger_epoch(self):
self._try_send()
def _try_send(self):
try:
v = {k: self.dic[k] for k in self.names}
except KeyError:
return
cmd = self.command.format(**v)
ret = os.system(cmd)
if ret != 0:
logger.error("Command {} failed with ret={}!".format(cmd, ret))
self.dic = {}
......@@ -18,37 +18,17 @@ class StatPrinter(Callback):
"2017-05-26")
# TODO make it into monitor?
class SendStat(Callback):
"""
Execute a command with some specific stats.
This is useful for, e.g. building a custom statistics monitor.
"""
def __init__(self, command, stats):
"""
Args:
command(str): a command to execute. Use format string with stat
names as keys.
stats(list or str): stat name(s) to use.
Example:
Send the stats to your phone through pushbullet:
.. code-block:: python
SendStat('curl -u your_id: https://api.pushbullet.com/v2/pushes \\
-d type=note -d title="validation error" \\
-d body={validation_error} > /dev/null 2>&1',
'validation_error')
"""
""" An equivalent of :class:`SendMonitorData`, but as a normal callback. """
def __init__(self, command, names):
self.command = command
if not isinstance(stats, list):
stats = [stats]
self.stats = stats
if not isinstance(names, list):
names = [names]
self.names = names
def _trigger(self):
M = self.trainer.monitors
v = {k: M.get_latest(k) for k in self.stats}
v = {k: M.get_latest(k) for k in self.names}
cmd = self.command.format(**v)
ret = os.system(cmd)
if ret != 0:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment