1
0
mirror of https://github.com/vcmi/vcmi.git synced 2025-03-17 20:58:07 +02:00

[programming challenge, SSN]

* caching and saving desired outputs
* saving of ANN
* full learning, not just one epoch
* stub of reporting of learning function (ANNCallback)
This commit is contained in:
mateuszb 2012-05-27 11:57:15 +00:00
parent 4ae3c8c8f1
commit 867d01dc34

View File

@ -1,6 +1,7 @@
//#include "../global.h" //#include "../global.h"
#include "StdInc.h" #include "StdInc.h"
#include "../lib/VCMI_Lib.h" #include "../lib/VCMI_Lib.h"
#include "boost/tuple/tuple.hpp"
namespace po = boost::program_options; namespace po = boost::program_options;
@ -204,7 +205,16 @@ double runSSN(FANN::neural_net & net, const DuelParameters dp, CArtifactInstance
return ret; return ret;
} }
void learnSSN(FANN::neural_net & net, const std::vector<std::pair<DuelParameters, CArtifactInstance *> > & input) int ANNCallback(FANN::neural_net &net, FANN::training_data &train,
unsigned int max_epochs, unsigned int epochs_between_reports,
float desired_error, unsigned int epochs, void *user_data)
{
//cout << "Epochs " << setw(8) << epochs << ". "
// << "Current Error: " << left << net.get_MSE() << right << endl;
return 0;
}
void learnSSN(FANN::neural_net & net, const std::vector<boost::tuple<DuelParameters, CArtifactInstance *, double> > & input)
{ {
FANN::training_data td; FANN::training_data td;
@ -212,13 +222,13 @@ void learnSSN(FANN::neural_net & net, const std::vector<std::pair<DuelParameters
double ** outputs = new double *[input.size()]; double ** outputs = new double *[input.size()];
for(int i=0; i<input.size(); ++i) for(int i=0; i<input.size(); ++i)
{ {
inputs[i] = genSSNinput(input[i].first, input[i].second); inputs[i] = genSSNinput(input[i].get<0>(), input[i].get<1>());
outputs[i] = new double; outputs[i] = new double;
*(outputs[i]) = rateArt(input[i].first, input[i].second); *(outputs[i]) = input[i].get<2>();
} }
td.set_train_data(input.size(), num_input, inputs, 1, outputs); td.set_train_data(input.size(), num_input, inputs, 1, outputs);
net.set_callback(ANNCallback, NULL);
net.train_epoch(td); net.train_on_data(td, 1000, 1000, 0.01);
} }
void initNet(FANN::neural_net & ret) void initNet(FANN::neural_net & ret)
@ -317,15 +327,22 @@ void SSNRun()
auto arts = genArts(btt); auto arts = genArts(btt);
//evaluate //evaluate
std::vector<std::pair<DuelParameters, CArtifactInstance *> > setups; std::vector<boost::tuple<DuelParameters, CArtifactInstance *, double> > setups;
std::ofstream desOuts("desiredOuts.dat");
for(int i=0; i<dps.size(); ++i) for(int i=0; i<dps.size(); ++i)
{ {
for(int j=0; j<arts.size(); ++j) for(int j=0; j<arts.size(); ++j)
{ {
setups.push_back(std::make_pair(dps[i], arts[j])); setups.push_back(boost::make_tuple(dps[i], arts[j], rateArt(dps[i], arts[i])));
desOuts << (*setups.rbegin()).get<2>() << " ";
} }
desOuts << std::endl;
} }
learnSSN(network, setups); learnSSN(network, setups);
network.save("network_config_file.net");
} }
int main(int argc, char **argv) int main(int argc, char **argv)