Commit ab59293b authored by Matthew Hausknecht's avatar Matthew Hausknecht

Added infrastructure for choosing between the high/low level feature set.

parent 8c1fef0d
......@@ -10,8 +10,9 @@ using namespace std;
int main() {
// Create the HFO environment
HFOEnvironment hfo;
// Connect the agent's server
hfo.connectToAgentServer(6000);
// Connect the agent's server on the given port with the given
// feature set. See possible feature sets in src/HFO.hpp.
hfo.connectToAgentServer(6000, LOW_LEVEL_FEATURE_SET);
// Play 5 episodes
for (int episode=0; episode<5; episode++) {
hfo_status_t status = IN_GAME;
......@@ -19,7 +20,7 @@ int main() {
// Grab the vector of state features for the current state
const std::vector<float>& feature_vec = hfo.getState();
// Create a dash action
Action a = {DASH, 100., 0.};
Action a = {DASH, 0., 0.};
// Perform the dash and recieve the current game status
status = hfo.act(a);
}
......
......@@ -12,8 +12,9 @@ if __name__ == '__main__':
exit()
# Create the HFO Environment
hfo = hfo.HFOEnvironment()
# Connect to the agent server
hfo.connectToAgentServer()
# Connect to the agent server on port 6000 with the specified
# feature set. See feature sets in hfo.py/hfo.hpp.
hfo.connectToAgentServer(6000, HFO_Features.HIGH_LEVEL_FEATURE_SET)
# Play 5 episodes
for episode in xrange(5):
status = HFO_Status.IN_GAME
......@@ -21,7 +22,7 @@ if __name__ == '__main__':
# Grab the state features from the environment
features = hfo.getState()
# Take an action and get the current game status
status = hfo.act((HFO_Actions.DASH, 100, 0))
status = hfo.act((HFO_Actions.DASH, 0, 0))
print 'Episode', episode, 'ended with',
# Check what the outcome of the episode was
if status == HFO_Status.GOAL:
......
import socket, struct, thread, time
class HFO_Features:
''' An enum of the possible HFO feature sets. For descriptions see
https://github.com/mhauskn/HFO/blob/master/doc/manual.pdf
'''
LOW_LEVEL_FEATURE_SET, HIGH_LEVEL_FEATURE_SET = range(2)
class HFO_Actions:
''' An enum of the possible HFO actions
......@@ -31,7 +38,8 @@ class HFOEnvironment(object):
self.numFeatures = None # Given by the server in handshake
self.features = None # The state features
def connectToAgentServer(self, server_port=6000):
def connectToAgentServer(self, server_port=6000,
feature_set=HFO_Features.HIGH_LEVEL_FEATURE_SET):
'''Connect to the server that controls the agent on the specified port. '''
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
print '[Agent Client] Connecting to Agent Server on port', server_port
......@@ -44,7 +52,7 @@ class HFOEnvironment(object):
else:
break
print '[Agent Client] Connected'
self.handshakeAgentServer()
self.handshakeAgentServer(feature_set)
# Get the initial state
state_data = self.socket.recv(struct.calcsize('f')*self.numFeatures)
if not state_data:
......@@ -53,7 +61,7 @@ class HFOEnvironment(object):
exit(1)
self.features = struct.unpack('f'*self.numFeatures, state_data)
def handshakeAgentServer(self):
def handshakeAgentServer(self, feature_set):
'''Handshake with the agent's server. '''
# Recieve float 123.2345
data = self.socket.recv(struct.calcsize("f"))
......@@ -61,6 +69,8 @@ class HFOEnvironment(object):
assert abs(f - 123.2345) < 1e-4, "Float handshake failed"
# Send float 5432.321
self.socket.send(struct.pack("f", 5432.321))
# Send the feature set request
self.socket.send(struct.pack("i", feature_set))
# Recieve the number of features
data = self.socket.recv(struct.calcsize("i"))
self.numFeatures = struct.unpack("i", data)[0]
......
......@@ -24,7 +24,8 @@ HFOEnvironment::~HFOEnvironment() {
close(sockfd);
}
void HFOEnvironment::connectToAgentServer(int server_port) {
void HFOEnvironment::connectToAgentServer(int server_port,
feature_set_t feature_set) {
std::cout << "[Agent Client] Connecting to Agent Server on port "
<< server_port << std::endl;
sockfd = socket(AF_INET, SOCK_STREAM, 0);
......@@ -49,7 +50,7 @@ void HFOEnvironment::connectToAgentServer(int server_port) {
sleep(1);
}
std::cout << "[Agent Client] Connected" << std::endl;
handshakeAgentServer();
handshakeAgentServer(feature_set);
// Get the initial game state
feature_vec.resize(numFeatures);
if (recv(sockfd, &(feature_vec.front()), numFeatures*sizeof(float), 0) < 0) {
......@@ -57,7 +58,7 @@ void HFOEnvironment::connectToAgentServer(int server_port) {
}
}
void HFOEnvironment::handshakeAgentServer() {
void HFOEnvironment::handshakeAgentServer(feature_set_t feature_set) {
// Recieve float 123.2345
float f;
if (recv(sockfd, &f, sizeof(float), 0) < 0) {
......@@ -72,6 +73,10 @@ void HFOEnvironment::handshakeAgentServer() {
if (send(sockfd, &f, sizeof(float), 0) < 0) {
error("[Agent Client] ERROR sending from socket");
}
// Send the feature set request
if (send(sockfd, &feature_set, sizeof(int), 0) < 0) {
error("[Agent Client] ERROR sending from socket");
}
// Recieve the number of features
if (recv(sockfd, &numFeatures, sizeof(int), 0) < 0) {
error("[Agent Client] ERROR recv from socket");
......
......@@ -3,6 +3,14 @@
#include <vector>
// For descriptions of the different feature sets see
// https://github.com/mhauskn/HFO/blob/master/doc/manual.pdf
enum feature_set_t
{
LOW_LEVEL_FEATURE_SET,
HIGH_LEVEL_FEATURE_SET
};
// The actions available to the agent
enum action_t
{
......@@ -39,7 +47,8 @@ class HFOEnvironment {
~HFOEnvironment();
// Connect to the server that controls the agent on the specified port.
void connectToAgentServer(int server_port=6008);
void connectToAgentServer(int server_port=6000,
feature_set_t feature_set=HIGH_LEVEL_FEATURE_SET);
// Get the current state of the domain. Returns a reference to feature_vec.
const std::vector<float>& getState();
......@@ -54,7 +63,7 @@ class HFOEnvironment {
// Handshake with the agent server to ensure data is being correctly
// passed. Also sets the number of features to expect.
virtual void handshakeAgentServer();
virtual void handshakeAgentServer(feature_set_t feature_set);
};
#endif
......@@ -137,7 +137,10 @@ Agent::Agent()
M_action_generator(createActionGenerator()),
lastTrainerMessageTime(-1),
server_port(6008),
server_running(false)
server_running(false),
num_teammates(-1),
num_opponents(-1),
playing_offense(false)
{
boost::shared_ptr< AudioMemory > audio_memory( new AudioMemory );
......@@ -198,13 +201,11 @@ bool Agent::initImpl(CmdLineParser & cmd_parser) {
// read additional options
result &= Strategy::instance().init(cmd_parser);
int numTeammates, numOpponents;
bool playingOffense;
rcsc::ParamMap my_params("Additional options");
my_params.add()
("numTeammates", "", &numTeammates)
("numOpponents", "", &numOpponents)
("playingOffense", "", &playingOffense)
("numTeammates", "", &num_teammates)
("numOpponents", "", &num_opponents)
("playingOffense", "", &playing_offense)
("serverPort", "", &server_port);
cmd_parser.parse(my_params);
if (cmd_parser.count("help") > 0) {
......@@ -245,11 +246,8 @@ bool Agent::initImpl(CmdLineParser & cmd_parser) {
<< std::endl;
}
assert(numTeammates >= 0);
assert(numOpponents >= 0);
feature_extractor = new LowLevelFeatureExtractor(numTeammates,
numOpponents,
playingOffense);
assert(num_teammates >= 0);
assert(num_opponents >= 0);
return true;
}
......@@ -292,6 +290,13 @@ void Agent::clientHandshake() {
if (abs(f - 5432.321) > 1e-4) {
error("[Agent Server] Handshake failed. Improper float recieved.");
}
// Recieve the feature set to use
feature_set_t feature_set;
if (recv(newsockfd, &feature_set, sizeof(int), 0) < 0) {
error("[Agent Server] ERROR recv from socket");
}
// Create the corresponding FeatureExtractor
feature_extractor = getFeatureExtractor(feature_set);
// Send the number of features
int numFeatures = feature_extractor->getNumFeatures();
assert(numFeatures > 0);
......@@ -309,6 +314,27 @@ void Agent::clientHandshake() {
std::cout << "[Agent Server] Handshake complete" << std::endl;
}
FeatureExtractor* Agent::getFeatureExtractor(feature_set_t feature_set_indx) {
if (feature_extractor != NULL) {
delete feature_extractor;
}
switch (feature_set_indx) {
case LOW_LEVEL_FEATURE_SET:
return new LowLevelFeatureExtractor(num_teammates, num_opponents,
playing_offense);
break;
case HIGH_LEVEL_FEATURE_SET:
return new HighLevelFeatureExtractor(num_teammates, num_opponents,
playing_offense);
break;
default:
std::cerr << "[Feature Extractor] ERROR Unrecognized Feature set index: "
<< feature_set_indx << std::endl;
exit(1);
}
}
hfo_status_t Agent::getGameStatus() {
hfo_status_t game_status = IN_GAME;
if (audioSensor().trainerMessageTime().cycle() > lastTrainerMessageTime) {
......
......@@ -69,12 +69,17 @@ protected:
// Transmit information to the client and ensure it can recieve.
void clientHandshake();
// Returns the feature extractor corresponding to the feature_set_t
FeatureExtractor* getFeatureExtractor(feature_set_t feature_set);
protected:
FeatureExtractor* feature_extractor;
long lastTrainerMessageTime; // Last time the trainer sent a message
int server_port; // Port to start the server on
bool server_running; // Is the server running?
int sockfd, newsockfd; // Server sockets
int num_teammates, num_opponents;
bool playing_offense;
private:
bool doPreprocess();
......
......@@ -5,6 +5,12 @@
#include "feature_extractor.h"
#include <vector>
/**
* This feature extractor creates the high level feature set used by
* Barrett et al.
* (http://www.cs.utexas.edu/~sbarrett/publications/details-THESIS14-Barrett.html)
* pages 159-160.
*/
class HighLevelFeatureExtractor : public FeatureExtractor {
public:
HighLevelFeatureExtractor(int num_teammates, int num_opponents,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment