Commit 2cd41e99 authored by Yuxin Wu

fix towers

parent adb45736
...@@ -111,7 +111,6 @@ class Model(ModelDesc): ...@@ -111,7 +111,6 @@ class Model(ModelDesc):
predict_onehot = tf.one_hot(self.greedy_choice, NUM_ACTIONS, 1.0, 0.0) predict_onehot = tf.one_hot(self.greedy_choice, NUM_ACTIONS, 1.0, 0.0)
best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1) best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1)
target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v) target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v)
sqrcost = tf.square(target - pred_action_value) sqrcost = tf.square(target - pred_action_value)
......
...@@ -54,6 +54,7 @@ class SaverRestore(SessionInit): ...@@ -54,6 +54,7 @@ class SaverRestore(SessionInit):
def __init__(self, model_path, prefix=None): def __init__(self, model_path, prefix=None):
""" """
:param model_path: a model file or a ``checkpoint`` file. :param model_path: a model file or a ``checkpoint`` file.
:param prefix: add a `prefix/` for every variable in this checkpoint
""" """
assert os.path.isfile(model_path) assert os.path.isfile(model_path)
if os.path.basename(model_path) == 'checkpoint': if os.path.basename(model_path) == 'checkpoint':
......
...@@ -57,6 +57,8 @@ class TrainConfig(object): ...@@ -57,6 +57,8 @@ class TrainConfig(object):
if 'nr_tower' in kwargs or 'tower' in kwargs: if 'nr_tower' in kwargs or 'tower' in kwargs:
self.set_tower(**kwargs) self.set_tower(**kwargs)
else:
self.tower = [0]
self.extra_threads_procs = kwargs.pop('extra_threads_procs', []) self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys())) assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
...@@ -72,3 +74,11 @@ class TrainConfig(object): ...@@ -72,3 +74,11 @@ class TrainConfig(object):
tower = list(range(tower)) tower = list(range(tower))
self.tower = tower self.tower = tower
assert isinstance(self.tower, list) assert isinstance(self.tower, list)
@property
def nr_tower(self):
    """Return how many towers are configured (the length of ``self.tower``)."""
    tower_ids = self.tower
    return len(tower_ids)


@nr_tower.setter
def nr_tower(self, value):
    """Replace ``self.tower`` with the consecutive ids ``0 .. value - 1``."""
    self.tower = [idx for idx in range(value)]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment