mirror of
https://github.com/vcmi/vcmi.git
synced 2024-12-26 22:57:00 +02:00
[programming challenge, SSN] REPL, various "fixes"
This commit is contained in:
parent
a900fe71c8
commit
82a6520feb
@ -94,6 +94,8 @@ struct Example
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct SSN_Runner;
|
||||||
|
|
||||||
class Framework
|
class Framework
|
||||||
{
|
{
|
||||||
static CArtifactInstance *generateArtWithBonus(const Bonus &b);
|
static CArtifactInstance *generateArtWithBonus(const Bonus &b);
|
||||||
@ -113,6 +115,8 @@ public:
|
|||||||
|
|
||||||
static void buildLearningSet();
|
static void buildLearningSet();
|
||||||
static vector<Example> loadExamples(bool printInfo = true);
|
static vector<Example> loadExamples(bool printInfo = true);
|
||||||
|
|
||||||
|
friend SSN_Runner;
|
||||||
};
|
};
|
||||||
|
|
||||||
vector<string> Framework::getFileNames(const string &dirname, const std::string &ext)
|
vector<string> Framework::getFileNames(const string &dirname, const std::string &ext)
|
||||||
@ -149,9 +153,9 @@ vector<Example> Framework::loadExamples(bool printInfo)
|
|||||||
examples.push_back(ex);
|
examples.push_back(ex);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tlog0 << "Found " << examples.size() << " examples.\n";
|
||||||
if(printInfo)
|
if(printInfo)
|
||||||
{
|
{
|
||||||
tlog0 << "Found " << examples.size() << " examples.\n";
|
|
||||||
BOOST_FOREACH(auto &ex, examples)
|
BOOST_FOREACH(auto &ex, examples)
|
||||||
{
|
{
|
||||||
tlog0 << format("Battle on army %d for bonus %d of value %d has resultdiff %lf\n") % ex.i % ex.j % ex.k % ex.value;
|
tlog0 << format("Battle on army %d for bonus %d of value %d has resultdiff %lf\n") % ex.i % ex.j % ex.k % ex.value;
|
||||||
@ -471,11 +475,15 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
SSN();
|
SSN();
|
||||||
|
SSN(string filename);
|
||||||
~SSN();
|
~SSN();
|
||||||
|
|
||||||
//returns mse after learning
|
//returns mse after learning
|
||||||
double learn(const std::vector<Example> & input, const ParameterSet & params);
|
double learn(const std::vector<Example> & input, const ParameterSet & params);
|
||||||
|
double learn(bool adjustParams = false);
|
||||||
|
|
||||||
|
SSN::ParameterSet getBestParams(vector<Example> &trainingSet);
|
||||||
|
SSN::ParameterSet getBestParams();
|
||||||
double test(const std::vector<Example> & input)
|
double test(const std::vector<Example> & input)
|
||||||
{
|
{
|
||||||
auto td = getTrainingData(input);
|
auto td = getTrainingData(input);
|
||||||
@ -485,11 +493,17 @@ public:
|
|||||||
double run(const DuelParameters &dp, CArtifactInstance * inst);
|
double run(const DuelParameters &dp, CArtifactInstance * inst);
|
||||||
|
|
||||||
void save(const std::string &filename);
|
void save(const std::string &filename);
|
||||||
|
void load(const std::string &filename);
|
||||||
};
|
};
|
||||||
|
|
||||||
SSN::SSN()
|
SSN::SSN()
|
||||||
{}
|
{}
|
||||||
|
|
||||||
|
SSN::SSN(string filename)
|
||||||
|
{
|
||||||
|
load(filename);
|
||||||
|
}
|
||||||
|
|
||||||
void SSN::init(const ParameterSet & params)
|
void SSN::init(const ParameterSet & params)
|
||||||
{
|
{
|
||||||
const float learning_rate = 0.7f;
|
const float learning_rate = 0.7f;
|
||||||
@ -517,7 +531,7 @@ double SSN::run(const DuelParameters &dp, CArtifactInstance * inst)
|
|||||||
double * input = genSSNinput(dp.sides[0], inst, dp.bfieldType, dp.terType);
|
double * input = genSSNinput(dp.sides[0], inst, dp.bfieldType, dp.terType);
|
||||||
double * out = net.run(input);
|
double * out = net.run(input);
|
||||||
double ret = *out;
|
double ret = *out;
|
||||||
free(out);
|
//free(out);
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
@ -539,7 +553,6 @@ double SSN::learn(const std::vector<Example> & input, const ParameterSet & param
|
|||||||
net.set_callback(ANNCallback, NULL);
|
net.set_callback(ANNCallback, NULL);
|
||||||
net.train_on_data(*td, 1000, 1000, 0.01);
|
net.train_on_data(*td, 1000, 1000, 0.01);
|
||||||
|
|
||||||
|
|
||||||
// int exNum = 130;
|
// int exNum = 130;
|
||||||
//
|
//
|
||||||
// for(int exNum =0; exNum<input.size(); ++exNum)
|
// for(int exNum =0; exNum<input.size(); ++exNum)
|
||||||
@ -553,6 +566,25 @@ double SSN::learn(const std::vector<Example> & input, const ParameterSet & param
|
|||||||
return net.test_data(*td);
|
return net.test_data(*td);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
double SSN::learn(bool adjustParams/* = false*/)
|
||||||
|
{
|
||||||
|
|
||||||
|
cout << "Loading examples...\n";
|
||||||
|
auto trainingSet = Framework::loadExamples(false);
|
||||||
|
cout << "Looking for best learning parameters...\n";
|
||||||
|
|
||||||
|
|
||||||
|
auto params = adjustParams ? getBestParams(trainingSet) : getBestParams();
|
||||||
|
|
||||||
|
cout << "Learning...\n";
|
||||||
|
|
||||||
|
//saving of best network
|
||||||
|
double finalLmse = learn(trainingSet, params);
|
||||||
|
cout << "Learning done, LMSE=" << finalLmse << endl;
|
||||||
|
save("last_network.net");
|
||||||
|
return finalLmse;
|
||||||
|
}
|
||||||
|
|
||||||
double * SSN::genSSNinput(const DuelParameters::SideSettings & dp, CArtifactInstance * art, si32 bfieldType, si32 terType)
|
double * SSN::genSSNinput(const DuelParameters::SideSettings & dp, CArtifactInstance * art, si32 bfieldType, si32 terType)
|
||||||
{
|
{
|
||||||
double * ret = new double[num_input];
|
double * ret = new double[num_input];
|
||||||
@ -633,16 +665,17 @@ FANN::training_data * SSN::getTrainingData( const std::vector<Example> &input )
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SSNRun()
|
void SSN::load(const std::string &filename)
|
||||||
|
{
|
||||||
|
net.create_from_file(filename);
|
||||||
|
cout << "Loaded a network from file " << filename << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
SSN::ParameterSet SSN::getBestParams(vector<Example> &trainingSet)
|
||||||
{
|
{
|
||||||
//Framework::buildLearningSet();
|
|
||||||
double percentToTrain = 0.8;
|
double percentToTrain = 0.8;
|
||||||
|
|
||||||
auto trainingSet = Framework::loadExamples(false);
|
|
||||||
|
|
||||||
std::vector<Example> testSet;
|
std::vector<Example> testSet;
|
||||||
|
|
||||||
|
|
||||||
for(int i=0, maxi = trainingSet.size()*(1-percentToTrain); i<maxi; ++i)
|
for(int i=0, maxi = trainingSet.size()*(1-percentToTrain); i<maxi; ++i)
|
||||||
{
|
{
|
||||||
int ind = rand()%trainingSet.size();
|
int ind = rand()%trainingSet.size();
|
||||||
@ -650,9 +683,6 @@ void SSNRun()
|
|||||||
trainingSet.erase(trainingSet.begin() + ind);
|
trainingSet.erase(trainingSet.begin() + ind);
|
||||||
}
|
}
|
||||||
|
|
||||||
SSN network;
|
|
||||||
|
|
||||||
|
|
||||||
SSN::ParameterSet bestParams;
|
SSN::ParameterSet bestParams;
|
||||||
double besttMSE = 1e10;
|
double besttMSE = 1e10;
|
||||||
|
|
||||||
@ -661,12 +691,6 @@ void SSNRun()
|
|||||||
|
|
||||||
FANN::activation_function_enum possibleFuns[] = {FANN::SIGMOID_SYMMETRIC_STEPWISE, FANN::LINEAR,
|
FANN::activation_function_enum possibleFuns[] = {FANN::SIGMOID_SYMMETRIC_STEPWISE, FANN::LINEAR,
|
||||||
FANN::SIGMOID, FANN::SIGMOID_STEPWISE, FANN::SIGMOID_SYMMETRIC};
|
FANN::SIGMOID, FANN::SIGMOID_STEPWISE, FANN::SIGMOID_SYMMETRIC};
|
||||||
//
|
|
||||||
// bestParams.actSteepHidden = 0.346;
|
|
||||||
// bestParams.actSteepnessOutput = 0.449;
|
|
||||||
// bestParams.hiddenActFun = FANN::SIGMOID_SYMMETRIC;
|
|
||||||
// bestParams.outActFun = FANN::SIGMOID_SYMMETRIC;
|
|
||||||
// bestParams.neuronsInHidden = 23;
|
|
||||||
|
|
||||||
for(int i=0; i<5000; i += 1)
|
for(int i=0; i<5000; i += 1)
|
||||||
{
|
{
|
||||||
@ -677,9 +701,9 @@ void SSNRun()
|
|||||||
ps.hiddenActFun = possibleFuns[rand()%ARRAY_COUNT(possibleFuns)];
|
ps.hiddenActFun = possibleFuns[rand()%ARRAY_COUNT(possibleFuns)];
|
||||||
ps.outActFun = possibleFuns[rand()%ARRAY_COUNT(possibleFuns)];
|
ps.outActFun = possibleFuns[rand()%ARRAY_COUNT(possibleFuns)];
|
||||||
|
|
||||||
double lmse = network.learn(trainingSet, ps);
|
double lmse = learn(trainingSet, ps);
|
||||||
|
|
||||||
double tmse = network.test(testSet);
|
double tmse = test(testSet);
|
||||||
if(tmse < besttMSE)
|
if(tmse < besttMSE)
|
||||||
{
|
{
|
||||||
besttMSE = tmse;
|
besttMSE = tmse;
|
||||||
@ -688,12 +712,199 @@ void SSNRun()
|
|||||||
|
|
||||||
cout << "hid:\t" << i << " lmse:\t" << lmse << " tmse:\t" << tmse << std::endl;
|
cout << "hid:\t" << i << " lmse:\t" << lmse << " tmse:\t" << tmse << std::endl;
|
||||||
}
|
}
|
||||||
//saving of best network
|
|
||||||
double debugMSE = network.learn(trainingSet, bestParams);
|
|
||||||
|
|
||||||
network.save("network_config_file.net");
|
return bestParams;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SSN::ParameterSet SSN::getBestParams()
|
||||||
|
{
|
||||||
|
// bestParams.actSteepHidden = 0.346;
|
||||||
|
// bestParams.actSteepnessOutput = 0.449;
|
||||||
|
// bestParams.hiddenActFun = FANN::SIGMOID_SYMMETRIC;
|
||||||
|
// bestParams.outActFun = FANN::SIGMOID_SYMMETRIC;
|
||||||
|
// bestParams.neuronsInHidden = 23;
|
||||||
|
|
||||||
|
|
||||||
|
SSN::ParameterSet params;
|
||||||
|
params.actSteepHidden = 1.18;
|
||||||
|
params.actSteepnessOutput = 1.26;
|
||||||
|
params.hiddenActFun = FANN::SIGMOID_STEPWISE;
|
||||||
|
params.outActFun = FANN::SIGMOID_SYMMETRIC;
|
||||||
|
params.neuronsInHidden = 47;
|
||||||
|
return params;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SSN_Runner
|
||||||
|
{
|
||||||
|
unique_ptr<SSN> ssn;
|
||||||
|
ArmyDescriptor ad;
|
||||||
|
|
||||||
|
void printHelp()
|
||||||
|
{
|
||||||
|
const char *cmds[] = {"help - prints this info", "create - creates a new ANN, needs to be learned then", "load <file> - loads ANN from file", "save <file> - saves current ANN to file", "learn - runs learning process using examples set", "ask <id> - evaluates given art", "exit - closes application",
|
||||||
|
"army clear - removes current army information", "army add <id> <count> - adds creature to army", "army remove <pos> - removes stack from position",
|
||||||
|
"army print - prints current army state", "army random - generates random army"};
|
||||||
|
cout << "Available commands:\n";
|
||||||
|
BOOST_FOREACH(auto cmd, cmds)
|
||||||
|
cout << "\t" << cmd << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
int run()
|
||||||
|
{
|
||||||
|
cout << "Welcome to the ANN interactive mode!\n";
|
||||||
|
printHelp();
|
||||||
|
|
||||||
|
while(1)
|
||||||
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
cout << "Please enter your command and press return.\n> ";
|
||||||
|
stringstream ss;
|
||||||
|
string input;
|
||||||
|
getline(cin, input);
|
||||||
|
ss.str(input);
|
||||||
|
|
||||||
|
string command, secondWord;
|
||||||
|
ss >> command >> secondWord;
|
||||||
|
|
||||||
|
if(command == "exit")
|
||||||
|
{
|
||||||
|
cout << "Ending...\n";
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
else if(command == "load")
|
||||||
|
{
|
||||||
|
if(secondWord.empty())
|
||||||
|
secondWord = "last_network.net";
|
||||||
|
|
||||||
|
ssn = unique_ptr<SSN>(new SSN(secondWord));
|
||||||
|
}
|
||||||
|
else if(command == "create")
|
||||||
|
{
|
||||||
|
ssn = unique_ptr<SSN>(new SSN());
|
||||||
|
cout << "Network successfully created. It still needs to be learnt.\n";
|
||||||
|
}
|
||||||
|
else if(command == "help")
|
||||||
|
{
|
||||||
|
printHelp();
|
||||||
|
}
|
||||||
|
|
||||||
|
else if(command == "army" && secondWord.size())
|
||||||
|
{
|
||||||
|
if(secondWord == "clear")
|
||||||
|
{
|
||||||
|
ad.clear();
|
||||||
|
cout << "Army is now empty.\n";
|
||||||
|
}
|
||||||
|
if(secondWord == "print")
|
||||||
|
{
|
||||||
|
cout << "Army contains " << ad.size() << " creatures.\n";
|
||||||
|
BOOST_FOREACH(auto &itr, ad)
|
||||||
|
{
|
||||||
|
cout << itr.first << " => " << itr.second.count << " of " << itr.second.type->namePl << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(secondWord == "erase")
|
||||||
|
{
|
||||||
|
int slot;
|
||||||
|
ss >> slot;
|
||||||
|
if(ad.find(slot) != ad.end())
|
||||||
|
{
|
||||||
|
ad.erase(slot);
|
||||||
|
cout << "Slot " << slot << " successfully erased.\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(secondWord == "add")
|
||||||
|
{
|
||||||
|
int id, count;
|
||||||
|
ss >> id >> count;
|
||||||
|
int i = 0;
|
||||||
|
if(id < 0 || id >= 118)
|
||||||
|
{
|
||||||
|
throw std::runtime_error("Id has to be in <0,118>");
|
||||||
|
}
|
||||||
|
if(count <= 0)
|
||||||
|
{
|
||||||
|
throw std::runtime_error("Count has to be > 0");
|
||||||
|
}
|
||||||
|
|
||||||
|
while(ad.find(i++) != ad.end());
|
||||||
|
if(i >= ARMY_SIZE)
|
||||||
|
{
|
||||||
|
tlog1 << "Cannot add stack, army is full!\n";
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
ad[i] = CStackBasicDescriptor(id, count);
|
||||||
|
tlog0 << "Creature successfully added to slot " << i << endl;;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(secondWord == "random")
|
||||||
|
{
|
||||||
|
srand(time(0));
|
||||||
|
ad.clear();
|
||||||
|
int stacks = rand() % 7 + 1;
|
||||||
|
for(int i = 0; i < stacks; i++)
|
||||||
|
{
|
||||||
|
CCreature *c = VLC->creh->creatures[rand() % 118];
|
||||||
|
ad[i] = CStackBasicDescriptor(c, c->growth);
|
||||||
|
}
|
||||||
|
cout << "Generated random army of " << stacks << " creatures.\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
else if(!ssn)
|
||||||
|
{
|
||||||
|
cout << "Error: you need to create or load ANN from file first!\n";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
else if(command == "learn")
|
||||||
|
{
|
||||||
|
ssn->learn();
|
||||||
|
}
|
||||||
|
else if(command == "save")
|
||||||
|
{
|
||||||
|
ssn->save(secondWord);
|
||||||
|
}
|
||||||
|
else if(command == "ask")
|
||||||
|
{
|
||||||
|
int artid = boost::lexical_cast<int>(secondWord);
|
||||||
|
CArtifact *art = VLC->arth->artifacts.at(artid);
|
||||||
|
|
||||||
|
DuelParameters dp = Framework::generateDuel(ad);
|
||||||
|
|
||||||
|
CArtifactInstance * artInst = new CArtifactInstance(art);
|
||||||
|
auto bonuses = art->getBonuses([](const Bonus*){return true;});
|
||||||
|
if(!bonuses->size())
|
||||||
|
{
|
||||||
|
tlog1 << "This artifact deosn't provide any bonuses. Please pick another one.";
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
BOOST_FOREACH(auto b, *bonuses)
|
||||||
|
artInst->addNewBonus(new Bonus(*b));
|
||||||
|
|
||||||
|
|
||||||
|
auto val = ssn->run(dp, artInst);
|
||||||
|
cout << "ANN rates " << art->Name() << " to value = " << val << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
tlog1 << "Unknown command \""<<command <<"\"!\n";
|
||||||
|
}
|
||||||
|
catch(std::exception &e)
|
||||||
|
{
|
||||||
|
tlog1 << "Encountered error: " << e.what() << endl;
|
||||||
|
}
|
||||||
|
catch(...)
|
||||||
|
{
|
||||||
|
tlog1 << "Encountered unknown error!" << endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
int main(int argc, char **argv)
|
int main(int argc, char **argv)
|
||||||
{
|
{
|
||||||
std::cout << "VCMI Odpalarka\nMy path: " << argv[0] << std::endl;
|
std::cout << "VCMI Odpalarka\nMy path: " << argv[0] << std::endl;
|
||||||
@ -764,7 +975,8 @@ int main(int argc, char **argv)
|
|||||||
VLC = new LibClasses();
|
VLC = new LibClasses();
|
||||||
VLC->init();
|
VLC->init();
|
||||||
|
|
||||||
SSNRun();
|
SSN_Runner runner;
|
||||||
|
runner.run();
|
||||||
|
|
||||||
return EXIT_SUCCESS;
|
return EXIT_SUCCESS;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user