diff --git a/AI/BattleAI/AttackPossibility.cpp b/AI/BattleAI/AttackPossibility.cpp index 496f73414..ce1803c34 100644 --- a/AI/BattleAI/AttackPossibility.cpp +++ b/AI/BattleAI/AttackPossibility.cpp @@ -103,6 +103,12 @@ int64_t AttackPossibility::damageDiff() const return defenderDamageReduce - attackerDamageReduce - collateralDamageReduce + shootersBlockedDmg; } +int64_t AttackPossibility::damageDiff(float positiveEffectMultiplier, float negativeEffectMultiplier) const +{ + return positiveEffectMultiplier * (defenderDamageReduce + shootersBlockedDmg) + - negativeEffectMultiplier * (attackerDamageReduce + collateralDamageReduce); +} + int64_t AttackPossibility::attackValue() const { return damageDiff(); @@ -121,9 +127,6 @@ int64_t AttackPossibility::calculateDamageReduce( std::shared_ptr state) { const float HEALTH_BOUNTY = 0.5; - const float KILL_BOUNTY = 1.0 - HEALTH_BOUNTY; - - vstd::amin(damageDealt, defender->getAvailableHealth()); // FIXME: provide distance info for Jousting bonus auto attackerUnitForMeasurement = attacker; @@ -146,11 +149,21 @@ int64_t AttackPossibility::calculateDamageReduce( attackerUnitForMeasurement = ourUnits.front(); } - auto enemyDamageBeforeAttack = damageCache.getOriginalDamage(defender, attackerUnitForMeasurement, state); - auto enemiesKilled = damageDealt / defender->getMaxHealth() + (damageDealt % defender->getMaxHealth() >= defender->getFirstHPleft() ? 1 : 0); - auto damagePerEnemy = enemyDamageBeforeAttack / (double)defender->getCount(); + auto maxHealth = defender->getMaxHealth(); + auto availableHealth = defender->getFirstHPleft() + ((defender->getCount() - 1) * maxHealth); - return (int64_t)(damagePerEnemy * (enemiesKilled * KILL_BOUNTY + damageDealt * HEALTH_BOUNTY / (double)defender->getMaxHealth())); + vstd::amin(damageDealt, availableHealth); + + auto enemyDamageBeforeAttack = damageCache.getOriginalDamage(defender, attackerUnitForMeasurement, state); + auto enemiesKilled = damageDealt / maxHealth + (damageDealt % maxHealth >= defender->getFirstHPleft() ? 1 : 0); + auto damagePerEnemy = enemyDamageBeforeAttack / (double)defender->getCount(); + + // lets use cached maxHealth here instead of getAvailableHealth + auto firstUnitHpLeft = (availableHealth - damageDealt) % maxHealth; + auto firstUnitHealthRatio = firstUnitHpLeft == 0 ? 1 : static_cast(firstUnitHpLeft) / maxHealth; + auto firstUnitKillValue = (1 - firstUnitHealthRatio) * (1 - firstUnitHealthRatio); + + return (int64_t)(damagePerEnemy * (enemiesKilled + firstUnitKillValue * HEALTH_BOUNTY)); } int64_t AttackPossibility::evaluateBlockedShootersDmg( diff --git a/AI/BattleAI/AttackPossibility.h b/AI/BattleAI/AttackPossibility.h index f96f59a6d..cd28839f4 100644 --- a/AI/BattleAI/AttackPossibility.h +++ b/AI/BattleAI/AttackPossibility.h @@ -55,6 +55,7 @@ public: int64_t damageDiff() const; int64_t attackValue() const; + int64_t damageDiff(float positiveEffectMultiplier, float negativeEffectMultiplier) const; static AttackPossibility evaluate( const BattleAttackInfo & attackInfo, diff --git a/AI/BattleAI/BattleAI.cpp b/AI/BattleAI/BattleAI.cpp index d5e59b6d7..60e2729f6 100644 --- a/AI/BattleAI/BattleAI.cpp +++ b/AI/BattleAI/BattleAI.cpp @@ -9,6 +9,7 @@ */ #include "StdInc.h" #include "BattleAI.h" +#include "BattleEvaluator.h" #include "BattleExchangeVariant.h" #include "StackWithBonuses.h" @@ -29,90 +30,6 @@ #define LOGL(text) print(text) #define LOGFL(text, formattingEl) print(boost::str(boost::format(text) % formattingEl)) -enum class SpellTypes -{ - ADVENTURE, BATTLE, OTHER -}; - -SpellTypes spellType(const CSpell * spell) -{ - if(!spell->isCombat() || spell->isCreatureAbility()) - return SpellTypes::OTHER; - - if(spell->isOffensive() || spell->hasEffects() || spell->hasBattleEffects()) - return SpellTypes::BATTLE; - - return SpellTypes::OTHER; -} - -class BattleEvaluator -{ - std::unique_ptr targets; - std::shared_ptr hb; - BattleExchangeEvaluator scoreEvaluator; - std::shared_ptr cb; - std::shared_ptr env; - bool activeActionMade = false; - std::optional cachedAttack; - PlayerColor playerID; - int side; - int64_t cachedScore; - DamageCache damageCache; - -public: - BattleAction selectStackAction(const CStack * stack); - void attemptCastingSpell(const CStack * stack); - std::optional findBestCreatureSpell(const CStack * stack); - BattleAction goTowardsNearest(const CStack * stack, std::vector hexes); - std::vector getBrokenWallMoatHexes() const; - void evaluateCreatureSpellcast(const CStack * stack, PossibleSpellcast & ps); //for offensive damaging spells only - void print(const std::string & text) const; - - BattleEvaluator(std::shared_ptr env, std::shared_ptr cb, const battle::Unit * activeStack, PlayerColor playerID, int side) - :scoreEvaluator(cb, env), cachedAttack(), playerID(playerID), side(side), env(env), cb(cb) - { - hb = std::make_shared(env.get(), cb); - damageCache.buildDamageCache(hb, side); - - targets = std::make_unique(activeStack, damageCache, hb); - cachedScore = EvaluationResult::INEFFECTIVE_SCORE; - } - - BattleEvaluator( - std::shared_ptr env, - std::shared_ptr cb, - std::shared_ptr hb, - DamageCache & damageCache, - const battle::Unit * activeStack, - PlayerColor playerID, - int side) - :scoreEvaluator(cb, env), cachedAttack(), playerID(playerID), side(side), env(env), cb(cb), hb(hb), damageCache(damageCache) - { - targets = std::make_unique(activeStack, damageCache, hb); - cachedScore = EvaluationResult::INEFFECTIVE_SCORE; - } -}; - -std::vector BattleEvaluator::getBrokenWallMoatHexes() const -{ - std::vector result; - - for(EWallPart wallPart : { EWallPart::BOTTOM_WALL, EWallPart::BELOW_GATE, EWallPart::OVER_GATE, EWallPart::UPPER_WALL }) - { - auto state = cb->battleGetWallState(wallPart); - - if(state != EWallState::DESTROYED) - continue; - - auto wallHex = cb->wallPartToBattleHex((EWallPart)wallPart); - auto moatHex = wallHex.cloneInDirection(BattleHex::LEFT); - - result.push_back(moatHex); - } - - return result; -} - CBattleAI::CBattleAI() : side(-1), wasWaitingForRealize(false), @@ -159,161 +76,22 @@ BattleAction CBattleAI::useHealingTent(const CStack *stack) return BattleAction::makeHeal(stack, woundHpToStack.rbegin()->second); //last element of the woundHpToStack is the most wounded stack } -std::optional BattleEvaluator::findBestCreatureSpell(const CStack *stack) -{ - //TODO: faerie dragon type spell should be selected by server - SpellID creatureSpellToCast = cb->battleGetRandomStackSpell(CRandomGenerator::getDefault(), stack, CBattleInfoCallback::RANDOM_AIMED); - if(stack->hasBonusOfType(BonusType::SPELLCASTER) && stack->canCast() && creatureSpellToCast != SpellID::NONE) - { - const CSpell * spell = creatureSpellToCast.toSpell(); - - if(spell->canBeCast(getCbc().get(), spells::Mode::CREATURE_ACTIVE, stack)) - { - std::vector possibleCasts; - spells::BattleCast temp(getCbc().get(), stack, spells::Mode::CREATURE_ACTIVE, spell); - for(auto & target : temp.findPotentialTargets()) - { - PossibleSpellcast ps; - ps.dest = target; - ps.spell = spell; - evaluateCreatureSpellcast(stack, ps); - possibleCasts.push_back(ps); - } - - std::sort(possibleCasts.begin(), possibleCasts.end(), [&](const PossibleSpellcast & lhs, const PossibleSpellcast & rhs) { return lhs.value > rhs.value; }); - if(!possibleCasts.empty() && possibleCasts.front().value > 0) - { - return possibleCasts.front(); - } - } - } - return std::nullopt; -} - -BattleAction BattleEvaluator::selectStackAction(const CStack * stack) -{ - //evaluate casting spell for spellcasting stack - std::optional bestSpellcast = findBestCreatureSpell(stack); - - auto moveTarget = scoreEvaluator.findMoveTowardsUnreachable(stack, *targets, damageCache, hb); - auto score = EvaluationResult::INEFFECTIVE_SCORE; - - if(targets->possibleAttacks.empty() && bestSpellcast.has_value()) - { - activeActionMade = true; - return BattleAction::makeCreatureSpellcast(stack, bestSpellcast->dest, bestSpellcast->spell->id); - } - - if(!targets->possibleAttacks.empty()) - { -#if BATTLE_TRACE_LEVEL>=1 - logAi->trace("Evaluating attack for %s", stack->getDescription()); -#endif - - auto evaluationResult = scoreEvaluator.findBestTarget(stack, *targets, damageCache, hb); - auto & bestAttack = evaluationResult.bestAttack; - - cachedAttack = bestAttack; - cachedScore = evaluationResult.score; - - //TODO: consider more complex spellcast evaluation, f.e. because "re-retaliation" during enemy move in same turn for melee attack etc. - if(bestSpellcast.has_value() && bestSpellcast->value > bestAttack.damageDiff()) - { - // return because spellcast value is damage dealt and score is dps reduce - activeActionMade = true; - return BattleAction::makeCreatureSpellcast(stack, bestSpellcast->dest, bestSpellcast->spell->id); - } - - if(evaluationResult.score > score) - { - score = evaluationResult.score; - - logAi->debug("BattleAI: %s -> %s x %d, from %d curpos %d dist %d speed %d: +%lld -%lld = %lld", - bestAttack.attackerState->unitType()->getJsonKey(), - bestAttack.affectedUnits[0]->unitType()->getJsonKey(), - (int)bestAttack.affectedUnits[0]->getCount(), - (int)bestAttack.from, - (int)bestAttack.attack.attacker->getPosition().hex, - bestAttack.attack.chargeDistance, - bestAttack.attack.attacker->speed(0, true), - bestAttack.defenderDamageReduce, - bestAttack.attackerDamageReduce, bestAttack.attackValue() - ); - - if (moveTarget.scorePerTurn <= score) - { - if(evaluationResult.wait) - { - return BattleAction::makeWait(stack); - } - else if(bestAttack.attack.shooting) - { - activeActionMade = true; - return BattleAction::makeShotAttack(stack, bestAttack.attack.defender); - } - else - { - activeActionMade = true; - return BattleAction::makeMeleeAttack(stack, bestAttack.attack.defender->getPosition(), bestAttack.from); - } - } - } - } - - //ThreatMap threatsToUs(stack); // These lines may be usefull but they are't used in the code. - if(moveTarget.scorePerTurn > score) - { - score = moveTarget.score; - cachedAttack = moveTarget.cachedAttack; - cachedScore = score; - - if(stack->waited()) - { - return goTowardsNearest(stack, moveTarget.positions); - } - else - { - return BattleAction::makeWait(stack); - } - } - - if(score <= EvaluationResult::INEFFECTIVE_SCORE - && !stack->hasBonusOfType(BonusType::FLYING) - && stack->unitSide() == BattleSide::ATTACKER - && cb->battleGetSiegeLevel() >= CGTownInstance::CITADEL) - { - auto brokenWallMoat = getBrokenWallMoatHexes(); - - if(brokenWallMoat.size()) - { - activeActionMade = true; - - if(stack->doubleWide() && vstd::contains(brokenWallMoat, stack->getPosition())) - return BattleAction::makeMove(stack, stack->getPosition().cloneInDirection(BattleHex::RIGHT)); - else - return goTowardsNearest(stack, brokenWallMoat); - } - } - - return BattleAction::makeDefend(stack); -} - void CBattleAI::yourTacticPhase(int distance) { cb->battleMakeTacticAction(BattleAction::makeEndOFTacticPhase(cb->battleGetTacticsSide())); } -uint64_t timeElapsed(std::chrono::time_point start) -{ - auto end = std::chrono::high_resolution_clock::now(); - - return std::chrono::duration_cast(end - start).count(); -} - void CBattleAI::activeStack( const CStack * stack ) { LOG_TRACE_PARAMS(logAi, "stack: %s", stack->nodeName()); + auto timeElapsed = [](std::chrono::time_point start) -> uint64_t + { + auto end = std::chrono::high_resolution_clock::now(); + + return std::chrono::duration_cast(end - start).count(); + }; + BattleAction result = BattleAction::makeDefend(stack); setCbc(cb); //TODO: make solid sure that AIs always use their callbacks (need to take care of event handlers too) @@ -332,12 +110,19 @@ void CBattleAI::activeStack( const CStack * stack ) return; } - BattleEvaluator evaluator(env, cb, stack, playerID, side); + BattleEvaluator evaluator(env, cb, stack, playerID, side, strengthRatio); result = evaluator.selectStackAction(stack); - if(evaluator.attemptCastingSpell(stack)) - return; + if(!skipCastUntilNextBattle && evaluator.canCastSpell()) + { + auto spelCasted = evaluator.attemptCastingSpell(stack); + + if(spelCasted) + return; + + skipCastUntilNextBattle = true; + } logAi->trace("Spellcast attempt completed in %lld", timeElapsed(start)); @@ -370,103 +155,6 @@ void CBattleAI::activeStack( const CStack * stack ) cb->battleMakeUnitAction(result); } -BattleAction BattleEvaluator::goTowardsNearest(const CStack * stack, std::vector hexes) -{ - auto reachability = cb->getReachability(stack); - auto avHexes = cb->battleGetAvailableHexes(reachability, stack, false); - - if(!avHexes.size() || !hexes.size()) //we are blocked or dest is blocked - { - return BattleAction::makeDefend(stack); - } - - std::sort(hexes.begin(), hexes.end(), [&](BattleHex h1, BattleHex h2) -> bool - { - return reachability.distances[h1] < reachability.distances[h2]; - }); - - for(auto hex : hexes) - { - if(vstd::contains(avHexes, hex)) - { - return BattleAction::makeMove(stack, hex); - } - - if(stack->coversPos(hex)) - { - logAi->warn("Warning: already standing on neighbouring tile!"); - //We shouldn't even be here... - return BattleAction::makeDefend(stack); - } - } - - BattleHex bestNeighbor = hexes.front(); - - if(reachability.distances[bestNeighbor] > GameConstants::BFIELD_SIZE) - { - return BattleAction::makeDefend(stack); - } - - scoreEvaluator.updateReachabilityMap(hb); - - if(stack->hasBonusOfType(BonusType::FLYING)) - { - std::set obstacleHexes; - - auto insertAffected = [](const CObstacleInstance & spellObst, std::set obstacleHexes) { - auto affectedHexes = spellObst.getAffectedTiles(); - obstacleHexes.insert(affectedHexes.cbegin(), affectedHexes.cend()); - }; - - const auto & obstacles = hb->battleGetAllObstacles(); - - for (const auto & obst: obstacles) { - - if(obst->triggersEffects()) - { - auto triggerAbility = VLC->spells()->getById(obst->getTrigger()); - auto triggerIsNegative = triggerAbility->isNegative() || triggerAbility->isDamage(); - - if(triggerIsNegative) - insertAffected(*obst, obstacleHexes); - } - } - // Flying stack doesn't go hex by hex, so we can't backtrack using predecessors. - // We just check all available hexes and pick the one closest to the target. - auto nearestAvailableHex = vstd::minElementByFun(avHexes, [&](BattleHex hex) -> int - { - const int NEGATIVE_OBSTACLE_PENALTY = 100; // avoid landing on negative obstacle (moat, fire wall, etc) - const int BLOCKED_STACK_PENALTY = 100; // avoid landing on moat - - auto distance = BattleHex::getDistance(bestNeighbor, hex); - - if(vstd::contains(obstacleHexes, hex)) - distance += NEGATIVE_OBSTACLE_PENALTY; - - return scoreEvaluator.checkPositionBlocksOurStacks(*hb, stack, hex) ? BLOCKED_STACK_PENALTY + distance : distance; - }); - - return BattleAction::makeMove(stack, *nearestAvailableHex); - } - else - { - BattleHex currentDest = bestNeighbor; - while(1) - { - if(!currentDest.isValid()) - { - return BattleAction::makeDefend(stack); - } - - if(vstd::contains(avHexes, currentDest) - && !scoreEvaluator.checkPositionBlocksOurStacks(*hb, stack, currentDest)) - return BattleAction::makeMove(stack, currentDest); - - currentDest = reachability.predecessors[currentDest]; - } - } -} - BattleAction CBattleAI::useCatapult(const CStack * stack) { BattleAction attack; @@ -515,348 +203,16 @@ BattleAction CBattleAI::useCatapult(const CStack * stack) return attack; } -<<<<<<< HEAD -bool CBattleAI::attemptCastingSpell() -======= -void BattleEvaluator::attemptCastingSpell(const CStack * activeStack) ->>>>>>> ea22737e9 (BattleAI: damage cache and switch to different model of spells evaluation) -{ - auto hero = cb->battleGetMyHero(); - if(!hero) - return false; - - if(cb->battleCanCastSpell(hero, spells::Mode::HERO) != ESpellCastProblem::OK) - return false; - - LOGL("Casting spells sounds like fun. Let's see..."); - //Get all spells we can cast - std::vector possibleSpells; - vstd::copy_if(VLC->spellh->objects, std::back_inserter(possibleSpells), [hero, this](const CSpell *s) -> bool - { - return s->canBeCast(cb.get(), spells::Mode::HERO, hero); - }); - LOGFL("I can cast %d spells.", possibleSpells.size()); - - vstd::erase_if(possibleSpells, [](const CSpell *s) - { - return spellType(s) != SpellTypes::BATTLE || s->getTargetType() == spells::AimType::LOCATION; - }); - - LOGFL("I know how %d of them works.", possibleSpells.size()); - - //Get possible spell-target pairs - std::vector possibleCasts; - for(auto spell : possibleSpells) - { - spells::BattleCast temp(cb.get(), hero, spells::Mode::HERO, spell); - - if(spell->getTargetType() == spells::AimType::LOCATION) - continue; - - const bool FAST = true; - - for(auto & target : temp.findPotentialTargets(FAST)) - { - PossibleSpellcast ps; - ps.dest = target; - ps.spell = spell; - possibleCasts.push_back(ps); - } - } - LOGFL("Found %d spell-target combinations.", possibleCasts.size()); - if(possibleCasts.empty()) - return false; - - using ValueMap = PossibleSpellcast::ValueMap; - - auto evaluateQueue = [&](ValueMap & values, const std::vector & queue, std::shared_ptr state, size_t minTurnSpan, bool * enemyHadTurnOut) -> bool - { - bool firstRound = true; - bool enemyHadTurn = false; - size_t ourTurnSpan = 0; - - bool stop = false; - - for(auto & round : queue) - { - if(!firstRound) - state->nextRound(0);//todo: set actual value? - for(auto unit : round) - { - if(!vstd::contains(values, unit->unitId())) - values[unit->unitId()] = 0; - - if(!unit->alive()) - continue; - - if(state->battleGetOwner(unit) != playerID) - { - enemyHadTurn = true; - - if(!firstRound || state->battleCastSpells(unit->unitSide()) == 0) - { - //enemy could counter our spell at this point - //anyway, we do not know what enemy will do - //just stop evaluation - stop = true; - break; - } - } - else if(!enemyHadTurn) - { - ourTurnSpan++; - } - - state->nextTurn(unit->unitId()); - - PotentialTargets pt(unit, damageCache, state); - - if(!pt.possibleAttacks.empty()) - { - AttackPossibility ap = pt.bestAction(); - - auto swb = state->getForUpdate(unit->unitId()); - *swb = *ap.attackerState; - - if(ap.defenderDamageReduce > 0) - swb->removeUnitBonus(Bonus::UntilAttack); - if(ap.attackerDamageReduce > 0) - swb->removeUnitBonus(Bonus::UntilBeingAttacked); - - for(auto affected : ap.affectedUnits) - { - swb = state->getForUpdate(affected->unitId()); - *swb = *affected; - - if(ap.defenderDamageReduce > 0) - swb->removeUnitBonus(Bonus::UntilBeingAttacked); - if(ap.attackerDamageReduce > 0 && ap.attack.defender->unitId() == affected->unitId()) - swb->removeUnitBonus(Bonus::UntilAttack); - } - } - - auto bav = pt.bestActionValue(); - - //best action is from effective owner`s point if view, we need to convert to our point if view - if(state->battleGetOwner(unit) != playerID) - bav = -bav; - values[unit->unitId()] += bav; - } - - firstRound = false; - - if(stop) - break; - } - - if(enemyHadTurnOut) - *enemyHadTurnOut = enemyHadTurn; - - return ourTurnSpan >= minTurnSpan; - }; - - ValueMap valueOfStack; - ValueMap healthOfStack; - - TStacks all = cb->battleGetAllStacks(false); - - size_t ourRemainingTurns = 0; - - for(auto unit : all) - { - healthOfStack[unit->unitId()] = unit->getAvailableHealth(); - valueOfStack[unit->unitId()] = 0; - - if(cb->battleGetOwner(unit) == playerID && unit->canMove() && !unit->moved()) - ourRemainingTurns++; - } - - LOGFL("I have %d turns left in this round", ourRemainingTurns); - - const bool castNow = ourRemainingTurns <= 1; - - if(castNow) - print("I should try to cast a spell now"); - else - print("I could wait better moment to cast a spell"); - - auto amount = all.size(); - - std::vector turnOrder; - - cb->battleGetTurnOrder(turnOrder, amount, 2); //no more than 1 turn after current, each unit at least once - - { - bool enemyHadTurn = false; - - auto state = std::make_shared(env.get(), cb); - - evaluateQueue(valueOfStack, turnOrder, state, 0, &enemyHadTurn); - - if(!enemyHadTurn) - { - auto battleIsFinishedOpt = state->battleIsFinished(); - - if(battleIsFinishedOpt) - { - print("No need to cast a spell. Battle will finish soon."); - return false; - } - } - } - - CStopWatch timer; - - tbb::parallel_for(tbb::blocked_range(0, possibleCasts.size()), [&](const tbb::blocked_range & r) - { - for(auto i = r.begin(); i != r.end(); i++) - { - auto & ps = possibleCasts[i]; - auto state = std::make_shared(env.get(), cb); - - spells::BattleCast cast(state.get(), hero, spells::Mode::HERO, ps.spell); - cast.castEval(state->getServerCallback(), ps.dest); - - auto allUnits = state->battleGetUnitsIf([](const battle::Unit * u) -> bool { return true; }); - - auto needFullEval = vstd::contains_if(allUnits, [&](const battle::Unit * u) -> bool - { - auto original = cb->battleGetUnitByID(u->unitId()); - return !original || u->speed() != original->speed(); - }); - - DamageCache innerCache(&damageCache); - innerCache.buildDamageCache(state, side); - - if(needFullEval || !cachedAttack) - { - PotentialTargets innerTargets(activeStack, damageCache, state); - BattleExchangeEvaluator innerEvaluator(state, env); - - if(!innerTargets.possibleAttacks.empty()) - { - innerEvaluator.updateReachabilityMap(state); - - auto newStackAction = innerEvaluator.findBestTarget(activeStack, innerTargets, innerCache, state); - - ps.value = newStackAction.score; - } - else - { - ps.value = 0; - } - } - else - { - ps.value = scoreEvaluator.calculateExchange(*cachedAttack, *targets, innerCache, state); - } - - for(auto unit : allUnits) - { - auto newHealth = unit->getAvailableHealth(); - auto oldHealth = healthOfStack[unit->unitId()]; - - if(oldHealth != newHealth) - { - auto damage = std::abs(oldHealth - newHealth); - auto originalDefender = cb->battleGetUnitByID(unit->unitId()); - auto dpsReduce = AttackPossibility::calculateDamageReduce(nullptr, originalDefender ? originalDefender : unit, damage, innerCache, state); - auto ourUnit = unit->unitSide() == side ? 1 : -1; - auto goodEffect = newHealth > oldHealth ? 1 : -1; - - ps.value += ourUnit * goodEffect * dpsReduce; - } - } - } - }); - - LOGFL("Evaluation took %d ms", timer.getDiff()); - - auto pscValue = [](const PossibleSpellcast &ps) -> int64_t - { - return ps.value; - }; - auto castToPerform = *vstd::maxElementByFun(possibleCasts, pscValue); - - if(castToPerform.value > cachedScore) - { - LOGFL("Best spell is %s (value %d). Will cast.", castToPerform.spell->getNameTranslated() % castToPerform.value); - BattleAction spellcast; - spellcast.actionType = EActionType::HERO_SPELL; - spellcast.spell = castToPerform.spell->getId(); - spellcast.setTarget(castToPerform.dest); - spellcast.side = side; - spellcast.stackNumber = (!side) ? -1 : -2; - cb->battleMakeSpellAction(spellcast); -<<<<<<< HEAD - movesSkippedByDefense = 0; - return true; -======= - activeActionMade = true; ->>>>>>> ea22737e9 (BattleAI: damage cache and switch to different model of spells evaluation) - } - else - { - LOGFL("Best spell is %s. But it is actually useless (value %d).", castToPerform.spell->getNameTranslated() % castToPerform.value); - return false; - } -} - -//Below method works only for offensive spells -void BattleEvaluator::evaluateCreatureSpellcast(const CStack * stack, PossibleSpellcast & ps) -{ - using ValueMap = PossibleSpellcast::ValueMap; - - RNGStub rngStub; - HypotheticBattle state(env.get(), cb); - TStacks all = cb->battleGetAllStacks(false); - - ValueMap healthOfStack; - ValueMap newHealthOfStack; - - for(auto unit : all) - { - healthOfStack[unit->unitId()] = unit->getAvailableHealth(); - } - - spells::BattleCast cast(&state, stack, spells::Mode::CREATURE_ACTIVE, ps.spell); - cast.castEval(state.getServerCallback(), ps.dest); - - for(auto unit : all) - { - auto unitId = unit->unitId(); - auto localUnit = state.battleGetUnitByID(unitId); - newHealthOfStack[unitId] = localUnit->getAvailableHealth(); - } - - int64_t totalGain = 0; - - for(auto unit : all) - { - auto unitId = unit->unitId(); - auto localUnit = state.battleGetUnitByID(unitId); - - auto healthDiff = newHealthOfStack[unitId] - healthOfStack[unitId]; - - if(localUnit->unitOwner() != getCbc()->getPlayerID()) - healthDiff = -healthDiff; - - if(healthDiff < 0) - { - ps.value = -1; - return; //do not damage own units at all - } - - totalGain += healthDiff; - } - - ps.value = totalGain; -} - void CBattleAI::battleStart(const CCreatureSet *army1, const CCreatureSet *army2, int3 tile, const CGHeroInstance *hero1, const CGHeroInstance *hero2, bool Side, bool replayAllowed) { LOG_TRACE(logAi); side = Side; + strengthRatio = static_cast(army1->getArmyStrength()) / static_cast(army2->getArmyStrength()); + + if(side == 1) + strengthRatio = 1 / strengthRatio; + + skipCastUntilNextBattle = false; } void CBattleAI::print(const std::string &text) const @@ -864,11 +220,6 @@ void CBattleAI::print(const std::string &text) const logAi->trace("%s Battle AI[%p]: %s", playerID.getStr(), this, text); } -void BattleEvaluator::print(const std::string & text) const -{ - logAi->trace("%s Battle AI[%p]: %s", playerID.getStr(), this, text); -} - std::optional CBattleAI::considerFleeingOrSurrendering() { BattleStateInfoForRetreat bs; diff --git a/AI/BattleAI/BattleAI.h b/AI/BattleAI/BattleAI.h index 73bff11c8..9c60a32b4 100644 --- a/AI/BattleAI/BattleAI.h +++ b/AI/BattleAI/BattleAI.h @@ -62,6 +62,8 @@ class CBattleAI : public CBattleGameInterface bool wasWaitingForRealize; bool wasUnlockingGs; int movesSkippedByDefense; + float strengthRatio; + bool skipCastUntilNextBattle; public: CBattleAI(); diff --git a/AI/BattleAI/BattleEvaluator.cpp b/AI/BattleAI/BattleEvaluator.cpp new file mode 100644 index 000000000..d5d101f1e --- /dev/null +++ b/AI/BattleAI/BattleEvaluator.cpp @@ -0,0 +1,679 @@ +/* + * BattleAI.cpp, part of VCMI engine + * + * Authors: listed in file AUTHORS in main folder + * + * License: GNU General Public License v2.0 or later + * Full text of license available in license.txt file, in main folder + * + */ +#include "StdInc.h" +#include "BattleEvaluator.h" +#include "BattleExchangeVariant.h" + +#include "StackWithBonuses.h" +#include "EnemyInfo.h" +#include "tbb/parallel_for.h" +#include "../../lib/CStopWatch.h" +#include "../../lib/CThreadHelper.h" +#include "../../lib/mapObjects/CGTownInstance.h" +#include "../../lib/spells/CSpellHandler.h" +#include "../../lib/spells/ISpellMechanics.h" +#include "../../lib/battle/BattleStateInfoForRetreat.h" +#include "../../lib/battle/CObstacleInstance.h" +#include "../../lib/battle/BattleAction.h" + +// TODO: remove +// Eventually only IBattleInfoCallback and battle::Unit should be used, +// CUnitState should be private and CStack should be removed completely +#include "../../lib/CStack.h" + +#define LOGL(text) print(text) +#define LOGFL(text, formattingEl) print(boost::str(boost::format(text) % formattingEl)) + +enum class SpellTypes +{ + ADVENTURE, BATTLE, OTHER +}; + +SpellTypes spellType(const CSpell * spell) +{ + if(!spell->isCombat() || spell->isCreatureAbility()) + return SpellTypes::OTHER; + + if(spell->isOffensive() || spell->hasEffects() || spell->hasBattleEffects()) + return SpellTypes::BATTLE; + + return SpellTypes::OTHER; +} + +std::vector BattleEvaluator::getBrokenWallMoatHexes() const +{ + std::vector result; + + for(EWallPart wallPart : { EWallPart::BOTTOM_WALL, EWallPart::BELOW_GATE, EWallPart::OVER_GATE, EWallPart::UPPER_WALL }) + { + auto state = cb->battleGetWallState(wallPart); + + if(state != EWallState::DESTROYED) + continue; + + auto wallHex = cb->wallPartToBattleHex((EWallPart)wallPart); + auto moatHex = wallHex.cloneInDirection(BattleHex::LEFT); + + result.push_back(moatHex); + } + + return result; +} + +std::optional BattleEvaluator::findBestCreatureSpell(const CStack *stack) +{ + //TODO: faerie dragon type spell should be selected by server + SpellID creatureSpellToCast = cb->battleGetRandomStackSpell(CRandomGenerator::getDefault(), stack, CBattleInfoCallback::RANDOM_AIMED); + if(stack->hasBonusOfType(BonusType::SPELLCASTER) && stack->canCast() && creatureSpellToCast != SpellID::NONE) + { + const CSpell * spell = creatureSpellToCast.toSpell(); + + if(spell->canBeCast(getCbc().get(), spells::Mode::CREATURE_ACTIVE, stack)) + { + std::vector possibleCasts; + spells::BattleCast temp(getCbc().get(), stack, spells::Mode::CREATURE_ACTIVE, spell); + for(auto & target : temp.findPotentialTargets()) + { + PossibleSpellcast ps; + ps.dest = target; + ps.spell = spell; + evaluateCreatureSpellcast(stack, ps); + possibleCasts.push_back(ps); + } + + std::sort(possibleCasts.begin(), possibleCasts.end(), [&](const PossibleSpellcast & lhs, const PossibleSpellcast & rhs) { return lhs.value > rhs.value; }); + if(!possibleCasts.empty() && possibleCasts.front().value > 0) + { + return possibleCasts.front(); + } + } + } + return std::nullopt; +} + +BattleAction BattleEvaluator::selectStackAction(const CStack * stack) +{ + //evaluate casting spell for spellcasting stack + std::optional bestSpellcast = findBestCreatureSpell(stack); + + auto moveTarget = scoreEvaluator.findMoveTowardsUnreachable(stack, *targets, damageCache, hb); + auto score = EvaluationResult::INEFFECTIVE_SCORE; + + if(targets->possibleAttacks.empty() && bestSpellcast.has_value()) + { + activeActionMade = true; + return BattleAction::makeCreatureSpellcast(stack, bestSpellcast->dest, bestSpellcast->spell->id); + } + + if(!targets->possibleAttacks.empty()) + { +#if BATTLE_TRACE_LEVEL>=1 + logAi->trace("Evaluating attack for %s", stack->getDescription()); +#endif + + auto evaluationResult = scoreEvaluator.findBestTarget(stack, *targets, damageCache, hb); + auto & bestAttack = evaluationResult.bestAttack; + + cachedAttack = bestAttack; + cachedScore = evaluationResult.score; + + //TODO: consider more complex spellcast evaluation, f.e. because "re-retaliation" during enemy move in same turn for melee attack etc. + if(bestSpellcast.has_value() && bestSpellcast->value > bestAttack.damageDiff()) + { + // return because spellcast value is damage dealt and score is dps reduce + activeActionMade = true; + return BattleAction::makeCreatureSpellcast(stack, bestSpellcast->dest, bestSpellcast->spell->id); + } + + if(evaluationResult.score > score) + { + score = evaluationResult.score; + + logAi->debug("BattleAI: %s -> %s x %d, from %d curpos %d dist %d speed %d: +%lld -%lld = %lld", + bestAttack.attackerState->unitType()->getJsonKey(), + bestAttack.affectedUnits[0]->unitType()->getJsonKey(), + (int)bestAttack.affectedUnits[0]->getCount(), + (int)bestAttack.from, + (int)bestAttack.attack.attacker->getPosition().hex, + bestAttack.attack.chargeDistance, + bestAttack.attack.attacker->speed(0, true), + bestAttack.defenderDamageReduce, + bestAttack.attackerDamageReduce, bestAttack.attackValue() + ); + + if (moveTarget.scorePerTurn <= score) + { + if(evaluationResult.wait) + { + return BattleAction::makeWait(stack); + } + else if(bestAttack.attack.shooting) + { + activeActionMade = true; + return BattleAction::makeShotAttack(stack, bestAttack.attack.defender); + } + else + { + if(bestAttack.collateralDamageReduce + && bestAttack.collateralDamageReduce >= bestAttack.defenderDamageReduce / 2 + && score < 0) + { + return BattleAction::makeDefend(stack); + } + else + { + activeActionMade = true; + return BattleAction::makeMeleeAttack(stack, bestAttack.attack.defender->getPosition(), bestAttack.from); + } + } + } + } + } + + //ThreatMap threatsToUs(stack); // These lines may be usefull but they are't used in the code. + if(moveTarget.scorePerTurn > score) + { + score = moveTarget.score; + cachedAttack = moveTarget.cachedAttack; + cachedScore = score; + + if(stack->waited()) + { + return goTowardsNearest(stack, moveTarget.positions); + } + else + { + return BattleAction::makeWait(stack); + } + } + + if(score <= EvaluationResult::INEFFECTIVE_SCORE + && !stack->hasBonusOfType(BonusType::FLYING) + && stack->unitSide() == BattleSide::ATTACKER + && cb->battleGetSiegeLevel() >= CGTownInstance::CITADEL) + { + auto brokenWallMoat = getBrokenWallMoatHexes(); + + if(brokenWallMoat.size()) + { + activeActionMade = true; + + if(stack->doubleWide() && vstd::contains(brokenWallMoat, stack->getPosition())) + return BattleAction::makeMove(stack, stack->getPosition().cloneInDirection(BattleHex::RIGHT)); + else + return goTowardsNearest(stack, brokenWallMoat); + } + } + + return BattleAction::makeDefend(stack); +} + +uint64_t timeElapsed(std::chrono::time_point start) +{ + auto end = std::chrono::high_resolution_clock::now(); + + return std::chrono::duration_cast(end - start).count(); +} + +BattleAction BattleEvaluator::goTowardsNearest(const CStack * stack, std::vector hexes) +{ + auto reachability = cb->getReachability(stack); + auto avHexes = cb->battleGetAvailableHexes(reachability, stack, false); + + if(!avHexes.size() || !hexes.size()) //we are blocked or dest is blocked + { + return BattleAction::makeDefend(stack); + } + + std::sort(hexes.begin(), hexes.end(), [&](BattleHex h1, BattleHex h2) -> bool + { + return reachability.distances[h1] < reachability.distances[h2]; + }); + + for(auto hex : hexes) + { + if(vstd::contains(avHexes, hex)) + { + return BattleAction::makeMove(stack, hex); + } + + if(stack->coversPos(hex)) + { + logAi->warn("Warning: already standing on neighbouring tile!"); + //We shouldn't even be here... + return BattleAction::makeDefend(stack); + } + } + + BattleHex bestNeighbor = hexes.front(); + + if(reachability.distances[bestNeighbor] > GameConstants::BFIELD_SIZE) + { + return BattleAction::makeDefend(stack); + } + + scoreEvaluator.updateReachabilityMap(hb); + + if(stack->hasBonusOfType(BonusType::FLYING)) + { + std::set obstacleHexes; + + auto insertAffected = [](const CObstacleInstance & spellObst, std::set obstacleHexes) { + auto affectedHexes = spellObst.getAffectedTiles(); + obstacleHexes.insert(affectedHexes.cbegin(), affectedHexes.cend()); + }; + + const auto & obstacles = hb->battleGetAllObstacles(); + + for (const auto & obst: obstacles) { + + if(obst->triggersEffects()) + { + auto triggerAbility = VLC->spells()->getById(obst->getTrigger()); + auto triggerIsNegative = triggerAbility->isNegative() || triggerAbility->isDamage(); + + if(triggerIsNegative) + insertAffected(*obst, obstacleHexes); + } + } + // Flying stack doesn't go hex by hex, so we can't backtrack using predecessors. + // We just check all available hexes and pick the one closest to the target. + auto nearestAvailableHex = vstd::minElementByFun(avHexes, [&](BattleHex hex) -> int + { + const int NEGATIVE_OBSTACLE_PENALTY = 100; // avoid landing on negative obstacle (moat, fire wall, etc) + const int BLOCKED_STACK_PENALTY = 100; // avoid landing on moat + + auto distance = BattleHex::getDistance(bestNeighbor, hex); + + if(vstd::contains(obstacleHexes, hex)) + distance += NEGATIVE_OBSTACLE_PENALTY; + + return scoreEvaluator.checkPositionBlocksOurStacks(*hb, stack, hex) ? BLOCKED_STACK_PENALTY + distance : distance; + }); + + return BattleAction::makeMove(stack, *nearestAvailableHex); + } + else + { + BattleHex currentDest = bestNeighbor; + while(1) + { + if(!currentDest.isValid()) + { + return BattleAction::makeDefend(stack); + } + + if(vstd::contains(avHexes, currentDest) + && !scoreEvaluator.checkPositionBlocksOurStacks(*hb, stack, currentDest)) + return BattleAction::makeMove(stack, currentDest); + + currentDest = reachability.predecessors[currentDest]; + } + } +} + +bool BattleEvaluator::canCastSpell() +{ + auto hero = cb->battleGetMyHero(); + if(!hero) + return false; + + return cb->battleCanCastSpell(hero, spells::Mode::HERO) == ESpellCastProblem::OK; +} + +bool BattleEvaluator::attemptCastingSpell(const CStack * activeStack) +{ + auto hero = cb->battleGetMyHero(); + if(!hero) + return false; + + LOGL("Casting spells sounds like fun. Let's see..."); + //Get all spells we can cast + std::vector possibleSpells; + vstd::copy_if(VLC->spellh->objects, std::back_inserter(possibleSpells), [hero, this](const CSpell *s) -> bool + { + return s->canBeCast(cb.get(), spells::Mode::HERO, hero); + }); + LOGFL("I can cast %d spells.", possibleSpells.size()); + + vstd::erase_if(possibleSpells, [](const CSpell *s) + { + return spellType(s) != SpellTypes::BATTLE || s->getTargetType() == spells::AimType::LOCATION; + }); + + LOGFL("I know how %d of them works.", possibleSpells.size()); + + //Get possible spell-target pairs + std::vector possibleCasts; + for(auto spell : possibleSpells) + { + spells::BattleCast temp(cb.get(), hero, spells::Mode::HERO, spell); + + if(spell->getTargetType() == spells::AimType::LOCATION) + continue; + + const bool FAST = true; + + for(auto & target : temp.findPotentialTargets(FAST)) + { + PossibleSpellcast ps; + ps.dest = target; + ps.spell = spell; + possibleCasts.push_back(ps); + } + } + LOGFL("Found %d spell-target combinations.", possibleCasts.size()); + if(possibleCasts.empty()) + return false; + + using ValueMap = PossibleSpellcast::ValueMap; + + auto evaluateQueue = [&](ValueMap & values, const std::vector & queue, std::shared_ptr state, size_t minTurnSpan, bool * enemyHadTurnOut) -> bool + { + bool firstRound = true; + bool enemyHadTurn = false; + size_t ourTurnSpan = 0; + + bool stop = false; + + for(auto & round : queue) + { + if(!firstRound) + state->nextRound(0);//todo: set actual value? + for(auto unit : round) + { + if(!vstd::contains(values, unit->unitId())) + values[unit->unitId()] = 0; + + if(!unit->alive()) + continue; + + if(state->battleGetOwner(unit) != playerID) + { + enemyHadTurn = true; + + if(!firstRound || state->battleCastSpells(unit->unitSide()) == 0) + { + //enemy could counter our spell at this point + //anyway, we do not know what enemy will do + //just stop evaluation + stop = true; + break; + } + } + else if(!enemyHadTurn) + { + ourTurnSpan++; + } + + state->nextTurn(unit->unitId()); + + PotentialTargets pt(unit, damageCache, state); + + if(!pt.possibleAttacks.empty()) + { + AttackPossibility ap = pt.bestAction(); + + auto swb = state->getForUpdate(unit->unitId()); + *swb = *ap.attackerState; + + if(ap.defenderDamageReduce > 0) + swb->removeUnitBonus(Bonus::UntilAttack); + if(ap.attackerDamageReduce > 0) + swb->removeUnitBonus(Bonus::UntilBeingAttacked); + + for(auto affected : ap.affectedUnits) + { + swb = state->getForUpdate(affected->unitId()); + *swb = *affected; + + if(ap.defenderDamageReduce > 0) + swb->removeUnitBonus(Bonus::UntilBeingAttacked); + if(ap.attackerDamageReduce > 0 && ap.attack.defender->unitId() == affected->unitId()) + swb->removeUnitBonus(Bonus::UntilAttack); + } + } + + auto bav = pt.bestActionValue(); + + //best action is from effective owner`s point if view, we need to convert to our point if view + if(state->battleGetOwner(unit) != playerID) + bav = -bav; + values[unit->unitId()] += bav; + } + + firstRound = false; + + if(stop) + break; + } + + if(enemyHadTurnOut) + *enemyHadTurnOut = enemyHadTurn; + + return ourTurnSpan >= minTurnSpan; + }; + + ValueMap valueOfStack; + ValueMap healthOfStack; + + TStacks all = cb->battleGetAllStacks(false); + + size_t ourRemainingTurns = 0; + + for(auto unit : all) + { + healthOfStack[unit->unitId()] = unit->getAvailableHealth(); + valueOfStack[unit->unitId()] = 0; + + if(cb->battleGetOwner(unit) == playerID && unit->canMove() && !unit->moved()) + ourRemainingTurns++; + } + + LOGFL("I have %d turns left in this round", ourRemainingTurns); + + const bool castNow = ourRemainingTurns <= 1; + + if(castNow) + print("I should try to cast a spell now"); + else + print("I could wait better moment to cast a spell"); + + auto amount = all.size(); + + std::vector turnOrder; + + cb->battleGetTurnOrder(turnOrder, amount, 2); //no more than 1 turn after current, each unit at least once + + { + bool enemyHadTurn = false; + + auto state = std::make_shared(env.get(), cb); + + evaluateQueue(valueOfStack, turnOrder, state, 0, &enemyHadTurn); + + if(!enemyHadTurn) + { + auto battleIsFinishedOpt = state->battleIsFinished(); + + if(battleIsFinishedOpt) + { + print("No need to cast a spell. Battle will finish soon."); + return false; + } + } + } + + CStopWatch timer; + + tbb::parallel_for(tbb::blocked_range(0, possibleCasts.size()), [&](const tbb::blocked_range & r) + { + for(auto i = r.begin(); i != r.end(); i++) + { + auto & ps = possibleCasts[i]; + auto state = std::make_shared(env.get(), cb); + + spells::BattleCast cast(state.get(), hero, spells::Mode::HERO, ps.spell); + cast.castEval(state->getServerCallback(), ps.dest); + + auto allUnits = state->battleGetUnitsIf([](const battle::Unit * u) -> bool { return true; }); + + auto needFullEval = vstd::contains_if(allUnits, [&](const battle::Unit * u) -> bool + { + auto original = cb->battleGetUnitByID(u->unitId()); + return !original || u->speed() != original->speed(); + }); + + DamageCache innerCache(&damageCache); + innerCache.buildDamageCache(state, side); + + if(needFullEval || !cachedAttack) + { + PotentialTargets innerTargets(activeStack, damageCache, state); + BattleExchangeEvaluator innerEvaluator(state, env, strengthRatio); + + if(!innerTargets.possibleAttacks.empty()) + { + innerEvaluator.updateReachabilityMap(state); + + auto newStackAction = innerEvaluator.findBestTarget(activeStack, innerTargets, innerCache, state); + + ps.value = newStackAction.score; + } + else + { + ps.value = 0; + } + } + else + { + ps.value = scoreEvaluator.calculateExchange(*cachedAttack, *targets, innerCache, state); + } + + for(auto unit : allUnits) + { + auto newHealth = unit->getAvailableHealth(); + auto oldHealth = healthOfStack[unit->unitId()]; + + if(oldHealth != newHealth) + { + auto damage = std::abs(oldHealth - newHealth); + auto originalDefender = cb->battleGetUnitByID(unit->unitId()); + + auto dpsReduce = AttackPossibility::calculateDamageReduce( + nullptr, + originalDefender && originalDefender->alive() ? originalDefender : unit, + damage, + innerCache, + state); + + auto ourUnit = unit->unitSide() == side ? 1 : -1; + auto goodEffect = newHealth > oldHealth ? 1 : -1; + + if(ourUnit * goodEffect == 1) + { + if(ourUnit && goodEffect && (unit->isClone() || unit->isGhost() || !unit->unitSlot().validSlot())) + continue; + + ps.value += dpsReduce * scoreEvaluator.getPositiveEffectMultiplier(); + } + else + ps.value -= dpsReduce * scoreEvaluator.getNegativeEffectMultiplier(); + } + } + } + }); + + LOGFL("Evaluation took %d ms", timer.getDiff()); + + auto pscValue = [](const PossibleSpellcast &ps) -> int64_t + { + return ps.value; + }; + auto castToPerform = *vstd::maxElementByFun(possibleCasts, pscValue); + + if(castToPerform.value > cachedScore) + { + LOGFL("Best spell is %s (value %d). Will cast.", castToPerform.spell->getNameTranslated() % castToPerform.value); + BattleAction spellcast; + spellcast.actionType = EActionType::HERO_SPELL; + spellcast.spell = castToPerform.spell->id; + spellcast.setTarget(castToPerform.dest); + spellcast.side = side; + spellcast.stackNumber = (!side) ? -1 : -2; + cb->battleMakeSpellAction(spellcast); + activeActionMade = true; + + return true; + } + + LOGFL("Best spell is %s. But it is actually useless (value %d).", castToPerform.spell->getNameTranslated() % castToPerform.value); + + return false; +} + +//Below method works only for offensive spells +void BattleEvaluator::evaluateCreatureSpellcast(const CStack * stack, PossibleSpellcast & ps) +{ + using ValueMap = PossibleSpellcast::ValueMap; + + RNGStub rngStub; + HypotheticBattle state(env.get(), cb); + TStacks all = cb->battleGetAllStacks(false); + + ValueMap healthOfStack; + ValueMap newHealthOfStack; + + for(auto unit : all) + { + healthOfStack[unit->unitId()] = unit->getAvailableHealth(); + } + + spells::BattleCast cast(&state, stack, spells::Mode::CREATURE_ACTIVE, ps.spell); + cast.castEval(state.getServerCallback(), ps.dest); + + for(auto unit : all) + { + auto unitId = unit->unitId(); + auto localUnit = state.battleGetUnitByID(unitId); + newHealthOfStack[unitId] = localUnit->getAvailableHealth(); + } + + int64_t totalGain = 0; + + for(auto unit : all) + { + auto unitId = unit->unitId(); + auto localUnit = state.battleGetUnitByID(unitId); + + auto healthDiff = newHealthOfStack[unitId] - healthOfStack[unitId]; + + if(localUnit->unitOwner() != getCbc()->getPlayerID()) + healthDiff = -healthDiff; + + if(healthDiff < 0) + { + ps.value = -1; + return; //do not damage own units at all + } + + totalGain += healthDiff; + } + + ps.value = totalGain; +} + +void BattleEvaluator::print(const std::string & text) const +{ + logAi->trace("%s Battle AI[%p]: %s", playerID.getStr(), this, text); +} + + + diff --git a/AI/BattleAI/BattleEvaluator.h b/AI/BattleAI/BattleEvaluator.h new file mode 100644 index 000000000..b3f61091c --- /dev/null +++ b/AI/BattleAI/BattleEvaluator.h @@ -0,0 +1,80 @@ +/* + * BattleEvaluator.h, part of VCMI engine + * + * Authors: listed in file AUTHORS in main folder + * + * License: GNU General Public License v2.0 or later + * Full text of license available in license.txt file, in main folder + * + */ +#pragma once +#include "../../lib/AI_Base.h" +#include "../../lib/battle/ReachabilityInfo.h" +#include "PossibleSpellcast.h" +#include "PotentialTargets.h" +#include "BattleExchangeVariant.h" + +VCMI_LIB_NAMESPACE_BEGIN + +class CSpell; + +VCMI_LIB_NAMESPACE_END + +class EnemyInfo; + +class BattleEvaluator +{ + std::unique_ptr targets; + std::shared_ptr hb; + BattleExchangeEvaluator scoreEvaluator; + std::shared_ptr cb; + std::shared_ptr env; + bool activeActionMade = false; + std::optional cachedAttack; + PlayerColor playerID; + int side; + int64_t cachedScore; + DamageCache damageCache; + float strengthRatio; + +public: + BattleAction selectStackAction(const CStack * stack); + bool attemptCastingSpell(const CStack * stack); + bool canCastSpell(); + std::optional findBestCreatureSpell(const CStack * stack); + BattleAction goTowardsNearest(const CStack * stack, std::vector hexes); + std::vector getBrokenWallMoatHexes() const; + void evaluateCreatureSpellcast(const CStack * stack, PossibleSpellcast & ps); //for offensive damaging spells only + void print(const std::string & text) const; + + BattleEvaluator( + std::shared_ptr env, + std::shared_ptr cb, + const battle::Unit * activeStack, + PlayerColor playerID, + int side, + float strengthRatio) + :scoreEvaluator(cb, env, strengthRatio), cachedAttack(), playerID(playerID), side(side), env(env), cb(cb), strengthRatio(strengthRatio) + { + hb = std::make_shared(env.get(), cb); + damageCache.buildDamageCache(hb, side); + + targets = std::make_unique(activeStack, damageCache, hb); + cachedScore = EvaluationResult::INEFFECTIVE_SCORE; + } + + BattleEvaluator( + std::shared_ptr env, + std::shared_ptr cb, + std::shared_ptr hb, + DamageCache & damageCache, + const battle::Unit * activeStack, + PlayerColor playerID, + int side, + float strengthRatio) + :scoreEvaluator(cb, env, strengthRatio), cachedAttack(), playerID(playerID), side(side), env(env), cb(cb), hb(hb), damageCache(damageCache), strengthRatio(strengthRatio) + { + targets = std::make_unique(activeStack, damageCache, hb); + cachedScore = EvaluationResult::INEFFECTIVE_SCORE; + } +}; diff --git a/AI/BattleAI/BattleExchangeVariant.cpp b/AI/BattleAI/BattleExchangeVariant.cpp index a53bb646b..b392af13a 100644 --- a/AI/BattleAI/BattleExchangeVariant.cpp +++ b/AI/BattleAI/BattleExchangeVariant.cpp @@ -41,7 +41,7 @@ int64_t BattleExchangeVariant::trackAttack(const AttackPossibility & ap, Hypothe unitToUpdate->movedThisRound = affectedUnit->movedThisRound; } - auto attackValue = ap.attackValue(); + auto attackValue = ap.damageDiff(positiveEffectMultiplier, negativeEffectMultiplier); dpsScore += attackValue; @@ -97,11 +97,11 @@ int64_t BattleExchangeVariant::trackAttack( if(isOurAttack) { - dpsScore += defenderDamageReduce; + dpsScore += defenderDamageReduce * positiveEffectMultiplier; attackerValue[attacker->unitId()].value += defenderDamageReduce; } else - dpsScore -= defenderDamageReduce; + dpsScore -= defenderDamageReduce * negativeEffectMultiplier; defender->damage(attackDamage); attacker->afterAttack(shooting, false); @@ -125,12 +125,12 @@ int64_t BattleExchangeVariant::trackAttack( if(isOurAttack) { - dpsScore -= attackerDamageReduce; + dpsScore -= attackerDamageReduce * negativeEffectMultiplier; attackerValue[attacker->unitId()].isRetalitated = true; } else { - dpsScore += attackerDamageReduce; + dpsScore += attackerDamageReduce * positiveEffectMultiplier; attackerValue[defender->unitId()].value += attackerDamageReduce; } @@ -206,7 +206,7 @@ MoveTarget BattleExchangeEvaluator::findMoveTowardsUnreachable( std::shared_ptr hb) { MoveTarget result; - BattleExchangeVariant ev; + BattleExchangeVariant ev(getPositiveEffectMultiplier(), getNegativeEffectMultiplier()); if(targets.unreachableEnemies.empty()) return result; @@ -353,6 +353,11 @@ std::vector BattleExchangeEvaluator::getExchangeUnits( } } + vstd::erase_if(exchangeUnits, [&](const battle::Unit * u) -> bool + { + return !hb->battleGetUnitByID(u->unitId())->alive(); + }); + return exchangeUnits; } @@ -376,7 +381,8 @@ int64_t BattleExchangeEvaluator::calculateExchange( std::vector ourStacks; std::vector enemyStacks; - enemyStacks.push_back(ap.attack.defender); + if(hb->battleGetUnitByID(ap.attack.defender->unitId())->alive()) + enemyStacks.push_back(ap.attack.defender); std::vector exchangeUnits = getExchangeUnits(ap, targets, hb); @@ -386,14 +392,7 @@ int64_t BattleExchangeEvaluator::calculateExchange( } auto exchangeBattle = std::make_shared(env.get(), hb); - BattleExchangeVariant v; - auto melleeAttackers = ourStacks; - - vstd::removeDuplicates(melleeAttackers); - vstd::erase_if(melleeAttackers, [&](const battle::Unit * u) -> bool - { - return !cb->battleCanShoot(u); - }); + BattleExchangeVariant v(getPositiveEffectMultiplier(), getNegativeEffectMultiplier()); for(auto unit : exchangeUnits) { @@ -403,12 +402,20 @@ int64_t BattleExchangeEvaluator::calculateExchange( bool isOur = exchangeBattle->battleMatchOwner(ap.attack.attacker, unit, true); auto & attackerQueue = isOur ? ourStacks : enemyStacks; - if(!vstd::contains(attackerQueue, unit)) + if(exchangeBattle->getForUpdate(unit->unitId())->alive() && !vstd::contains(attackerQueue, unit)) { attackerQueue.push_back(unit); } } + auto melleeAttackers = ourStacks; + + vstd::removeDuplicates(melleeAttackers); + vstd::erase_if(melleeAttackers, [&](const battle::Unit * u) -> bool + { + return !cb->battleCanShoot(u); + }); + bool canUseAp = true; for(auto activeUnit : exchangeUnits) @@ -430,7 +437,7 @@ int64_t BattleExchangeEvaluator::calculateExchange( auto targetUnit = ap.attack.defender; - if(!isOur || !exchangeBattle->getForUpdate(targetUnit->unitId())->alive()) + if(!isOur || !exchangeBattle->battleGetUnitByID(targetUnit->unitId())->alive()) { auto estimateAttack = [&](const battle::Unit * u) -> int64_t { @@ -459,7 +466,10 @@ int64_t BattleExchangeEvaluator::calculateExchange( { auto reachable = exchangeBattle->battleGetUnitsIf([&](const battle::Unit * u) -> bool { - if(!u->alive() || u->unitSide() == attacker->unitSide()) + if(u->unitSide() == attacker->unitSide()) + return false; + + if(!exchangeBattle->getForUpdate(u->unitId())->alive()) return false; return vstd::contains_if(reachabilityMap[u->getPosition()], [&](const battle::Unit * other) -> bool @@ -506,12 +516,12 @@ int64_t BattleExchangeEvaluator::calculateExchange( vstd::erase_if(attackerQueue, [&](const battle::Unit * u) -> bool { - return !exchangeBattle->getForUpdate(u->unitId())->alive(); + return !exchangeBattle->battleGetUnitByID(u->unitId())->alive(); }); vstd::erase_if(oppositeQueue, [&](const battle::Unit * u) -> bool { - return !exchangeBattle->getForUpdate(u->unitId())->alive(); + return !exchangeBattle->battleGetUnitByID(u->unitId())->alive(); }); } diff --git a/AI/BattleAI/BattleExchangeVariant.h b/AI/BattleAI/BattleExchangeVariant.h index 3d95dd912..9837924a6 100644 --- a/AI/BattleAI/BattleExchangeVariant.h +++ b/AI/BattleAI/BattleExchangeVariant.h @@ -59,7 +59,8 @@ struct EvaluationResult class BattleExchangeVariant { public: - BattleExchangeVariant(): dpsScore(0) {} + BattleExchangeVariant(float positiveEffectMultiplier, float negativeEffectMultiplier) + : dpsScore(0), positiveEffectMultiplier(positiveEffectMultiplier), negativeEffectMultiplier(negativeEffectMultiplier) {} int64_t trackAttack(const AttackPossibility & ap, HypotheticBattle & state); @@ -80,6 +81,8 @@ public: std::map & reachabilityMap); private: + float positiveEffectMultiplier; + float negativeEffectMultiplier; int64_t dpsScore; std::map attackerValue; }; @@ -91,9 +94,15 @@ private: std::shared_ptr env; std::map> reachabilityMap; std::vector turnOrder; + float negativeEffectMultiplier; public: - BattleExchangeEvaluator(std::shared_ptr cb, std::shared_ptr env): cb(cb), env(env) {} + BattleExchangeEvaluator( + std::shared_ptr cb, + std::shared_ptr env, + float strengthRatio): cb(cb), env(env) { + negativeEffectMultiplier = strengthRatio; + } EvaluationResult findBestTarget( const battle::Unit * activeStack, @@ -118,4 +127,7 @@ public: std::shared_ptr hb); std::vector getAdjacentUnits(const battle::Unit * unit); + + float getPositiveEffectMultiplier() { return 1; } + float getNegativeEffectMultiplier() { return negativeEffectMultiplier; } }; \ No newline at end of file diff --git a/AI/BattleAI/CMakeLists.txt b/AI/BattleAI/CMakeLists.txt index 1850e24f1..335c92f5c 100644 --- a/AI/BattleAI/CMakeLists.txt +++ b/AI/BattleAI/CMakeLists.txt @@ -1,6 +1,7 @@ set(battleAI_SRCS AttackPossibility.cpp BattleAI.cpp + BattleEvaluator.cpp common.cpp EnemyInfo.cpp PossibleSpellcast.cpp @@ -15,6 +16,7 @@ set(battleAI_HEADERS AttackPossibility.h BattleAI.h + BattleEvaluator.h common.h EnemyInfo.h PotentialTargets.h