Commit 227f8f9e authored by Matthew Hausknecht's avatar Matthew Hausknecht

Added mid-level actions.

parent 9048bd73
...@@ -11,9 +11,9 @@ except: ...@@ -11,9 +11,9 @@ except:
exit() exit()
def get_random_action(): def get_random_action():
""" Returns a random high-level action """ """Returns a random high-level action. Pass is omitted for simplicity."""
high_lv_actions = [HFO_Actions.SHOOT, HFO_Actions.PASS, HFO_Actions.DRIBBLE] high_lv_actions = [HFO_Actions.SHOOT, HFO_Actions.DRIBBLE]
return (random.choice(high_lv_actions), 0, 0) return random.choice(high_lv_actions)
def play_hfo(num): def play_hfo(num):
""" Method called by a thread to play 5 games of HFO """ """ Method called by a thread to play 5 games of HFO """
...@@ -27,7 +27,7 @@ def play_hfo(num): ...@@ -27,7 +27,7 @@ def play_hfo(num):
if state[5] == 1: #state[5] is 1 when player has the ball if state[5] == 1: #state[5] is 1 when player has the ball
status = hfo_env.act(get_random_action()) status = hfo_env.act(get_random_action())
else: else:
status = hfo_env.act((HFO_Actions.MOVE, 0, 0)) status = hfo_env.act(HFO_Actions.MOVE)
except: except:
pass pass
finally: finally:
......
...@@ -22,9 +22,8 @@ int main() { ...@@ -22,9 +22,8 @@ int main() {
// Get the vector of state features for the current state // Get the vector of state features for the current state
const std::vector<float>& feature_vec = hfo.getState(); const std::vector<float>& feature_vec = hfo.getState();
// Create a dash action // Create a dash action
Action a = {DASH, 20.0, 0.0};
// Perform the dash and recieve the current game status // Perform the dash and recieve the current game status
status = hfo.act(a); status = hfo.act(DASH, 20.0);
} }
// Check what the outcome of the episode was // Check what the outcome of the episode was
cout << "Episode " << episode << " ended with status: "; cout << "Episode " << episode << " ended with status: ";
......
...@@ -22,7 +22,7 @@ if __name__ == '__main__': ...@@ -22,7 +22,7 @@ if __name__ == '__main__':
# Grab the state features from the environment # Grab the state features from the environment
features = hfo.getState() features = hfo.getState()
# Take an action and get the current game status # Take an action and get the current game status
status = hfo.act((HFO_Actions.DASH, 0, 0)) status = hfo.act(HFO_Actions.DASH, 20.0, 0)
print 'Episode', episode, 'ended with', print 'Episode', episode, 'ended with',
# Check what the outcome of the episode was # Check what the outcome of the episode was
if status == HFO_Status.GOAL: if status == HFO_Status.GOAL:
......
...@@ -10,10 +10,9 @@ using namespace hfo; ...@@ -10,10 +10,9 @@ using namespace hfo;
// $./bin/HFO --offense-agents 1 // $./bin/HFO --offense-agents 1
// Returns a random high-level action // Returns a random high-level action
Action get_random_high_lv_action() { action_t get_random_high_lv_action() {
action_t action_indx = (action_t) ((rand() % 4) + 4); action_t action_indx = (action_t) ((rand() % 4) + MOVE);
Action act = {action_indx, 0, 0}; return action_indx;
return act;
} }
int main(int argc, char** argv) { int main(int argc, char** argv) {
...@@ -32,10 +31,9 @@ int main(int argc, char** argv) { ...@@ -32,10 +31,9 @@ int main(int argc, char** argv) {
while (status == IN_GAME) { while (status == IN_GAME) {
// Get the vector of state features for the current state // Get the vector of state features for the current state
const vector<float>& feature_vec = hfo.getState(); const vector<float>& feature_vec = hfo.getState();
// Create a dash action // Perform the action and recieve the current game status
Action a = get_random_high_lv_action(); status = hfo.act(get_random_high_lv_action());
// Perform the dash and recieve the current game status
status = hfo.act(a);
} }
} }
hfo.act(QUIT);
}; };
...@@ -9,10 +9,12 @@ using namespace hfo; ...@@ -9,10 +9,12 @@ using namespace hfo;
// Before running this program, first Start HFO server: // Before running this program, first Start HFO server:
// $./bin/HFO --offense-agents 1 // $./bin/HFO --offense-agents 1
float arg1, arg2;
// Returns a random low-level action // Returns a random low-level action
Action get_random_low_lv_action() { action_t get_random_low_lv_action() {
action_t action_indx = (action_t) (rand() % 4); action_t action_indx = (action_t) ((rand() % 4) + DASH);
float arg1, arg2; std::cout << action_indx << std::endl;
switch (action_indx) { switch (action_indx) {
case DASH: case DASH:
arg1 = (rand() / float(RAND_MAX)) * 200 - 100; // power: [-100, 100] arg1 = (rand() / float(RAND_MAX)) * 200 - 100; // power: [-100, 100]
...@@ -34,8 +36,7 @@ Action get_random_low_lv_action() { ...@@ -34,8 +36,7 @@ Action get_random_low_lv_action() {
cout << "Invalid Action Index: " << action_indx; cout << "Invalid Action Index: " << action_indx;
break; break;
} }
Action act = {action_indx, arg1, arg2}; return action_indx;
return act;
} }
int main(int argc, char** argv) { int main(int argc, char** argv) {
...@@ -54,10 +55,8 @@ int main(int argc, char** argv) { ...@@ -54,10 +55,8 @@ int main(int argc, char** argv) {
while (status == IN_GAME) { while (status == IN_GAME) {
// Get the vector of state features for the current state // Get the vector of state features for the current state
const vector<float>& feature_vec = hfo.getState(); const vector<float>& feature_vec = hfo.getState();
// Create a dash action // Perform the action and recieve the current game status
Action a = get_random_low_lv_action(); status = hfo.act(get_random_low_lv_action(), arg1, arg2);
// Perform the dash and recieve the current game status
status = hfo.act(a);
} }
} }
}; };
...@@ -14,14 +14,20 @@ class HFO_Actions: ...@@ -14,14 +14,20 @@ class HFO_Actions:
[Low-Level] Turn(direction) [Low-Level] Turn(direction)
[Low-Level] Tackle(direction) [Low-Level] Tackle(direction)
[Low-Level] Kick(power, direction) [Low-Level] Kick(power, direction)
[Mid-Level] Kick_To(target_x, target_y, speed)
[Mid-Level] Move(target_x, target_y)
[Mid-Level] Dribble(target_x, target_y)
[Mid-Level] Intercept(): Intercept the ball
[High-Level] Move(): Reposition player according to strategy [High-Level] Move(): Reposition player according to strategy
[High-Level] Shoot(): Shoot the ball [High-Level] Shoot(): Shoot the ball
[High-Level] Pass(): Pass to the most open teammate [High-Level] Pass(teammate_unum): Pass to teammate
[High-Level] Dribble(): Offensive dribble [High-Level] Dribble(): Offensive dribble
QUIT NOOP(): Do Nothing
QUIT(): Quit the game
''' '''
DASH, TURN, TACKLE, KICK, MOVE, SHOOT, PASS, DRIBBLE, QUIT = range(9) DASH, TURN, TACKLE, KICK, KICK_TO, MOVE_TO, DRIBBLE_TO, INTERCEPT, \
MOVE, SHOOT, PASS, DRIBBLE, NOOP, QUIT = range(14)
class HFO_Status: class HFO_Status:
''' Current status of the HFO game. ''' ''' Current status of the HFO game. '''
...@@ -38,6 +44,24 @@ class HFOEnvironment(object): ...@@ -38,6 +44,24 @@ class HFOEnvironment(object):
self.numFeatures = None # Given by the server in handshake self.numFeatures = None # Given by the server in handshake
self.features = None # The state features self.features = None # The state features
def NumParams(self, action_type):
''' Returns the number of required parameters for each action type. '''
return {
HFO_Actions.DASH : 2,
HFO_Actions.TURN : 1,
HFO_Actions.TACKLE : 1,
HFO_Actions.KICK : 2,
HFO_Actions.KICK_TO : 3,
HFO_Actions.MOVE_TO : 2,
HFO_Actions.DRIBBLE_TO : 2,
HFO_Actions.INTERCEPT : 0,
HFO_Actions.MOVE : 0,
HFO_Actions.SHOOT : 0,
HFO_Actions.PASS : 1,
HFO_Actions.DRIBBLE : 0,
HFO_Actions.NOOP : 0,
HFO_Actions.QUIT : 0}.get(action_type, -1);
def connectToAgentServer(self, server_port=6000, def connectToAgentServer(self, server_port=6000,
feature_set=HFO_Features.HIGH_LEVEL_FEATURE_SET): feature_set=HFO_Features.HIGH_LEVEL_FEATURE_SET):
'''Connect to the server that controls the agent on the specified port. ''' '''Connect to the server that controls the agent on the specified port. '''
...@@ -92,9 +116,14 @@ class HFOEnvironment(object): ...@@ -92,9 +116,14 @@ class HFOEnvironment(object):
size numFeatures. ''' size numFeatures. '''
return self.features return self.features
def act(self, action): def act(self, *args):
''' Send an action and recieve the game status.''' ''' Send an action and recieve the game status.'''
self.socket.send(struct.pack("iff", *action)) assert len(args) > 0, 'Not enough arguments provided to act'
action_type = args[0]
n_params = self.NumParams(action_type)
assert n_params == len(args) - 1, 'Incorrect number of params to act: '\
'Required %d provided %d'%(n_params, len(args)-1)
self.socket.send(struct.pack('i'+'f'*n_params, *args))
# Get the current game status # Get the current game status
data = self.socket.recv(struct.calcsize("i")) data = self.socket.recv(struct.calcsize("i"))
status = struct.unpack("i", data)[0] status = struct.unpack("i", data)[0]
......
...@@ -47,6 +47,41 @@ std::string HFOEnvironment::ActionToString(Action action) { ...@@ -47,6 +47,41 @@ std::string HFOEnvironment::ActionToString(Action action) {
return ss.str(); return ss.str();
}; };
int HFOEnvironment::NumParams(action_t action) {
switch (action) {
case DASH:
return 2;
case TURN:
return 1;
case TACKLE:
return 1;
case KICK:
return 2;
case KICK_TO:
return 3;
case MOVE_TO:
return 2;
case DRIBBLE_TO:
return 2;
case INTERCEPT:
return 0;
case MOVE:
return 0;
case SHOOT:
return 0;
case PASS:
return 1;
case DRIBBLE:
return 0;
case NOOP:
return 0;
case QUIT:
return 0;
}
std::cerr << "Unrecognized Action: " << action;
return -1;
}
bool HFOEnvironment::ParseConfig(const std::string& message, Config& config) { bool HFOEnvironment::ParseConfig(const std::string& message, Config& config) {
config.num_offense = -1; config.num_offense = -1;
config.num_defense = -1; config.num_defense = -1;
...@@ -203,14 +238,31 @@ const std::vector<float>& HFOEnvironment::getState() { ...@@ -203,14 +238,31 @@ const std::vector<float>& HFOEnvironment::getState() {
return feature_vec; return feature_vec;
} }
status_t HFOEnvironment::act(Action action) { status_t HFOEnvironment::act(action_t action, ...) {
status_t game_status; status_t game_status;
// Send the action // Send the action_type
if (send(sockfd, &action, sizeof(Action), 0) < 0) { if (send(sockfd, &action, sizeof(action_t), 0) < 0) {
perror("[Agent Client] ERROR sending from socket"); perror("[Agent Client] ERROR sending from socket");
close(sockfd); close(sockfd);
exit(1); exit(1);
} }
// Send the arguments
int n_args = NumParams(action);
if (n_args > 0) {
float params[n_args];
va_list vl;
va_start(vl, n_args);
for (int i = 0; i < n_args; ++i) {
params[i] = va_arg(vl, double);
}
va_end(vl);
// Send the arguments
if (send(sockfd, &params, sizeof(float) * n_args, 0) < 0) {
perror("[Agent Client] ERROR sending from socket");
close(sockfd);
exit(1);
}
}
// Get the game status // Get the game status
if (recv(sockfd, &game_status, sizeof(status_t), 0) < 0) { if (recv(sockfd, &game_status, sizeof(status_t), 0) < 0) {
perror("[Agent Client] ERROR recieving from socket"); perror("[Agent Client] ERROR recieving from socket");
......
...@@ -17,14 +17,19 @@ enum feature_set_t ...@@ -17,14 +17,19 @@ enum feature_set_t
// The actions available to the agent // The actions available to the agent
enum action_t enum action_t
{ {
DASH, // [Low-Level] Dash(power, relative_direction) DASH, // [Low-Level] Dash(power, direction)
TURN, // [Low-Level] Turn(direction) TURN, // [Low-Level] Turn(direction)
TACKLE, // [Low-Level] Tackle(direction) TACKLE, // [Low-Level] Tackle(direction)
KICK, // [Low-Level] Kick(power, direction) KICK, // [Low-Level] Kick(power, direction)
KICK_TO, // [Mid-Level] Kick_To(target_x, target_y, speed)
MOVE_TO, // [Mid-Level] Move(target_x, target_y)
DRIBBLE_TO, // [Mid-Level] Dribble(target_x, target_y)
INTERCEPT, // [Mid-Level] Intercept(): Intercept the ball
MOVE, // [High-Level] Move(): Reposition player according to strategy MOVE, // [High-Level] Move(): Reposition player according to strategy
SHOOT, // [High-Level] Shoot(): Shoot the ball SHOOT, // [High-Level] Shoot(): Shoot the ball
PASS, // [High-Level] Pass(teammate_unum): Pass to the most open teammate PASS, // [High-Level] Pass(teammate_unum): Pass to the most open teammate
DRIBBLE, // [High-Level] Dribble(): Offensive dribble DRIBBLE, // [High-Level] Dribble(): Offensive dribble
NOOP, // Do nothing
QUIT // Special action to quit the game QUIT // Special action to quit the game
}; };
...@@ -65,6 +70,9 @@ class HFOEnvironment { ...@@ -65,6 +70,9 @@ class HFOEnvironment {
// Returns a string representation of an action. // Returns a string representation of an action.
static std::string ActionToString(Action action); static std::string ActionToString(Action action);
// Get the number of parameters needed for a action.
static int NumParams(action_t action);
// Parse a Trainer message to populate config. Returns a bool // Parse a Trainer message to populate config. Returns a bool
// indicating if the struct was correctly parsed. // indicating if the struct was correctly parsed.
static bool ParseConfig(const std::string& message, Config& config); static bool ParseConfig(const std::string& message, Config& config);
...@@ -77,7 +85,7 @@ class HFOEnvironment { ...@@ -77,7 +85,7 @@ class HFOEnvironment {
const std::vector<float>& getState(); const std::vector<float>& getState();
// Take an action and recieve the resulting game status // Take an action and recieve the resulting game status
status_t act(Action action); status_t act(action_t action, ...);
protected: protected:
int numFeatures; // The number of features in this domain int numFeatures; // The number of features in this domain
......
...@@ -425,38 +425,66 @@ void Agent::actionImpl() { ...@@ -425,38 +425,66 @@ void Agent::actionImpl() {
exit(1); exit(1);
} }
// Get the action // Get the action type
Action action; action_t action;
if (recv(newsockfd, &action, sizeof(Action), 0) < 0) { if (recv(newsockfd, &action, sizeof(action_t), 0) < 0) {
perror("[Agent Server] ERROR recv from socket"); perror("[Agent Server] ERROR recv from socket");
close(sockfd); close(sockfd);
exit(1); exit(1);
} }
if (action.action == SHOOT) { // Get the parameters for that action
int n_args = HFOEnvironment::NumParams(action);
float params[n_args];
if (n_args > 0) {
if (recv(newsockfd, &params, sizeof(float)*n_args, 0) < 0) {
perror("[Agent Server] ERROR recv from socket");
close(sockfd);
exit(1);
}
}
if (action == SHOOT) {
const ShootGenerator::Container & cont = const ShootGenerator::Container & cont =
ShootGenerator::instance().courses(this->world(), false); ShootGenerator::instance().courses(this->world(), false);
ShootGenerator::Container::const_iterator best_shoot ShootGenerator::Container::const_iterator best_shoot
= std::min_element(cont.begin(), cont.end(), ShootGenerator::ScoreCmp()); = std::min_element(cont.begin(), cont.end(), ShootGenerator::ScoreCmp());
Body_SmartKick(best_shoot->target_point_, best_shoot->first_ball_speed_, Body_SmartKick(best_shoot->target_point_, best_shoot->first_ball_speed_,
best_shoot->first_ball_speed_ * 0.99, 3).execute(this); best_shoot->first_ball_speed_ * 0.99, 3).execute(this);
} else if (action.action == PASS) { } else if (action == PASS) {
Force_Pass pass; Force_Pass pass;
int receiver = int(action.arg1); int receiver = int(params[0]);
pass.get_pass_to_player(this->world(), receiver); pass.get_pass_to_player(this->world(), receiver);
pass.execute(this); pass.execute(this);
} }
switch(action.action) { switch(action) {
case DASH: case DASH:
this->doDash(action.arg1, action.arg2); this->doDash(params[0], params[1]);
break; break;
case TURN: case TURN:
this->doTurn(action.arg1); this->doTurn(params[0]);
break; break;
case TACKLE: case TACKLE:
this->doTackle(action.arg1, false); this->doTackle(params[0], false);
break; break;
case KICK: case KICK:
this->doKick(action.arg1, action.arg2); this->doKick(params[0], params[1]);
break;
case KICK_TO:
Body_SmartKick(Vector2D(feature_extractor->absoluteXPos(params[0]),
feature_extractor->absoluteYPos(params[1])),
params[2], params[2] * 0.99, 3).execute(this);
break;
case MOVE_TO:
Body_GoToPoint(Vector2D(feature_extractor->absoluteXPos(params[0]),
feature_extractor->absoluteYPos(params[1])), 0.25,
ServerParam::i().maxDashPower()).execute(this);
break;
case DRIBBLE_TO:
Body_Dribble(Vector2D(feature_extractor->absoluteXPos(params[0]),
feature_extractor->absoluteYPos(params[1])), 1.0,
ServerParam::i().maxDashPower(), 2).execute(this);
break;
case INTERCEPT:
Body_Intercept().execute(this);
break; break;
case MOVE: case MOVE:
this->doMove(); this->doMove();
...@@ -468,13 +496,15 @@ void Agent::actionImpl() { ...@@ -468,13 +496,15 @@ void Agent::actionImpl() {
case DRIBBLE: case DRIBBLE:
this->doDribble(); this->doDribble();
break; break;
case NOOP:
break;
case QUIT: case QUIT:
std::cout << "[Agent Server] Got quit from agent." << std::endl; std::cout << "[Agent Server] Got quit from agent." << std::endl;
close(sockfd); close(sockfd);
exit(0); exit(0);
default: default:
std::cerr << "[Agent Server] ERROR Unsupported Action: " std::cerr << "[Agent Server] ERROR Unsupported Action: "
<< action.action << std::endl; << action << std::endl;
close(sockfd); close(sockfd);
exit(1); exit(1);
} }
......
...@@ -9,8 +9,13 @@ ...@@ -9,8 +9,13 @@
using namespace rcsc; using namespace rcsc;
FeatureExtractor::FeatureExtractor() : FeatureExtractor::FeatureExtractor(int num_teammates,
numFeatures(-1) int num_opponents,
bool playing_offense) :
numFeatures(-1),
numTeammates(num_teammates),
numOpponents(num_opponents),
playingOffense(playing_offense)
{ {
const ServerParam& SP = ServerParam::i(); const ServerParam& SP = ServerParam::i();
...@@ -92,17 +97,35 @@ void FeatureExtractor::addFeature(float val) { ...@@ -92,17 +97,35 @@ void FeatureExtractor::addFeature(float val) {
feature_vec[featIndx++] = val; feature_vec[featIndx++] = val;
} }
void FeatureExtractor::addNormFeature(float val, float min_val, float max_val) { float FeatureExtractor::normalize(float val, float min_val, float max_val) {
assert(featIndx < numFeatures);
if (val < min_val || val > max_val) { if (val < min_val || val > max_val) {
std::cout << "Feature " << featIndx << " Violated Feature Bounds: " << val std::cout << "Feature " << featIndx << " Violated Feature Bounds: " << val
<< " Expected min/max: [" << min_val << ", " << max_val << "]" << std::endl; << " Expected min/max: [" << min_val << ", "
<< max_val << "]" << std::endl;
val = std::min(std::max(val, min_val), max_val); val = std::min(std::max(val, min_val), max_val);
} }
feature_vec[featIndx++] = ((val - min_val) / (max_val - min_val)) return ((val - min_val) / (max_val - min_val))
* (FEAT_MAX - FEAT_MIN) + FEAT_MIN; * (FEAT_MAX - FEAT_MIN) + FEAT_MIN;
} }
float FeatureExtractor::unnormalize(float val, float min_val, float max_val) {
if (val < FEAT_MIN || val > FEAT_MAX) {
std::cout << "Unnormalized value Violated Feature Bounds: " << val
<< " Expected min/max: [" << FEAT_MIN << ", "
<< FEAT_MAX << "]" << std::endl;
float ft_max = FEAT_MAX; // Linker error on OSX otherwise...?
float ft_min = FEAT_MIN;
val = std::min(std::max(val, ft_min), ft_max);
}
return ((val - FEAT_MIN) / (FEAT_MAX - FEAT_MIN))
* (max_val - min_val) + min_val;
}
void FeatureExtractor::addNormFeature(float val, float min_val, float max_val) {
assert(featIndx < numFeatures);
feature_vec[featIndx++] = normalize(val, min_val, max_val);
}
void FeatureExtractor::checkFeatures() { void FeatureExtractor::checkFeatures() {
assert(feature_vec.size() == numFeatures); assert(feature_vec.size() == numFeatures);
for (int i=0; i<numFeatures; ++i) { for (int i=0; i<numFeatures; ++i) {
...@@ -241,3 +264,33 @@ void FeatureExtractor::splitAngles(std::vector<OpenAngle> &openAngles, ...@@ -241,3 +264,33 @@ void FeatureExtractor::splitAngles(std::vector<OpenAngle> &openAngles,
} }
openAngles = resAngles; openAngles = resAngles;
} }
float FeatureExtractor::normalizedXPos(float absolute_x_pos) {
float tolerance_x = .1 * pitchHalfLength;
if (playingOffense) {
return normalize(absolute_x_pos, -tolerance_x, pitchHalfLength + tolerance_x);
} else {
return normalize(absolute_x_pos, -pitchHalfLength-tolerance_x, tolerance_x);
}
}
float FeatureExtractor::normalizedYPos(float absolute_y_pos) {
float tolerance_y = .1 * pitchHalfWidth;
return normalize(absolute_y_pos, -pitchHalfWidth - tolerance_y,
pitchHalfWidth + tolerance_y);
}
float FeatureExtractor::absoluteXPos(float normalized_x_pos) {
float tolerance_x = .1 * pitchHalfLength;
if (playingOffense) {
return unnormalize(normalized_x_pos, -tolerance_x, pitchHalfLength + tolerance_x);
} else {
return unnormalize(normalized_x_pos, -pitchHalfLength-tolerance_x, tolerance_x);
}
}
float FeatureExtractor::absoluteYPos(float normalized_y_pos) {
float tolerance_y = .1 * pitchHalfWidth;
return unnormalize(normalized_y_pos, -pitchHalfWidth - tolerance_y,
pitchHalfWidth + tolerance_y);
}
...@@ -8,7 +8,7 @@ typedef std::pair<float, float> OpenAngle; ...@@ -8,7 +8,7 @@ typedef std::pair<float, float> OpenAngle;
class FeatureExtractor { class FeatureExtractor {
public: public:
FeatureExtractor(); FeatureExtractor(int num_teammates, int num_opponents, bool playing_offense);
virtual ~FeatureExtractor(); virtual ~FeatureExtractor();
// Updated the state features stored in feature_vec // Updated the state features stored in feature_vec
...@@ -68,6 +68,12 @@ public: ...@@ -68,6 +68,12 @@ public:
float oppAngleBottom, float oppAngleBottom,
float oppAngleTop); float oppAngleTop);
// Convert back and forth between normalized and absolute x,y postions
float normalizedXPos(float absolute_x_pos);
float normalizedYPos(float absolute_y_pos);
float absoluteXPos(float normalized_x_pos);
float absoluteYPos(float normalized_y_pos);
protected: protected:
// Encodes an angle feature as the sin and cosine of that angle, // Encodes an angle feature as the sin and cosine of that angle,
// effectively transforming a single angle into two features. // effectively transforming a single angle into two features.
...@@ -92,6 +98,11 @@ protected: ...@@ -92,6 +98,11 @@ protected:
// Add a feature without normalizing // Add a feature without normalizing
void addFeature(float val); void addFeature(float val);
// Returns a normalized feature value
float normalize(float val, float min_val, float max_val);
// Converts a normalized feature value back into original space
float unnormalize(float val, float min_val, float max_val);
// Add a feature and normalize to the range [FEAT_MIN, FEAT_MAX] // Add a feature and normalize to the range [FEAT_MIN, FEAT_MAX]
void addNormFeature(float val, float min_val, float max_val); void addNormFeature(float val, float min_val, float max_val);
...@@ -118,6 +129,9 @@ protected: ...@@ -118,6 +129,9 @@ protected:
// Useful measures defined by the Server Parameters // Useful measures defined by the Server Parameters
float pitchLength, pitchWidth, pitchHalfLength, pitchHalfWidth, float pitchLength, pitchWidth, pitchHalfLength, pitchHalfWidth,
goalHalfWidth, penaltyAreaLength, penaltyAreaWidth; goalHalfWidth, penaltyAreaLength, penaltyAreaWidth;
int numTeammates; // Number of teammates in HFO
int numOpponents; // Number of opponents in HFO
bool playingOffense; // Are we playing offense or defense?
}; };
#endif // FEATURE_EXTRACTOR_H #endif // FEATURE_EXTRACTOR_H
...@@ -10,10 +10,7 @@ using namespace rcsc; ...@@ -10,10 +10,7 @@ using namespace rcsc;
HighLevelFeatureExtractor::HighLevelFeatureExtractor(int num_teammates, HighLevelFeatureExtractor::HighLevelFeatureExtractor(int num_teammates,
int num_opponents, int num_opponents,
bool playing_offense) : bool playing_offense) :
FeatureExtractor(), FeatureExtractor(num_teammates, num_opponents, playing_offense)
numTeammates(num_teammates),
numOpponents(num_opponents),
playingOffense(playing_offense)
{ {
assert(numTeammates >= 0); assert(numTeammates >= 0);
assert(numOpponents >= 0); assert(numOpponents >= 0);
......
...@@ -25,9 +25,6 @@ protected: ...@@ -25,9 +25,6 @@ protected:
const static int num_basic_features = 9; const static int num_basic_features = 9;
// Number of features for each player or opponent in game. // Number of features for each player or opponent in game.
const static int features_per_teammate = 5; const static int features_per_teammate = 5;
int numTeammates; // Number of teammates in HFO
int numOpponents; // Number of opponents in HFO
bool playingOffense; // Are we playing offense or defense?
}; };
#endif // HIGHLEVEL_FEATURE_EXTRACTOR_H #endif // HIGHLEVEL_FEATURE_EXTRACTOR_H
...@@ -10,10 +10,7 @@ using namespace rcsc; ...@@ -10,10 +10,7 @@ using namespace rcsc;
LowLevelFeatureExtractor::LowLevelFeatureExtractor(int num_teammates, LowLevelFeatureExtractor::LowLevelFeatureExtractor(int num_teammates,
int num_opponents, int num_opponents,
bool playing_offense) : bool playing_offense) :
FeatureExtractor(), FeatureExtractor(num_teammates, num_opponents, playing_offense)
numTeammates(num_teammates),
numOpponents(num_opponents),
playingOffense(playing_offense)
{ {
assert(numTeammates >= 0); assert(numTeammates >= 0);
assert(numOpponents >= 0); assert(numOpponents >= 0);
......
...@@ -19,9 +19,6 @@ protected: ...@@ -19,9 +19,6 @@ protected:
const static int num_basic_features = 58; const static int num_basic_features = 58;
// Number of features for each player or opponent in game. // Number of features for each player or opponent in game.
const static int features_per_player = 8; const static int features_per_player = 8;
int numTeammates; // Number of teammates in HFO
int numOpponents; // Number of opponents in HFO
bool playingOffense; // Are we playing offense or defense?
}; };
#endif // LOWLEVEL_FEATURE_EXTRACTOR_H #endif // LOWLEVEL_FEATURE_EXTRACTOR_H
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment