Commit 5f98e6ca authored by Yuxin Wu's avatar Yuxin Wu

fix param callback gbu

parent 7da1d899
...@@ -138,6 +138,7 @@ class Model(ModelDesc): ...@@ -138,6 +138,7 @@ class Model(ModelDesc):
def predictor(self, state): def predictor(self, state):
return self.predict_value.eval(feed_dict={'state:0': [state]})[0] return self.predict_value.eval(feed_dict={'state:0': [state]})[0]
#return self.predict_value.eval(feed_dict={'input_deque:0': [state]})[0]
def get_config(): def get_config():
basename = os.path.basename(__file__) basename = os.path.basename(__file__)
...@@ -174,7 +175,7 @@ def get_config(): ...@@ -174,7 +175,7 @@ def get_config():
PeriodicCallback(Evaluator(EVAL_EPISODE, 'fct/output:0'), 2), PeriodicCallback(Evaluator(EVAL_EPISODE, 'fct/output:0'), 2),
]), ]),
# save memory for multiprocess evaluator # save memory for multiprocess evaluator
session_config=get_default_sess_config(0.3), session_config=get_default_sess_config(0.6),
model=M, model=M,
step_per_epoch=STEP_PER_EPOCH, step_per_epoch=STEP_PER_EPOCH,
) )
...@@ -208,4 +209,5 @@ if __name__ == '__main__': ...@@ -208,4 +209,5 @@ if __name__ == '__main__':
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
SimpleTrainer(config).train() SimpleTrainer(config).train()
#QueueInputTrainer(config).train()
...@@ -3,9 +3,9 @@ ...@@ -3,9 +3,9 @@
# File: cifar10-resnet.py # File: cifar10-resnet.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
import tensorflow as tf import tensorflow as tf
import argparse import argparse
import numpy as np
import os import os
from tensorpack.train import TrainConfig, QueueInputTrainer from tensorpack.train import TrainConfig, QueueInputTrainer
......
...@@ -7,6 +7,7 @@ import tensorflow as tf ...@@ -7,6 +7,7 @@ import tensorflow as tf
from abc import abstractmethod, ABCMeta, abstractproperty from abc import abstractmethod, ABCMeta, abstractproperty
import operator import operator
import six import six
import os
from .base import Callback from .base import Callback
from ..utils import logger from ..utils import logger
...@@ -29,6 +30,7 @@ class HyperParam(object): ...@@ -29,6 +30,7 @@ class HyperParam(object):
""" define how the value of the param will be set""" """ define how the value of the param will be set"""
pass pass
@property
def readable_name(self): def readable_name(self):
""" A name to display""" """ A name to display"""
return self._readable_name return self._readable_name
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment