Commit 9850edf5 authored by Yuxin Wu's avatar Yuxin Wu

fix import in DQN. add auto_reuse_vs in DQN.

parent c1f4adaf
...@@ -10,6 +10,7 @@ from tensorpack.utils import logger ...@@ -10,6 +10,7 @@ from tensorpack.utils import logger
from tensorpack.tfutils import ( from tensorpack.tfutils import (
collection, summary, get_current_tower_context, optimizer, gradproc) collection, summary, get_current_tower_context, optimizer, gradproc)
from tensorpack.tfutils import symbolic_functions as symbf from tensorpack.tfutils import symbolic_functions as symbf
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
class Model(ModelDesc): class Model(ModelDesc):
...@@ -34,11 +35,16 @@ class Model(ModelDesc): ...@@ -34,11 +35,16 @@ class Model(ModelDesc):
def _get_DQN_prediction(self, image): def _get_DQN_prediction(self, image):
pass pass
# decorate the function
@auto_reuse_variable_scope
def get_DQN_prediction(self, image):
return self._get_DQN_prediction(image)
def _build_graph(self, inputs): def _build_graph(self, inputs):
comb_state, action, reward, isOver = inputs comb_state, action, reward, isOver = inputs
comb_state = tf.cast(comb_state, tf.float32) comb_state = tf.cast(comb_state, tf.float32)
state = tf.slice(comb_state, [0, 0, 0, 0], [-1, -1, -1, self.channel], name='state') state = tf.slice(comb_state, [0, 0, 0, 0], [-1, -1, -1, self.channel], name='state')
self.predict_value = self._get_DQN_prediction(state) self.predict_value = self.get_DQN_prediction(state)
if not get_current_tower_context().is_training: if not get_current_tower_context().is_training:
return return
...@@ -51,18 +57,15 @@ class Model(ModelDesc): ...@@ -51,18 +57,15 @@ class Model(ModelDesc):
self.predict_value, 1), name='predict_reward') self.predict_value, 1), name='predict_reward')
summary.add_moving_summary(max_pred_reward) summary.add_moving_summary(max_pred_reward)
with tf.variable_scope('target'), \ with tf.variable_scope('target'):
collection.freeze_collection([tf.GraphKeys.TRAINABLE_VARIABLES]): targetQ_predict_value = self.get_DQN_prediction(next_state) # NxA
targetQ_predict_value = self._get_DQN_prediction(next_state) # NxA
if self.method != 'Double': if self.method != 'Double':
# DQN # DQN
best_v = tf.reduce_max(targetQ_predict_value, 1) # N, best_v = tf.reduce_max(targetQ_predict_value, 1) # N,
else: else:
# Double-DQN # Double-DQN
sc = tf.get_variable_scope() next_predict_value = self.get_DQN_prediction(next_state)
with tf.variable_scope(sc, reuse=True):
next_predict_value = self._get_DQN_prediction(next_state)
self.greedy_choice = tf.argmax(next_predict_value, 1) # N, self.greedy_choice = tf.argmax(next_predict_value, 1) # N,
predict_onehot = tf.one_hot(self.greedy_choice, self.num_actions, 1.0, 0.0) predict_onehot = tf.one_hot(self.greedy_choice, self.num_actions, 1.0, 0.0)
best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1) best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1)
......
...@@ -13,6 +13,7 @@ from six.moves import queue ...@@ -13,6 +13,7 @@ from six.moves import queue
from tensorpack import * from tensorpack import *
from tensorpack.utils.concurrency import * from tensorpack.utils.concurrency import *
from tensorpack.utils.stats import * from tensorpack.utils.stats import *
from tensorpack.utils.utils import get_tqdm_kwargs
def play_one_episode(player, func, verbose=False): def play_one_episode(player, func, verbose=False):
......
...@@ -26,10 +26,8 @@ def auto_reuse_variable_scope(func): ...@@ -26,10 +26,8 @@ def auto_reuse_variable_scope(func):
h = hash((tf.get_default_graph(), scope.name)) h = hash((tf.get_default_graph(), scope.name))
# print("Entering " + scope.name + " reuse: " + str(h in used_scope)) # print("Entering " + scope.name + " reuse: " + str(h in used_scope))
if h in used_scope: if h in used_scope:
ns = scope.original_name_scope
with tf.variable_scope(scope, reuse=True): with tf.variable_scope(scope, reuse=True):
with tf.name_scope(ns): return func(*args, **kwargs)
return func(*args, **kwargs)
else: else:
used_scope.add(h) used_scope.add(h)
return func(*args, **kwargs) return func(*args, **kwargs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment