Commit 9850edf5 authored by Yuxin Wu's avatar Yuxin Wu

fix import in DQN. add auto_reuse_vs in DQN.

parent c1f4adaf
......@@ -10,6 +10,7 @@ from tensorpack.utils import logger
from tensorpack.tfutils import (
collection, summary, get_current_tower_context, optimizer, gradproc)
from tensorpack.tfutils import symbolic_functions as symbf
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
class Model(ModelDesc):
......@@ -34,11 +35,16 @@ class Model(ModelDesc):
def _get_DQN_prediction(self, image):
    """Build the Q-value network for a batch of input states.

    Abstract hook: subclasses implement the actual network here.
    # NOTE(review): callers treat the result as an NxA tensor
    # (batch x num_actions) of predicted Q-values — confirm in subclass.
    """
    pass
# decorate the function
@auto_reuse_variable_scope
def get_DQN_prediction(self, image):
    # Public wrapper around the subclass-implemented _get_DQN_prediction.
    # The decorator memoizes the variable scope: the first call creates
    # the network variables, and every later call in the same graph/scope
    # reuses them instead of creating duplicates — this replaces the old
    # manual `tf.variable_scope(..., reuse=True)` bookkeeping.
    return self._get_DQN_prediction(image)
def _build_graph(self, inputs):
comb_state, action, reward, isOver = inputs
comb_state = tf.cast(comb_state, tf.float32)
state = tf.slice(comb_state, [0, 0, 0, 0], [-1, -1, -1, self.channel], name='state')
self.predict_value = self._get_DQN_prediction(state)
self.predict_value = self.get_DQN_prediction(state)
if not get_current_tower_context().is_training:
return
......@@ -51,18 +57,15 @@ class Model(ModelDesc):
self.predict_value, 1), name='predict_reward')
summary.add_moving_summary(max_pred_reward)
with tf.variable_scope('target'), \
collection.freeze_collection([tf.GraphKeys.TRAINABLE_VARIABLES]):
targetQ_predict_value = self._get_DQN_prediction(next_state) # NxA
with tf.variable_scope('target'):
targetQ_predict_value = self.get_DQN_prediction(next_state) # NxA
if self.method != 'Double':
# DQN
best_v = tf.reduce_max(targetQ_predict_value, 1) # N,
else:
# Double-DQN
sc = tf.get_variable_scope()
with tf.variable_scope(sc, reuse=True):
next_predict_value = self._get_DQN_prediction(next_state)
next_predict_value = self.get_DQN_prediction(next_state)
self.greedy_choice = tf.argmax(next_predict_value, 1) # N,
predict_onehot = tf.one_hot(self.greedy_choice, self.num_actions, 1.0, 0.0)
best_v = tf.reduce_sum(targetQ_predict_value * predict_onehot, 1)
......
......@@ -13,6 +13,7 @@ from six.moves import queue
from tensorpack import *
from tensorpack.utils.concurrency import *
from tensorpack.utils.stats import *
from tensorpack.utils.utils import get_tqdm_kwargs
def play_one_episode(player, func, verbose=False):
......
......@@ -26,10 +26,8 @@ def auto_reuse_variable_scope(func):
h = hash((tf.get_default_graph(), scope.name))
# print("Entering " + scope.name + " reuse: " + str(h in used_scope))
if h in used_scope:
ns = scope.original_name_scope
with tf.variable_scope(scope, reuse=True):
with tf.name_scope(ns):
return func(*args, **kwargs)
return func(*args, **kwargs)
else:
used_scope.add(h)
return func(*args, **kwargs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment