// highlevel_feature_extractor.cpp
#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include "highlevel_feature_extractor.h"
#include <rcsc/common/server_param.h>

using namespace rcsc;

/**
 * Constructs an extractor for a game with the given number of teammates and
 * opponents, and sizes the output feature vector accordingly.
 *
 * @param num_teammates   number of teammates whose features are emitted (>= 0)
 * @param num_opponents   number of opponents whose features are emitted (>= 0)
 * @param playing_offense whether this agent plays offense; forwarded to the
 *                        FeatureExtractor base
 */
HighLevelFeatureExtractor::HighLevelFeatureExtractor(int num_teammates,
                                                     int num_opponents,
                                                     bool playing_offense) :
  FeatureExtractor(num_teammates, num_opponents, playing_offense)
{
  assert(numTeammates >= 0);
  assert(numOpponents >= 0);

  // Fixed (self/ball/goal) features, plus a fixed number of features per
  // teammate and per opponent.
  numFeatures = num_basic_features + features_per_teammate * numTeammates
      + features_per_opponent * numOpponents;
  // One trailing feature reports the status of the last action.
  numFeatures++; // action status

  feature_vec.resize(numFeatures);
}

HighLevelFeatureExtractor::~HighLevelFeatureExtractor()
{
  // Intentionally empty.
}

25 26 27
const std::vector<float>&
HighLevelFeatureExtractor::ExtractFeatures(const rcsc::WorldModel& wm,
					   bool last_action_status) {
28 29 30 31
  featIndx = 0;
  const ServerParam& SP = ServerParam::i();
  const SelfObject& self = wm.self();
  const Vector2D& self_pos = self.pos();
32
  const float self_ang = self.body().radian();
33 34
  const PlayerPtrCont& teammates = wm.teammatesFromSelf();
  const PlayerPtrCont& opponents = wm.opponentsFromSelf();
35 36 37
  float maxR = sqrtf(SP.pitchHalfLength() * SP.pitchHalfLength()
                     + SP.pitchHalfWidth() * SP.pitchHalfWidth());
  // features about self pos
38 39
  // Allow the agent to go 10% over the playfield in any direction
  float tolerance_x = .1 * SP.pitchHalfLength();
drallensmith's avatar
drallensmith committed
40
  float tolerance_y = .1 * SP.pitchHalfWidth();
41 42 43 44 45 46 47 48 49 50 51
  // Feature[0]: X-postion
  if (playingOffense) {
    addNormFeature(self_pos.x, -tolerance_x, SP.pitchHalfLength() + tolerance_x);
  } else {
    addNormFeature(self_pos.x, -SP.pitchHalfLength()-tolerance_x, tolerance_x);
  }

  // Feature[1]: Y-Position
  addNormFeature(self_pos.y, -SP.pitchHalfWidth() - tolerance_y,
                 SP.pitchHalfWidth() + tolerance_y);

52
  // Feature[2]: Self Angle
53
  addNormFeature(self_ang, -M_PI, M_PI);
54

55 56
  float r;
  float th;
57
  // Features about the ball
58 59
  Vector2D ball_pos = wm.ball().pos();
  angleDistToPoint(self_pos, ball_pos, th, r);
60
  // Feature[3] and [4]: (x,y) postition of the ball
61 62 63 64 65
  if (playingOffense) {
    addNormFeature(ball_pos.x, -tolerance_x, SP.pitchHalfLength() + tolerance_x);
  } else {
    addNormFeature(ball_pos.x, -SP.pitchHalfLength()-tolerance_x, tolerance_x);
  }
66
  addNormFeature(ball_pos.y, -SP.pitchHalfWidth() - tolerance_y, SP.pitchHalfWidth() + tolerance_y);
67
  // Feature[5]: Able to kick
68
  addNormFeature(self.isKickable(), false, true);
69

70
  // Features about distance to goal center
71
  Vector2D goalCenter(SP.pitchHalfLength(), 0);
72 73 74
  if (!playingOffense) {
    goalCenter.assign(-SP.pitchHalfLength(), 0);
  }
75
  angleDistToPoint(self_pos, goalCenter, th, r);
76
  // Feature[6]: Goal Center Distance
77
  addNormFeature(r, 0, maxR);
78
  // Feature[7]: Angle to goal center
79
  addNormFeature(th, -M_PI, M_PI);
80
  // Feature[8]: largest open goal angle
81
  addNormFeature(calcLargestGoalAngle(wm, self_pos), 0, M_PI);
82 83 84 85 86 87 88
  // Feature[9]: Dist to our closest opp
  if (numOpponents > 0) {
    calcClosestOpp(wm, self_pos, th, r);
    addNormFeature(r, 0, maxR);
  } else {
    addFeature(FEAT_INVALID);
  }
89

90
  // Features[9 - 9+T]: teammate's open angle to goal
91
  int detected_teammates = 0;
92 93
  for (PlayerPtrCont::const_iterator it=teammates.begin(); it != teammates.end(); ++it) {
    const PlayerObject* teammate = *it;
94
    if (valid(teammate) && detected_teammates < numTeammates) {
95
      addNormFeature(calcLargestGoalAngle(wm, teammate->pos()), 0, M_PI);
96 97
      detected_teammates++;
    }
98
  }
99 100
  // Add zero features for any missing teammates
  for (int i=detected_teammates; i<numTeammates; ++i) {
101
    addFeature(FEAT_INVALID);
102 103
  }

104
  // Features[9+T - 9+2T]: teammates' dists to closest opps
105 106
  if (numOpponents > 0) {
    detected_teammates = 0;
107 108
    for (PlayerPtrCont::const_iterator it=teammates.begin(); it != teammates.end(); ++it) {
      const PlayerObject* teammate = *it;
109
      if (valid(teammate) && detected_teammates < numTeammates) {
110
        calcClosestOpp(wm, teammate->pos(), th, r);
111
        addNormFeature(r, 0, maxR);
112 113 114 115 116 117 118
        detected_teammates++;
      }
    }
    // Add zero features for any missing teammates
    for (int i=detected_teammates; i<numTeammates; ++i) {
      addFeature(FEAT_INVALID);
    }
119 120 121 122
  } else { // If no opponents, add invalid features
    for (int i=0; i<numTeammates; ++i) {
      addFeature(FEAT_INVALID);
    }
123 124
  }

125
  // Features [9+2T - 9+3T]: open angle to teammates
126
  detected_teammates = 0;
127 128
  for (PlayerPtrCont::const_iterator it=teammates.begin(); it != teammates.end(); ++it) {
    const PlayerObject* teammate = *it;
129
    if (valid(teammate) && detected_teammates < numTeammates) {
130
      addNormFeature(calcLargestTeammateAngle(wm, self_pos, teammate->pos()),0,M_PI);
131 132 133 134 135
      detected_teammates++;
    }
  }
  // Add zero features for any missing teammates
  for (int i=detected_teammates; i<numTeammates; ++i) {
136
    addFeature(FEAT_INVALID);
137 138
  }

139
  // Features [9+3T - 9+6T]: x, y, unum of teammates
140
  detected_teammates = 0;
141 142
  for (PlayerPtrCont::const_iterator it=teammates.begin(); it != teammates.end(); ++it) {
    const PlayerObject* teammate = *it;
143
    if (valid(teammate) && detected_teammates < numTeammates) {
144
      if (playingOffense) {
145
        addNormFeature(teammate->pos().x, -tolerance_x, SP.pitchHalfLength() + tolerance_x);
146
      } else {
147
        addNormFeature(teammate->pos().x, -SP.pitchHalfLength()-tolerance_x, tolerance_x);
148
      }
149 150
      addNormFeature(teammate->pos().y, -tolerance_y - SP.pitchHalfWidth(), SP.pitchHalfWidth() + tolerance_y);
      addFeature(teammate->unum());
151
      detected_teammates++;
152 153
    }
  }
154 155
  // Add zero features for any missing teammates
  for (int i=detected_teammates; i<numTeammates; ++i) {
156 157 158
    addFeature(FEAT_INVALID);
    addFeature(FEAT_INVALID);
    addFeature(FEAT_INVALID);
159
  }
160

161
  // Features [9+6T - 9+6T+3O]: x, y, unum of opponents
162
  int detected_opponents = 0;
163 164
  for (PlayerPtrCont::const_iterator it = opponents.begin(); it != opponents.end(); ++it) {
    const PlayerObject* opponent = *it;
165
    if (valid(opponent) && detected_opponents < numOpponents) {
166
      if (playingOffense) {
167
        addNormFeature(opponent->pos().x, -tolerance_x, SP.pitchHalfLength() + tolerance_x);
168
      } else {
169
        addNormFeature(opponent->pos().x, -SP.pitchHalfLength()-tolerance_x, tolerance_x);
170
      }
171 172
      addNormFeature(opponent->pos().y, -tolerance_y - SP.pitchHalfWidth(), SP.pitchHalfWidth() + tolerance_y);
      addFeature(opponent->unum());
173 174 175 176 177 178 179 180 181 182
      detected_opponents++;
    }
  }
  // Add zero features for any missing opponents
  for (int i=detected_opponents; i<numOpponents; ++i) {
    addFeature(FEAT_INVALID);
    addFeature(FEAT_INVALID);
    addFeature(FEAT_INVALID);
  }

183
  if (last_action_status) {
184 185 186 187 188
    addFeature(FEAT_MAX);
  } else {
    addFeature(FEAT_MIN);
  }

189
  assert(featIndx == numFeatures);
190
  // checkFeatures();
191 192
  return feature_vec;
}
193 194 195 196 197 198 199 200 201 202

bool HighLevelFeatureExtractor::valid(const rcsc::PlayerObject* player) {
  if (!player) {return false;} //avoid segfaults
  const rcsc::Vector2D& pos = player->pos();
  if (!player->posValid()) {
    return false;
  }
  return pos.isValid();
}