1
0
mirror of https://github.com/vcmi/vcmi.git synced 2025-07-15 01:24:45 +02:00

BattleAI: bigger reachability map

This commit is contained in:
Andrii Danylchenko
2023-10-13 22:04:26 +03:00
parent 9eb9404f28
commit 870fbd50e3
4 changed files with 115 additions and 50 deletions

View File

@ -157,13 +157,9 @@ float AttackPossibility::calculateDamageReduce(
auto enemyDamageBeforeAttack = damageCache.getOriginalDamage(defender, attackerUnitForMeasurement, state); auto enemyDamageBeforeAttack = damageCache.getOriginalDamage(defender, attackerUnitForMeasurement, state);
auto enemiesKilled = damageDealt / maxHealth + (damageDealt % maxHealth >= defender->getFirstHPleft() ? 1 : 0); auto enemiesKilled = damageDealt / maxHealth + (damageDealt % maxHealth >= defender->getFirstHPleft() ? 1 : 0);
auto damagePerEnemy = enemyDamageBeforeAttack / (double)defender->getCount(); auto damagePerEnemy = enemyDamageBeforeAttack / (double)defender->getCount();
auto lastUnitKillValue = (damageDealt % maxHealth) / (double)maxHealth;;
// lets use cached maxHealth here instead of getAvailableHealth return damagePerEnemy * (enemiesKilled + lastUnitKillValue * HEALTH_BOUNTY);
auto firstUnitHpLeft = (availableHealth - damageDealt) % maxHealth;
auto firstUnitHealthRatio = firstUnitHpLeft == 0 ? 1 : static_cast<float>(firstUnitHpLeft) / maxHealth;
auto firstUnitKillValue = (1 - firstUnitHealthRatio) * (1 - firstUnitHealthRatio);
return damagePerEnemy * (enemiesKilled + firstUnitKillValue * HEALTH_BOUNTY);
} }
int64_t AttackPossibility::evaluateBlockedShootersDmg( int64_t AttackPossibility::evaluateBlockedShootersDmg(

View File

@ -149,7 +149,7 @@ BattleAction BattleEvaluator::selectStackAction(const CStack * stack)
bestAttack.attack.attacker->speed(0, true), bestAttack.attack.attacker->speed(0, true),
bestAttack.defenderDamageReduce, bestAttack.defenderDamageReduce,
bestAttack.attackerDamageReduce, bestAttack.attackerDamageReduce,
bestAttack.attackValue() score
); );
if (moveTarget.scorePerTurn <= score) if (moveTarget.scorePerTurn <= score)
@ -580,7 +580,7 @@ bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
} }
else else
{ {
ps.value = scoreEvaluator.evaluateExchange(*cachedAttack, *targets, innerCache, state); ps.value = scoreEvaluator.evaluateExchange(*cachedAttack, 0, *targets, innerCache, state);
} }
for(auto unit : allUnits) for(auto unit : allUnits)

View File

@ -116,6 +116,10 @@ float BattleExchangeVariant::trackAttack(
} }
} }
#if BATTLE_TRACE_LEVEL >= 1
logAi->trace("ap shooters blocking: %lld", ap.shootersBlockedDmg);
#endif
attackValue += ap.shootersBlockedDmg; attackValue += ap.shootersBlockedDmg;
dpsScore.enemyDamageReduce += ap.shootersBlockedDmg; dpsScore.enemyDamageReduce += ap.shootersBlockedDmg;
attacker->afterAttack(ap.attack.shooting, false); attacker->afterAttack(ap.attack.shooting, false);
@ -233,13 +237,17 @@ EvaluationResult BattleExchangeEvaluator::findBestTarget(
for(auto & ap : targets.possibleAttacks) for(auto & ap : targets.possibleAttacks)
{ {
float score = evaluateExchange(ap, targets, damageCache, hbWaited); float score = evaluateExchange(ap, 0, targets, damageCache, hbWaited);
if(score > result.score) if(score > result.score)
{ {
result.score = score; result.score = score;
result.bestAttack = ap; result.bestAttack = ap;
result.wait = true; result.wait = true;
#if BATTLE_TRACE_LEVEL >= 1
logAi->trace("New high score %2f", result.score);
#endif
} }
} }
} }
@ -258,13 +266,17 @@ EvaluationResult BattleExchangeEvaluator::findBestTarget(
for(auto & ap : targets.possibleAttacks) for(auto & ap : targets.possibleAttacks)
{ {
float score = evaluateExchange(ap, targets, damageCache, hb); float score = evaluateExchange(ap, 0, targets, damageCache, hb);
if(score >= result.score) if(score > result.score || score == result.score && result.wait)
{ {
result.score = score; result.score = score;
result.bestAttack = ap; result.bestAttack = ap;
result.wait = false; result.wait = false;
#if BATTLE_TRACE_LEVEL >= 1
logAi->trace("New high score %2f", result.score);
#endif
} }
} }
@ -312,7 +324,10 @@ MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable(
auto hexes = closestStack->getSurroundingHexes(); auto hexes = closestStack->getSurroundingHexes();
auto enemySpeed = closestStack->speed(); auto enemySpeed = closestStack->speed();
auto speedRatio = speed / static_cast<float>(enemySpeed); auto speedRatio = speed / static_cast<float>(enemySpeed);
auto penalty = speedRatio > 1 ? 1 : speedRatio; auto multiplier = speedRatio > 1 ? 1 : speedRatio;
if(enemy->canShoot())
multiplier *= 1.5f;
for(auto hex : hexes) for(auto hex : hexes)
{ {
@ -323,7 +338,7 @@ MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable(
attack.shootersBlockedDmg = 0; // we do not want to count on it, it is not for sure attack.shootersBlockedDmg = 0; // we do not want to count on it, it is not for sure
auto score = calculateExchange(attack, turnsToRich, targets, damageCache, hb); auto score = calculateExchange(attack, turnsToRich, targets, damageCache, hb);
auto scorePerTurn = BattleScore(score.enemyDamageReduce * std::sqrt(penalty / turnsToRich), score.ourDamageReduce); auto scorePerTurn = BattleScore(score.enemyDamageReduce * std::sqrt(multiplier / turnsToRich), score.ourDamageReduce);
if(result.scorePerTurn < scoreValue(scorePerTurn)) if(result.scorePerTurn < scoreValue(scorePerTurn))
{ {
@ -371,12 +386,13 @@ std::vector<const battle::Unit *> BattleExchangeEvaluator::getAdjacentUnits(cons
ReachabilityData BattleExchangeEvaluator::getExchangeUnits( ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
const AttackPossibility & ap, const AttackPossibility & ap,
uint8_t turn,
PotentialTargets & targets, PotentialTargets & targets,
std::shared_ptr<HypotheticBattle> hb) std::shared_ptr<HypotheticBattle> hb)
{ {
ReachabilityData result; ReachabilityData result;
auto hexes = ap.attack.defender->getHexes(); auto hexes = ap.attack.defender->getSurroundingHexes();
if(!ap.attack.shooting) hexes.push_back(ap.from); if(!ap.attack.shooting) hexes.push_back(ap.from);
@ -384,7 +400,7 @@ ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
for(auto hex : hexes) for(auto hex : hexes)
{ {
vstd::concatenate(allReachableUnits, reachabilityMap[hex]); vstd::concatenate(allReachableUnits, turn == 0 ? reachabilityMap[hex] : getOneTurnReachableUnits(turn, hex));
} }
vstd::removeDuplicates(allReachableUnits); vstd::removeDuplicates(allReachableUnits);
@ -460,17 +476,33 @@ ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
float BattleExchangeEvaluator::evaluateExchange( float BattleExchangeEvaluator::evaluateExchange(
const AttackPossibility & ap, const AttackPossibility & ap,
uint8_t turn,
PotentialTargets & targets, PotentialTargets & targets,
DamageCache & damageCache, DamageCache & damageCache,
std::shared_ptr<HypotheticBattle> hb) std::shared_ptr<HypotheticBattle> hb)
{ {
BattleScore score = calculateExchange(ap, targets, damageCache, hb); if(ap.from.hex == 127)
{
logAi->trace("x");
}
BattleScore score = calculateExchange(ap, turn, targets, damageCache, hb);
#if BATTLE_TRACE_LEVEL >= 1
logAi->trace(
"calculateExchange score +%2f -%2fx%2f = %2f",
score.enemyDamageReduce,
score.ourDamageReduce,
getNegativeEffectMultiplier(),
scoreValue(score));
#endif
return scoreValue(score); return scoreValue(score);
} }
BattleScore BattleExchangeEvaluator::calculateExchange( BattleScore BattleExchangeEvaluator::calculateExchange(
const AttackPossibility & ap, const AttackPossibility & ap,
uint8_t turn,
PotentialTargets & targets, PotentialTargets & targets,
DamageCache & damageCache, DamageCache & damageCache,
std::shared_ptr<HypotheticBattle> hb) std::shared_ptr<HypotheticBattle> hb)
@ -492,8 +524,6 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
if(hb->battleGetUnitByID(ap.attack.defender->unitId())->alive()) if(hb->battleGetUnitByID(ap.attack.defender->unitId())->alive())
enemyStacks.push_back(ap.attack.defender); enemyStacks.push_back(ap.attack.defender);
vstd::amin(turn, reachabilityMapByTurns.size() - 1);
ReachabilityData exchangeUnits = getExchangeUnits(ap, turn, targets, hb); ReachabilityData exchangeUnits = getExchangeUnits(ap, turn, targets, hb);
if(exchangeUnits.units.empty()) if(exchangeUnits.units.empty())
@ -511,10 +541,15 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
bool isOur = exchangeBattle->battleMatchOwner(ap.attack.attacker, unit, true); bool isOur = exchangeBattle->battleMatchOwner(ap.attack.attacker, unit, true);
auto & attackerQueue = isOur ? ourStacks : enemyStacks; auto & attackerQueue = isOur ? ourStacks : enemyStacks;
auto u = exchangeBattle->getForUpdate(unit->unitId());
if(exchangeBattle->getForUpdate(unit->unitId())->alive() && !vstd::contains(attackerQueue, unit)) if(u->alive() && !vstd::contains(attackerQueue, unit))
{ {
attackerQueue.push_back(unit); attackerQueue.push_back(unit);
#if BATTLE_TRACE_LEVEL
logAi->trace("Exchanging: %s", u->getDescription());
#endif
} }
} }
@ -562,7 +597,7 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
true); true);
#if BATTLE_TRACE_LEVEL>=1 #if BATTLE_TRACE_LEVEL>=1
logAi->trace("Best target selector %s->%s score = %2f", attacker->getDescription(), u->getDescription(), score); logAi->trace("Best target selector %s->%s score = %2f", attacker->getDescription(), stackWithBonuses->getDescription(), score);
#endif #endif
return score; return score;
@ -645,6 +680,13 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
// avoid blocking path for stronger stack by weaker stack // avoid blocking path for stronger stack by weaker stack
// the method checks if all stacks can be placed around enemy // the method checks if all stacks can be placed around enemy
std::map<BattleHex, battle::Units> reachabilityMap;
auto hexes = ap.attack.defender->getSurroundingHexes();
for(auto hex : hexes)
reachabilityMap[hex] = getOneTurnReachableUnits(turn, hex);
v.adjustPositions(melleeAttackers, ap, reachabilityMap); v.adjustPositions(melleeAttackers, ap, reachabilityMap);
#if BATTLE_TRACE_LEVEL>=1 #if BATTLE_TRACE_LEVEL>=1
@ -743,11 +785,31 @@ void BattleExchangeEvaluator::updateReachabilityMap( std::shared_ptr<HypotheticB
turnOrder.clear(); turnOrder.clear();
hb->battleGetTurnOrder(turnOrder, std::numeric_limits<int>::max(), TURN_DEPTH); hb->battleGetTurnOrder(turnOrder, std::numeric_limits<int>::max(), TURN_DEPTH);
reachabilityMap.clear();
for(int turn = 0; turn < turnOrder.size(); turn++) for(auto turn : turnOrder)
{ {
auto & turnQueue = turnOrder[turn]; for(auto u : turn)
{
if(!vstd::contains(reachabilityCache, u->unitId()))
{
reachabilityCache[u->unitId()] = hb->getReachability(u);
}
}
}
for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); hex = hex + 1)
{
reachabilityMap[hex] = getOneTurnReachableUnits(0, hex);
}
}
std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, BattleHex hex)
{
std::vector<const battle::Unit *> result;
for(int i = 0; i < turnOrder.size(); i++, turn++)
{
auto & turnQueue = turnOrder[i];
HypotheticBattle turnBattle(env.get(), cb); HypotheticBattle turnBattle(env.get(), cb);
for(const battle::Unit * unit : turnQueue) for(const battle::Unit * unit : turnQueue)
@ -755,23 +817,25 @@ void BattleExchangeEvaluator::updateReachabilityMap( std::shared_ptr<HypotheticB
if(unit->isTurret()) if(unit->isTurret())
continue; continue;
auto unitSpeed = unit->speed(turn);
if(turnBattle.battleCanShoot(unit)) if(turnBattle.battleCanShoot(unit))
{ {
for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); hex = hex + 1) result.push_back(unit);
{
reachabilityMap[hex].push_back(unit);
}
continue; continue;
} }
auto unitReachability = turnBattle.getReachability(unit); auto unitSpeed = unit->speed(turn);
auto radius = unitSpeed * (turn + 1);
for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); hex = hex + 1) ReachabilityInfo unitReachability = vstd::getOrCompute(
reachabilityCache,
unit->unitId(),
[&](ReachabilityInfo & data)
{ {
bool reachable = unitReachability.distances[hex] <= unitSpeed; data = turnBattle.getReachability(unit);
});
bool reachable = unitReachability.distances[hex] <= radius;
if(!reachable && unitReachability.accessibility[hex] == EAccessibility::ALIVE_STACK) if(!reachable && unitReachability.accessibility[hex] == EAccessibility::ALIVE_STACK)
{ {
@ -781,7 +845,7 @@ void BattleExchangeEvaluator::updateReachabilityMap( std::shared_ptr<HypotheticB
{ {
for(BattleHex neighbor : hex.neighbouringTiles()) for(BattleHex neighbor : hex.neighbouringTiles())
{ {
reachable = unitReachability.distances[neighbor] <= unitSpeed; reachable = unitReachability.distances[neighbor] <= radius;
if(reachable) break; if(reachable) break;
} }
@ -790,11 +854,12 @@ void BattleExchangeEvaluator::updateReachabilityMap( std::shared_ptr<HypotheticB
if(reachable) if(reachable)
{ {
reachabilityMap[hex].push_back(unit); result.push_back(unit);
}
} }
} }
} }
return result;
} }
// avoid blocking path for stronger stack by weaker stack // avoid blocking path for stronger stack by weaker stack

View File

@ -132,6 +132,7 @@ class BattleExchangeEvaluator
private: private:
std::shared_ptr<CBattleInfoCallback> cb; std::shared_ptr<CBattleInfoCallback> cb;
std::shared_ptr<Environment> env; std::shared_ptr<Environment> env;
std::map<uint32_t, ReachabilityInfo> reachabilityCache;
std::map<BattleHex, std::vector<const battle::Unit *>> reachabilityMap; std::map<BattleHex, std::vector<const battle::Unit *>> reachabilityMap;
std::vector<battle::Units> turnOrder; std::vector<battle::Units> turnOrder;
float negativeEffectMultiplier; float negativeEffectMultiplier;
@ -140,6 +141,7 @@ private:
BattleScore calculateExchange( BattleScore calculateExchange(
const AttackPossibility & ap, const AttackPossibility & ap,
uint8_t turn,
PotentialTargets & targets, PotentialTargets & targets,
DamageCache & damageCache, DamageCache & damageCache,
std::shared_ptr<HypotheticBattle> hb); std::shared_ptr<HypotheticBattle> hb);
@ -151,7 +153,7 @@ public:
std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<CBattleInfoCallback> cb,
std::shared_ptr<Environment> env, std::shared_ptr<Environment> env,
float strengthRatio): cb(cb), env(env) { float strengthRatio): cb(cb), env(env) {
negativeEffectMultiplier = std::sqrt(strengthRatio); negativeEffectMultiplier = strengthRatio >= 1 ? 1 : strengthRatio;
} }
EvaluationResult findBestTarget( EvaluationResult findBestTarget(
@ -162,10 +164,12 @@ public:
float evaluateExchange( float evaluateExchange(
const AttackPossibility & ap, const AttackPossibility & ap,
uint8_t turn,
PotentialTargets & targets, PotentialTargets & targets,
DamageCache & damageCache, DamageCache & damageCache,
std::shared_ptr<HypotheticBattle> hb); std::shared_ptr<HypotheticBattle> hb);
std::vector<const battle::Unit *> getOneTurnReachableUnits(uint8_t turn, BattleHex hex);
void updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb); void updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb);
ReachabilityData getExchangeUnits( ReachabilityData getExchangeUnits(