1
0
mirror of https://github.com/vcmi/vcmi.git synced 2024-12-24 22:14:36 +02:00

BattleAI: bigger reachability map

This commit is contained in:
Andrii Danylchenko 2023-10-13 22:04:26 +03:00
parent 9eb9404f28
commit 870fbd50e3
4 changed files with 115 additions and 50 deletions

View File

@ -157,13 +157,9 @@ float AttackPossibility::calculateDamageReduce(
auto enemyDamageBeforeAttack = damageCache.getOriginalDamage(defender, attackerUnitForMeasurement, state);
auto enemiesKilled = damageDealt / maxHealth + (damageDealt % maxHealth >= defender->getFirstHPleft() ? 1 : 0);
auto damagePerEnemy = enemyDamageBeforeAttack / (double)defender->getCount();
// lets use cached maxHealth here instead of getAvailableHealth
auto firstUnitHpLeft = (availableHealth - damageDealt) % maxHealth;
auto firstUnitHealthRatio = firstUnitHpLeft == 0 ? 1 : static_cast<float>(firstUnitHpLeft) / maxHealth;
auto firstUnitKillValue = (1 - firstUnitHealthRatio) * (1 - firstUnitHealthRatio);
auto lastUnitKillValue = (damageDealt % maxHealth) / (double)maxHealth;;
return damagePerEnemy * (enemiesKilled + firstUnitKillValue * HEALTH_BOUNTY);
return damagePerEnemy * (enemiesKilled + lastUnitKillValue * HEALTH_BOUNTY);
}
int64_t AttackPossibility::evaluateBlockedShootersDmg(

View File

@ -149,7 +149,7 @@ BattleAction BattleEvaluator::selectStackAction(const CStack * stack)
bestAttack.attack.attacker->speed(0, true),
bestAttack.defenderDamageReduce,
bestAttack.attackerDamageReduce,
bestAttack.attackValue()
score
);
if (moveTarget.scorePerTurn <= score)
@ -580,7 +580,7 @@ bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack)
}
else
{
ps.value = scoreEvaluator.evaluateExchange(*cachedAttack, *targets, innerCache, state);
ps.value = scoreEvaluator.evaluateExchange(*cachedAttack, 0, *targets, innerCache, state);
}
for(auto unit : allUnits)

View File

@ -116,6 +116,10 @@ float BattleExchangeVariant::trackAttack(
}
}
#if BATTLE_TRACE_LEVEL >= 1
logAi->trace("ap shooters blocking: %lld", ap.shootersBlockedDmg);
#endif
attackValue += ap.shootersBlockedDmg;
dpsScore.enemyDamageReduce += ap.shootersBlockedDmg;
attacker->afterAttack(ap.attack.shooting, false);
@ -233,13 +237,17 @@ EvaluationResult BattleExchangeEvaluator::findBestTarget(
for(auto & ap : targets.possibleAttacks)
{
float score = evaluateExchange(ap, targets, damageCache, hbWaited);
float score = evaluateExchange(ap, 0, targets, damageCache, hbWaited);
if(score > result.score)
{
result.score = score;
result.bestAttack = ap;
result.wait = true;
#if BATTLE_TRACE_LEVEL >= 1
logAi->trace("New high score %2f", result.score);
#endif
}
}
}
@ -258,13 +266,17 @@ EvaluationResult BattleExchangeEvaluator::findBestTarget(
for(auto & ap : targets.possibleAttacks)
{
float score = evaluateExchange(ap, targets, damageCache, hb);
float score = evaluateExchange(ap, 0, targets, damageCache, hb);
if(score >= result.score)
if(score > result.score || score == result.score && result.wait)
{
result.score = score;
result.bestAttack = ap;
result.wait = false;
#if BATTLE_TRACE_LEVEL >= 1
logAi->trace("New high score %2f", result.score);
#endif
}
}
@ -312,7 +324,10 @@ MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable(
auto hexes = closestStack->getSurroundingHexes();
auto enemySpeed = closestStack->speed();
auto speedRatio = speed / static_cast<float>(enemySpeed);
auto penalty = speedRatio > 1 ? 1 : speedRatio;
auto multiplier = speedRatio > 1 ? 1 : speedRatio;
if(enemy->canShoot())
multiplier *= 1.5f;
for(auto hex : hexes)
{
@ -323,7 +338,7 @@ MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable(
attack.shootersBlockedDmg = 0; // we do not want to count on it, it is not for sure
auto score = calculateExchange(attack, turnsToRich, targets, damageCache, hb);
auto scorePerTurn = BattleScore(score.enemyDamageReduce * std::sqrt(penalty / turnsToRich), score.ourDamageReduce);
auto scorePerTurn = BattleScore(score.enemyDamageReduce * std::sqrt(multiplier / turnsToRich), score.ourDamageReduce);
if(result.scorePerTurn < scoreValue(scorePerTurn))
{
@ -371,12 +386,13 @@ std::vector<const battle::Unit *> BattleExchangeEvaluator::getAdjacentUnits(cons
ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
const AttackPossibility & ap,
uint8_t turn,
PotentialTargets & targets,
std::shared_ptr<HypotheticBattle> hb)
{
ReachabilityData result;
auto hexes = ap.attack.defender->getHexes();
auto hexes = ap.attack.defender->getSurroundingHexes();
if(!ap.attack.shooting) hexes.push_back(ap.from);
@ -384,7 +400,7 @@ ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
for(auto hex : hexes)
{
vstd::concatenate(allReachableUnits, reachabilityMap[hex]);
vstd::concatenate(allReachableUnits, turn == 0 ? reachabilityMap[hex] : getOneTurnReachableUnits(turn, hex));
}
vstd::removeDuplicates(allReachableUnits);
@ -460,17 +476,33 @@ ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
float BattleExchangeEvaluator::evaluateExchange(
const AttackPossibility & ap,
uint8_t turn,
PotentialTargets & targets,
DamageCache & damageCache,
std::shared_ptr<HypotheticBattle> hb)
{
BattleScore score = calculateExchange(ap, targets, damageCache, hb);
if(ap.from.hex == 127)
{
logAi->trace("x");
}
BattleScore score = calculateExchange(ap, turn, targets, damageCache, hb);
#if BATTLE_TRACE_LEVEL >= 1
logAi->trace(
"calculateExchange score +%2f -%2fx%2f = %2f",
score.enemyDamageReduce,
score.ourDamageReduce,
getNegativeEffectMultiplier(),
scoreValue(score));
#endif
return scoreValue(score);
}
BattleScore BattleExchangeEvaluator::calculateExchange(
const AttackPossibility & ap,
uint8_t turn,
PotentialTargets & targets,
DamageCache & damageCache,
std::shared_ptr<HypotheticBattle> hb)
@ -492,8 +524,6 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
if(hb->battleGetUnitByID(ap.attack.defender->unitId())->alive())
enemyStacks.push_back(ap.attack.defender);
vstd::amin(turn, reachabilityMapByTurns.size() - 1);
ReachabilityData exchangeUnits = getExchangeUnits(ap, turn, targets, hb);
if(exchangeUnits.units.empty())
@ -511,10 +541,15 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
bool isOur = exchangeBattle->battleMatchOwner(ap.attack.attacker, unit, true);
auto & attackerQueue = isOur ? ourStacks : enemyStacks;
auto u = exchangeBattle->getForUpdate(unit->unitId());
if(exchangeBattle->getForUpdate(unit->unitId())->alive() && !vstd::contains(attackerQueue, unit))
if(u->alive() && !vstd::contains(attackerQueue, unit))
{
attackerQueue.push_back(unit);
#if BATTLE_TRACE_LEVEL
logAi->trace("Exchanging: %s", u->getDescription());
#endif
}
}
@ -562,7 +597,7 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
true);
#if BATTLE_TRACE_LEVEL>=1
logAi->trace("Best target selector %s->%s score = %2f", attacker->getDescription(), u->getDescription(), score);
logAi->trace("Best target selector %s->%s score = %2f", attacker->getDescription(), stackWithBonuses->getDescription(), score);
#endif
return score;
@ -645,6 +680,13 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
// avoid blocking path for stronger stack by weaker stack
// the method checks if all stacks can be placed around enemy
std::map<BattleHex, battle::Units> reachabilityMap;
auto hexes = ap.attack.defender->getSurroundingHexes();
for(auto hex : hexes)
reachabilityMap[hex] = getOneTurnReachableUnits(turn, hex);
v.adjustPositions(melleeAttackers, ap, reachabilityMap);
#if BATTLE_TRACE_LEVEL>=1
@ -736,18 +778,38 @@ bool BattleExchangeEvaluator::canBeHitThisTurn(const AttackPossibility & ap)
return false;
}
void BattleExchangeEvaluator::updateReachabilityMap( std::shared_ptr<HypotheticBattle> hb)
void BattleExchangeEvaluator::updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb)
{
const int TURN_DEPTH = 2;
turnOrder.clear();
hb->battleGetTurnOrder(turnOrder, std::numeric_limits<int>::max(), TURN_DEPTH);
reachabilityMap.clear();
for(int turn = 0; turn < turnOrder.size(); turn++)
hb->battleGetTurnOrder(turnOrder, std::numeric_limits<int>::max(), TURN_DEPTH);
for(auto turn : turnOrder)
{
auto & turnQueue = turnOrder[turn];
for(auto u : turn)
{
if(!vstd::contains(reachabilityCache, u->unitId()))
{
reachabilityCache[u->unitId()] = hb->getReachability(u);
}
}
}
for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); hex = hex + 1)
{
reachabilityMap[hex] = getOneTurnReachableUnits(0, hex);
}
}
std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, BattleHex hex)
{
std::vector<const battle::Unit *> result;
for(int i = 0; i < turnOrder.size(); i++, turn++)
{
auto & turnQueue = turnOrder[i];
HypotheticBattle turnBattle(env.get(), cb);
for(const battle::Unit * unit : turnQueue)
@ -755,46 +817,49 @@ void BattleExchangeEvaluator::updateReachabilityMap( std::shared_ptr<HypotheticB
if(unit->isTurret())
continue;
auto unitSpeed = unit->speed(turn);
if(turnBattle.battleCanShoot(unit))
{
for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); hex = hex + 1)
{
reachabilityMap[hex].push_back(unit);
}
result.push_back(unit);
continue;
}
auto unitReachability = turnBattle.getReachability(unit);
auto unitSpeed = unit->speed(turn);
auto radius = unitSpeed * (turn + 1);
for(BattleHex hex = BattleHex::TOP_LEFT; hex.isValid(); hex = hex + 1)
{
bool reachable = unitReachability.distances[hex] <= unitSpeed;
if(!reachable && unitReachability.accessibility[hex] == EAccessibility::ALIVE_STACK)
ReachabilityInfo unitReachability = vstd::getOrCompute(
reachabilityCache,
unit->unitId(),
[&](ReachabilityInfo & data)
{
const battle::Unit * hexStack = cb->battleGetUnitByPos(hex);
data = turnBattle.getReachability(unit);
});
if(hexStack && cb->battleMatchOwner(unit, hexStack, false))
bool reachable = unitReachability.distances[hex] <= radius;
if(!reachable && unitReachability.accessibility[hex] == EAccessibility::ALIVE_STACK)
{
const battle::Unit * hexStack = cb->battleGetUnitByPos(hex);
if(hexStack && cb->battleMatchOwner(unit, hexStack, false))
{
for(BattleHex neighbor : hex.neighbouringTiles())
{
for(BattleHex neighbor : hex.neighbouringTiles())
{
reachable = unitReachability.distances[neighbor] <= unitSpeed;
reachable = unitReachability.distances[neighbor] <= radius;
if(reachable) break;
}
if(reachable) break;
}
}
}
if(reachable)
{
reachabilityMap[hex].push_back(unit);
}
if(reachable)
{
result.push_back(unit);
}
}
}
return result;
}
// avoid blocking path for stronger stack by weaker stack

View File

@ -132,6 +132,7 @@ class BattleExchangeEvaluator
private:
std::shared_ptr<CBattleInfoCallback> cb;
std::shared_ptr<Environment> env;
std::map<uint32_t, ReachabilityInfo> reachabilityCache;
std::map<BattleHex, std::vector<const battle::Unit *>> reachabilityMap;
std::vector<battle::Units> turnOrder;
float negativeEffectMultiplier;
@ -140,6 +141,7 @@ private:
BattleScore calculateExchange(
const AttackPossibility & ap,
uint8_t turn,
PotentialTargets & targets,
DamageCache & damageCache,
std::shared_ptr<HypotheticBattle> hb);
@ -151,7 +153,7 @@ public:
std::shared_ptr<CBattleInfoCallback> cb,
std::shared_ptr<Environment> env,
float strengthRatio): cb(cb), env(env) {
negativeEffectMultiplier = std::sqrt(strengthRatio);
negativeEffectMultiplier = strengthRatio >= 1 ? 1 : strengthRatio;
}
EvaluationResult findBestTarget(
@ -162,10 +164,12 @@ public:
float evaluateExchange(
const AttackPossibility & ap,
uint8_t turn,
PotentialTargets & targets,
DamageCache & damageCache,
std::shared_ptr<HypotheticBattle> hb);
std::vector<const battle::Unit *> getOneTurnReachableUnits(uint8_t turn, BattleHex hex);
void updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb);
ReachabilityData getExchangeUnits(