Commit 2cd41e99 authored by Yuxin Wu's avatar Yuxin Wu

fix towers

parent adb45736
......@@ -111,7 +111,6 @@ class Model(ModelDesc):
predict_onehot = tf.one_hot(self.greedy_choice, NUM_ACTIONS, 1.0, 0.0)
best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1)
target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v)
sqrcost = tf.square(target - pred_action_value)
......
......@@ -54,6 +54,7 @@ class SaverRestore(SessionInit):
def __init__(self, model_path, prefix=None):
"""
:param model_path: a model file or a ``checkpoint`` file.
:param prefix: add a `prefix/` for every variable in this checkpoint
"""
assert os.path.isfile(model_path)
if os.path.basename(model_path) == 'checkpoint':
......
......@@ -57,6 +57,8 @@ class TrainConfig(object):
if 'nr_tower' in kwargs or 'tower' in kwargs:
self.set_tower(**kwargs)
else:
self.tower = [0]
self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
......@@ -72,3 +74,11 @@ class TrainConfig(object):
tower = list(range(tower))
self.tower = tower
assert isinstance(self.tower, list)
@property
def nr_tower(self):
return len(self.tower)
@nr_tower.setter
def nr_tower(self, value):
self.tower = list(range(value))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment