Commit 9d3cf419 authored by Yuxin Wu's avatar Yuxin Wu

total timer

parent bc1ba816
......@@ -189,7 +189,7 @@ def eval_with_funcs(predict_funcs):
score = play_one_episode(player, self.func)
self.queue_put_stoppable(self.q, score)
q = queue.Queue()
q = queue.Queue(maxsize=3)
threads = [Worker(f, q) for f in predict_funcs]
for k in threads:
......
......@@ -158,6 +158,32 @@ class AtariPlayer(RLEnvironment):
if __name__ == '__main__':
import sys
import time
def benchmark():
a = AtariPlayer(sys.argv[1], viz=False, height_range=(28,-8))
num = a.get_num_actions()
rng = get_rng(num)
start = time.time()
cnt = 0
while True:
act = rng.choice(range(num))
r, o = a.action(act)
a.current_state()
cnt += 1
if cnt == 5000:
break
print time.time() - start
if len(sys.argv) == 3 and sys.argv[2] == 'benchmark':
import threading, multiprocessing
for k in range(3):
#th = multiprocessing.Process(target=benchmark)
th = threading.Thread(target=benchmark)
th.start()
time.sleep(0.02)
benchmark()
else:
a = AtariPlayer(sys.argv[1],
viz=0.03, height_range=(28,-8))
num = a.get_num_actions()
......
......@@ -5,6 +5,7 @@
from ..base import DataFlow
from ...utils import *
from ...utils.timer import *
from six.moves import zip, map
from collections import Counter
import json
......
......@@ -11,6 +11,7 @@ import tqdm
import tensorflow as tf
from .config import TrainConfig
from ..utils import *
from ..utils.timer import *
from ..utils.concurrency import start_proc_mask_signal
from ..callbacks import StatHolder
from ..tfutils import *
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: timer.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from contextlib import contextmanager
import time
from collections import defaultdict
import six
from .stat import StatCounter
from . import logger
__all__ = ['total_timer', 'timed_operation', 'print_total_timer']
@contextmanager
def timed_operation(msg, log_start=False):
if log_start:
logger.info('start {} ...'.format(msg))
start = time.time()
yield
logger.info('{} finished, time={:.2f}sec.'.format(
msg, time.time() - start))
_TOTAL_TIMER_DATA = defaultdict(StatCounter)
@contextmanager
def total_timer(msg):
start = time.time()
yield
t = time.time() - start
_TOTAL_TIMER_DATA[msg].feed(t)
def print_total_timer():
for k, v in six.iteritems(_TOTAL_TIMER_DATA):
logger.info("Total Time: {} -> {} sec".format(k, v.sum))
......@@ -11,7 +11,7 @@ import numpy as np
from . import logger
__all__ = ['timed_operation', 'change_env',
__all__ = ['change_env',
'get_rng', 'memoized', 'get_nr_gpu', 'get_gpus']
#def expand_dim_if_necessary(var, dp):
......@@ -28,15 +28,6 @@ __all__ = ['timed_operation', 'change_env',
# dp = dp.reshape(new_shape)
# return dp
@contextmanager
def timed_operation(msg, log_start=False):
if log_start:
logger.info('start {} ...'.format(msg))
start = time.time()
yield
logger.info('{} finished, time={:.2f}sec.'.format(
msg, time.time() - start))
@contextmanager
def change_env(name, val):
oldval = os.environ.get(name, None)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment