2011-11-10 09:21:27 +00:00
//#include "../global.h"
2012-05-12 17:59:32 +00:00
# include "StdInc.h"
# include "../lib/VCMI_Lib.h"
2011-10-20 23:29:32 +00:00
namespace po = boost : : program_options ;
2012-05-28 11:03:20 +00:00
namespace fs = boost : : filesystem ;
using namespace std ;
2011-10-20 23:29:32 +00:00
2012-05-30 21:00:23 +00:00
# include <boost/random.hpp>
2012-05-13 14:51:13 +00:00
//FANN
2012-05-14 17:57:22 +00:00
# include <doublefann.h>
2012-05-13 14:51:13 +00:00
# include <fann_cpp.h>
2012-05-12 17:59:32 +00:00
// Run configuration; main() fills these from the command-line options
// (aiLeft, aiRight, battle, resultsOut, logsDir).
std : : string leftAI , rightAI , battle , results , logsDir ;
// Set in main() when the "visualization" option is present; playBattle() then also starts the client.
bool withVisualization = false ;
// Executable name of the server; main() picks it per platform (_WIN32 or not).
std : : string servername ;
// Executable name of the AI runner; main() picks it per platform (_WIN32 or not).
std : : string runnername ;
// Library root object; main() assigns it a new LibClasses and calls init() on it.
extern DLL_EXPORT LibClasses * VLC ;
2012-05-30 17:32:05 +00:00
// Set of artifact instances keyed by int; rateArt() uses the artifact's first possible slot as the key.
typedef std : : map < int , CArtifactInstance * > TArtSet ;
// Small free helpers shared by the rest of this file.
namespace Utilities
{
	// Returns s untouched when it holds no space character; otherwise returns it
	// wrapped in the quoting prefix and suffix.
	std::string addQuotesIfNeeded(const std::string &s)
	{
		const bool containsSpace = s.find_first_of(' ') != std::string::npos;
		if(!containsSpace)
			return s;

		return " \" " + s + " \" ";
	}

	// Prints the one-line usage hint to standard output.
	void prog_help()
	{
		std::cout << " If run without args, then StupidAI will be run on b1.json. \n ";
	}

	// Converts an integer to its textual form.
	string toString(int i)
	{
		const string text = boost::lexical_cast<string>(i);
		return text;
	}

	// Builds a label for a bonus from its value, type name and subtype.
	string describeBonus(const Bonus &b)
	{
		string label = " + " + toString(b.val);
		label = label + " _to_ " + bonusTypeToString(b.type);
		label = label + " _sub " + toString(b.subtype);
		return label;
	}
}
using namespace Utilities ;
2012-05-28 11:03:20 +00:00
struct Example
{
//ANN input
DuelParameters dp ;
CArtifactInstance * art ;
//ANN expected output
double value ;
//other
std : : string description ;
int i , j , k ; //helper values for identification
Example ( ) { }
Example ( const DuelParameters & Dp , CArtifactInstance * Art , double Value )
: dp ( Dp ) , art ( Art ) , value ( Value )
{ }
inline bool operator < ( const Example & rhs ) const
{
if ( k < rhs . k )
return true ;
if ( k > rhs . k )
return false ;
if ( j < rhs . j )
return true ;
if ( j > rhs . j )
return false ;
if ( i < rhs . i )
return true ;
if ( i > rhs . i )
return false ;
return false ;
}
bool operator = = ( const Example & rhs ) const
{
return rhs . i = = i & & rhs . j = = j & & rhs . k = = k ;
}
template < typename Handler > void serialize ( Handler & h , const int version )
{
h & dp & art & value & description & i & j & k ;
}
} ;
2012-05-31 20:42:10 +00:00
struct SSN_Runner ;
2012-05-30 17:32:05 +00:00
class Framework
{
static CArtifactInstance * generateArtWithBonus ( const Bonus & b ) ;
static DuelParameters generateDuel ( const ArmyDescriptor & ad ) ; //generates simple duel where both sides have given army
static void runCommand ( const std : : string & command , const std : : string & name , const std : : string & logsDir = " " ) ;
static double playBattle ( const DuelParameters & dp ) ;
static double cmpArtSets ( DuelParameters dp , TArtSet setL , TArtSet setR ) ;
static double rateArt ( const DuelParameters dp , CArtifactInstance * inst ) ; //rates given artifact
static int theLastN ( ) ;
static vector < string > getFileNames ( const string & dirname = " ./examples/ " , const std : : string & ext = " example " ) ;
static vector < ArmyDescriptor > learningArmies ( ) ;
static vector < Bonus > learningBonuses ( ) ;
public :
Framework ( ) ;
~ Framework ( ) ;
static void buildLearningSet ( ) ;
static vector < Example > loadExamples ( bool printInfo = true ) ;
2012-05-31 20:42:10 +00:00
friend SSN_Runner ;
2012-05-30 17:32:05 +00:00
} ;
// Lists the regular files in dirname whose file name ends with ext and returns
// their paths. When dirname does not exist, a warning is logged and the
// directory is created before it is scanned.
vector<string> Framework::getFileNames(const string &dirname, const std::string &ext)
{
	if(!fs::exists(dirname))
	{
		tlog1 << " Cannot find " << dirname << " directory! Will attempt creating it. \n ";
		fs::create_directory(dirname);
	}

	vector<string> found;
	fs::path dir(dirname);
	fs::directory_iterator noMoreEntries;
	fs::directory_iterator entry(dir);
	while(entry != noMoreEntries)
	{
		const bool regular = fs::is_regular_file(entry->status());
		if(regular && boost::ends_with(entry->path().filename(), ext))
			found.push_back(entry->path().string());

		++entry;
	}
	return found;
}
2012-05-30 17:32:05 +00:00
// Deserializes every stored example file and returns the resulting examples.
// The number of examples is always logged; with printInfo set, one summary
// line per example is logged as well.
vector<Example> Framework::loadExamples(bool printInfo)
{
	std::vector<Example> loaded;

	const vector<string> files = getFileNames(" ./examples/ ", " example ");
	for(size_t n = 0; n < files.size(); ++n)
	{
		string path = files[n];
		CLoadFile reader(path);
		Example sample;
		reader >> sample;
		loaded.push_back(sample);
	}

	tlog0 << " Found " << loaded.size() << " examples. \n ";

	if(!printInfo)
		return loaded;

	for(size_t n = 0; n < loaded.size(); ++n)
	{
		Example &sample = loaded[n];
		tlog0 << format(" Battle on army %d for bonus %d of value %d has resultdiff %lf \n ") % sample.i % sample.j % sample.k % sample.value;
	}
	return loaded;
}
2012-05-30 17:32:05 +00:00
int Framework : : theLastN ( )
2012-05-28 11:03:20 +00:00
{
2012-05-30 17:32:05 +00:00
auto fnames = getFileNames ( ) ;
if ( ! fnames . size ( ) )
return - 1 ;
range : : sort ( fnames , [ ] ( const std : : string & a , const std : : string & b )
{
return boost : : lexical_cast < int > ( fs : : basename ( a ) ) < boost : : lexical_cast < int > ( fs : : basename ( b ) ) ;
} ) ;
return boost : : lexical_cast < int > ( fs : : basename ( fnames . back ( ) ) ) ;
2012-05-28 11:03:20 +00:00
}
2012-05-30 17:32:05 +00:00
void Framework : : buildLearningSet ( )
2012-05-28 11:03:20 +00:00
{
2012-05-30 17:32:05 +00:00
vector < Example > examples = loadExamples ( ) ;
range : : sort ( examples ) ;
2012-05-28 11:03:20 +00:00
2012-05-30 17:32:05 +00:00
int startExamplesFrom = 0 ;
ofstream learningLog ( " log.txt " , std : : ios : : app ) ;
int n = theLastN ( ) + 1 ;
auto armies = learningArmies ( ) ;
auto bonuese = learningBonuses ( ) ;
for ( int i = 0 ; i < armies . size ( ) ; i + + )
2012-05-28 11:03:20 +00:00
{
2012-05-30 17:32:05 +00:00
string army = " army " + toString ( i ) ;
for ( int j = 0 ; j < bonuese . size ( ) ; j + + )
{
Bonus b = bonuese [ j ] ;
string bonusStr = " bonus " + toString ( j ) + describeBonus ( b ) ;
for ( int k = 0 ; k < 10 ; k + + )
{
int nHere = n + + ;
// if(nHere < startExamplesFrom)
// continue;
//
tlog2 < < " n= " < < nHere < < std : : endl ;
b . val = k ;
Example ex ;
ex . i = i ;
ex . j = j ;
ex . k = k ;
ex . art = generateArtWithBonus ( b ) ;
ex . dp = generateDuel ( armies [ i ] ) ;
ex . description = army + " \t " + describeBonus ( b ) + " \t " ;
if ( vstd : : contains ( examples , ex ) )
{
string msg = str ( format ( " n=%d \t army %d \t bonus %d \t result %lf \t Bonus#%s# " ) % nHere % i % j % ex . value % describeBonus ( b ) ) ;
tlog0 < < " Already present example, skipping " < < msg ;
continue ;
}
ex . value = rateArt ( ex . dp , ex . art ) ;
CSaveFile output ( " ./examples/ " + toString ( nHere ) + " .example " ) ;
output < < ex ;
time_t rawtime ;
struct tm * timeinfo ;
time ( & rawtime ) ;
timeinfo = localtime ( & rawtime ) ;
string msg = str ( format ( " n=%d \t army %d \t bonus %d \t result %lf \t Bonus#%s# \t date: %s " ) % nHere % i % j % ex . value % describeBonus ( b ) % asctime ( timeinfo ) ) ;
learningLog < < msg < < flush ;
tlog0 < < msg ;
}
}
2012-05-28 11:03:20 +00:00
}
2012-05-30 17:32:05 +00:00
tlog0 < < " Set of learning/testing examples is complete and ready! \n " ;
2012-05-28 11:03:20 +00:00
}
2012-05-30 17:32:05 +00:00
// Returns the five armies used for building the learning set, in this order:
// low-HP creatures, powerful creatures, Castle weekly growth, upgraded Rampart
// weekly growth, shooters only.
vector<ArmyDescriptor> Framework::learningArmies()
{
	vector<ArmyDescriptor> armies;

	//army made of creatures with few hit points
	ArmyDescriptor weak;
	weak[0] = CStackBasicDescriptor(1, 9); //halberdier
	weak[1] = CStackBasicDescriptor(14, 20); //centaur
	weak[2] = CStackBasicDescriptor(139, 123); //peasant
	weak[3] = CStackBasicDescriptor(70, 30); //troglodyte
	weak[4] = CStackBasicDescriptor(42, 50); //imp
	armies.push_back(weak);

	//army made of powerful creatures
	ArmyDescriptor strong;
	strong[0] = CStackBasicDescriptor(13, 17); //archangel
	strong[1] = CStackBasicDescriptor(132, 8); //azure dragon
	strong[2] = CStackBasicDescriptor(133, 10); //crystal dragon
	strong[3] = CStackBasicDescriptor(83, 22); //black dragon
	armies.push_back(strong);

	//army made of one week of creature growth in the Castle town
	auto &castle = VLC->townh->towns[0];
	ArmyDescriptor castleWeek;
	for(int level = 0; level < 7; level++)
	{
		auto &creature = VLC->creh->creatures[castle.basicCreatures[level]];
		castleWeek[level] = CStackBasicDescriptor(creature.get(), creature->growth);
	}
	castleWeek[5].type = VLC->creh->creatures[52]; //replace cavaliers with Efreeti -> stupid ai sometimes blocks with two-hex walkers
	armies.push_back(castleWeek);

	//army made of one week of upgraded creature growth in the Rampart town
	auto &rampart = VLC->townh->towns[1];
	ArmyDescriptor rampartWeek;
	for(int level = 0; level < 7; level++)
	{
		auto &creature = VLC->creh->creatures[rampart.upgradedCreatures[level]];
		rampartWeek[level] = CStackBasicDescriptor(creature.get(), creature->growth);
	}
	rampartWeek[5].type = VLC->creh->creatures[52]; //replace unicorn with Efreeti -> stupid ai sometimes blocks with two-hex walkers
	armies.push_back(rampartWeek);

	//army made of shooters only
	ArmyDescriptor ranged;
	ranged[0] = CStackBasicDescriptor(35, 17); //arch mage
	ranged[1] = CStackBasicDescriptor(41, 1); //titan
	ranged[2] = CStackBasicDescriptor(3, 70); //marksman
	ranged[3] = CStackBasicDescriptor(89, 50); //upgraded orc
	armies.push_back(ranged);

	return armies;
}
2012-05-30 17:32:05 +00:00
// Returns the ten bonus templates used for building the learning set.
// Only type and subtype are set here; buildLearningSet() fills in the value.
vector<Bonus> Framework::learningBonuses()
{
	vector<Bonus> bonuses;
	Bonus proto;

	//stores the given type/subtype in the shared prototype and appends a copy of it
	auto append = [&](decltype(proto.type) kind, decltype(proto.subtype) sub)
	{
		proto.type = kind;
		proto.subtype = sub;
		bonuses.push_back(proto);
	};

	append(Bonus::PRIMARY_SKILL, PrimarySkill::ATTACK);
	append(Bonus::PRIMARY_SKILL, PrimarySkill::DEFENSE);

	//all remaining bonuses use subtype 0
	append(Bonus::STACK_HEALTH, 0);
	append(Bonus::STACKS_SPEED, 0);
	append(Bonus::BLOCKS_RETALIATION, 0);
	append(Bonus::ADDITIONAL_RETALIATION, 0);
	append(Bonus::ADDITIONAL_ATTACK, 0);
	append(Bonus::CREATURE_DAMAGE, 0);
	append(Bonus::ALWAYS_MAXIMUM_DAMAGE, 0);
	append(Bonus::NO_DISTANCE_PENALTY, 0);

	return bonuses;
}
2012-05-30 17:32:05 +00:00
double Framework : : rateArt ( const DuelParameters dp , CArtifactInstance * inst )
2011-12-09 21:37:32 +00:00
{
2012-05-30 17:32:05 +00:00
TArtSet setL , setR ;
setL [ inst - > artType - > possibleSlots [ 0 ] ] = inst ;
2011-12-09 21:37:32 +00:00
2012-05-30 17:32:05 +00:00
double resultLR = cmpArtSets ( dp , setL , setR ) ,
resultRL = cmpArtSets ( dp , setR , setL ) ,
resultsBase = cmpArtSets ( dp , TArtSet ( ) , TArtSet ( ) ) ;
2011-12-09 21:37:32 +00:00
2012-05-30 17:32:05 +00:00
//lewa strona z art 0.9
//bez artefaktow -0.41
//prawa strona z art. -0.926
double LRgain = resultLR - resultsBase ,
RLgain = resultsBase - resultRL ;
2012-05-31 20:13:30 +00:00
return ( LRgain + RLgain ) / 4 ;
2011-11-10 09:21:27 +00:00
}
2012-05-30 17:32:05 +00:00
double Framework : : cmpArtSets ( DuelParameters dp , TArtSet setL , TArtSet setR )
2011-11-10 09:21:27 +00:00
{
2012-05-30 17:32:05 +00:00
dp . sides [ 0 ] . artifacts = setL ;
dp . sides [ 1 ] . artifacts = setR ;
2011-11-10 09:21:27 +00:00
2012-05-30 17:32:05 +00:00
auto battleOutcome = playBattle ( dp ) ;
return battleOutcome ;
2011-10-20 23:29:32 +00:00
}
2011-09-27 21:54:40 +00:00
2012-05-30 17:32:05 +00:00
double Framework : : playBattle ( const DuelParameters & dp )
2012-05-12 17:59:32 +00:00
{
2012-05-28 11:03:20 +00:00
string battleFileName = " pliczek.ssnb " ;
2012-05-12 17:59:32 +00:00
{
2012-05-28 11:03:20 +00:00
CSaveFile out ( battleFileName ) ;
2012-05-12 17:59:32 +00:00
out < < dp ;
}
2012-05-28 11:03:20 +00:00
std : : string serverCommand = servername + " " + addQuotesIfNeeded ( battleFileName ) + " " + addQuotesIfNeeded ( leftAI ) + " " + addQuotesIfNeeded ( rightAI ) + " " + addQuotesIfNeeded ( results ) + " " + addQuotesIfNeeded ( logsDir ) + " " + ( withVisualization ? " v " : " " ) ;
2012-05-12 17:59:32 +00:00
std : : string runnerCommand = runnername + " " + addQuotesIfNeeded ( logsDir ) ;
std : : cout < < " Server command: " < < serverCommand < < std : : endl < < " Runner command: " < < runnerCommand < < std : : endl ;
int code = 0 ;
boost : : thread t ( [ & ]
{
code = std : : system ( serverCommand . c_str ( ) ) ;
} ) ;
runCommand ( runnerCommand , " first_runner " , logsDir ) ;
runCommand ( runnerCommand , " second_runner " , logsDir ) ;
runCommand ( runnerCommand , " third_runner " , logsDir ) ;
if ( withVisualization )
{
//boost::this_thread::sleep(boost::posix_time::millisec(500)); //FIXME
boost : : thread tttt ( boost : : bind ( std : : system , " VCMI_Client.exe -battle " ) ) ;
}
//boost::this_thread::sleep(boost::posix_time::seconds(5));
t . join ( ) ;
return code / 1000000.0 ;
}
2012-05-30 17:32:05 +00:00
void Framework : : runCommand ( const std : : string & command , const std : : string & name , const std : : string & logsDir /*= ""*/ )
{
static std : : string commands [ 100000 ] ;
static int i = 0 ;
std : : string & cmd = commands [ i + + ] ;
if ( logsDir . size ( ) & & name . size ( ) )
{
std : : string directionLogs = logsDir + " / " + name + " .txt " ;
cmd = command + " > " + addQuotesIfNeeded ( directionLogs ) ;
}
else
cmd = command ;
2012-05-13 14:51:13 +00:00
2012-05-30 17:32:05 +00:00
boost : : thread tt ( boost : : bind ( std : : system , cmd . c_str ( ) ) ) ;
}
2012-05-13 14:51:13 +00:00
2012-05-30 17:32:05 +00:00
// Builds a duel in which both sides field the given army. Battlefield and
// terrain type are both 1; the left hero has id 0, the right hero id 1, and
// both get four primary skills of value 0.
DuelParameters Framework::generateDuel(const ArmyDescriptor &ad)
{
	DuelParameters duel;
	duel.bfieldType = 1;
	duel.terType = 1;

	auto &left = duel.sides[0];
	left.heroId = 0;
	left.heroPrimSkills.resize(4, 0);
	BOOST_FOREACH(auto &slot, ad)
	{
		const auto &descr = slot.second;
		left.stacks[slot.first] = DuelParameters::SideSettings::StackSettings(descr.type->idNumber, descr.count);
	}

	//the right side is a copy of the left one, apart from the hero id
	duel.sides[1] = left;
	duel.sides[1].heroId = 1;
	return duel;
}
2012-05-30 17:32:05 +00:00
// Creates a new artifact instance carrying a copy of the given bonus.
// All instances share one artifact type, created on the first call and kept for
// the rest of the program; its only possible slot is Arts::LEFT_HAND.
// The instance and the bonus copy are allocated with new and returned as raw
// pointers; nothing in this function releases them.
CArtifactInstance *Framework::generateArtWithBonus(const Bonus &b)
{
	static CArtifact *nowy = NULL;

	if(!nowy)
	{
		nowy = new CArtifact();
		nowy->description = " Cudowny miecz Towa gwarantuje zwyciestwo ";
		nowy->name = " Cudowny miecz ";
		nowy->constituentOf = nowy->constituents = NULL;
		nowy->possibleSlots.push_back(Arts::LEFT_HAND);
	}

	CArtifactInstance *artinst = new CArtifactInstance(nowy);
	artinst->addNewBonus(new Bonus(b));
	return artinst;
}
2012-05-12 17:59:32 +00:00
2012-05-30 17:32:05 +00:00
class SSN
2012-05-28 11:03:20 +00:00
{
2012-05-30 17:32:05 +00:00
FANN : : neural_net net ;
2012-05-30 21:00:23 +00:00
struct ParameterSet ;
2012-05-12 17:59:32 +00:00
2012-05-30 21:00:23 +00:00
void init ( const ParameterSet & params ) ;
FANN : : training_data * getTrainingData ( const std : : vector < Example > & input ) ;
2012-05-30 17:32:05 +00:00
static int ANNCallback ( FANN : : neural_net & net , FANN : : training_data & train , unsigned int max_epochs , unsigned int epochs_between_reports , float desired_error , unsigned int epochs , void * user_data ) ;
static double * genSSNinput ( const DuelParameters : : SideSettings & dp , CArtifactInstance * art , si32 bfieldType , si32 terType ) ;
static const unsigned int num_input = 18 ;
public :
2012-05-30 21:00:23 +00:00
struct ParameterSet
{
unsigned int neuronsInHidden ;
double actSteepHidden , actSteepnessOutput ;
FANN : : activation_function_enum hiddenActFun , outActFun ;
} ;
2012-05-30 17:32:05 +00:00
SSN ( ) ;
2012-05-31 20:42:10 +00:00
SSN ( string filename ) ;
2012-05-30 17:32:05 +00:00
~ SSN ( ) ;
2012-05-21 20:11:02 +00:00
2012-05-30 21:00:23 +00:00
//returns mse after learning
double learn ( const std : : vector < Example > & input , const ParameterSet & params ) ;
2012-05-31 20:42:10 +00:00
double learn ( bool adjustParams = false ) ;
2012-05-30 21:00:23 +00:00
2012-05-31 20:42:10 +00:00
SSN : : ParameterSet getBestParams ( vector < Example > & trainingSet ) ;
SSN : : ParameterSet getBestParams ( ) ;
2012-05-30 21:00:23 +00:00
double test ( const std : : vector < Example > & input )
{
auto td = getTrainingData ( input ) ;
return net . test_data ( * td ) ;
delete td ;
}
2012-05-30 17:32:05 +00:00
double run ( const DuelParameters & dp , CArtifactInstance * inst ) ;
void save ( const std : : string & filename ) ;
2012-05-31 20:42:10 +00:00
void load ( const std : : string & filename ) ;
2012-05-30 17:32:05 +00:00
} ;
// Creates the wrapper without configuring the network; learn() or load() has to follow.
SSN::SSN()
{
}
2012-05-12 17:59:32 +00:00
2012-05-31 20:42:10 +00:00
// Creates the wrapper and immediately loads the network stored in the given file.
SSN::SSN(string filename)
{
	this->load(filename);
}
2012-05-30 21:00:23 +00:00
void SSN : : init ( const ParameterSet & params )
2012-05-13 14:51:13 +00:00
{
2012-05-30 17:32:05 +00:00
const float learning_rate = 0.7f ;
const unsigned int num_layers = 3 ;
const unsigned int num_output = 1 ;
2012-05-30 21:00:23 +00:00
const float desired_error = 0.01f ;
const unsigned int max_iterations = 30000 ;
2012-05-30 17:32:05 +00:00
const unsigned int iterations_between_reports = 1000 ;
2012-05-12 17:59:32 +00:00
2012-05-30 21:00:23 +00:00
net . create_standard ( num_layers , num_input , params . neuronsInHidden , num_output ) ;
2012-05-21 20:11:02 +00:00
2012-05-30 17:32:05 +00:00
net . set_learning_rate ( learning_rate ) ;
2012-05-28 11:03:20 +00:00
2012-05-30 21:00:23 +00:00
net . set_activation_steepness_hidden ( params . actSteepHidden ) ;
net . set_activation_steepness_output ( params . actSteepnessOutput ) ;
2012-05-28 11:03:20 +00:00
2012-05-30 21:00:23 +00:00
net . set_activation_function_hidden ( params . hiddenActFun ) ;
net . set_activation_function_output ( params . outActFun ) ;
2012-05-28 11:03:20 +00:00
2012-05-30 17:32:05 +00:00
net . randomize_weights ( 0.0 , 1.0 ) ;
2012-05-13 14:51:13 +00:00
}
2012-05-12 17:59:32 +00:00
2012-05-30 17:32:05 +00:00
double SSN : : run ( const DuelParameters & dp , CArtifactInstance * inst )
{
double * input = genSSNinput ( dp . sides [ 0 ] , inst , dp . bfieldType , dp . terType ) ;
double * out = net . run ( input ) ;
double ret = * out ;
2012-05-31 20:42:10 +00:00
//free(out);
2012-05-30 17:32:05 +00:00
return ret ;
}
2012-05-14 17:57:22 +00:00
2012-05-30 17:32:05 +00:00
int SSN : : ANNCallback ( FANN : : neural_net & net , FANN : : training_data & train , unsigned int max_epochs , unsigned int epochs_between_reports , float desired_error , unsigned int epochs , void * user_data )
{
//cout << "Epochs " << setw(8) << epochs << ". "
// << "Current Error: " << left << net.get_MSE() << right << endl;
return 0 ;
}
2012-05-14 17:57:22 +00:00
2012-05-30 21:00:23 +00:00
double SSN : : learn ( const std : : vector < Example > & input , const ParameterSet & params )
2012-05-30 17:32:05 +00:00
{
2012-05-30 21:00:23 +00:00
init ( params ) ;
2012-05-30 17:32:05 +00:00
//FIXME - sypie przy destrukcji
//FANN::training_data td;
2012-05-30 21:00:23 +00:00
FANN : : training_data * td = getTrainingData ( input ) ;
2012-05-30 17:32:05 +00:00
net . set_callback ( ANNCallback , NULL ) ;
2012-05-30 21:00:23 +00:00
net . train_on_data ( * td , 1000 , 1000 , 0.01 ) ;
2012-05-31 20:13:30 +00:00
// int exNum = 130;
//
// for(int exNum =0; exNum<input.size(); ++exNum)
// {
// double * testIn = genSSNinput(input[exNum].dp.sides[0], input[exNum].art, input[exNum].dp.bfieldType, input[exNum].dp.terType);
//
// double ans = *net.run(testIn);
// int g = 0;
// }
2012-05-30 21:00:23 +00:00
return net . test_data ( * td ) ;
2012-05-30 17:32:05 +00:00
}
2012-05-31 20:42:10 +00:00
double SSN : : learn ( bool adjustParams /* = false*/ )
{
cout < < " Loading examples... \n " ;
auto trainingSet = Framework : : loadExamples ( false ) ;
cout < < " Looking for best learning parameters... \n " ;
auto params = adjustParams ? getBestParams ( trainingSet ) : getBestParams ( ) ;
cout < < " Learning... \n " ;
//saving of best network
double finalLmse = learn ( trainingSet , params ) ;
cout < < " Learning done, LMSE= " < < finalLmse < < endl ;
save ( " last_network.net " ) ;
return finalLmse ;
}
2012-05-30 17:32:05 +00:00
double * SSN : : genSSNinput ( const DuelParameters : : SideSettings & dp , CArtifactInstance * art , si32 bfieldType , si32 terType )
2012-05-14 17:57:22 +00:00
{
double * ret = new double [ num_input ] ;
double * cur = ret ;
//general description
2012-05-28 20:31:32 +00:00
* ( cur + + ) = bfieldType / 30.0 ;
* ( cur + + ) = terType / 12.0 ;
2012-05-14 17:57:22 +00:00
//creature & hero description
2012-05-30 17:32:05 +00:00
2012-05-28 20:31:32 +00:00
* ( cur + + ) = dp . heroId / 200.0 ;
for ( int k = 0 ; k < 4 ; + + k )
* ( cur + + ) = dp . heroPrimSkills [ k ] / 20.0 ;
2012-05-16 20:26:34 +00:00
2012-05-28 20:31:32 +00:00
//weighted average of statistics
auto avg = [ & ] ( std : : function < int ( CCreature * ) > getter ) - > double
{
double ret = 0.0 ;
int div = 0 ;
for ( int i = 0 ; i < 7 ; + + i )
2012-05-14 17:57:22 +00:00
{
2012-05-28 20:31:32 +00:00
auto & cstack = dp . stacks [ i ] ;
if ( cstack . count > 0 )
2012-05-16 20:26:34 +00:00
{
2012-05-28 20:31:32 +00:00
ret + = getter ( VLC - > creh - > creatures [ cstack . type ] ) * cstack . count ;
div + = cstack . count ;
2012-05-16 20:26:34 +00:00
}
2012-05-28 20:31:32 +00:00
}
return ret / div ;
} ;
2012-05-16 20:26:34 +00:00
2012-05-28 20:31:32 +00:00
* ( cur + + ) = avg ( [ ] ( CCreature * c ) { return c - > attack ; } ) / 50.0 ;
* ( cur + + ) = avg ( [ ] ( CCreature * c ) { return c - > defence ; } ) / 50.0 ;
* ( cur + + ) = avg ( [ ] ( CCreature * c ) { return c - > speed ; } ) / 15.0 ;
* ( cur + + ) = avg ( [ ] ( CCreature * c ) { return c - > hitPoints ; } ) / 1000.0 ;
2012-05-14 17:57:22 +00:00
//bonus description
2012-05-21 20:11:02 +00:00
auto & blist = art - > getBonusList ( ) ;
2012-05-30 17:32:05 +00:00
2012-05-26 13:27:32 +00:00
* ( cur + + ) = blist [ 0 ] - > type / 100.0 ;
* ( cur + + ) = blist [ 0 ] - > subtype / 10.0 ;
* ( cur + + ) = blist [ 0 ] - > val / 100.0 ; ;
* ( cur + + ) = art - > Attack ( ) / 10.0 ;
* ( cur + + ) = art - > Defense ( ) / 10.0 ;
* ( cur + + ) = blist . valOfBonuses ( Selector : : type ( Bonus : : STACKS_SPEED ) ) / 5.0 ;
* ( cur + + ) = blist . valOfBonuses ( Selector : : type ( Bonus : : STACK_HEALTH ) ) / 10.0 ;
return ret ;
}
2012-05-30 17:32:05 +00:00
void SSN : : save ( const std : : string & filename )
2012-05-26 13:27:32 +00:00
{
2012-05-30 17:32:05 +00:00
net . save ( filename ) ;
2012-05-14 17:57:22 +00:00
}
2012-05-30 17:32:05 +00:00
// Nothing to release explicitly; the network member cleans up after itself.
SSN::~SSN()
{}
2012-05-30 21:00:23 +00:00
// Converts examples to a FANN training set: one input vector (see genSSNinput)
// and one expected output (the example's value divided by 4) per example.
// Returns a training_data object allocated with new; callers own it.
FANN::training_data *SSN::getTrainingData(const std::vector<Example> &input)
{
	const size_t count = input.size();
	double **inputs = new double *[count];
	double **outputs = new double *[count];
	for(size_t i = 0; i < count; ++i)
	{
		const auto &ci = input[i];
		inputs[i] = genSSNinput(ci.dp.sides[0], ci.art, ci.dp.bfieldType, ci.dp.terType);
		outputs[i] = new double(ci.value / 4);
	}

	FANN::training_data *ret = new FANN::training_data;
	ret->set_train_data(count, num_input, inputs, 1, outputs);

	//Per the FANN C++ wrapper documentation, set_train_data makes its own copy of the
	//data and leaves deletion of the passed arrays to the caller - they used to be leaked.
	for(size_t i = 0; i < count; ++i)
	{
		delete[] inputs[i]; //new[] in genSSNinput
		delete outputs[i];
	}
	delete[] inputs;
	delete[] outputs;

	return ret;
}
2012-05-31 20:42:10 +00:00
void SSN : : load ( const std : : string & filename )
2012-05-13 14:51:13 +00:00
{
2012-05-31 20:42:10 +00:00
net . create_from_file ( filename ) ;
cout < < " Loaded a network from file " < < filename < < endl ;
}
2012-05-30 21:00:23 +00:00
2012-05-31 20:42:10 +00:00
// Random search for training parameters. 20% of the examples are moved at
// random out of trainingSet (which is modified!) into a test set; then 5000
// random parameter sets are tried, each by training on the remaining examples
// and measuring the error on the test set. The set with the lowest test error
// is returned.
SSN::ParameterSet SSN::getBestParams(vector<Example> &trainingSet)
{
	double percentToTrain = 0.8;

	std::vector<Example> testSet;
	for(int i = 0, maxi = trainingSet.size() * (1 - percentToTrain); i < maxi; ++i)
	{
		int ind = rand() % trainingSet.size();
		testSet.push_back(trainingSet[ind]);
		trainingSet.erase(trainingSet.begin() + ind);
	}

	//Start from the fixed defaults so that a defined parameter set is returned even
	//if no trial ever beats the initial bound (e.g. every test error is NaN).
	SSN::ParameterSet bestParams = getBestParams();
	double besttMSE = 1e10;

	boost::mt19937 rng;
	boost::uniform_01<boost::mt19937> zeroone(rng);

	FANN::activation_function_enum possibleFuns[] = {FANN::SIGMOID_SYMMETRIC_STEPWISE, FANN::LINEAR,
		FANN::SIGMOID, FANN::SIGMOID_STEPWISE, FANN::SIGMOID_SYMMETRIC};

	for(int i = 0; i < 5000; i += 1)
	{
		SSN::ParameterSet ps;
		ps.actSteepHidden = zeroone() + 0.3;
		ps.actSteepnessOutput = zeroone() + 0.3;
		ps.neuronsInHidden = rand() % 40 + 10;
		ps.hiddenActFun = possibleFuns[rand() % ARRAY_COUNT(possibleFuns)];
		ps.outActFun = possibleFuns[rand() % ARRAY_COUNT(possibleFuns)];

		double lmse = learn(trainingSet, ps);
		double tmse = test(testSet);
		if(tmse < besttMSE)
		{
			besttMSE = tmse;
			bestParams = ps;
		}
		cout << " hid: \t " << i << " lmse: \t " << lmse << " tmse: \t " << tmse << std::endl;
	}
	return bestParams;
}
// Returns the fixed, hand-picked parameter set used when no search is requested.
SSN::ParameterSet SSN::getBestParams()
{
	//an earlier set of values, kept for reference:
	// bestParams.actSteepHidden = 0.346;
	// bestParams.actSteepnessOutput = 0.449;
	// bestParams.hiddenActFun = FANN::SIGMOID_SYMMETRIC;
	// bestParams.outActFun = FANN::SIGMOID_SYMMETRIC;
	// bestParams.neuronsInHidden = 23;

	SSN::ParameterSet chosen;
	chosen.neuronsInHidden = 14;
	chosen.actSteepHidden = 0.784;
	chosen.actSteepnessOutput = 0.713;
	chosen.hiddenActFun = FANN::SIGMOID_STEPWISE;
	chosen.outActFun = FANN::SIGMOID_SYMMETRIC_STEPWISE;
	return chosen;
}
2012-05-31 20:42:10 +00:00
struct SSN_Runner
{
unique_ptr < SSN > ssn ;
ArmyDescriptor ad ;
void printHelp ( )
{
const char * cmds [ ] = { " help - prints this info " , " create - creates a new ANN, needs to be learned then " , " load <file> - loads ANN from file " , " save <file> - saves current ANN to file " , " learn - runs learning process using examples set " , " ask <id> - evaluates given art " , " exit - closes application " ,
" army clear - removes current army information " , " army add <id> <count> - adds creature to army " , " army remove <pos> - removes stack from position " ,
" army print - prints current army state " , " army random - generates random army " } ;
cout < < " Available commands: \n " ;
BOOST_FOREACH ( auto cmd , cmds )
cout < < " \t " < < cmd < < endl ;
}
int run ( )
{
cout < < " Welcome to the ANN interactive mode! \n " ;
printHelp ( ) ;
while ( 1 )
{
try
{
cout < < " Please enter your command and press return. \n > " ;
stringstream ss ;
string input ;
getline ( cin , input ) ;
ss . str ( input ) ;
string command , secondWord ;
ss > > command > > secondWord ;
if ( command = = " exit " )
{
cout < < " Ending... \n " ;
exit ( 0 ) ;
}
else if ( command = = " load " )
{
if ( secondWord . empty ( ) )
secondWord = " last_network.net " ;
ssn = unique_ptr < SSN > ( new SSN ( secondWord ) ) ;
}
else if ( command = = " create " )
{
ssn = unique_ptr < SSN > ( new SSN ( ) ) ;
cout < < " Network successfully created. It still needs to be learnt. \n " ;
}
else if ( command = = " help " )
{
printHelp ( ) ;
}
else if ( command = = " army " & & secondWord . size ( ) )
{
if ( secondWord = = " clear " )
{
ad . clear ( ) ;
cout < < " Army is now empty. \n " ;
}
if ( secondWord = = " print " )
{
cout < < " Army contains " < < ad . size ( ) < < " creatures. \n " ;
BOOST_FOREACH ( auto & itr , ad )
{
cout < < itr . first < < " => " < < itr . second . count < < " of " < < itr . second . type - > namePl < < endl ;
}
}
if ( secondWord = = " erase " )
{
int slot ;
ss > > slot ;
if ( ad . find ( slot ) ! = ad . end ( ) )
{
ad . erase ( slot ) ;
cout < < " Slot " < < slot < < " successfully erased. \n " ;
}
}
if ( secondWord = = " add " )
{
int id , count ;
ss > > id > > count ;
int i = 0 ;
if ( id < 0 | | id > = 118 )
{
throw std : : runtime_error ( " Id has to be in <0,118> " ) ;
}
if ( count < = 0 )
{
throw std : : runtime_error ( " Count has to be > 0 " ) ;
}
while ( ad . find ( i + + ) ! = ad . end ( ) ) ;
if ( i > = ARMY_SIZE )
{
tlog1 < < " Cannot add stack, army is full! \n " ;
}
else
{
ad [ i ] = CStackBasicDescriptor ( id , count ) ;
tlog0 < < " Creature successfully added to slot " < < i < < endl ; ;
}
}
if ( secondWord = = " random " )
{
srand ( time ( 0 ) ) ;
ad . clear ( ) ;
int stacks = rand ( ) % 7 + 1 ;
for ( int i = 0 ; i < stacks ; i + + )
{
CCreature * c = VLC - > creh - > creatures [ rand ( ) % 118 ] ;
ad [ i ] = CStackBasicDescriptor ( c , c - > growth ) ;
}
cout < < " Generated random army of " < < stacks < < " creatures. \n " ;
}
}
else if ( ! ssn )
{
cout < < " Error: you need to create or load ANN from file first! \n " ;
continue ;
}
else if ( command = = " learn " )
{
ssn - > learn ( ) ;
}
else if ( command = = " save " )
{
ssn - > save ( secondWord ) ;
}
else if ( command = = " ask " )
{
2012-05-31 22:16:03 +00:00
if ( ad . empty ( ) )
{
throw std : : runtime_error ( " Army needs to be set first! " ) ;
}
2012-05-31 20:42:10 +00:00
int artid = boost : : lexical_cast < int > ( secondWord ) ;
CArtifact * art = VLC - > arth - > artifacts . at ( artid ) ;
DuelParameters dp = Framework : : generateDuel ( ad ) ;
CArtifactInstance * artInst = new CArtifactInstance ( art ) ;
auto bonuses = art - > getBonuses ( [ ] ( const Bonus * ) { return true ; } ) ;
if ( ! bonuses - > size ( ) )
{
tlog1 < < " This artifact deosn't provide any bonuses. Please pick another one. " ;
}
else
{
BOOST_FOREACH ( auto b , * bonuses )
artInst - > addNewBonus ( new Bonus ( * b ) ) ;
auto val = ssn - > run ( dp , artInst ) ;
cout < < " ANN rates " < < art - > Name ( ) < < " to value = " < < val < < endl ;
}
}
else
tlog1 < < " Unknown command \" " < < command < < " \" ! \n " ;
}
catch ( std : : exception & e )
{
tlog1 < < " Encountered error: " < < e . what ( ) < < endl ;
}
catch ( . . . )
{
tlog1 < < " Encountered unknown error! " < < endl ;
}
}
}
} ;
2011-10-23 13:53:37 +00:00
int main ( int argc , char * * argv )
2011-09-27 21:54:40 +00:00
{
2011-10-20 23:29:32 +00:00
std : : cout < < " VCMI Odpalarka \n My path: " < < argv [ 0 ] < < std : : endl ;
po : : options_description opts ( " Allowed options " ) ;
opts . add_options ( )
2011-11-10 09:21:27 +00:00
( " help,h " , " Display help and exit " )
( " aiLeft,l " , po : : value < std : : string > ( ) - > default_value ( " StupidAI " ) , " Left AI path " )
( " aiRight,r " , po : : value < std : : string > ( ) - > default_value ( " StupidAI " ) , " Right AI path " )
2012-05-12 17:59:32 +00:00
( " battle,b " , po : : value < std : : string > ( ) - > default_value ( " pliczek.ssnb " ) , " Duel file path " )
2011-11-14 17:52:09 +00:00
( " resultsOut,o " , po : : value < std : : string > ( ) - > default_value ( " ./results.txt " ) , " Output file when results will be appended " )
2011-11-10 09:21:27 +00:00
( " logsDir,d " , po : : value < std : : string > ( ) - > default_value ( " . " ) , " Directory where log files will be created " )
2011-10-20 23:29:32 +00:00
( " visualization,v " , " Runs a client to display a visualization of battle " ) ;
2011-11-10 09:21:27 +00:00
try
2011-10-20 23:29:32 +00:00
{
2011-11-10 09:21:27 +00:00
po : : variables_map vm ;
po : : store ( po : : parse_command_line ( argc , argv , opts ) , vm ) ;
po : : notify ( vm ) ;
if ( vm . count ( " help " ) )
2011-10-20 23:29:32 +00:00
{
2011-11-10 09:21:27 +00:00
opts . print ( std : : cout ) ;
prog_help ( ) ;
return 0 ;
2011-10-20 23:29:32 +00:00
}
2011-11-10 09:21:27 +00:00
leftAI = vm [ " aiLeft " ] . as < std : : string > ( ) ;
rightAI = vm [ " aiRight " ] . as < std : : string > ( ) ;
battle = vm [ " battle " ] . as < std : : string > ( ) ;
results = vm [ " resultsOut " ] . as < std : : string > ( ) ;
logsDir = vm [ " logsDir " ] . as < std : : string > ( ) ;
withVisualization = vm . count ( " visualization " ) ;
2011-10-20 23:29:32 +00:00
}
2011-11-10 09:21:27 +00:00
catch ( std : : exception & e )
2011-10-20 23:29:32 +00:00
{
2011-11-10 09:21:27 +00:00
std : : cerr < < " Failure during parsing command-line options: \n " < < e . what ( ) < < std : : endl ;
exit ( 1 ) ;
2011-10-20 23:29:32 +00:00
}
2011-11-10 09:21:27 +00:00
std : : cout < < " Config: \n " < < leftAI < < " vs " < < rightAI < < " on " < < battle < < std : : endl ;
if ( leftAI . empty ( ) | | rightAI . empty ( ) | | battle . empty ( ) )
2011-10-20 23:29:32 +00:00
{
2011-11-10 09:21:27 +00:00
std : : cerr < < " I wasn't able to retreive names of AI or battles. Ending. \n " ;
return 1 ;
2011-10-20 23:29:32 +00:00
}
2012-05-12 17:59:32 +00:00
runnername =
2011-10-01 15:07:07 +00:00
# ifdef _WIN32
" VCMI_BattleAiHost.exe "
# else
" ./vcmirunner "
# endif
;
2012-05-12 17:59:32 +00:00
servername =
2011-10-01 15:07:07 +00:00
# ifdef _WIN32
" VCMI_server.exe "
# else
" ./vcmiserver "
# endif
2011-10-06 20:55:19 +00:00
;
2012-05-12 17:59:32 +00:00
VLC = new LibClasses ( ) ;
VLC - > init ( ) ;
2011-11-10 09:21:27 +00:00
2012-05-31 20:42:10 +00:00
SSN_Runner runner ;
runner . run ( ) ;
2011-09-29 21:53:39 +00:00
2011-09-27 21:54:40 +00:00
return EXIT_SUCCESS ;
2012-05-16 20:26:34 +00:00
}