diff --git a/utils/priority_queue.lua b/utils/priority_queue.lua index bccd2358..85f07335 100644 --- a/utils/priority_queue.lua +++ b/utils/priority_queue.lua @@ -1,20 +1,45 @@ +local Debug = require 'utils.debug' +local is_closure = Debug.is_closure +local floor = math.floor + local PriorityQueue = {} -function PriorityQueue.new() - return {} -end - -local function default_comp(a, b) +local function default_comparator(a, b) return a < b end -local function HeapifyFromEndToStart(queue, comp) - comp = comp or default_comp - local pos = #queue +--- Min heap implementation of a priority queue. Smaller elements, as determined by the comparator, +-- have a higher priority. +-- @param comparator the comparator function used to compare elements +-- @usage +-- local PriorityQueue = require 'utils.priority_queue' +-- +-- local queue = PriorityQueue.new() +-- PriorityQueue.push(queue, 4) +-- PriorityQueue.push(queue, 7) +-- PriorityQueue.push(queue, 2) +-- +-- game.print(PriorityQueue.pop(queue)) -- 2 +-- game.print(PriorityQueue.pop(queue)) -- 4 +-- game.print(PriorityQueue.pop(queue)) -- 7 +function PriorityQueue.new(comparator) + if comparator == nil then + comparator = default_comparator + elseif is_closure(comparator) then + error('comparator cannot be a closure.', 2) + end + + return {_comparator = comparator} +end + +local function heapify_from_end_to_start(self) + local comparator = self._comparator + local pos = #self while pos > 1 do - local parent = bit32.rshift(pos, 1) -- integer division by 2 - if comp(queue[pos], queue[parent]) then - queue[pos], queue[parent] = queue[parent], queue[pos] + local parent = floor(pos / 2) + local a, b = self[pos], self[parent] + if comparator(a, b) then + self[pos], self[parent] = b, a pos = parent else break @@ -22,25 +47,26 @@ local function HeapifyFromEndToStart(queue, comp) end end -local function HeapifyFromStartToEnd(queue, comp) - comp = comp or default_comp +local function heapify_from_start_to_end(self) + local comparator = self._comparator local parent = 1 local smallest = 1 + local count = #self while true do local child = parent * 2 - if child > #queue then + if child > count then break end - if comp(queue[child], queue[parent]) then + if comparator(self[child], self[parent]) then smallest = child end child = child + 1 - if child <= #queue and comp(queue[child], queue[smallest]) then + if child <= count and comparator(self[child], self[smallest]) then smallest = child end if parent ~= smallest then - queue[parent], queue[smallest] = queue[smallest], queue[parent] + self[parent], self[smallest] = self[smallest], self[parent] parent = smallest else break @@ -48,27 +74,31 @@ local function HeapifyFromStartToEnd(queue, comp) end end -function PriorityQueue.size(queue) - return #queue +--- Returns the number of the number of elements in the priority queue. +function PriorityQueue.size(self) + return #self end -function PriorityQueue.push(queue, element, comp) - table.insert(queue, element) - HeapifyFromEndToStart(queue, comp) +-- Inserts an element into the priority queue. +function PriorityQueue.push(self, element) + self[#self + 1] = element + heapify_from_end_to_start(self) end -function PriorityQueue.pop(queue, comp) - local element = queue[1] +-- Removes and returns the highest priority element from the priority queue. +function PriorityQueue.pop(self) + local element = self[1] - queue[1] = queue[#queue] - queue[#queue] = nil - HeapifyFromStartToEnd(queue, comp) + self[1] = self[#self] + self[#self] = nil + heapify_from_start_to_end(self) return element end -function PriorityQueue.peek(queue) - return queue[1] +-- Returns, without removing, the highest priority element from the priority queue. +function PriorityQueue.peek(self) + return self[1] end return PriorityQueue