Commit c280473d authored by Yuxin Wu's avatar Yuxin Wu

InferenceRunner select tower from TrainConfig (#249)

parent 9f056711
......@@ -6,18 +6,23 @@
import numpy as np
import os
import sys
import re
import time
import random
import uuid
import argparse
import multiprocessing
import threading
import cv2
from collections import deque
import tensorflow as tf
import six
from six.moves import queue
import tensorflow as tf
if six.PY3:
from concurrent import futures # py3
CancelledError = futures.CancelledError
else:
CancelledError = Exception
from tensorpack import *
from tensorpack.utils.concurrency import *
......@@ -42,7 +47,7 @@ STEPS_PER_EPOCH = 6000
EVAL_EPISODE = 50
BATCH_SIZE = 128
SIMULATOR_PROC = 50
PREDICTOR_THREAD_PER_GPU = 2
PREDICTOR_THREAD_PER_GPU = 3
PREDICTOR_THREAD = None
EVALUATE_PROC = min(multiprocessing.cpu_count() // 2, 20)
......@@ -156,7 +161,11 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def _on_state(self, state, ident):
def cb(outputs):
distrib, value = outputs.result()
try:
distrib, value = outputs.result()
except CancelledError:
logger.info("Client {} cancelled.".format(ident))
return
assert np.all(np.isfinite(distrib)), distrib
action = np.random.choice(len(distrib), p=distrib)
client = self.clients[ident]
......
......@@ -104,13 +104,15 @@ class InferenceRunnerBase(Callback):
def _setup_graph(self):
self._input_data.setup(self.trainer.model)
self._setup_input_names()
# Use predict_tower in train config. either gpuid or -1
self._predict_tower_id = self.trainer.config.predict_tower[0]
in_tensors = self._find_input_tensors()
assert isinstance(in_tensors, list), in_tensors
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
def fn(_):
self.trainer.model.build_graph(in_tensors)
PredictorTowerBuilder(fn, self._prefix).build(0)
PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
self._feed_tensors = self._find_feed_tensors()
self._hooks = [self._build_hook(inf) for inf in self.infs]
......@@ -122,7 +124,7 @@ class InferenceRunnerBase(Callback):
def _get_tensors_maybe_in_tower(self, names):
placeholder_names = set([k.name for k in self.trainer.model.get_inputs_desc()])
get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower
return get_tensor_fn(placeholder_names, names, 0, prefix=self._prefix)
return get_tensor_fn(placeholder_names, names, self._predict_tower_id, prefix=self._prefix)
def _find_input_tensors(self):
pass
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment