Commit 342cdc70 authored by Yuxin Wu

Deprecate huber_loss used in DQN

parent ca278b70
@@ -369,7 +369,9 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
                 'GaussianDeform',
                 'dump_chkpt_vars',
                 'VisualQA',
-                'ParamRestore']:
+                'ParamRestore',
+                'huber_loss'
+                ]:
         return True
     if name in ['get_data', 'size', 'reset_state']:
         # skip these methods with empty docstring
......
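
(Aside: `autodoc_skip_member` is a standard Sphinx hook; returning True hides a member from the generated API docs, which is how the deprecated `huber_loss` is dropped from the reference. A minimal, self-contained sketch of the pattern, with a hypothetical skip list:)

# Minimal sketch of the Sphinx hook used above (skip list hypothetical).
_HIDDEN = {'huber_loss'}

def autodoc_skip_member(app, what, name, obj, skip, options):
    if name in _HIDDEN:
        return True   # True => omit this member from the docs
    return skip       # otherwise keep autodoc's default decision

def setup(app):
    # conf.py registers the hook on the 'autodoc-skip-member' event
    app.connect('autodoc-skip-member', autodoc_skip_member)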
@@ -5,12 +5,14 @@
 import abc
 import tensorflow as tf
+import tensorpack
 from tensorpack import ModelDesc, InputDesc
 from tensorpack.utils import logger
 from tensorpack.tfutils import (
     collection, summary, get_current_tower_context, optimizer, gradproc)
 from tensorpack.tfutils import symbolic_functions as symbf
 from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope

+assert tensorpack.tfutils.common.get_tf_version_number() >= 1.2


 class Model(ModelDesc):
@@ -72,14 +74,14 @@ class Model(ModelDesc):
         target = reward + (1.0 - tf.cast(isOver, tf.float32)) * self.gamma * tf.stop_gradient(best_v)

-        self.cost = tf.reduce_mean(symbf.huber_loss(
-            target - pred_action_value), name='cost')
+        self.cost = tf.losses.huber_loss(
+            target, pred_action_value, reduction=tf.losses.Reduction.MEAN)
         summary.add_param_summary(('conv.*/W', ['histogram', 'rms']),
                                   ('fc.*/W', ['histogram', 'rms']))   # monitor all W
         summary.add_moving_summary(self.cost)

     def _get_optimizer(self):
-        lr = symbf.get_scalar_var('learning_rate', 1e-3, summary=True)
+        lr = tf.get_variable('learning_rate', initializer=1e-3, trainable=False)
         opt = tf.train.AdamOptimizer(lr, epsilon=1e-3)
         return optimizer.apply_grad_processors(
             opt, [gradproc.GlobalNormClip(10), gradproc.SummaryGradient()])
......
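
(Aside: `tf.losses.huber_loss` penalizes the residual quadratically near zero and linearly in the tails, then applies the requested reduction; note the old `symbf.huber_loss` took the residual directly, while the new call takes labels and predictions separately. A rough NumPy sketch of the element-wise formula TensorFlow uses, assuming delta = 1:)

import numpy as np

def huber(residual, delta=1.0):
    # quadratic for small residuals, linear for large ones
    abs_r = np.abs(residual)
    return np.where(abs_r <= delta,
                    0.5 * np.square(residual),
                    delta * (abs_r - 0.5 * delta))

# With reduction=MEAN, the loss is the average of huber(target - pred).
print(huber(np.array([-3.0, -0.5, 0.2, 2.0])).mean())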
@@ -9,6 +9,11 @@ import os.path

 __all__ = []

+"""
+This module should be removed in the future.
+"""
+

 def _global_import(name):
     p = __import__(name, globals(), locals(), level=1)
     lst = p.__all__ if '__all__' in dir(p) else dir(p)
......
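
(Aside: `_global_import` is a re-export idiom; it imports a sibling module relative to the package and lifts its public names into the package namespace. A simplified sketch of the pattern; the collapsed tail of the real function is not reproduced here, and this must live inside a package `__init__.py`:)

__all__ = []

def _global_import(name):
    # level=1 makes this a relative import of a sibling module
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
    for k in lst:
        globals()[k] = getattr(p, k)   # re-export at package level
        __all__.append(k)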
@@ -23,7 +23,9 @@ class DataFlowTerminated(BaseException):

 class DataFlowReentrantGuard(object):
     """
-    A tool to enforce thread-level non-reentrancy on DataFlow.
+    A tool to enforce non-reentrancy.
+    Mostly used on DataFlow whose :meth:`get_data` is stateful,
+    so that multiple instances of the iterator cannot co-exist.
     """
     def __init__(self):
         self._lock = threading.Lock()
@@ -31,7 +33,7 @@ class DataFlowReentrantGuard(object):

     def __enter__(self):
         self._succ = self._lock.acquire(False)
         if not self._succ:
-            raise threading.ThreadError("This DataFlow cannot be reused under different threads!")
+            raise threading.ThreadError("This DataFlow is not reentrant!")

     def __exit__(self, exc_type, exc_val, exc_tb):
         self._lock.release()
......
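
(Aside: the guard is meant to be held for the whole lifetime of one iterator. A minimal usage sketch, with a hypothetical DataFlow subclass and an assumed import path:)

from tensorpack.dataflow.base import DataFlow, DataFlowReentrantGuard

class MyStatefulFlow(DataFlow):          # hypothetical example
    def __init__(self):
        self._guard = DataFlowReentrantGuard()

    def get_data(self):
        # A second concurrent iterator fails fast with ThreadError
        # instead of silently sharing mutable iteration state.
        with self._guard:
            for i in range(10):
                yield [i]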
@@ -6,6 +6,8 @@ import tensorflow as tf
 from contextlib import contextmanager
 import numpy as np

+from ..utils.develop import deprecated
+
 # __all__ = ['get_scalar_var']
@@ -124,6 +126,7 @@ def rms(x, name=None):
     return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)


+@deprecated("Please use tf.losses.huber_loss instead!")
 def huber_loss(x, delta=1, name='huber_loss'):
     r"""
     Huber loss of x.
......
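
(Aside: a generic sketch of what a `@deprecated(...)` decorator like the one added above typically does; this is not tensorpack's exact implementation:)

import functools
import warnings

def deprecated(message=''):
    def wrapper(func):
        @functools.wraps(func)
        def inner(*args, **kwargs):
            # warn at the caller's location, then delegate to the original
            warnings.warn("{} is deprecated. {}".format(func.__name__, message),
                          DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)
        return inner
    return wrapper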