Commit 5f98e6ca authored by Yuxin Wu's avatar Yuxin Wu

fix param callback gbu

parent 7da1d899
......@@ -138,6 +138,7 @@ class Model(ModelDesc):
def predictor(self, state):
return self.predict_value.eval(feed_dict={'state:0': [state]})[0]
#return self.predict_value.eval(feed_dict={'input_deque:0': [state]})[0]
def get_config():
basename = os.path.basename(__file__)
......@@ -174,7 +175,7 @@ def get_config():
PeriodicCallback(Evaluator(EVAL_EPISODE, 'fct/output:0'), 2),
]),
# save memory for multiprocess evaluator
session_config=get_default_sess_config(0.3),
session_config=get_default_sess_config(0.6),
model=M,
step_per_epoch=STEP_PER_EPOCH,
)
......@@ -208,4 +209,5 @@ if __name__ == '__main__':
if args.load:
config.session_init = SaverRestore(args.load)
SimpleTrainer(config).train()
#QueueInputTrainer(config).train()
......@@ -3,9 +3,9 @@
# File: cifar10-resnet.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
import tensorflow as tf
import argparse
import numpy as np
import os
from tensorpack.train import TrainConfig, QueueInputTrainer
......
......@@ -7,6 +7,7 @@ import tensorflow as tf
from abc import abstractmethod, ABCMeta, abstractproperty
import operator
import six
import os
from .base import Callback
from ..utils import logger
......@@ -29,6 +30,7 @@ class HyperParam(object):
""" define how the value of the param will be set"""
pass
@property
def readable_name(self):
""" A name to display"""
return self._readable_name
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment