diff --git a/Odpalarka/main.cpp b/Odpalarka/main.cpp index 5502f2ae2..e2158aae3 100644 --- a/Odpalarka/main.cpp +++ b/Odpalarka/main.cpp @@ -1,6 +1,7 @@ //#include "../global.h" #include "StdInc.h" #include "../lib/VCMI_Lib.h" +#include "boost/tuple/tuple.hpp" namespace po = boost::program_options; @@ -204,7 +205,16 @@ double runSSN(FANN::neural_net & net, const DuelParameters dp, CArtifactInstance return ret; } -void learnSSN(FANN::neural_net & net, const std::vector > & input) +int ANNCallback(FANN::neural_net &net, FANN::training_data &train, + unsigned int max_epochs, unsigned int epochs_between_reports, + float desired_error, unsigned int epochs, void *user_data) +{ + //cout << "Epochs " << setw(8) << epochs << ". " + // << "Current Error: " << left << net.get_MSE() << right << endl; + return 0; +} + +void learnSSN(FANN::neural_net & net, const std::vector > & input) { FANN::training_data td; @@ -212,13 +222,13 @@ void learnSSN(FANN::neural_net & net, const std::vector(), input[i].get<1>()); outputs[i] = new double; - *(outputs[i]) = rateArt(input[i].first, input[i].second); + *(outputs[i]) = input[i].get<2>(); } td.set_train_data(input.size(), num_input, inputs, 1, outputs); - - net.train_epoch(td); + net.set_callback(ANNCallback, NULL); + net.train_on_data(td, 1000, 1000, 0.01); } void initNet(FANN::neural_net & ret) @@ -317,15 +327,22 @@ void SSNRun() auto arts = genArts(btt); //evaluate - std::vector > setups; + std::vector > setups; + + std::ofstream desOuts("desiredOuts.dat"); + for(int i=0; i() << " "; } + desOuts << std::endl; } + learnSSN(network, setups); + network.save("network_config_file.net"); } int main(int argc, char **argv)