Commit ed4c009f authored by Matthew Hausknecht's avatar Matthew Hausknecht

Refactored features. Added higher level actions.

parent 57f30d7f
...@@ -36,6 +36,7 @@ class Trainer(object): ...@@ -36,6 +36,7 @@ class Trainer(object):
self._serverPort = args.port + 1 # The port the server is listening on self._serverPort = args.port + 1 # The port the server is listening on
self._coachPort = args.port + 2 # The coach port to talk with the server self._coachPort = args.port + 2 # The coach port to talk with the server
self._logDir = args.logDir # Directory to store logs self._logDir = args.logDir # Directory to store logs
self._record = args.record # Record states + actions
self._numOffense = args.numOffense # Number offensive players self._numOffense = args.numOffense # Number offensive players
self._numDefense = args.numDefense # Number defensive players self._numDefense = args.numDefense # Number defensive players
self._maxTrials = args.numTrials # Maximum number of trials to play self._maxTrials = args.numTrials # Maximum number of trials to play
...@@ -67,6 +68,7 @@ class Trainer(object): ...@@ -67,6 +68,7 @@ class Trainer(object):
self._agentNumInt = -1 # Agent's internal team number self._agentNumInt = -1 # Agent's internal team number
self._agentNumExt = -1 # Agent's external team number self._agentNumExt = -1 # Agent's external team number
self._agentServerPort = args.port # Port for agent's server self._agentServerPort = args.port # Port for agent's server
self._agentOnBall = args.agent_on_ball # If true, agent starts with the ball
# =============== MISC =============== # # =============== MISC =============== #
self._offenseTeam = '' # Name of the offensive team self._offenseTeam = '' # Name of the offensive team
self._defenseTeam = '' # Name of the defensive team self._defenseTeam = '' # Name of the defensive team
...@@ -107,6 +109,8 @@ class Trainer(object): ...@@ -107,6 +109,8 @@ class Trainer(object):
%(self._agentTeam, self._agentNumExt, self._serverPort, %(self._agentTeam, self._agentNumExt, self._serverPort,
self._coachPort, self._logDir, numTeammates, numOpponents, self._coachPort, self._logDir, numTeammates, numOpponents,
self._agent_play_offense, self._agentServerPort) self._agent_play_offense, self._agentServerPort)
if self._record:
agentCmd += ' --record'
agentCmd = os.path.join(binary_dir, agentCmd) agentCmd = os.path.join(binary_dir, agentCmd)
agentCmd = agentCmd.split(' ') agentCmd = agentCmd.split(' ')
# Ignore stderr because librcsc continually prints to it # Ignore stderr because librcsc continually prints to it
...@@ -527,6 +531,9 @@ class Trainer(object): ...@@ -527,6 +531,9 @@ class Trainer(object):
# Move the rest of the offense # Move the rest of the offense
for i in xrange(1, self._numOffense + 1): for i in xrange(1, self._numOffense + 1):
self.movePlayer(self._offenseTeam, i, self.getOffensiveResetPosition()) self.movePlayer(self._offenseTeam, i, self.getOffensiveResetPosition())
# Move the agent to the ball
if self._agent and self._agentOnBall:
self.movePlayer(self._offenseTeam, 1, self._ballPosition)
# Move the defensive goalie # Move the defensive goalie
if self._numDefense > 0: if self._numDefense > 0:
self.movePlayer(self._defenseTeam, 0, [0.5 * self.PITCH_LENGTH,0]) self.movePlayer(self._defenseTeam, 0, [0.5 * self.PITCH_LENGTH,0])
......
...@@ -11,10 +11,13 @@ SERVER_CMD = 'rcssserver' ...@@ -11,10 +11,13 @@ SERVER_CMD = 'rcssserver'
# Command to run the monitor. Edit as needed. # Command to run the monitor. Edit as needed.
MONITOR_CMD = 'rcssmonitor' MONITOR_CMD = 'rcssmonitor'
def getAgentDirCmd(binary_dir, teamname, server_port=6000, coach_port=6002, logDir='/tmp'): def getAgentDirCmd(binary_dir, teamname, server_port=6000, coach_port=6002,
logDir='log', record=False):
""" Returns the team name, command, and directory to run a team. """ """ Returns the team name, command, and directory to run a team. """
cmd = 'start.sh -t %s -p %i -P %i --log-dir %s'%(teamname, server_port, cmd = 'start.sh -t %s -p %i -P %i --log-dir %s'%(teamname, server_port,
coach_port, logDir) coach_port, logDir)
if record:
cmd += ' --record'
cmd = os.path.join(binary_dir, cmd) cmd = os.path.join(binary_dir, cmd)
return teamname, cmd return teamname, cmd
...@@ -53,8 +56,10 @@ def main(args, team1='left', team2='right', rng=numpy.random.RandomState()): ...@@ -53,8 +56,10 @@ def main(args, team1='left', team2='right', rng=numpy.random.RandomState()):
args.logDir, args.logDir) args.logDir, args.logDir)
if args.sync: if args.sync:
serverOptions += ' server::synch_mode=on' serverOptions += ' server::synch_mode=on'
team1, team1Cmd = getAgentDirCmd(binary_dir, team1, server_port, coach_port, args.logDir) team1, team1Cmd = getAgentDirCmd(binary_dir, team1, server_port, coach_port,
team2, team2Cmd = getAgentDirCmd(binary_dir, team2, server_port, coach_port, args.logDir) args.logDir, args.record)
team2, team2Cmd = getAgentDirCmd(binary_dir, team2, server_port, coach_port,
args.logDir, args.record)
try: try:
# Launch the Server # Launch the Server
server = launch(SERVER_CMD + serverOptions, name='server') server = launch(SERVER_CMD + serverOptions, name='server')
...@@ -114,6 +119,10 @@ def parseArgs(args=None): ...@@ -114,6 +119,10 @@ def parseArgs(args=None):
' will be incrementally allocated the following ports.') ' will be incrementally allocated the following ports.')
p.add_argument('--log-dir', dest='logDir', default='log/', p.add_argument('--log-dir', dest='logDir', default='log/',
help='Directory to store logs.') help='Directory to store logs.')
p.add_argument('--record', dest='record', action='store_true',
help='Record logs of states and actions.')
p.add_argument('--agent-on-ball', dest='agent_on_ball', action='store_true',
help='Agent starts with the ball.')
return p.parse_args(args=args) return p.parse_args(args=args)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -66,6 +66,7 @@ usage() ...@@ -66,6 +66,7 @@ usage()
echo " --log-dir DIRECTORY specifies debug log directory (default: /tmp)" echo " --log-dir DIRECTORY specifies debug log directory (default: /tmp)"
echo " --debug-log-ext EXTENSION specifies debug log file extension (default: .log)" echo " --debug-log-ext EXTENSION specifies debug log file extension (default: .log)"
echo " --fullstate FULLSTATE_TYPE specifies fullstate model handling" echo " --fullstate FULLSTATE_TYPE specifies fullstate model handling"
echo " --record records actions (default: off)"
echo " FULLSTATE_TYPE is one of [ignore|reference|override].") 1>&2 echo " FULLSTATE_TYPE is one of [ignore|reference|override].") 1>&2
} }
...@@ -162,6 +163,10 @@ do ...@@ -162,6 +163,10 @@ do
offline_mode="on" offline_mode="on"
;; ;;
--record)
record="--record"
;;
--debug) --debug)
debugopt="${debugopt} --debug" debugopt="${debugopt} --debug"
coachdebug="${coachdebug} --debug" coachdebug="${coachdebug} --debug"
...@@ -277,6 +282,7 @@ opt="${opt} --debug_server_host ${debug_server_host}" ...@@ -277,6 +282,7 @@ opt="${opt} --debug_server_host ${debug_server_host}"
opt="${opt} --debug_server_port ${debug_server_port}" opt="${opt} --debug_server_port ${debug_server_port}"
opt="${opt} ${offline_logging}" opt="${opt} ${offline_logging}"
opt="${opt} ${debugopt}" opt="${opt} ${debugopt}"
opt="${opt} ${record}"
ping -c 1 $host ping -c 1 $host
......
...@@ -64,6 +64,7 @@ usage() ...@@ -64,6 +64,7 @@ usage()
echo " --log-dir DIRECTORY specifies debug log directory (default: /tmp)" echo " --log-dir DIRECTORY specifies debug log directory (default: /tmp)"
echo " --debug-log-ext EXTENSION specifies debug log file extension (default: .log)" echo " --debug-log-ext EXTENSION specifies debug log file extension (default: .log)"
echo " --fullstate FULLSTATE_TYPE specifies fullstate model handling" echo " --fullstate FULLSTATE_TYPE specifies fullstate model handling"
echo " --record records actions (default: off)"
echo " FULLSTATE_TYPE is one of [ignore|reference|override]." echo " FULLSTATE_TYPE is one of [ignore|reference|override]."
echo " --offensePlayers player1 ... specifies the numbers of the offense players" echo " --offensePlayers player1 ... specifies the numbers of the offense players"
echo " --defensePlayers player1 ... specifies the numbers of the defense players" echo " --defensePlayers player1 ... specifies the numbers of the defense players"
...@@ -176,6 +177,10 @@ do ...@@ -176,6 +177,10 @@ do
offline_mode="on" offline_mode="on"
;; ;;
--record)
record="--record"
;;
--record_stats_file) --record_stats_file)
if [ $# -lt 2 ]; then if [ $# -lt 2 ]; then
usage usage
...@@ -202,7 +207,7 @@ do ...@@ -202,7 +207,7 @@ do
opts="${opts} --learn-index ${2}" opts="${opts} --learn-index ${2}"
shift 1 shift 1
;; ;;
--learn-path) --learn-path)
if [ $# -lt 2 ]; then if [ $# -lt 2 ]; then
usage usage
...@@ -211,7 +216,7 @@ do ...@@ -211,7 +216,7 @@ do
opts="${opts} --learn-path ${2}" opts="${opts} --learn-path ${2}"
shift 1 shift 1
;; ;;
--model-path) --model-path)
if [ $# -lt 2 ]; then if [ $# -lt 2 ]; then
usage usage
...@@ -283,7 +288,7 @@ do ...@@ -283,7 +288,7 @@ do
opts="${opts} --seed ${2}" opts="${opts} --seed ${2}"
shift 1 shift 1
;; ;;
--trainer) --trainer)
if [ $# -lt 2 ]; then if [ $# -lt 2 ]; then
usage usage
...@@ -301,11 +306,11 @@ do ...@@ -301,11 +306,11 @@ do
opts="${opts} --save-path ${2}" opts="${opts} --save-path ${2}"
shift 1 shift 1
;; ;;
--gdb) --gdb)
use_gdb="true" use_gdb="true"
;; ;;
--run-debug-version) --run-debug-version)
run_debug_version="true" run_debug_version="true"
;; ;;
...@@ -426,6 +431,7 @@ opt="${opt} --debug_server_host ${debug_server_host}" ...@@ -426,6 +431,7 @@ opt="${opt} --debug_server_host ${debug_server_host}"
opt="${opt} --debug_server_port ${debug_server_port}" opt="${opt} --debug_server_port ${debug_server_port}"
opt="${opt} ${offline_logging}" opt="${opt} ${offline_logging}"
opt="${opt} ${debugopt}" opt="${opt} ${debugopt}"
opt="${opt} ${record}"
ping -c 1 $host ping -c 1 $host
......
...@@ -3,14 +3,18 @@ import socket, struct, thread, time ...@@ -3,14 +3,18 @@ import socket, struct, thread, time
class HFO_Actions: class HFO_Actions:
''' An enum of the possible HFO actions ''' An enum of the possible HFO actions
Dash(power, relative_direction) [Low-Level] Dash(power, relative_direction)
Turn(direction) [Low-Level] Turn(direction)
Tackle(direction) [Low-Level] Tackle(direction)
Kick(power, direction) [Low-Level] Kick(power, direction)
[High-Level] Move(): Reposition player according to strategy
[High-Level] Shoot(): Shoot the ball
[High-Level] Pass(): Pass to the most open teammate
[High-Level] Dribble(): Offensive dribble
QUIT QUIT
''' '''
DASH, TURN, TACKLE, KICK, QUIT = range(5) DASH, TURN, TACKLE, KICK, MOVE, SHOOT, PASS, DRIBBLE, QUIT = range(9)
class HFO_Status: class HFO_Status:
''' Current status of the HFO game. ''' ''' Current status of the HFO game. '''
......
...@@ -6,11 +6,15 @@ ...@@ -6,11 +6,15 @@
// The actions available to the agent // The actions available to the agent
enum action_t enum action_t
{ {
DASH, // Dash(power, relative_direction) DASH, // [Low-Level] Dash(power, relative_direction)
TURN, // Turn(direction) TURN, // [Low-Level] Turn(direction)
TACKLE, // Tackle(direction) TACKLE, // [Low-Level] Tackle(direction)
KICK, // Kick(power, direction) KICK, // [Low-Level] Kick(power, direction)
QUIT // Special action to quit the game MOVE, // [High-Level] Move(): Reposition player according to strategy
SHOOT, // [High-Level] Shoot(): Shoot the ball
PASS, // [High-Level] Pass(): Pass to the most open teammate
DRIBBLE, // [High-Level] Dribble(): Offensive dribble
QUIT // Special action to quit the game
}; };
// The current status of the HFO game // The current status of the HFO game
......
This diff is collapsed.
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "field_evaluator.h" #include "field_evaluator.h"
#include "communication.h" #include "communication.h"
#include "HFO.hpp" #include "HFO.hpp"
#include "feature_extractor.h"
#include <rcsc/player/player_agent.h> #include <rcsc/player/player_agent.h>
#include <vector> #include <vector>
...@@ -59,72 +60,29 @@ protected: ...@@ -59,72 +60,29 @@ protected:
virtual FieldEvaluator::ConstPtr createFieldEvaluator() const; virtual FieldEvaluator::ConstPtr createFieldEvaluator() const;
virtual ActionGenerator::ConstPtr createActionGenerator() const; virtual ActionGenerator::ConstPtr createActionGenerator() const;
// Updated the state features stored in feature_vec
void updateStateFeatures();
// Get the current game status // Get the current game status
hfo_status_t getGameStatus(); hfo_status_t getGameStatus();
// Encodes an angle feature as the sin and cosine of that angle,
// effectively transforming a single angle into two features.
void addAngFeature(const rcsc::AngleDeg& ang);
// Encodes a proximity feature which is defined by a distance as
// well as a maximum possible distance, which acts as a
// normalizer. Encodes the distance as [0-far, 1-close]. Ignores
// distances greater than maxDist or less than 0.
void addDistFeature(float dist, float maxDist);
// Add the angle and distance to the landmark to the feature_vec
void addLandmarkFeatures(const rcsc::Vector2D& landmark,
const rcsc::Vector2D& self_pos,
const rcsc::AngleDeg& self_ang);
// Add features corresponding to another player.
void addPlayerFeatures(rcsc::PlayerObject& player,
const rcsc::Vector2D& self_pos,
const rcsc::AngleDeg& self_ang);
// Start the server and listen for a connection. // Start the server and listen for a connection.
void startServer(int server_port=6008); void startServer(int server_port=6008);
// Transmit information to the client and ensure it can receive. // Transmit information to the client and ensure it can receive.
void clientHandshake(); void clientHandshake();
// Add a feature without normalizing
void addFeature(float val);
// Add a feature and normalize to the range [FEAT_MIN, FEAT_MAX]
void addNormFeature(float val, float min_val, float max_val);
protected: protected:
int numTeammates; // Number of teammates in HFO FeatureExtractor* feature_extractor;
int numOpponents; // Number of opponents in HFO
bool playingOffense; // Are we playing offense or defense?
int numFeatures; // Total number of features
// Number of features for non-player objects.
const static int num_basic_features = 58;
// Number of features for each player or opponent in game.
const static int features_per_player = 8;
int featIndx; // Feature being populated
std::vector<float> feature_vec; // Contains the current features
// Observed values of some parameters.
const static float observedSelfSpeedMax = 0.46;
const static float observedPlayerSpeedMax = 0.75;
const static float observedStaminaMax = 8000.;
const static float observedBallSpeedMax = 5.0;
float maxHFORadius; // Maximum possible distance in HFO playable region
// Useful measures defined by the Server Parameters
float pitchLength, pitchWidth, pitchHalfLength, pitchHalfWidth,
goalHalfWidth, penaltyAreaLength, penaltyAreaWidth;
long lastTrainerMessageTime; // Last time the trainer sent a message long lastTrainerMessageTime; // Last time the trainer sent a message
int server_port; // Port to start the server on int server_port; // Port to start the server on
bool server_running; // Is the server running? bool server_running; // Is the server running?
int sockfd, newsockfd; // Server sockets int sockfd, newsockfd; // Server sockets
bool record; // Record states + actions
private: private:
bool doPreprocess(); bool doPreprocess();
bool doShoot(); bool doShoot();
bool doPass();
bool doDribble();
bool doMove();
bool doForceKick(); bool doForceKick();
bool doHeardPassReceive(); bool doHeardPassReceive();
......
#ifdef HAVE_CONFIG_H
#include <config.h>
#endif
#include "highlevel_feature_extractor.h"
#include <rcsc/common/server_param.h>
using namespace rcsc;
// Stores the game configuration and sizes feature_vec so that it holds the
// fixed block of basic features plus one block per teammate and per opponent.
//
//   num_teammates   - number of teammates to reserve feature slots for (>= 0)
//   num_opponents   - number of opponents to reserve feature slots for (>= 0)
//   playing_offense - whether the agent plays offense
HighLevelFeatureExtractor::HighLevelFeatureExtractor(int num_teammates,
                                                     int num_opponents,
                                                     bool playing_offense)
    : FeatureExtractor(),
      numTeammates(num_teammates),
      numOpponents(num_opponents),
      playingOffense(playing_offense)
{
  // A negative count would produce a nonsensical feature total below.
  assert(numTeammates >= 0);
  assert(numOpponents >= 0);

  // Every tracked player contributes the same number of features.
  const int tracked_players = numTeammates + numOpponents;
  numFeatures = num_basic_features + features_per_player * tracked_players;
  feature_vec.resize(numFeatures);
}
// Empty destructor body: no cleanup is performed at this level.
HighLevelFeatureExtractor::~HighLevelFeatureExtractor()
{
}
const std::vector<float>& HighLevelFeatureExtractor::ExtractFeatures(
const WorldModel& wm) {
featIndx = 0;
const ServerParam& SP = ServerParam::i();
// ======================== SELF FEATURES ======================== //
const SelfObject& self = wm.self();
const Vector2D& self_pos = self.pos();
const AngleDeg& self_ang = self.body();
addFeature(self.posValid() ? FEAT_MAX : FEAT_MIN);
// ADD_FEATURE(self_pos.x);
// ADD_FEATURE(self_pos.y);
// Direction and speed of the agent.
addFeature(self.velValid() ? FEAT_MAX : FEAT_MIN);
if (self.velValid()) {
addAngFeature(self_ang - self.vel().th());
addNormFeature(self.speed(), 0., observedSelfSpeedMax);
} else {
addFeature(0);
addFeature(0);
addFeature(0);
}
// Global Body Angle -- 0:right -90:up 90:down 180/-180:left
addAngFeature(self_ang);
// Neck Angle -- We probably don't need this unless we are
// controlling the neck manually.
// std::cout << "Face Error: " << self.faceError() << std::endl;
// if (self.faceValid()) {
// std::cout << "FaceAngle: " << self.face() << std::endl;
// }
addNormFeature(self.stamina(), 0., observedStaminaMax);
addFeature(self.isFrozen() ? FEAT_MAX : FEAT_MIN);
// Probabilities - Do we want these???
// std::cout << "catchProb: " << self.catchProbability() << std::endl;
// std::cout << "tackleProb: " << self.tackleProbability() << std::endl;
// std::cout << "fouldProb: " << self.foulProbability() << std::endl;
// Features indicating if we are colliding with an object
addFeature(self.collidesWithBall() ? FEAT_MAX : FEAT_MIN);
addFeature(self.collidesWithPlayer() ? FEAT_MAX : FEAT_MIN);
addFeature(self.collidesWithPost() ? FEAT_MAX : FEAT_MIN);
addFeature(self.isKickable() ? FEAT_MAX : FEAT_MIN);
// inertiaPoint estimates the ball point after a number of steps
// self.inertiaPoint(n_steps);
// ======================== LANDMARK FEATURES ======================== //
// Top Bottom Center of Goal
rcsc::Vector2D goalCenter(pitchHalfLength, 0);
addLandmarkFeatures(goalCenter, self_pos, self_ang);
rcsc::Vector2D goalPostTop(pitchHalfLength, -goalHalfWidth);
addLandmarkFeatures(goalPostTop, self_pos, self_ang);
rcsc::Vector2D goalPostBot(pitchHalfLength, goalHalfWidth);
addLandmarkFeatures(goalPostBot, self_pos, self_ang);
// Top Bottom Center of Penalty Box
rcsc::Vector2D penaltyBoxCenter(pitchHalfLength - penaltyAreaLength, 0);
addLandmarkFeatures(penaltyBoxCenter, self_pos, self_ang);
rcsc::Vector2D penaltyBoxTop(pitchHalfLength - penaltyAreaLength,
-penaltyAreaWidth / 2.);
addLandmarkFeatures(penaltyBoxTop, self_pos, self_ang);
rcsc::Vector2D penaltyBoxBot(pitchHalfLength - penaltyAreaLength,
penaltyAreaWidth / 2.);
addLandmarkFeatures(penaltyBoxBot, self_pos, self_ang);
// Corners of the Playable Area
rcsc::Vector2D centerField(0, 0);
addLandmarkFeatures(centerField, self_pos, self_ang);
rcsc::Vector2D cornerTopLeft(0, -pitchHalfWidth);
addLandmarkFeatures(cornerTopLeft, self_pos, self_ang);
rcsc::Vector2D cornerTopRight(pitchHalfLength, -pitchHalfWidth);
addLandmarkFeatures(cornerTopRight, self_pos, self_ang);
rcsc::Vector2D cornerBotRight(pitchHalfLength, pitchHalfWidth);
addLandmarkFeatures(cornerBotRight, self_pos, self_ang);
rcsc::Vector2D cornerBotLeft(0, pitchHalfWidth);
addLandmarkFeatures(cornerBotLeft, self_pos, self_ang);
// Distances to the edges of the playable area
if (self.posValid()) {
// Distance to Left field line
addDistFeature(self_pos.x, pitchHalfLength);
// Distance to Right field line
addDistFeature(pitchHalfLength - self_pos.x, pitchHalfLength);
// Distance to top field line
addDistFeature(pitchHalfWidth + self_pos.y, pitchWidth);
// Distance to Bottom field line
addDistFeature(pitchHalfWidth - self_pos.y, pitchWidth);
} else {
addFeature(0);
addFeature(0);
addFeature(0);
addFeature(0);
}
// ======================== BALL FEATURES ======================== //
const BallObject& ball = wm.ball();
// Angle and distance to the ball
addFeature(ball.rposValid() ? FEAT_MAX : FEAT_MIN);
if (ball.rposValid()) {
addAngFeature(ball.angleFromSelf());
addDistFeature(ball.distFromSelf(), maxHFORadius);
} else {
addFeature(0);
addFeature(0);
addFeature(0);
}
// Velocity and direction of the ball
addFeature(ball.velValid() ? FEAT_MAX : FEAT_MIN);
if (ball.velValid()) {
// SeverParam lists ballSpeedMax a 2.7 which is too low
addNormFeature(ball.vel().r(), 0., observedBallSpeedMax);
addAngFeature(ball.vel().th());
} else {
addFeature(0);
addFeature(0);
addFeature(0);
}
assert(featIndx == num_basic_features);
// ======================== TEAMMATE FEATURES ======================== //
// Vector of PlayerObject pointers sorted by increasing distance from self
int detected_teammates = 0;
const PlayerPtrCont& teammates = wm.teammatesFromSelf();
for (PlayerPtrCont::const_iterator it = teammates.begin();
it != teammates.end(); ++it) {
PlayerObject* teammate = *it;
if (teammate->pos().x > 0 && teammate->unum() > 0 &&
detected_teammates < numTeammates) {
addPlayerFeatures(*teammate, self_pos, self_ang);
detected_teammates++;
}
}
// Add zero features for any missing teammates
for (int i=detected_teammates; i<numTeammates; ++i) {
for (int j=0; j<features_per_player; ++j) {
addFeature(0);
}
}
// ======================== OPPONENT FEATURES ======================== //
int detected_opponents = 0;
const PlayerPtrCont& opponents = wm.opponentsFromSelf();
for (PlayerPtrCont::const_iterator it = opponents.begin();
it != opponents.end(); ++it) {
PlayerObject* opponent = *it;
if (opponent->pos().x > 0 && opponent->unum() > 0 &&
detected_opponents < numOpponents) {
addPlayerFeatures(*opponent, self_pos, self_ang);
detected_opponents++;
}
}
// Add zero features for any missing opponents
for (int i=detected_opponents; i<numOpponents; ++i) {
for (int j=0; j<features_per_player; ++j) {
addFeature(0);
}
}
assert(featIndx == numFeatures);
checkFeatures();
return feature_vec;
}
#ifndef HIGHLEVEL_FEATURE_EXTRACTOR_H
#define HIGHLEVEL_FEATURE_EXTRACTOR_H
#include <rcsc/player/player_agent.h>
#include "feature_extractor.h"
#include <vector>
// FeatureExtractor subclass configured with the number of teammates and
// opponents in the game and with the side the agent plays on.
class HighLevelFeatureExtractor : public FeatureExtractor {
 public:
  // num_teammates / num_opponents: player counts in HFO.
  // playing_offense: whether the agent is on offense.
  HighLevelFeatureExtractor(int num_teammates, int num_opponents,
                            bool playing_offense);
  virtual ~HighLevelFeatureExtractor();

  // Updates the state features stored in feature_vec from the given world
  // model and returns them.
  virtual const std::vector<float>& ExtractFeatures(const rcsc::WorldModel& wm);

 protected:
  // Number of features for non-player objects.
  static const int num_basic_features = 58;
  // Number of features for each player or opponent in game.
  static const int features_per_player = 8;

  int numTeammates;     // Number of teammates in HFO
  int numOpponents;     // Number of opponents in HFO
  bool playingOffense;  // Are we playing offense or defense?
};
#endif // HIGHLEVEL_FEATURE_EXTRACTOR_H
#ifdef HAVE_CONFIG_H
#include <config.h>
#endif
#include "lowlevel_feature_extractor.h"
#include <rcsc/common/server_param.h>
using namespace rcsc;
// Stores the game configuration and sizes feature_vec so that it holds the
// fixed block of basic features plus one block per teammate and per opponent.
//
//   num_teammates   - number of teammates to reserve feature slots for (>= 0)
//   num_opponents   - number of opponents to reserve feature slots for (>= 0)
//   playing_offense - whether the agent plays offense
LowLevelFeatureExtractor::LowLevelFeatureExtractor(int num_teammates,
                                                   int num_opponents,
                                                   bool playing_offense)
    : FeatureExtractor(),
      numTeammates(num_teammates),
      numOpponents(num_opponents),
      playingOffense(playing_offense)
{
  // A negative count would produce a nonsensical feature total below.
  assert(numTeammates >= 0);
  assert(numOpponents >= 0);

  // Every tracked player contributes the same number of features.
  const int tracked_players = numTeammates + numOpponents;
  numFeatures = num_basic_features + features_per_player * tracked_players;
  feature_vec.resize(numFeatures);
}
// Empty destructor body: no cleanup is performed at this level.
LowLevelFeatureExtractor::~LowLevelFeatureExtractor()
{
}
const std::vector<float>& LowLevelFeatureExtractor::ExtractFeatures(
const WorldModel& wm) {
featIndx = 0;
const ServerParam& SP = ServerParam::i();
// ======================== SELF FEATURES ======================== //
const SelfObject& self = wm.self();
const Vector2D& self_pos = self.pos();
const AngleDeg& self_ang = self.body();
addFeature(self.posValid() ? FEAT_MAX : FEAT_MIN);
// ADD_FEATURE(self_pos.x);
// ADD_FEATURE(self_pos.y);
// Direction and speed of the agent.
addFeature(self.velValid() ? FEAT_MAX : FEAT_MIN);
if (self.velValid()) {
addAngFeature(self_ang - self.vel().th());
addNormFeature(self.speed(), 0., observedSelfSpeedMax);
} else {
addFeature(0);
addFeature(0);
addFeature(0);
}
// Global Body Angle -- 0:right -90:up 90:down 180/-180:left
addAngFeature(self_ang);
// Neck Angle -- We probably don't need this unless we are
// controlling the neck manually.
// std::cout << "Face Error: " << self.faceError() << std::endl;
// if (self.faceValid()) {
// std::cout << "FaceAngle: " << self.face() << std::endl;
// }
addNormFeature(self.stamina(), 0., observedStaminaMax);
addFeature(self.isFrozen() ? FEAT_MAX : FEAT_MIN);
// Probabilities - Do we want these???
// std::cout << "catchProb: " << self.catchProbability() << std::endl;
// std::cout << "tackleProb: " << self.tackleProbability() << std::endl;
// std::cout << "fouldProb: " << self.foulProbability() << std::endl;
// Features indicating if we are colliding with an object
addFeature(self.collidesWithBall() ? FEAT_MAX : FEAT_MIN);
addFeature(self.collidesWithPlayer() ? FEAT_MAX : FEAT_MIN);
addFeature(self.collidesWithPost() ? FEAT_MAX : FEAT_MIN);
addFeature(self.isKickable() ? FEAT_MAX : FEAT_MIN);
// inertiaPoint estimates the ball point after a number of steps
// self.inertiaPoint(n_steps);
// ======================== LANDMARK FEATURES ======================== //
// Top Bottom Center of Goal
rcsc::Vector2D goalCenter(pitchHalfLength, 0);
addLandmarkFeatures(goalCenter, self_pos, self_ang);
rcsc::Vector2D goalPostTop(pitchHalfLength, -goalHalfWidth);
addLandmarkFeatures(goalPostTop, self_pos, self_ang);
rcsc::Vector2D goalPostBot(pitchHalfLength, goalHalfWidth);
addLandmarkFeatures(goalPostBot, self_pos, self_ang);
// Top Bottom Center of Penalty Box
rcsc::Vector2D penaltyBoxCenter(pitchHalfLength - penaltyAreaLength, 0);
addLandmarkFeatures(penaltyBoxCenter, self_pos, self_ang);
rcsc::Vector2D penaltyBoxTop(pitchHalfLength - penaltyAreaLength,
-penaltyAreaWidth / 2.);
addLandmarkFeatures(penaltyBoxTop, self_pos, self_ang);
rcsc::Vector2D penaltyBoxBot(pitchHalfLength - penaltyAreaLength,
penaltyAreaWidth / 2.);
addLandmarkFeatures(penaltyBoxBot, self_pos, self_ang);
// Corners of the Playable Area
rcsc::Vector2D centerField(0, 0);
addLandmarkFeatures(centerField, self_pos, self_ang);
rcsc::Vector2D cornerTopLeft(0, -pitchHalfWidth);
addLandmarkFeatures(cornerTopLeft, self_pos, self_ang);
rcsc::Vector2D cornerTopRight(pitchHalfLength, -pitchHalfWidth);
addLandmarkFeatures(cornerTopRight, self_pos, self_ang);
rcsc::Vector2D cornerBotRight(pitchHalfLength, pitchHalfWidth);
addLandmarkFeatures(cornerBotRight, self_pos, self_ang);
rcsc::Vector2D cornerBotLeft(0, pitchHalfWidth);
addLandmarkFeatures(cornerBotLeft, self_pos, self_ang);
// Distances to the edges of the playable area
if (self.posValid()) {
// Distance to Left field line
addDistFeature(self_pos.x, pitchHalfLength);
// Distance to Right field line
addDistFeature(pitchHalfLength - self_pos.x, pitchHalfLength);
// Distance to top field line
addDistFeature(pitchHalfWidth + self_pos.y, pitchWidth);
// Distance to Bottom field line
addDistFeature(pitchHalfWidth - self_pos.y, pitchWidth);
} else {
addFeature(0);
addFeature(0);
addFeature(0);
addFeature(0);
}
// ======================== BALL FEATURES ======================== //
const BallObject& ball = wm.ball();
// Angle and distance to the ball
addFeature(ball.rposValid() ? FEAT_MAX : FEAT_MIN);
if (ball.rposValid()) {
addAngFeature(ball.angleFromSelf());
addDistFeature(ball.distFromSelf(), maxHFORadius);
} else {
addFeature(0);
addFeature(0);
addFeature(0);
}
// Velocity and direction of the ball
addFeature(ball.velValid() ? FEAT_MAX : FEAT_MIN);
if (ball.velValid()) {
// SeverParam lists ballSpeedMax a 2.7 which is too low
addNormFeature(ball.vel().r(), 0., observedBallSpeedMax);
addAngFeature(ball.vel().th());
} else {
addFeature(0);
addFeature(0);
addFeature(0);
}
assert(featIndx == num_basic_features);
// ======================== TEAMMATE FEATURES ======================== //
// Vector of PlayerObject pointers sorted by increasing distance from self
int detected_teammates = 0;
const PlayerPtrCont& teammates = wm.teammatesFromSelf();
for (PlayerPtrCont::const_iterator it = teammates.begin();
it != teammates.end(); ++it) {
PlayerObject* teammate = *it;
if (teammate->pos().x > 0 && teammate->unum() > 0 &&
detected_teammates < numTeammates) {
addPlayerFeatures(*teammate, self_pos, self_ang);
detected_teammates++;
}
}
// Add zero features for any missing teammates
for (int i=detected_teammates; i<numTeammates; ++i) {
for (int j=0; j<features_per_player; ++j) {
addFeature(0);
}
}
// ======================== OPPONENT FEATURES ======================== //
int detected_opponents = 0;
const PlayerPtrCont& opponents = wm.opponentsFromSelf();
for (PlayerPtrCont::const_iterator it = opponents.begin();
it != opponents.end(); ++it) {
PlayerObject* opponent = *it;
if (opponent->pos().x > 0 && opponent->unum() > 0 &&
detected_opponents < numOpponents) {
addPlayerFeatures(*opponent, self_pos, self_ang);
detected_opponents++;
}
}
// Add zero features for any missing opponents
for (int i=detected_opponents; i<numOpponents; ++i) {
for (int j=0; j<features_per_player; ++j) {
addFeature(0);
}
}
assert(featIndx == numFeatures);
checkFeatures();
return feature_vec;
}
#ifndef LOWLEVEL_FEATURE_EXTRACTOR_H
#define LOWLEVEL_FEATURE_EXTRACTOR_H
#include <rcsc/player/player_agent.h>
#include "feature_extractor.h"
#include <vector>
// FeatureExtractor subclass configured with the number of teammates and
// opponents in the game and with the side the agent plays on.
class LowLevelFeatureExtractor : public FeatureExtractor {
 public:
  // num_teammates / num_opponents: player counts in HFO.
  // playing_offense: whether the agent is on offense.
  LowLevelFeatureExtractor(int num_teammates, int num_opponents,
                           bool playing_offense);
  virtual ~LowLevelFeatureExtractor();

  // Updates the state features stored in feature_vec from the given world
  // model and returns them.
  virtual const std::vector<float>& ExtractFeatures(const rcsc::WorldModel& wm);

 protected:
  // Number of features for non-player objects.
  static const int num_basic_features = 58;
  // Number of features for each player or opponent in game.
  static const int features_per_player = 8;

  int numTeammates;     // Number of teammates in HFO
  int numOpponents;     // Number of opponents in HFO
  bool playingOffense;  // Are we playing offense or defense?
};
#endif // LOWLEVEL_FEATURE_EXTRACTOR_H
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment