Commit 763d3da5 authored by Shashank Suhas

Changes to master branch. Uncommented weights save/load code

parent 99aaa836
...@@ -84,3 +84,6 @@ soccerwindow2-prefix/ ...@@ -84,3 +84,6 @@ soccerwindow2-prefix/
log/ log/
bin/log/ bin/log/
high_level_sarsa_agent high_level_sarsa_agent
# Misc
example/sarsa_offense/
...@@ -163,17 +163,17 @@ void CMAC::increaseMinTrace(){ ...@@ -163,17 +163,17 @@ void CMAC::increaseMinTrace(){
void CMAC::read(char *fileName){ void CMAC::read(char *fileName){
std::cout<<"Not reading weights"<<std::endl; // std::cout<<"Not reading weights"<<std::endl;
// std::fstream file; std::fstream file;
// file.open(fileName, std::ios::in | std::ios::binary); file.open(fileName, std::ios::in | std::ios::binary);
// file.read((char *) weights, RL_MEMORY_SIZE * sizeof(double)); file.read((char *) weights, RL_MEMORY_SIZE * sizeof(double));
// unsigned long pos = file.tellg(); unsigned long pos = file.tellg();
// file.close(); file.close();
colTab->restore(fileName, pos);
// for(int i=0; i<RL_MEMORY_SIZE; i++) // for(int i=0; i<RL_MEMORY_SIZE; i++)
// std::cout<<weights[i]<<std::endl; // std::cout<<weights[i]<<std::endl;
// colTab->restore(fileName, pos);
} }
void CMAC::write(char *fileName){ void CMAC::write(char *fileName){
......
...@@ -107,25 +107,10 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR, ...@@ -107,25 +107,10 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR,
"_" + std::to_string(numTMates + 1) + "_" + std::to_string(numTMates + 1) +
"_" + std::to_string(suffix); "_" + std::to_string(suffix);
wtFile = &s[0u]; wtFile = &s[0u];
// std::string filename1 = "early_game_model_" + std::to_string(port) +
// "_" + std::to_string(numTMates + 1) +
// "_" + std::to_string(suffix);
// std::string filename2 = "late_game_model_" + std::to_string(port) +
// "_" + std::to_string(numTMates + 1) +
// "_" + std::to_string(suffix);
// char *str1 = &filename1[0u];
// char *str2 = &filename2[0u];
double lambda = 0; double lambda = 0;
CMAC *fa = new CMAC(numF, numA, range, min, res); CMAC *fa = new CMAC(numF, numA, range, min, res);
// CMAC *fa1 = new CMAC(numF, numA, range, min, res);
// CMAC *fa2 = new CMAC(numF, numA, range, min, res);
// CMAC *fa = fa1;
SarsaAgent *sa = new SarsaAgent(numF, numA, learnR, eps, lambda, fa, wtFile, wtFile); SarsaAgent *sa = new SarsaAgent(numF, numA, learnR, eps, lambda, fa, wtFile, wtFile);
// SarsaAgent *sa1 = new SarsaAgent(numF, numA, learnR, eps, lambda, fa1, str1, str1);
// SarsaAgent *sa2 = new SarsaAgent(numF, numA, learnR, eps, lambda, fa2, str2, str2);
// SarsaAgent *sa = sa1;
hfo::HFOEnvironment hfo; hfo::HFOEnvironment hfo;
hfo::status_t status; hfo::status_t status;
...@@ -141,9 +126,6 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR, ...@@ -141,9 +126,6 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR,
// bool model_changed_flag = false; // bool model_changed_flag = false;
while (status == hfo::IN_GAME) { while (status == hfo::IN_GAME) {
const std::vector<float>& state_vec = hfo.getState(); const std::vector<float>& state_vec = hfo.getState();
// If has ball
// sleep(1);
// std::cout<<state_vec[numTMates]<<std::endl;
if(state_vec[5] == 1) { if(state_vec[5] == 1) {
if(action != -1) { if(action != -1) {
...@@ -151,12 +133,6 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR, ...@@ -151,12 +133,6 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR,
sa->update(state, action, reward, discFac); sa->update(state, action, reward, discFac);
} }
// if(state_vec[numTMates] >= 0.2 && model_changed_flag == false)
// {
// sa = sa2;
// model_changed_flag = true;
// }
// Fill up state array // Fill up state array
purgeFeatures(state, state_vec, numTMates, oppPres); purgeFeatures(state, state_vec, numTMates, oppPres);
...@@ -182,16 +158,10 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR, ...@@ -182,16 +158,10 @@ void offenseAgent(int port, int numTMates, int numEpi, double learnR,
reward = getReward(status); reward = getReward(status);
sa->update(state, action, reward, discFac); sa->update(state, action, reward, discFac);
sa->endEpisode(); sa->endEpisode();
// sa1->endEpisode();
// sa2->endEpisode();
// sa = sa1;
// model_changed_flag = false;
} }
} }
delete sa, fa; delete sa, fa;
// delete sa1, sa2;
// delete fa1, fa2;
} }
int main(int argc, char **argv) { int main(int argc, char **argv) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment