1
0
mirror of https://github.com/vcmi/vcmi.git synced 2024-11-28 08:48:48 +02:00

BattleAI: optional simulation depth

This commit is contained in:
Andrii Danylchenko 2024-08-21 22:18:26 +03:00
parent 22de88ad68
commit ff8a745a50
7 changed files with 74 additions and 22 deletions

View File

@ -23,6 +23,7 @@
#include "../../lib/battle/BattleAction.h"
#include "../../lib/battle/BattleStateInfoForRetreat.h"
#include "../../lib/battle/CObstacleInstance.h"
#include "../../lib/StartInfo.h"
#include "../../lib/CStack.h" // TODO: remove
// Eventually only IBattleInfoCallback and battle::Unit should be used,
// CUnitState should be private and CStack should be removed completely
@ -122,6 +123,11 @@ static float getStrengthRatio(std::shared_ptr<CBattleInfoCallback> cb, BattleSid
return enemy == 0 ? 1.0f : static_cast<float>(our) / enemy;
}
int getSimulationTurnsCount(const StartInfo * startInfo)
{
return startInfo->difficulty < 4 ? 2 : 10;
}
void CBattleAI::activeStack(const BattleID & battleID, const CStack * stack )
{
LOG_TRACE_PARAMS(logAi, "stack: %s", stack->nodeName());
@ -154,7 +160,10 @@ void CBattleAI::activeStack(const BattleID & battleID, const CStack * stack )
logAi->trace("Build evaluator and targets");
#endif
BattleEvaluator evaluator(env, cb, stack, playerID, battleID, side, getStrengthRatio(cb->getBattle(battleID), side));
BattleEvaluator evaluator(
env, cb, stack, playerID, battleID, side,
getStrengthRatio(cb->getBattle(battleID), side),
getSimulationTurnsCount(env->game()->getStartInfo()));
result = evaluator.selectStackAction(stack);

View File

@ -49,6 +49,45 @@ SpellTypes spellType(const CSpell * spell)
return SpellTypes::OTHER;
}
BattleEvaluator::BattleEvaluator(
std::shared_ptr<Environment> env,
std::shared_ptr<CBattleCallback> cb,
const battle::Unit * activeStack,
PlayerColor playerID,
BattleID battleID,
BattleSide side,
float strengthRatio,
int simulationTurnsCount)
:scoreEvaluator(cb->getBattle(battleID), env, strengthRatio, simulationTurnsCount),
cachedAttack(), playerID(playerID), side(side), env(env),
cb(cb), strengthRatio(strengthRatio), battleID(battleID), simulationTurnsCount(simulationTurnsCount)
{
hb = std::make_shared<HypotheticBattle>(env.get(), cb->getBattle(battleID));
damageCache.buildDamageCache(hb, side);
targets = std::make_unique<PotentialTargets>(activeStack, damageCache, hb);
cachedScore = EvaluationResult::INEFFECTIVE_SCORE;
}
BattleEvaluator::BattleEvaluator(
std::shared_ptr<Environment> env,
std::shared_ptr<CBattleCallback> cb,
std::shared_ptr<HypotheticBattle> hb,
DamageCache & damageCache,
const battle::Unit * activeStack,
PlayerColor playerID,
BattleID battleID,
BattleSide side,
float strengthRatio,
int simulationTurnsCount)
:scoreEvaluator(cb->getBattle(battleID), env, strengthRatio, simulationTurnsCount),
cachedAttack(), playerID(playerID), side(side), env(env), cb(cb), hb(hb),
damageCache(damageCache), strengthRatio(strengthRatio), battleID(battleID), simulationTurnsCount(simulationTurnsCount)
{
targets = std::make_unique<PotentialTargets>(activeStack, damageCache, hb);
cachedScore = EvaluationResult::INEFFECTIVE_SCORE;
}
std::vector<BattleHex> BattleEvaluator::getBrokenWallMoatHexes() const
{
std::vector<BattleHex> result;
@ -618,7 +657,7 @@ bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
#endif
PotentialTargets innerTargets(activeStack, innerCache, state);
BattleExchangeEvaluator innerEvaluator(state, env, strengthRatio);
BattleExchangeEvaluator innerEvaluator(state, env, strengthRatio, simulationTurnsCount);
if(!innerTargets.possibleAttacks.empty())
{

View File

@ -37,6 +37,7 @@ class BattleEvaluator
float cachedScore;
DamageCache damageCache;
float strengthRatio;
int simulationTurnsCount;
public:
BattleAction selectStackAction(const CStack * stack);
@ -56,15 +57,8 @@ public:
PlayerColor playerID,
BattleID battleID,
BattleSide side,
float strengthRatio)
:scoreEvaluator(cb->getBattle(battleID), env, strengthRatio), cachedAttack(), playerID(playerID), side(side), env(env), cb(cb), strengthRatio(strengthRatio), battleID(battleID)
{
hb = std::make_shared<HypotheticBattle>(env.get(), cb->getBattle(battleID));
damageCache.buildDamageCache(hb, side);
targets = std::make_unique<PotentialTargets>(activeStack, damageCache, hb);
cachedScore = EvaluationResult::INEFFECTIVE_SCORE;
}
float strengthRatio,
int simulationTurnsCount);
BattleEvaluator(
std::shared_ptr<Environment> env,
@ -75,10 +69,6 @@ public:
PlayerColor playerID,
BattleID battleID,
BattleSide side,
float strengthRatio)
:scoreEvaluator(cb->getBattle(battleID), env, strengthRatio), cachedAttack(), playerID(playerID), side(side), env(env), cb(cb), hb(hb), damageCache(damageCache), strengthRatio(strengthRatio), battleID(battleID)
{
targets = std::make_unique<PotentialTargets>(activeStack, damageCache, hb);
cachedScore = EvaluationResult::INEFFECTIVE_SCORE;
}
float strengthRatio,
int simulationTurnsCount);
};

View File

@ -655,11 +655,14 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
});
bool canUseAp = true;
const int totalTurnsCount = 10;
std::set<uint32_t> blockedShooters;
for(int exchangeTurn = 0; exchangeTurn < totalTurnsCount; exchangeTurn++)
int totalTurnsCount = simulationTurnsCount >= turn + turnOrder.size()
? simulationTurnsCount
: turn + turnOrder.size();
for(int exchangeTurn = 0; exchangeTurn < simulationTurnsCount; exchangeTurn++)
{
bool isMovingTurm = exchangeTurn < turn;
int queueTurn = exchangeTurn >= exchangeUnits.units.size()
@ -826,6 +829,14 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
auto score = v.getScore();
if(simulationTurnsCount < totalTurnsCount)
{
float scalingRatio = simulationTurnsCount / static_cast<float>(totalTurnsCount);
score.enemyDamageReduce *= scalingRatio;
score.ourDamageReduce *= scalingRatio;
}
if(turn > 0)
{
auto turnMultiplier = 1 - std::min(0.2, 0.05 * turn);

View File

@ -132,6 +132,7 @@ private:
std::map<BattleHex, std::vector<const battle::Unit *>> reachabilityMap;
std::vector<battle::Units> turnOrder;
float negativeEffectMultiplier;
int simulationTurnsCount;
float scoreValue(const BattleScore & score) const;
@ -149,7 +150,8 @@ public:
BattleExchangeEvaluator(
std::shared_ptr<CBattleInfoCallback> cb,
std::shared_ptr<Environment> env,
float strengthRatio): cb(cb), env(env) {
float strengthRatio,
int simulationTurnsCount): cb(cb), env(env), simulationTurnsCount(simulationTurnsCount){
negativeEffectMultiplier = strengthRatio >= 1 ? 1 : strengthRatio * strengthRatio;
}

View File

@ -56,7 +56,7 @@ public:
// //various
virtual int getDate(Date mode=Date::DAY) const = 0; //mode=0 - total days in game, mode=1 - day of week, mode=2 - current week, mode=3 - current month
// const StartInfo * getStartInfo(bool beforeRandomization = false)const;
virtual const StartInfo * getStartInfo(bool beforeRandomization = false) const = 0;
virtual bool isAllowed(SpellID id) const = 0;
virtual bool isAllowed(ArtifactID id) const = 0;
virtual bool isAllowed(SecondarySkill id) const = 0;
@ -143,7 +143,7 @@ protected:
public:
//various
int getDate(Date mode=Date::DAY)const override; //mode=0 - total days in game, mode=1 - day of week, mode=2 - current week, mode=3 - current month
virtual const StartInfo * getStartInfo(bool beforeRandomization = false)const;
const StartInfo * getStartInfo(bool beforeRandomization = false) const override;
bool isAllowed(SpellID id) const override;
bool isAllowed(ArtifactID id) const override;
bool isAllowed(SecondarySkill id) const override;

View File

@ -17,6 +17,7 @@ class IGameInfoCallbackMock : public IGameInfoCallback
public:
//various
MOCK_CONST_METHOD1(getDate, int(Date));
MOCK_CONST_METHOD1(getStartInfo, const StartInfo *(bool));
MOCK_CONST_METHOD1(isAllowed, bool(SpellID));
MOCK_CONST_METHOD1(isAllowed, bool(ArtifactID));