1
0
mirror of https://github.com/vcmi/vcmi.git synced 2024-12-22 22:13:35 +02:00

BattleAI: fix bait for archers when need to go long way

This commit is contained in:
Andrii Danylchenko 2024-07-19 13:44:44 +03:00
parent b3fc6743d9
commit 8cdfa26fb5
8 changed files with 183 additions and 30 deletions

View File

@ -213,7 +213,7 @@ BattleAction BattleEvaluator::selectStackAction(const CStack * stack)
moveTarget.score, moveTarget.score,
moveTarget.scorePerTurn); moveTarget.scorePerTurn);
return goTowardsNearest(stack, moveTarget.positions); return goTowardsNearest(stack, moveTarget.positions, *targets);
} }
else else
{ {
@ -235,7 +235,7 @@ BattleAction BattleEvaluator::selectStackAction(const CStack * stack)
if(stack->doubleWide() && vstd::contains(brokenWallMoat, stack->getPosition())) if(stack->doubleWide() && vstd::contains(brokenWallMoat, stack->getPosition()))
return BattleAction::makeMove(stack, stack->getPosition().cloneInDirection(BattleHex::RIGHT)); return BattleAction::makeMove(stack, stack->getPosition().cloneInDirection(BattleHex::RIGHT));
else else
return goTowardsNearest(stack, brokenWallMoat); return goTowardsNearest(stack, brokenWallMoat, *targets);
} }
} }
@ -249,7 +249,7 @@ uint64_t timeElapsed(std::chrono::time_point<std::chrono::high_resolution_clock>
return std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count(); return std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
} }
BattleAction BattleEvaluator::goTowardsNearest(const CStack * stack, std::vector<BattleHex> hexes) BattleAction BattleEvaluator::goTowardsNearest(const CStack * stack, std::vector<BattleHex> hexes, const PotentialTargets & targets)
{ {
auto reachability = cb->getBattle(battleID)->getReachability(stack); auto reachability = cb->getBattle(battleID)->getReachability(stack);
auto avHexes = cb->getBattle(battleID)->battleGetAvailableHexes(reachability, stack, false); auto avHexes = cb->getBattle(battleID)->battleGetAvailableHexes(reachability, stack, false);
@ -272,7 +272,27 @@ BattleAction BattleEvaluator::goTowardsNearest(const CStack * stack, std::vector
{ {
if(vstd::contains(avHexes, hex)) if(vstd::contains(avHexes, hex))
{ {
return BattleAction::makeMove(stack, hex); auto additionalScore = 0;
std::optional<AttackPossibility> attackOnTheWay;
for(auto & target : targets.possibleAttacks)
{
if(!target.attack.shooting && target.from == hex && target.attackValue() > additionalScore)
{
additionalScore = target.attackValue();
attackOnTheWay = target;
}
}
if(attackOnTheWay)
{
activeActionMade = true;
return BattleAction::makeMeleeAttack(stack, attackOnTheWay->attack.defender->getPosition(), attackOnTheWay->from);
}
else
{
return BattleAction::makeMove(stack, hex);
}
} }
if(stack->coversPos(hex)) if(stack->coversPos(hex))

View File

@ -43,7 +43,7 @@ public:
bool attemptCastingSpell(const CStack * stack); bool attemptCastingSpell(const CStack * stack);
bool canCastSpell(); bool canCastSpell();
std::optional<PossibleSpellcast> findBestCreatureSpell(const CStack * stack); std::optional<PossibleSpellcast> findBestCreatureSpell(const CStack * stack);
BattleAction goTowardsNearest(const CStack * stack, std::vector<BattleHex> hexes); BattleAction goTowardsNearest(const CStack * stack, std::vector<BattleHex> hexes, const PotentialTargets & targets);
std::vector<BattleHex> getBrokenWallMoatHexes() const; std::vector<BattleHex> getBrokenWallMoatHexes() const;
void evaluateCreatureSpellcast(const CStack * stack, PossibleSpellcast & ps); //for offensive damaging spells only void evaluateCreatureSpellcast(const CStack * stack, PossibleSpellcast & ps); //for offensive damaging spells only
void print(const std::string & text) const; void print(const std::string & text) const;

View File

@ -277,6 +277,36 @@ EvaluationResult BattleExchangeEvaluator::findBestTarget(
return result; return result;
} }
ReachabilityInfo getReachabilityWithEnemyBypass(
const battle::Unit * activeStack,
DamageCache & damageCache,
std::shared_ptr<HypotheticBattle> state)
{
ReachabilityInfo::Parameters params(activeStack, activeStack->getPosition());
if(!params.flying)
{
for(const auto * unit : state->battleAliveUnits())
{
if(unit->unitSide() == activeStack->unitSide())
continue;
auto dmg = damageCache.getOriginalDamage(activeStack, unit, state);
auto turnsToKill = unit->getAvailableHealth() / dmg + 1;
vstd::amin(turnsToKill, 100);
for(auto & hex : unit->getHexes())
if(hex.isAvailable()) //towers can have <0 pos; we don't also want to overwrite side columns
params.destructibleEnemyTurns[hex] = turnsToKill * unit->getMovementRange();
}
params.bypassEnemyStacks = true;
}
return state->getReachability(params);
}
MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable( MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable(
const battle::Unit * activeStack, const battle::Unit * activeStack,
PotentialTargets & targets, PotentialTargets & targets,
@ -286,6 +316,8 @@ MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable(
MoveTarget result; MoveTarget result;
BattleExchangeVariant ev; BattleExchangeVariant ev;
logAi->trace("Find move towards unreachable. Enemies count %d", targets.unreachableEnemies.size());
if(targets.unreachableEnemies.empty()) if(targets.unreachableEnemies.empty())
return result; return result;
@ -296,17 +328,17 @@ MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable(
updateReachabilityMap(hb); updateReachabilityMap(hb);
auto dists = cb->getReachability(activeStack); auto dists = getReachabilityWithEnemyBypass(activeStack, damageCache, hb);
auto flying = activeStack->hasBonusOfType(BonusType::FLYING);
for(const battle::Unit * enemy : targets.unreachableEnemies) for(const battle::Unit * enemy : targets.unreachableEnemies)
{ {
std::vector<const battle::Unit *> adjacentStacks = getAdjacentUnits(enemy); logAi->trace(
auto closestStack = *vstd::minElementByFun(adjacentStacks, [&](const battle::Unit * u) -> int64_t "Checking movement towards %d of %s",
{ enemy->getCount(),
return dists.distToNearestNeighbour(activeStack, u) * 100000 - activeStack->getTotalHealth(); enemy->creatureId().toCreature()->getNameSingularTranslated());
});
auto distance = dists.distToNearestNeighbour(activeStack, closestStack); auto distance = dists.distToNearestNeighbour(activeStack, enemy);
if(distance >= GameConstants::BFIELD_SIZE) if(distance >= GameConstants::BFIELD_SIZE)
continue; continue;
@ -315,30 +347,94 @@ MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable(
continue; continue;
auto turnsToRich = (distance - 1) / speed + 1; auto turnsToRich = (distance - 1) / speed + 1;
auto hexes = closestStack->getSurroundingHexes(); auto hexes = enemy->getSurroundingHexes();
auto enemySpeed = closestStack->getMovementRange(); auto enemySpeed = enemy->getMovementRange();
auto speedRatio = speed / static_cast<float>(enemySpeed); auto speedRatio = speed / static_cast<float>(enemySpeed);
auto multiplier = speedRatio > 1 ? 1 : speedRatio; auto multiplier = speedRatio > 1 ? 1 : speedRatio;
if(enemy->canShoot()) if(enemy->canShoot())
multiplier *= 1.5f; multiplier *= 1.5f;
for(auto hex : hexes) for(auto & hex : hexes)
{ {
// FIXME: provide distance info for Jousting bonus // FIXME: provide distance info for Jousting bonus
auto bai = BattleAttackInfo(activeStack, closestStack, 0, cb->battleCanShoot(activeStack)); auto bai = BattleAttackInfo(activeStack, enemy, 0, cb->battleCanShoot(activeStack));
auto attack = AttackPossibility::evaluate(bai, hex, damageCache, hb); auto attack = AttackPossibility::evaluate(bai, hex, damageCache, hb);
attack.shootersBlockedDmg = 0; // we do not want to count on it, it is not for sure attack.shootersBlockedDmg = 0; // we do not want to count on it, it is not for sure
auto score = calculateExchange(attack, turnsToRich, targets, damageCache, hb); auto score = calculateExchange(attack, turnsToRich, targets, damageCache, hb);
auto scorePerTurn = BattleScore(score.enemyDamageReduce * std::sqrt(multiplier / turnsToRich), score.ourDamageReduce); auto scorePerTurn = BattleScore(score.enemyDamageReduce * multiplier / turnsToRich, score.ourDamageReduce);
#if BATTLE_TRACE_LEVEL >= 1
logAi->trace("Multiplier: %f, turns: %d, current score %f, new score %f", multiplier, turnsToRich, result.scorePerTurn, scoreValue(scorePerTurn));
#endif
if(result.scorePerTurn < scoreValue(scorePerTurn)) if(result.scorePerTurn < scoreValue(scorePerTurn))
{ {
result.scorePerTurn = scoreValue(scorePerTurn); result.scorePerTurn = scoreValue(scorePerTurn);
result.score = scoreValue(score); result.score = scoreValue(score);
result.positions = closestStack->getAttackableHexes(activeStack); result.positions.clear();
#if BATTLE_TRACE_LEVEL >= 1
logAi->trace("New high score");
#endif
for(BattleHex enemyHex : enemy->getAttackableHexes(activeStack))
{
while(!flying && dists.distances[enemyHex] > speed)
{
enemyHex = dists.predecessors.at(enemyHex);
if(dists.accessibility[enemyHex] == EAccessibility::ALIVE_STACK)
{
auto defenderToBypass = hb->battleGetUnitByPos(enemyHex);
if(defenderToBypass)
{
#if BATTLE_TRACE_LEVEL >= 1
logAi->trace("Found target to bypass at %d", enemyHex.hex);
#endif
auto attackHex = dists.predecessors[enemyHex];
auto baiBypass = BattleAttackInfo(activeStack, defenderToBypass, 0, cb->battleCanShoot(activeStack));
auto attackBypass = AttackPossibility::evaluate(baiBypass, attackHex, damageCache, hb);
auto adjacentStacks = getAdjacentUnits(enemy);
adjacentStacks.push_back(defenderToBypass);
vstd::removeDuplicates(adjacentStacks);
auto bypassScore = calculateExchange(
attackBypass,
dists.distances[attackHex],
targets,
damageCache,
hb,
adjacentStacks);
if(scoreValue(bypassScore) > result.score)
{
auto newMultiplier = multiplier * speed * turnsToRich / dists.distances[attackHex];
result.score = scoreValue(bypassScore);
scorePerTurn = BattleScore(
score.enemyDamageReduce * newMultiplier,
score.ourDamageReduce);
result.scorePerTurn = scoreValue(scorePerTurn);
#if BATTLE_TRACE_LEVEL >= 1
logAi->trace("New high score after bypass %f", scoreValue(scorePerTurn));
#endif
}
}
}
}
result.positions.push_back(enemyHex);
}
result.cachedAttack = attack; result.cachedAttack = attack;
result.turnsToRich = turnsToRich; result.turnsToRich = turnsToRich;
} }
@ -382,7 +478,8 @@ ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
const AttackPossibility & ap, const AttackPossibility & ap,
uint8_t turn, uint8_t turn,
PotentialTargets & targets, PotentialTargets & targets,
std::shared_ptr<HypotheticBattle> hb) const std::shared_ptr<HypotheticBattle> hb,
std::vector<const battle::Unit *> additionalUnits) const
{ {
ReachabilityData result; ReachabilityData result;
@ -390,7 +487,7 @@ ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
if(!ap.attack.shooting) hexes.push_back(ap.from); if(!ap.attack.shooting) hexes.push_back(ap.from);
std::vector<const battle::Unit *> allReachableUnits; std::vector<const battle::Unit *> allReachableUnits = additionalUnits;
for(auto hex : hexes) for(auto hex : hexes)
{ {
@ -432,7 +529,7 @@ ReachabilityData BattleExchangeEvaluator::getExchangeUnits(
for(auto unit : allReachableUnits) for(auto unit : allReachableUnits)
{ {
auto accessible = !unit->canShoot(); auto accessible = !unit->canShoot() || vstd::contains(additionalUnits, unit);
if(!accessible) if(!accessible)
{ {
@ -494,7 +591,8 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
uint8_t turn, uint8_t turn,
PotentialTargets & targets, PotentialTargets & targets,
DamageCache & damageCache, DamageCache & damageCache,
std::shared_ptr<HypotheticBattle> hb) const std::shared_ptr<HypotheticBattle> hb,
std::vector<const battle::Unit *> additionalUnits) const
{ {
#if BATTLE_TRACE_LEVEL>=1 #if BATTLE_TRACE_LEVEL>=1
logAi->trace("Battle exchange at %d", ap.attack.shooting ? ap.dest.hex : ap.from.hex); logAi->trace("Battle exchange at %d", ap.attack.shooting ? ap.dest.hex : ap.from.hex);
@ -513,7 +611,7 @@ BattleScore BattleExchangeEvaluator::calculateExchange(
if(hb->battleGetUnitByID(ap.attack.defender->unitId())->alive()) if(hb->battleGetUnitByID(ap.attack.defender->unitId())->alive())
enemyStacks.push_back(ap.attack.defender); enemyStacks.push_back(ap.attack.defender);
ReachabilityData exchangeUnits = getExchangeUnits(ap, turn, targets, hb); ReachabilityData exchangeUnits = getExchangeUnits(ap, turn, targets, hb, additionalUnits);
if(exchangeUnits.units.empty()) if(exchangeUnits.units.empty())
{ {

View File

@ -139,7 +139,8 @@ private:
uint8_t turn, uint8_t turn,
PotentialTargets & targets, PotentialTargets & targets,
DamageCache & damageCache, DamageCache & damageCache,
std::shared_ptr<HypotheticBattle> hb) const; std::shared_ptr<HypotheticBattle> hb,
std::vector<const battle::Unit *> additionalUnits = {}) const;
bool canBeHitThisTurn(const AttackPossibility & ap); bool canBeHitThisTurn(const AttackPossibility & ap);
@ -171,7 +172,8 @@ public:
const AttackPossibility & ap, const AttackPossibility & ap,
uint8_t turn, uint8_t turn,
PotentialTargets & targets, PotentialTargets & targets,
std::shared_ptr<HypotheticBattle> hb) const; std::shared_ptr<HypotheticBattle> hb,
std::vector<const battle::Unit *> additionalUnits = {}) const;
bool checkPositionBlocksOurStacks(HypotheticBattle & hb, const battle::Unit * unit, BattleHex position); bool checkPositionBlocksOurStacks(HypotheticBattle & hb, const battle::Unit * unit, BattleHex position);

View File

@ -18,9 +18,19 @@ VCMI_LIB_NAMESPACE_BEGIN
bool AccessibilityInfo::tileAccessibleWithGate(BattleHex tile, BattleSide side) const bool AccessibilityInfo::tileAccessibleWithGate(BattleHex tile, BattleSide side) const
{ {
//at(otherHex) != EAccessibility::ACCESSIBLE && (at(otherHex) != EAccessibility::GATE || side != BattleSide::DEFENDER) //at(otherHex) != EAccessibility::ACCESSIBLE && (at(otherHex) != EAccessibility::GATE || side != BattleSide::DEFENDER)
if(at(tile) != EAccessibility::ACCESSIBLE) auto accessibility = at(tile);
if(at(tile) != EAccessibility::GATE || side != BattleSide::DEFENDER)
if(accessibility == EAccessibility::ALIVE_STACK)
{
auto destructible = destructibleEnemyTurns.find(tile);
return destructible != destructibleEnemyTurns.end();
}
if(accessibility != EAccessibility::ACCESSIBLE)
if(accessibility != EAccessibility::GATE || side != BattleSide::DEFENDER)
return false; return false;
return true; return true;
} }

View File

@ -35,6 +35,8 @@ using TAccessibilityArray = std::array<EAccessibility, GameConstants::BFIELD_SIZ
struct DLL_LINKAGE AccessibilityInfo : TAccessibilityArray struct DLL_LINKAGE AccessibilityInfo : TAccessibilityArray
{ {
std::map<BattleHex, ui8> destructibleEnemyTurns;
public: public:
bool accessible(BattleHex tile, const battle::Unit * stack) const; //checks for both tiles if stack is double wide bool accessible(BattleHex tile, const battle::Unit * stack) const; //checks for both tiles if stack is double wide
bool accessible(BattleHex tile, bool doubleWide, BattleSide side) const; //checks for both tiles if stack is double wide bool accessible(BattleHex tile, bool doubleWide, BattleSide side) const; //checks for both tiles if stack is double wide

View File

@ -1052,16 +1052,29 @@ ReachabilityInfo CBattleInfoCallback::makeBFS(const AccessibilityInfo &accessibi
continue; continue;
const int costToNeighbour = ret.distances[curHex.hex] + 1; const int costToNeighbour = ret.distances[curHex.hex] + 1;
for(BattleHex neighbour : BattleHex::neighbouringTilesCache[curHex.hex]) for(BattleHex neighbour : BattleHex::neighbouringTilesCache[curHex.hex])
{ {
if(neighbour.isValid()) if(neighbour.isValid())
{ {
auto additionalCost = 0;
if(params.bypassEnemyStacks)
{
auto enemyToBypass = params.destructibleEnemyTurns.find(neighbour);
if(enemyToBypass != params.destructibleEnemyTurns.end())
{
additionalCost = enemyToBypass->second;
}
}
const int costFoundSoFar = ret.distances[neighbour.hex]; const int costFoundSoFar = ret.distances[neighbour.hex];
if(accessibleCache[neighbour.hex] && costToNeighbour < costFoundSoFar) if(accessibleCache[neighbour.hex] && costToNeighbour + additionalCost < costFoundSoFar)
{ {
hexq.push(neighbour); hexq.push(neighbour);
ret.distances[neighbour.hex] = costToNeighbour; ret.distances[neighbour.hex] = costToNeighbour + additionalCost;
ret.predecessors[neighbour.hex] = curHex; ret.predecessors[neighbour.hex] = curHex;
} }
} }
@ -1236,7 +1249,13 @@ ReachabilityInfo CBattleInfoCallback::getReachability(const ReachabilityInfo::Pa
if(params.flying) if(params.flying)
return getFlyingReachability(params); return getFlyingReachability(params);
else else
return makeBFS(getAccessibility(params.knownAccessible), params); {
auto accessibility = getAccessibility(params.knownAccessible);
accessibility.destructibleEnemyTurns = params.destructibleEnemyTurns;
return makeBFS(accessibility, params);
}
} }
ReachabilityInfo CBattleInfoCallback::getFlyingReachability(const ReachabilityInfo::Parameters &params) const ReachabilityInfo CBattleInfoCallback::getFlyingReachability(const ReachabilityInfo::Parameters &params) const

View File

@ -29,7 +29,9 @@ struct DLL_LINKAGE ReachabilityInfo
bool doubleWide = false; bool doubleWide = false;
bool flying = false; bool flying = false;
bool ignoreKnownAccessible = false; //Ignore obstacles if it is in accessible hexes bool ignoreKnownAccessible = false; //Ignore obstacles if it is in accessible hexes
bool bypassEnemyStacks = false; // in case of true will count amount of turns needed to kill enemy and thus move forward
std::vector<BattleHex> knownAccessible; //hexes that will be treated as accessible, even if they're occupied by stack (by default - tiles occupied by stack we do reachability for, so it doesn't block itself) std::vector<BattleHex> knownAccessible; //hexes that will be treated as accessible, even if they're occupied by stack (by default - tiles occupied by stack we do reachability for, so it doesn't block itself)
std::map<BattleHex, ui8> destructibleEnemyTurns; // hom many turns it is needed to kill enemy on specific hex
BattleHex startPosition; //assumed position of stack BattleHex startPosition; //assumed position of stack
BattleSide perspective = BattleSide::ALL_KNOWING; //some obstacles (eg. quicksands) may be invisible for some side BattleSide perspective = BattleSide::ALL_KNOWING; //some obstacles (eg. quicksands) may be invisible for some side