mirror of
				https://github.com/vcmi/vcmi.git
				synced 2025-10-31 00:07:39 +02:00 
			
		
		
		
	Try to implement lazy evaluation for reachability map
This commit is contained in:
		| @@ -519,14 +519,14 @@ ReachabilityData BattleExchangeEvaluator::getExchangeUnits( | ||||
| 	 | ||||
| 	for(auto hex : hexes) | ||||
| 	{ | ||||
| 		vstd::concatenate(allReachableUnits, turn == 0 ? reachabilityMap.at(hex.toInt()) : getOneTurnReachableUnits(turn, hex)); | ||||
| 		vstd::concatenate(allReachableUnits, getOneTurnReachableUnits(turn, hex)); | ||||
| 	} | ||||
|  | ||||
| 	if(!ap.attack.attacker->isTurret()) | ||||
| 	{ | ||||
| 		for(auto hex : ap.attack.attacker->getHexes()) | ||||
| 		{ | ||||
| 			auto unitsReachingAttacker = turn == 0 ? reachabilityMap.at(hex.toInt()) : getOneTurnReachableUnits(turn, hex); | ||||
| 			auto unitsReachingAttacker = getOneTurnReachableUnits(turn, hex); | ||||
| 			for(auto unit : unitsReachingAttacker) | ||||
| 			{ | ||||
| 				if(unit->unitSide() != ap.attack.attacker->unitSide()) | ||||
| @@ -800,7 +800,9 @@ BattleScore BattleExchangeEvaluator::calculateExchange( | ||||
| 							if(!u->getPosition().isValid()) | ||||
| 								return false; // e.g. tower shooters | ||||
|  | ||||
| 							return vstd::contains_if(reachabilityMap.at(u->getPosition().toInt()), [&attacker](const battle::Unit * other) -> bool | ||||
| 							const auto & reachableUnits = getOneTurnReachableUnits(0, u->getPosition()); | ||||
|  | ||||
| 							return vstd::contains_if(reachableUnits, [&attacker](const battle::Unit * other) -> bool | ||||
| 								{ | ||||
| 									return attacker->unitId() == other->unitId(); | ||||
| 								}); | ||||
| @@ -887,7 +889,7 @@ bool BattleExchangeEvaluator::canBeHitThisTurn(const AttackPossibility & ap) | ||||
| { | ||||
| 	for(auto pos : ap.attack.attacker->getSurroundingHexes()) | ||||
| 	{ | ||||
| 		for(auto u : reachabilityMap.at(pos.toInt())) | ||||
| 		for(auto u : getOneTurnReachableUnits(0, pos)) | ||||
| 		{ | ||||
| 			if(u->unitSide() != ap.attack.attacker->unitSide()) | ||||
| 			{ | ||||
| @@ -899,6 +901,22 @@ bool BattleExchangeEvaluator::canBeHitThisTurn(const AttackPossibility & ap) | ||||
| 	return false; | ||||
| } | ||||
|  | ||||
| void ReachabilityMapCache::update(const std::vector<battle::Units> & turnOrder, std::shared_ptr<HypotheticBattle> hb) | ||||
| { | ||||
| 	for(auto turn : turnOrder) | ||||
| 	{ | ||||
| 		for(auto u : turn) | ||||
| 		{ | ||||
| 			if(!vstd::contains(unitReachabilityMap, u->unitId())) | ||||
| 			{ | ||||
| 				unitReachabilityMap[u->unitId()] = hb->getReachability(u); | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	hexReachabilityPerTurn.clear(); | ||||
| } | ||||
|  | ||||
| void BattleExchangeEvaluator::updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb) | ||||
| { | ||||
| 	const int TURN_DEPTH = 2; | ||||
| @@ -906,28 +924,25 @@ void BattleExchangeEvaluator::updateReachabilityMap(std::shared_ptr<HypotheticBa | ||||
| 	turnOrder.clear(); | ||||
|  | ||||
| 	hb->battleGetTurnOrder(turnOrder, std::numeric_limits<int>::max(), TURN_DEPTH); | ||||
|  | ||||
| 	for(auto turn : turnOrder) | ||||
| 	{ | ||||
| 		for(auto u : turn) | ||||
| 		{ | ||||
| 			if(!vstd::contains(reachabilityCache, u->unitId())) | ||||
| 			{ | ||||
| 				reachabilityCache[u->unitId()] = hb->getReachability(u); | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	tbb::parallel_for(tbb::blocked_range<size_t>(0, reachabilityMap.size()), [&](const tbb::blocked_range<size_t> & r) | ||||
| 	{ | ||||
| 		for(auto i = r.begin(); i != r.end(); i++) | ||||
| 			reachabilityMap[i] = getOneTurnReachableUnits(0, BattleHex(i)); | ||||
| 	}); | ||||
| 	reachabilityMap.update(turnOrder, hb); | ||||
| } | ||||
|  | ||||
| std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const | ||||
| const battle::Units & ReachabilityMapCache::getOneTurnReachableUnits(std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<Environment> env, const std::vector<battle::Units> & turnOrder, uint8_t turn, BattleHex hex) | ||||
| { | ||||
| 	std::vector<const battle::Unit *> result; | ||||
| 	auto & turnData = hexReachabilityPerTurn[turn]; | ||||
|  | ||||
| 	if (!turnData.isValid[hex.toInt()]) | ||||
| 	{ | ||||
| 		turnData.hexes[hex.toInt()] = computeOneTurnReachableUnits(cb, env, turnOrder, turn, hex); | ||||
| 		turnData.isValid.set(hex.toInt()); | ||||
| 	} | ||||
|  | ||||
| 	return turnData.hexes[hex.toInt()]; | ||||
| } | ||||
|  | ||||
| battle::Units ReachabilityMapCache::computeOneTurnReachableUnits(std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<Environment> env, const std::vector<battle::Units> & turnOrder, uint8_t turn, BattleHex hex) | ||||
| { | ||||
| 	battle::Units result; | ||||
|  | ||||
| 	for(int i = 0; i < turnOrder.size(); i++, turn++) | ||||
| 	{ | ||||
| @@ -949,10 +964,10 @@ std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUn | ||||
| 			auto unitSpeed = unit->getMovementRange(turn); | ||||
| 			auto radius = unitSpeed * (turn + 1); | ||||
|  | ||||
| 			auto reachabilityIter = reachabilityCache.find(unit->unitId()); | ||||
| 			assert(reachabilityIter != reachabilityCache.end()); // missing updateReachabilityMap call? | ||||
| 			auto reachabilityIter = unitReachabilityMap.find(unit->unitId()); | ||||
| 			assert(reachabilityIter != unitReachabilityMap.end()); // missing updateReachabilityMap call? | ||||
|  | ||||
| 			ReachabilityInfo unitReachability = reachabilityIter != reachabilityCache.end() ? reachabilityIter->second : turnBattle.getReachability(unit); | ||||
| 			ReachabilityInfo unitReachability = reachabilityIter != unitReachabilityMap.end() ? reachabilityIter->second : turnBattle.getReachability(unit); | ||||
|  | ||||
| 			bool reachable = unitReachability.distances.at(hex.toInt()) <= radius; | ||||
|  | ||||
| @@ -981,6 +996,11 @@ std::vector<const battle::Unit *> BattleExchangeEvaluator::getOneTurnReachableUn | ||||
| 	return result; | ||||
| } | ||||
|  | ||||
| const battle::Units & BattleExchangeEvaluator::getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const | ||||
| { | ||||
| 	return reachabilityMap.getOneTurnReachableUnits(cb, env, turnOrder, turn, hex); | ||||
| } | ||||
|  | ||||
| // avoid blocking path for stronger stack by weaker stack | ||||
| bool BattleExchangeEvaluator::checkPositionBlocksOurStacks(HypotheticBattle & hb, const battle::Unit * activeUnit, BattleHex position) | ||||
| { | ||||
| @@ -1032,9 +1052,11 @@ bool BattleExchangeEvaluator::checkPositionBlocksOurStacks(HypotheticBattle & hb | ||||
| 					} | ||||
| 				} | ||||
|  | ||||
| 				if(!reachable && std::count(reachabilityMap[hex.toInt()].begin(), reachabilityMap[hex.toInt()].end(), unit) > 1) | ||||
| 				if(!reachable) | ||||
| 				{ | ||||
| 					blockingScore += ratio * (enemyUnit ? BLOCKING_OWN_ATTACK_PENALTY : BLOCKING_OWN_MOVE_PENALTY); | ||||
| 					auto reachableUnits = getOneTurnReachableUnits(0, hex); | ||||
| 					if (std::count(reachableUnits.begin(), reachableUnits.end(), unit) > 1) | ||||
| 						blockingScore += ratio * (enemyUnit ? BLOCKING_OWN_ATTACK_PENALTY : BLOCKING_OWN_MOVE_PENALTY); | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
|   | ||||
| @@ -123,13 +123,29 @@ struct ReachabilityData | ||||
| 	std::set<uint32_t> enemyUnitsReachingAttacker; | ||||
| }; | ||||
|  | ||||
| class ReachabilityMapCache | ||||
| { | ||||
| 	struct PerTurnData{ | ||||
| 		std::bitset<GameConstants::BFIELD_SIZE> isValid; | ||||
| 		std::array<battle::Units, GameConstants::BFIELD_SIZE> hexes; | ||||
| 	}; | ||||
|  | ||||
| 	std::map<uint32_t, ReachabilityInfo> unitReachabilityMap; // unit ID -> reachability | ||||
| 	std::map<uint32_t, PerTurnData> hexReachabilityPerTurn; | ||||
|  | ||||
| 	//const ReachabilityInfo & update(); | ||||
| 	battle::Units computeOneTurnReachableUnits(std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<Environment> env, const std::vector<battle::Units> & turnOrder, uint8_t turn, BattleHex hex); | ||||
| public: | ||||
| 	const battle::Units & getOneTurnReachableUnits(std::shared_ptr<CBattleInfoCallback> cb, std::shared_ptr<Environment> env, const std::vector<battle::Units> & turnOrder, uint8_t turn, BattleHex hex); | ||||
| 	void update(const std::vector<battle::Units> & turnOrder, std::shared_ptr<HypotheticBattle> hb); | ||||
| }; | ||||
|  | ||||
| class BattleExchangeEvaluator | ||||
| { | ||||
| private: | ||||
| 	std::shared_ptr<CBattleInfoCallback> cb; | ||||
| 	std::shared_ptr<Environment> env; | ||||
| 	std::map<uint32_t, ReachabilityInfo> reachabilityCache; | ||||
| 	std::array<std::vector<const battle::Unit *>, GameConstants::BFIELD_SIZE> reachabilityMap; | ||||
| 	mutable ReachabilityMapCache reachabilityMap; | ||||
| 	std::vector<battle::Units> turnOrder; | ||||
| 	float negativeEffectMultiplier; | ||||
| 	int simulationTurnsCount; | ||||
| @@ -169,7 +185,7 @@ public: | ||||
| 		DamageCache & damageCache, | ||||
| 		std::shared_ptr<HypotheticBattle> hb) const; | ||||
|  | ||||
| 	std::vector<const battle::Unit *> getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const; | ||||
| 	const battle::Units & getOneTurnReachableUnits(uint8_t turn, BattleHex hex) const; | ||||
| 	void updateReachabilityMap(std::shared_ptr<HypotheticBattle> hb); | ||||
|  | ||||
| 	ReachabilityData getExchangeUnits( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user