Commit 3c70ad03 authored by Matthew Hausknecht's avatar Matthew Hausknecht Committed by GitHub

Merge pull request #39 from drallensmith/docs_update

Feedback enabled
parents a02aa143 28c26906
No preview for this file type
......@@ -324,6 +324,10 @@ features.
\item [$3O$] {\textbf{X, Y, and Uniform Number of
Opponents} - For each opponent: the x-position, y-position and
uniform number of that opponent.}
\item [$+1$] {\textbf{Last\_Action\_Success\_Possible} - Whether there is any chance
the last action taken was successful, either in accomplishing the
usual intent of the action or (primarily for the offense) in some other way such as
getting out of a goal-collision state. 1 for yes, -1 for no.}
\end{enumerate}
\begin{figure}[htp]
......@@ -518,6 +522,10 @@ low-level features:
sorted by proximity to the agent.}
\item [$O$] {\textbf{Opponent Uniform Nums} [Unum] One uniform number for each opponent active in HFO,
sorted by proximity to the player.}
\item [$+1$] {\textbf{Last\_Action\_Success\_Possible} [Boolean] Whether there is any chance
the last action taken was successful, either in accomplishing the
usual intent of the action or (primarily for the offense) in some other way such as getting
out of a goal-collision state.}
\end{enumerate}
\section{Action Space}
......
......@@ -19,7 +19,7 @@ except ImportError:
' run: \"pip install .\"')
exit()
GOAL_POS_X = 1.0
GOAL_POS_X = 0.9
GOAL_POS_Y = 0.0
# below - from hand_coded_defense_agent.cpp except LOW_KICK_DIST
......
......@@ -6,11 +6,12 @@ import os
hfo_lib = cdll.LoadLibrary(os.path.join(os.path.dirname(__file__),
'libhfo_c.so'))
''' Possible feature sets '''
"""Possible feature sets"""
NUM_FEATURE_SETS = 2
LOW_LEVEL_FEATURE_SET, HIGH_LEVEL_FEATURE_SET = list(range(NUM_FEATURE_SETS))
''' An enum of the possible HFO actions
"""
An enum of the possible HFO actions, including:
[Low-Level] Dash(power, relative_direction)
[Low-Level] Turn(direction)
[Low-Level] Tackle(direction)
......@@ -25,25 +26,50 @@ LOW_LEVEL_FEATURE_SET, HIGH_LEVEL_FEATURE_SET = list(range(NUM_FEATURE_SETS))
[High-Level] Dribble(): Offensive dribble
[High-Level] Catch(): Catch the ball (Goalie Only)
NOOP(): Do Nothing
QUIT(): Quit the game '''
QUIT(): Quit the game
"""
NUM_HFO_ACTIONS = 20
DASH, TURN, TACKLE, KICK, KICK_TO, MOVE_TO, DRIBBLE_TO, INTERCEPT, \
MOVE, SHOOT, PASS, DRIBBLE, CATCH, NOOP, QUIT, REDUCE_ANGLE_TO_GOAL,MARK_PLAYER,DEFEND_GOAL,GO_TO_BALL,REORIENT = list(range(NUM_HFO_ACTIONS))
ACTION_STRINGS = ["Dash", "Turn", "Tackle", "Kick", "KickTo", "MoveTo", "DribbleTo", "Intercept", "Move", "Shoot", "Pass", "Dribble", "Catch", "No-op", "Quit", "Reduce_Angle_To_Goal", "Mark_Player", "Defend_Goal", "Go_To_Ball", "Reorient"]
''' Possible game status
DASH,TURN,TACKLE,KICK,KICK_TO,MOVE_TO,DRIBBLE_TO,INTERCEPT,MOVE,SHOOT,PASS,DRIBBLE,CATCH,NOOP,QUIT,REDUCE_ANGLE_TO_GOAL,MARK_PLAYER,DEFEND_GOAL,GO_TO_BALL,REORIENT = list(range(NUM_HFO_ACTIONS))
ACTION_STRINGS = {DASH: "Dash",
TURN: "Turn",
TACKLE: "Tackle",
KICK: "Kick",
KICK_TO: "KickTo",
MOVE_TO: "MoveTo",
DRIBBLE_TO: "DribbleTo",
INTERCEPT: "Intercept",
MOVE: "Move",
SHOOT: "Shoot",
PASS: "Pass",
DRIBBLE: "Dribble",
CATCH: "Catch",
NOOP: "No-op",
QUIT: "Quit",
REDUCE_ANGLE_TO_GOAL: "Reduce_Angle_To_Goal",
MARK_PLAYER: "Mark_Player",
DEFEND_GOAL: "Defend_Goal",
GO_TO_BALL: "Go_To_Ball",
REORIENT: "Reorient"}
"""
Possible game statuses:
[IN_GAME] Game is currently active
[GOAL] A goal has been scored by the offense
[CAPTURED_BY_DEFENSE] The defense has captured the ball
[OUT_OF_BOUNDS] Ball has gone out of bounds
[OUT_OF_TIME] Trial has ended due to time limit
[SERVER_DOWN] Server is not alive
'''
"""
NUM_GAME_STATUS_STATES = 6
IN_GAME, GOAL, CAPTURED_BY_DEFENSE, OUT_OF_BOUNDS, OUT_OF_TIME, SERVER_DOWN = list(range(NUM_GAME_STATUS_STATES))
STATUS_STRINGS = ["InGame", "Goal", "CapturedByDefense", "OutOfBounds", "OutOfTime", "ServerDown"]
''' Possible sides '''
STATUS_STRINGS = {IN_GAME: "InGame",
GOAL: "Goal",
CAPTURED_BY_DEFENSE: "CapturedByDefense",
OUT_OF_BOUNDS: "OutOfBounds",
OUT_OF_TIME: "OutOfTime",
SERVER_DOWN: "ServerDown"}
"""Possible sides."""
RIGHT, NEUTRAL, LEFT = list(range(-1,2))
class Player(Structure): pass
......@@ -109,7 +135,14 @@ class HFOEnvironment(object):
play_goalie: is this player the goalie
record_dir: record agent's states/actions/rewards to this directory
"""
hfo_lib.connectToServer(self.obj, feature_set, config_dir.encode('utf-8'), server_port,server_addr.encode('utf-8'), team_name.encode('utf-8'), play_goalie, record_dir.encode('utf-8'))
hfo_lib.connectToServer(self.obj,
feature_set,
config_dir.encode('utf-8'),
server_port,server_addr.encode('utf-8'),
team_name.encode('utf-8'),
play_goalie,
record_dir.encode('utf-8'))
def getStateSize(self):
""" Returns the number of state features """
return hfo_lib.getStateSize(self.obj)
......
......@@ -150,6 +150,9 @@ Agent::Agent()
// set communication planner
M_communication = Communication::Ptr(new SampleCommunication());
// setup last_action variable
last_action_status = false;
}
Agent::~Agent() {
......@@ -249,83 +252,97 @@ void Agent::actionImpl() {
}
// For now let's not worry about turning the neck or setting the vision.
// However, do this now so doesn't override anything changed by the requested action.
// TODO for librcsc: setViewActionDefault, setNeckActionDefault that will not overwrite if already set.
// But do the settings now, so that doesn't override any set by the actions below.
// TODO: Add setViewActionDefault, setNeckActionDefault to librcsc that only set if not already set.
const WorldModel & wm = this->world();
this->setViewAction(new View_Tactical());
this->setNeckAction(new Neck_TurnToBallOrScan());
if (wm.ball().posValid()) {
this->setNeckAction(new Neck_TurnToBallOrScan()); // if not ball().posValid(), requests possibly-invalid queuedNextBallPos()
} else {
this->setNeckAction(new Neck_ScanField()); // equivalent to Neck_TurnToBall()
}
switch(requested_action) {
case DASH:
this->doDash(params[0], params[1]);
last_action_status = this->doDash(params[0], params[1]);
break;
case TURN:
this->doTurn(params[0]);
last_action_status = this->doTurn(params[0]);
break;
case TACKLE:
this->doTackle(params[0], false);
last_action_status = this->doTackle(params[0], false);
break;
case KICK:
this->doKick(params[0], params[1]);
last_action_status = this->doKick(params[0], params[1]);
break;
case KICK_TO:
if (feature_extractor != NULL) {
Body_SmartKick(Vector2D(feature_extractor->absoluteXPos(params[0]),
last_action_status = Body_SmartKick(Vector2D(feature_extractor->absoluteXPos(params[0]),
feature_extractor->absoluteYPos(params[1])),
params[2], params[2] * 0.99, 3).execute(this);
}
break;
case MOVE_TO:
if (feature_extractor != NULL) {
Body_GoToPoint(Vector2D(feature_extractor->absoluteXPos(params[0]),
last_action_status = Body_GoToPoint(Vector2D(feature_extractor->absoluteXPos(params[0]),
feature_extractor->absoluteYPos(params[1])), 0.25,
ServerParam::i().maxDashPower()).execute(this);
last_action_status |= wm.self().collidesWithPost(); // can get out of collision w/post
}
break;
case DRIBBLE_TO:
if (feature_extractor != NULL) {
Body_Dribble(Vector2D(feature_extractor->absoluteXPos(params[0]),
last_action_status = Body_Dribble(Vector2D(feature_extractor->absoluteXPos(params[0]),
feature_extractor->absoluteYPos(params[1])), 1.0,
ServerParam::i().maxDashPower(), 2).execute(this);
last_action_status |= wm.self().collidesWithPost(); // ditto
}
break;
case INTERCEPT:
Body_Intercept().execute(this);
last_action_status = Body_Intercept().execute(this);
last_action_status |= wm.self().collidesWithPost(); // ditto
break;
case MOVE:
this->doMove();
last_action_status = this->doMove();
break;
case SHOOT:
this->doSmartKick();
last_action_status = this->doSmartKick();
break;
case PASS:
this->doPassTo(int(params[0]));
last_action_status = this->doPassTo(int(params[0]));
break;
case DRIBBLE:
this->doDribble();
last_action_status = this->doDribble();
break;
case CATCH:
this->doCatch();
last_action_status = this->doCatch();
break;
case NOOP:
last_action_status = false;
break;
case QUIT:
std::cout << "Got quit from agent." << std::endl;
handleExit();
return;
case REDUCE_ANGLE_TO_GOAL:
this->doReduceAngleToGoal();
last_action_status = this->doReduceAngleToGoal();
break;
case MARK_PLAYER:
this->doMarkPlayer(int(params[0]));
last_action_status = this->doMarkPlayer(int(params[0]));
break;
case DEFEND_GOAL:
this->doDefendGoal();
last_action_status = this->doDefendGoal();
break;
case GO_TO_BALL:
this->doGoToBall();
last_action_status = this->doGoToBall();
break;
case REORIENT:
this->doReorient();
last_action_status = this->doReorient();
break;
default:
std::cerr << "ERROR: Unsupported Action: "
......@@ -381,7 +398,8 @@ void
Agent::UpdateFeatures()
{
if (feature_extractor != NULL) {
state = feature_extractor->ExtractFeatures(this->world());
state = feature_extractor->ExtractFeatures(this->world(),
getLastActionStatus());
}
}
......@@ -617,7 +635,11 @@ Agent::doPreprocess()
wm.self().tackleExpires() );
// face neck to ball
this->setViewAction( new View_Tactical() );
if (wm.ball().posValid()) {
this->setNeckAction( new Neck_TurnToBallOrScan() );
} else{
this->setNeckAction( new Neck_TurnToBall() );
}
return true;
}
......@@ -642,8 +664,7 @@ Agent::doPreprocess()
{
dlog.addText( Logger::TEAM,
__FILE__": invalid my pos" );
Bhv_Emergency().execute( this ); // includes change view
return true;
return Bhv_Emergency().execute( this ); // includes change view
}
//
......@@ -659,8 +680,7 @@ Agent::doPreprocess()
dlog.addText( Logger::TEAM,
__FILE__": search ball" );
this->setViewAction( new View_Tactical() );
Bhv_NeckBodyToBall().execute( this );
return true;
return Bhv_NeckBodyToBall().execute( this );
}
//
......@@ -774,19 +794,15 @@ Agent::doReorient()
//
// ball localization error
//
const int count_thr = ( wm.self().goalie()
? 10
: 5 );
if ( wm.ball().posCount() > count_thr
|| ( wm.gameMode().type() != GameMode::PlayOn
&& wm.ball().seenPosCount() > count_thr + 10 ) )
{
const BallObject& ball = wm.ball();
if (! ( ball.posValid() && ball.velValid() )) {
dlog.addText( Logger::TEAM,
__FILE__": search ball" );
return Bhv_NeckBodyToBall().execute( this );
}
//
// check pass message
//
......@@ -795,10 +811,19 @@ Agent::doReorient()
return true;
}
const BallObject& ball = wm.ball();
if (! ( ball.rposValid() && ball.velValid() )) {
//
// ball localization error
//
const int count_thr = ( wm.self().goalie()
? 10
: 5 );
if ( wm.ball().posCount() > count_thr
|| ( wm.gameMode().type() != GameMode::PlayOn
&& wm.ball().seenPosCount() > count_thr + 10 ) )
{
dlog.addText( Logger::TEAM,
__FILE__": search ball" );
return Bhv_NeckBodyToBall().execute( this );
}
......@@ -847,9 +872,10 @@ Agent::doSmartKick()
ShootGenerator::instance().courses(this->world(), false);
ShootGenerator::Container::const_iterator best_shoot
= std::min_element(cont.begin(), cont.end(), ShootGenerator::ScoreCmp());
Body_SmartKick(best_shoot->target_point_, best_shoot->first_ball_speed_,
best_shoot->first_ball_speed_ * 0.99, 3).execute(this);
return true;
return Body_SmartKick(best_shoot->target_point_,
best_shoot->first_ball_speed_,
best_shoot->first_ball_speed_ * 0.99,
3).execute(this);
}
......@@ -870,9 +896,12 @@ bool
Agent::doPassTo(int receiver)
{
Force_Pass pass;
pass.get_pass_to_player(this->world(), receiver);
pass.execute(this);
const WorldModel & wm = this->world();
pass.get_pass_to_player(wm, receiver);
if (pass.execute(this) || wm.self().collidesWithBall()) { // can sometimes fix
return true;
}
return false;
}
/*-------------------------------------------------------------------*/
......@@ -889,10 +918,13 @@ Agent::doDribble()
M_action_generator = ActionGenerator::ConstPtr(g);
ActionChainHolder::instance().setFieldEvaluator( M_field_evaluator );
ActionChainHolder::instance().setActionGenerator( M_action_generator );
doPreprocess();
bool preprocess_success = doPreprocess();
ActionChainHolder::instance().update( world() );
Bhv_ChainAction(ActionChainHolder::instance().graph()).execute(this);
if (Bhv_ChainAction(ActionChainHolder::instance().graph()).execute(this) ||
preprocess_success) {
return true;
}
return false;
}
/*-------------------------------------------------------------------*/
......@@ -903,9 +935,7 @@ bool
Agent::doMove()
{
Strategy::instance().update( world() );
int role_num = Strategy::i().roleNumber(world().self().unum());
Bhv_BasicMove().execute(this);
return true;
return Bhv_BasicMove().execute(this);
}
/*-------------------------------------------------------------------*/
......@@ -948,8 +978,12 @@ bool Agent::doMarkPlayer(int unum) {
}
double x = player_pos.x + (kicker_pos.x - player_pos.x)*0.1;
double y = player_pos.y + (kicker_pos.y - player_pos.y)*0.1;
Body_GoToPoint(Vector2D(x,y), 0.25, ServerParam::i().maxDashPower()).execute(this);
if (Body_GoToPoint(Vector2D(x,y), 0.25, ServerParam::i().maxDashPower()).execute(this) ||
wm.self().collidesWithPost()) { // latter because sometimes fixes
return true;
}
return false;
}
/*-------------------------------------------------------------------*/
......@@ -958,7 +992,7 @@ bool Agent::doMarkPlayer(int unum) {
* This action cuts off the angle between the shooter and the goal the players always move to a dynamic line in between the kicker and the goal.
*/
/* Comparator for sorting teammated based on y positions.*/
/* Comparator for sorting teammates based on y positions.*/
bool compare_y_pos (PlayerObject* i, PlayerObject* j) {
return i->pos().y < j->pos().y;
}
......@@ -970,7 +1004,12 @@ bool Agent::doReduceAngleToGoal() {
const PlayerPtrCont::const_iterator o_end = wm.opponentsFromSelf().end();
Vector2D ball_pos = wm.ball().pos();
const BallObject& ball = wm.ball();
if (! ball.posValid()) {
return false;
}
Vector2D ball_pos = ball.pos();
double nearRatio = 0.9;
const PlayerPtrCont::const_iterator o_t_end = wm.teammatesFromSelf().end();
......@@ -1036,28 +1075,40 @@ bool Agent::doReduceAngleToGoal() {
double dist_to_end2 = targetLineEnd2.dist2(ball_pos);
double ratio = dist_to_end2/(dist_to_end1+dist_to_end2);
Vector2D target = targetLineEnd1 * ratio + targetLineEnd2 * (1-ratio);
Body_GoToPoint(target, 0.25, ServerParam::i().maxDashPower()).execute(this);
if (Body_GoToPoint(target, 0.25, ServerParam::i().maxDashPower()).execute(this) ||
wm.self().collidesWithPost()) { // latter because sometimes fixes
return true;
}
return false;
}
/*-------------------------------------------------------------------*/
/*!
*
* This action cuts off the angle between the shooter and the goal the players always moves on a fixed line.
* This action cuts off the angle between the shooter and the goal; the player always moves on a fixed line.
*/
bool Agent::doDefendGoal() {
const WorldModel & wm = this->world();
Vector2D goal_pos1( -ServerParam::i().pitchHalfLength() + ServerParam::i().goalAreaLength(), ServerParam::i().goalHalfWidth() );
Vector2D goal_pos2( -ServerParam::i().pitchHalfLength() + ServerParam::i().goalAreaLength(), -ServerParam::i().goalHalfWidth() );
Vector2D ball_pos = wm.ball().pos();
const BallObject& ball = wm.ball();
if (! ball.posValid()) {
return false;
}
Vector2D ball_pos = ball.pos();
double dist_to_post1 = goal_pos1.dist2(ball_pos);
double dist_to_post2 = goal_pos2.dist2(ball_pos);
double ratio = dist_to_post2/(dist_to_post1+dist_to_post2);
Vector2D target = goal_pos1 * ratio + goal_pos2 * (1-ratio);
Body_GoToPoint(target, 0.25, ServerParam::i().maxDashPower()).execute(this);
if (Body_GoToPoint(target, 0.25, ServerParam::i().maxDashPower()).execute(this) ||
wm.self().collidesWithPost()) { // latter because sometimes fixes
return true;
}
return false;
}
/*-------------------------------------------------------------------*/
......@@ -1070,8 +1121,11 @@ bool Agent::doDefendGoal() {
bool Agent::doGoToBall() {
const WorldModel & wm = this->world();
Body_GoToPoint(wm.ball().pos(), 0.25, ServerParam::i().maxDashPower()).execute(this);
return true;
const BallObject& ball = wm.ball();
if (! ball.posValid()) {
return false;
}
return Body_GoToPoint(ball.pos(), 0.25, ServerParam::i().maxDashPower()).execute(this);
}
/*-------------------------------------------------------------------*/
......
// -*-c++-*-
#ifndef AGENT_H
#define AGENT_H
......@@ -63,6 +65,7 @@ protected:
std::vector<float> params; // Parameters of current action
int num_teammates; // Number of teammates
int num_opponents; // Number of opponents
bool last_action_status; // Recorded return status of last action
public:
inline const std::vector<float>& getState() { return state; }
......@@ -72,6 +75,7 @@ protected:
int getUnum(); // Returns the uniform number of the player
inline int getNumTeammates() { return num_teammates; }
inline int getNumOpponents() { return num_opponents; }
inline bool getLastActionStatus() { return last_action_status; }
inline void setFeatureSet(hfo::feature_set_t fset) { feature_set = fset; }
inline std::vector<float>* mutable_params() { return &params; }
......
......@@ -84,16 +84,17 @@ Bhv_BasicMove::execute( PlayerAgent * agent )
{
dlog.addText( Logger::TEAM,
__FILE__": intercept" );
Body_Intercept().execute( agent );
bool success = Body_Intercept().execute( agent );
agent->setNeckAction( new Neck_OffensiveInterceptNeck() );
return true;
return success;
}
const Vector2D target_point = Strategy::i().getPosition( wm.self().unum() );
const double dash_power = Strategy::get_normal_dash_power( wm );
double dist_thr = wm.ball().distFromSelf() * 0.1;
const BallObject& ball = wm.ball();
double dist_thr = ball.distFromSelf() * 0.1;
if ( dist_thr < 1.0 ) dist_thr = 1.0;
dlog.addText( Logger::TEAM,
......@@ -105,21 +106,27 @@ Bhv_BasicMove::execute( PlayerAgent * agent )
agent->debugClient().setTarget( target_point );
agent->debugClient().addCircle( target_point, dist_thr );
if ( ! Body_GoToPoint( target_point, dist_thr, dash_power
).execute( agent ) )
{
Body_TurnToBall().execute( agent );
bool success = false;
if ( Body_GoToPoint( target_point, dist_thr, dash_power
).execute( agent ) ||
Body_TurnToBall().execute( agent ) ) {
if (ball.posValid() || wm.self().collidesWithPost()) {
success = true;
}
} else {
success = false;
}
if ( wm.existKickableOpponent()
&& wm.ball().distFromSelf() < 18.0 )
{
&& ball.distFromSelf() < 18.0 ) {
agent->setNeckAction( new Neck_TurnToBall() );
}
else
{
} else if ( ball.posValid() ) {
agent->setNeckAction( new Neck_TurnToBallOrScan() );
} else {
agent->setNeckAction( new Neck_TurnToBall() );
}
return true;
return success;
}
......@@ -41,7 +41,7 @@ enum action_t
REORIENT // [High-Level] Handle lost position of self/ball, misc other situations; variant of doPreprocess called in DRIBBLE
};
// Status of a HFO game
// Status of an HFO game
enum status_t
{
IN_GAME, // Game is currently active
......@@ -176,7 +176,7 @@ inline std::string ActionToString(action_t action) {
};
/**
* Returns a string representation of a game_status.
* Returns a string representation of a game status.
*/
inline std::string StatusToString(status_t status) {
switch (status) {
......
// -*-c++-*-
#ifndef FEATURE_EXTRACTOR_H
#define FEATURE_EXTRACTOR_H
......@@ -12,7 +14,8 @@ public:
virtual ~FeatureExtractor();
// Updated the state features stored in feature_vec
virtual const std::vector<float>& ExtractFeatures(const rcsc::WorldModel& wm) = 0;
virtual const std::vector<float>& ExtractFeatures(const rcsc::WorldModel& wm,
bool last_action_status) = 0;
// Record the current state
void LogFeatures();
......
......@@ -16,13 +16,15 @@ HighLevelFeatureExtractor::HighLevelFeatureExtractor(int num_teammates,
assert(numOpponents >= 0);
numFeatures = num_basic_features + features_per_teammate * numTeammates
+ features_per_opponent * numOpponents;
numFeatures++; // action status
feature_vec.resize(numFeatures);
}
HighLevelFeatureExtractor::~HighLevelFeatureExtractor() {}
const std::vector<float>& HighLevelFeatureExtractor::ExtractFeatures(
const WorldModel& wm) {
const std::vector<float>&
HighLevelFeatureExtractor::ExtractFeatures(const rcsc::WorldModel& wm,
bool last_action_status) {
featIndx = 0;
const ServerParam& SP = ServerParam::i();
const SelfObject& self = wm.self();
......@@ -178,6 +180,12 @@ const std::vector<float>& HighLevelFeatureExtractor::ExtractFeatures(
addFeature(FEAT_INVALID);
}
if (last_action_status) {
addFeature(FEAT_MAX);
} else {
addFeature(FEAT_MIN);
}
assert(featIndx == numFeatures);
// checkFeatures();
return feature_vec;
......
// -*-c++-*-
#ifndef HIGHLEVEL_FEATURE_EXTRACTOR_H
#define HIGHLEVEL_FEATURE_EXTRACTOR_H
......@@ -18,7 +20,8 @@ public:
virtual ~HighLevelFeatureExtractor();
// Updated the state features stored in feature_vec
virtual const std::vector<float>& ExtractFeatures(const rcsc::WorldModel& wm);
virtual const std::vector<float>& ExtractFeatures(const rcsc::WorldModel& wm,
bool last_action_status);
protected:
// Number of features for non-player objects.
......
......@@ -17,13 +17,15 @@ LowLevelFeatureExtractor::LowLevelFeatureExtractor(int num_teammates,
numFeatures = num_basic_features +
features_per_player * (numTeammates + numOpponents);
numFeatures += numTeammates + numOpponents; // Uniform numbers
numFeatures++; // action state
feature_vec.resize(numFeatures);
}
LowLevelFeatureExtractor::~LowLevelFeatureExtractor() {}
const std::vector<float>& LowLevelFeatureExtractor::ExtractFeatures(
const WorldModel& wm) {
const std::vector<float>&
LowLevelFeatureExtractor::ExtractFeatures(const rcsc::WorldModel& wm,
bool last_action_status) {
featIndx = 0;
const ServerParam& SP = ServerParam::i();
// ======================== SELF FEATURES ======================== //
......@@ -197,7 +199,7 @@ const std::vector<float>& LowLevelFeatureExtractor::ExtractFeatures(
detected_teammates++;
}
}
// Add -2 features for any missing teammates
// Add -1 features for any missing teammates
for (int i=detected_teammates; i<numTeammates; ++i) {
addFeature(FEAT_MIN);
}
......@@ -212,11 +214,17 @@ const std::vector<float>& LowLevelFeatureExtractor::ExtractFeatures(
detected_opponents++;
}
}
// Add -2 features for any missing opponents
// Add -1 features for any missing opponents
for (int i=detected_opponents; i<numOpponents; ++i) {
addFeature(FEAT_MIN);
}
if (last_action_status) {
addFeature(FEAT_MAX);
} else {
addFeature(FEAT_MIN);
}
assert(featIndx == numFeatures);
checkFeatures();
return feature_vec;
......
// -*-c++-*-
#ifndef LOWLEVEL_FEATURE_EXTRACTOR_H
#define LOWLEVEL_FEATURE_EXTRACTOR_H
......@@ -12,7 +14,8 @@ public:
virtual ~LowLevelFeatureExtractor();
// Updated the state features stored in feature_vec
virtual const std::vector<float>& ExtractFeatures(const rcsc::WorldModel& wm);
virtual const std::vector<float>& ExtractFeatures(const rcsc::WorldModel& wm,
bool last_action_status);
protected:
// Number of features for non-player objects.
......
......@@ -257,7 +257,7 @@ SamplePlayer::actionImpl()
lastTrainerMessageTime = audioSensor().trainerMessageTime().cycle();
}
if (feature_extractor != NULL) {
feature_extractor->ExtractFeatures(this->world());
feature_extractor->ExtractFeatures(this->world(), true);
feature_extractor->LogFeatures();
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment