Commit c280473d authored by Yuxin Wu's avatar Yuxin Wu

InferenceRunner select tower from TrainConfig (#249)

parent 9f056711
...@@ -6,18 +6,23 @@ ...@@ -6,18 +6,23 @@
import numpy as np import numpy as np
import os import os
import sys import sys
import re
import time import time
import random import random
import uuid import uuid
import argparse import argparse
import multiprocessing import multiprocessing
import threading import threading
import cv2 import cv2
from collections import deque import tensorflow as tf
import six import six
from six.moves import queue from six.moves import queue
import tensorflow as tf if six.PY3:
from concurrent import futures # py3
CancelledError = futures.CancelledError
else:
CancelledError = Exception
from tensorpack import * from tensorpack import *
from tensorpack.utils.concurrency import * from tensorpack.utils.concurrency import *
...@@ -42,7 +47,7 @@ STEPS_PER_EPOCH = 6000 ...@@ -42,7 +47,7 @@ STEPS_PER_EPOCH = 6000
EVAL_EPISODE = 50 EVAL_EPISODE = 50
BATCH_SIZE = 128 BATCH_SIZE = 128
SIMULATOR_PROC = 50 SIMULATOR_PROC = 50
PREDICTOR_THREAD_PER_GPU = 2 PREDICTOR_THREAD_PER_GPU = 3
PREDICTOR_THREAD = None PREDICTOR_THREAD = None
EVALUATE_PROC = min(multiprocessing.cpu_count() // 2, 20) EVALUATE_PROC = min(multiprocessing.cpu_count() // 2, 20)
...@@ -156,7 +161,11 @@ class MySimulatorMaster(SimulatorMaster, Callback): ...@@ -156,7 +161,11 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def _on_state(self, state, ident): def _on_state(self, state, ident):
def cb(outputs): def cb(outputs):
try:
distrib, value = outputs.result() distrib, value = outputs.result()
except CancelledError:
logger.info("Client {} cancelled.".format(ident))
return
assert np.all(np.isfinite(distrib)), distrib assert np.all(np.isfinite(distrib)), distrib
action = np.random.choice(len(distrib), p=distrib) action = np.random.choice(len(distrib), p=distrib)
client = self.clients[ident] client = self.clients[ident]
......
...@@ -104,13 +104,15 @@ class InferenceRunnerBase(Callback): ...@@ -104,13 +104,15 @@ class InferenceRunnerBase(Callback):
def _setup_graph(self): def _setup_graph(self):
self._input_data.setup(self.trainer.model) self._input_data.setup(self.trainer.model)
self._setup_input_names() self._setup_input_names()
# Use predict_tower in train config. either gpuid or -1
self._predict_tower_id = self.trainer.config.predict_tower[0]
in_tensors = self._find_input_tensors() in_tensors = self._find_input_tensors()
assert isinstance(in_tensors, list), in_tensors assert isinstance(in_tensors, list), in_tensors
with tf.variable_scope(tf.get_variable_scope(), reuse=True): with tf.variable_scope(tf.get_variable_scope(), reuse=True):
def fn(_): def fn(_):
self.trainer.model.build_graph(in_tensors) self.trainer.model.build_graph(in_tensors)
PredictorTowerBuilder(fn, self._prefix).build(0) PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
self._feed_tensors = self._find_feed_tensors() self._feed_tensors = self._find_feed_tensors()
self._hooks = [self._build_hook(inf) for inf in self.infs] self._hooks = [self._build_hook(inf) for inf in self.infs]
...@@ -122,7 +124,7 @@ class InferenceRunnerBase(Callback): ...@@ -122,7 +124,7 @@ class InferenceRunnerBase(Callback):
def _get_tensors_maybe_in_tower(self, names): def _get_tensors_maybe_in_tower(self, names):
placeholder_names = set([k.name for k in self.trainer.model.get_inputs_desc()]) placeholder_names = set([k.name for k in self.trainer.model.get_inputs_desc()])
get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower
return get_tensor_fn(placeholder_names, names, 0, prefix=self._prefix) return get_tensor_fn(placeholder_names, names, self._predict_tower_id, prefix=self._prefix)
def _find_input_tensors(self): def _find_input_tensors(self):
pass pass
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment