You've already forked mailcow-dockerized
							
							
				mirror of
				https://github.com/mailcow/mailcow-dockerized.git
				synced 2025-10-30 23:57:54 +02:00 
			
		
		
		
	
		
			
				
	
	
		
			675 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Lua
		
	
	
	
	
	
			
		
		
	
	
			675 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Lua
		
	
	
	
	
	
| --[[
 | |
| Copyright (c) 2011-2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
 | |
| Copyright (c) 2016-2017, Andrew Lewis <nerf@judo.za.org>
 | |
| 
 | |
| Licensed under the Apache License, Version 2.0 (the "License");
 | |
| you may not use this file except in compliance with the License.
 | |
| You may obtain a copy of the License at
 | |
| 
 | |
|     http://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
| Unless required by applicable law or agreed to in writing, software
 | |
| distributed under the License is distributed on an "AS IS" BASIS,
 | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| See the License for the specific language governing permissions and
 | |
| limitations under the License.
 | |
| ]]--
 | |
| 
 | |
| if confighelp then
 | |
|   return
 | |
| end
 | |
| 
 | |
| -- A plugin that implements ratelimits using redis
 | |
| 
 | |
-- Shared empty table: lets chained lookups such as ((x or E)[1] or E).field
-- fall through to nil instead of raising on a missing level
local E = {}
-- Module name: configuration section to read and tag for debug logging
local N = 'ratelimit'
-- Redis connection parameters; assigned when the configuration is parsed
local redis_params

-- Module defaults; values from the configuration are merged over these
local settings = {
  -- Senders that are considered as bounce
  bounce_senders = {
    'postmaster',
    'mailer-daemon',
    '',
    'null',
    'fetchmail-daemon',
    'mdaemon',
  },
  -- Do not check ratelimits for these recipients
  whitelisted_rcpts = {
    'postmaster',
    'mailer-daemon',
  },
  -- Prepended to every bucket hash to form the Redis key
  prefix = 'RL',
  -- Per-message multipliers applied to the dynamic rate / burst of a bucket
  ham_factor_rate = 1.01,
  spam_factor_rate = 0.99,
  ham_factor_burst = 1.02,
  spam_factor_burst = 0.98,
  -- Bounds for the dynamic multipliers (max: x, min: 1/x)
  max_rate_mult = 5,
  max_bucket_mult = 10,
  -- Bucket expiration in seconds: 2 days by default
  expire = 60 * 60 * 24 * 2,
  -- Parsed buckets, keyed by limit name
  limits = {},
  allow_local = false,
}
 | |
| 
 | |
-- Checks bucket, updating it if needed
-- KEYS[1] - prefix to update, e.g. RL_<triplet>_<seconds>
-- KEYS[2] - current time in milliseconds
-- KEYS[3] - bucket leak rate (messages per millisecond)
-- KEYS[4] - bucket burst
-- KEYS[5] - expire for a bucket
-- Returns an array {limited, burst, dynr, dynb} where `limited` is 1 if the
-- message should be ratelimited and 0 if not; the remaining elements are the
-- current burst and dynamic multipliers, reported for logging
-- Redis keys used:
--   l - last hit
--   b - current burst
--   dr - current dynamic rate multiplier (*10000)
--   db - current dynamic burst multiplier (*10000)
local bucket_check_script = [[
  local last = redis.call('HGET', KEYS[1], 'l')
  local now = tonumber(KEYS[2])
  local dynr, dynb = 0, 0
  if not last then
    -- New bucket
    redis.call('HSET', KEYS[1], 'l', KEYS[2])
    redis.call('HSET', KEYS[1], 'b', '0')
    redis.call('HSET', KEYS[1], 'dr', '10000')
    redis.call('HSET', KEYS[1], 'db', '10000')
    redis.call('EXPIRE', KEYS[1], KEYS[5])
    return {0, 0, 1, 1}
  end

  last = tonumber(last)
  local burst = tonumber(redis.call('HGET', KEYS[1], 'b'))
  -- Perform leak
  if burst > 0 then
   if last < now then
    local rate = tonumber(KEYS[3])
    dynr = tonumber(redis.call('HGET', KEYS[1], 'dr')) / 10000.0
    rate = rate * dynr
    local leaked = ((now - last) * rate)
    burst = burst - leaked
    redis.call('HINCRBYFLOAT', KEYS[1], 'b', -(leaked))
    -- The interval [last, now] has been accounted for: advance the last hit,
    -- otherwise the next check leaks the very same interval once more
    -- (the update script does not run for messages that were ratelimited)
    redis.call('HSET', KEYS[1], 'l', KEYS[2])
   end
  else
   burst = 0
   redis.call('HSET', KEYS[1], 'b', '0')
  end

  dynb = tonumber(redis.call('HGET', KEYS[1], 'db')) / 10000.0

  if (burst + 1) * dynb > tonumber(KEYS[4]) then
   return {1, tostring(burst), tostring(dynr), tostring(dynb)}
  end

  return {0, tostring(burst), tostring(dynr), tostring(dynb)}
]]
-- Handle of the registered check script (assigned in load_scripts)
local bucket_check_id
 | |
| 
 | |
| 
 | |
-- Updates a bucket
-- KEYS[1] - prefix to update, e.g. RL_<triplet>_<seconds>
-- KEYS[2] - current time in milliseconds
-- KEYS[3] - dynamic rate multiplier
-- KEYS[4] - dynamic burst multiplier
-- KEYS[5] - max dyn rate (min: 1/x)
-- KEYS[6] - max burst rate (min: 1/x)
-- KEYS[7] - expire for a bucket
-- Returns an array {burst, dr, db}: the burst as it was before this hit was
-- added and the dynamic multipliers after the update ({1, 1, 1} for a bucket
-- that has just been created)
-- Redis keys used:
--   l - last hit
--   b - current burst
--   dr - current dynamic rate multiplier
--   db - current dynamic burst multiplier
local bucket_update_script = [[
  local last = redis.call('HGET', KEYS[1], 'l')
  local now = tonumber(KEYS[2])
  if not last then
    -- New bucket
    redis.call('HSET', KEYS[1], 'l', KEYS[2])
    redis.call('HSET', KEYS[1], 'b', '1')
    redis.call('HSET', KEYS[1], 'dr', '10000')
    redis.call('HSET', KEYS[1], 'db', '10000')
    redis.call('EXPIRE', KEYS[1], KEYS[7])
    return {1, 1, 1}
  end

  local burst = tonumber(redis.call('HGET', KEYS[1], 'b'))
  local db = tonumber(redis.call('HGET', KEYS[1], 'db')) / 10000
  local dr = tonumber(redis.call('HGET', KEYS[1], 'dr')) / 10000

  -- A multiplier may grow only while it is below its maximum and shrink only
  -- while it is above its minimum. Testing both bounds regardless of the
  -- direction would freeze a multiplier forever once it touches a bound.
  local rate_mult = tonumber(KEYS[3])
  local max_dr = tonumber(KEYS[5])
  if (rate_mult > 1.0 and dr < max_dr) or
      (rate_mult < 1.0 and dr > 1.0 / max_dr) then
    dr = dr * rate_mult
    redis.call('HSET', KEYS[1], 'dr', tostring(math.floor(dr * 10000)))
  end

  local burst_mult = tonumber(KEYS[4])
  local max_db = tonumber(KEYS[6])
  if (burst_mult > 1.0 and db < max_db) or
      (burst_mult < 1.0 and db > 1.0 / max_db) then
    db = db * burst_mult
    redis.call('HSET', KEYS[1], 'db', tostring(math.floor(db * 10000)))
  end

  redis.call('HINCRBYFLOAT', KEYS[1], 'b', 1)
  redis.call('HSET', KEYS[1], 'l', KEYS[2])
  redis.call('EXPIRE', KEYS[1], KEYS[7])

  return {tostring(burst), tostring(dr), tostring(db)}
]]
-- Handle of the registered update script (assigned in load_scripts)
local bucket_update_id
 | |
| 
 | |
-- Default builder of the message attached to a soft reject.
-- Invoked as message_func(task, limit_type, prefix, bucket); only the limit
-- name is used by this default implementation.
local message_func = function(_, limit_type, _, _)
  local template = 'Ratelimit "%s" exceeded'
  return template:format(limit_type)
end
 | |
| 
 | |
| local rspamd_logger = require "rspamd_logger"
 | |
| local rspamd_util = require "rspamd_util"
 | |
| local rspamd_lua_utils = require "lua_util"
 | |
| local lua_redis = require "lua_redis"
 | |
| local fun = require "fun"
 | |
| local lua_maps = require "lua_maps"
 | |
| local lua_util = require "lua_util"
 | |
| local rspamd_hash = require "rspamd_cryptobox_hash"
 | |
| 
 | |
| 
 | |
--- Registers both bucket scripts with lua_redis and stores their handles.
-- Does nothing when no Redis server is configured: the on-load hook calls
-- this function unconditionally, including when the module has been disabled
-- because `redis_params` could not be obtained.
-- @param cfg configuration object passed by the on-load hook (unused)
-- @param ev_base event base passed by the on-load hook (unused)
local function load_scripts(cfg, ev_base)
  if not redis_params then
    return
  end

  bucket_check_id = lua_redis.add_redis_script(bucket_check_script, redis_params)
  bucket_update_id = lua_redis.add_redis_script(bucket_update_script, redis_params)
end
 | |
| 
 | |
-- LPeg grammar for limit strings; built on the first parse and then reused
local limit_parser

--- Parses a limit given as "<count>[k|m|g] / <time>[s|m|h|d]".
-- @param lim string to parse
-- @param no_error when true, a malformed string is not logged
-- @return the time part in seconds followed by the message count, or nil when
--   the string does not match or the time part is zero
local function parse_string_limit(lim, no_error)
  local time_units = { s = 1, m = 60, h = 3600, d = 86400 }
  local count_units = { [''] = 1, k = 1000, m = 1000000, g = 1000000000 }

  local function parse_time_suffix(s)
    return time_units[s]
  end
  local function parse_num_suffix(s)
    return count_units[s]
  end
  -- Folds the captured number and its optional suffix into one value
  local function multiply(acc, val)
    return acc * val
  end

  local lpeg = require "lpeg"

  if not limit_parser then
    local grammar = {}
    limit_parser = grammar

    local digits = lpeg.R("09") ^ 1
    grammar.integer = (lpeg.S("+-") ^ -1) * digits
    grammar.fractional = lpeg.P(".") * digits
    grammar.number =
        (grammar.integer * (grammar.fractional ^ -1)) +
        (lpeg.S("+-") * grammar.fractional)

    local numeric = grammar.number / tonumber
    local time_suffix = (lpeg.S("smhd") / parse_time_suffix) ^ -1
    local count_suffix = (lpeg.S("kmg") / parse_num_suffix) ^ -1
    grammar.time = lpeg.Cf(lpeg.Cc(1) * numeric * time_suffix, multiply)
    grammar.suffixed_number = lpeg.Cf(lpeg.Cc(1) * numeric * count_suffix, multiply)

    local separator = (lpeg.S(" ") ^ 0) * lpeg.S("/") * (lpeg.S(" ") ^ 0)
    grammar.limit = lpeg.Ct(grammar.suffixed_number * separator * grammar.time)
  end

  local parsed = lpeg.match(limit_parser.limit, lim)
  -- A zero time part is rejected: callers take its reciprocal
  local usable = parsed and parsed[1] and parsed[2] and parsed[2] ~= 0

  if usable then
    return parsed[2], parsed[1]
  end

  if not no_error then
    rspamd_logger.errx(rspamd_config, 'bad limit: %s', lim)
  end

  return nil
end
 | |
| 
 | |
--- Converts one configured limit into a list of buckets.
-- Accepted forms of `data`:
--  * old limit in format [burst, rate] (numbers or numeric strings)
--  * a string in Andrew's string format, e.g. "10 / 1m"
--  * vector of such strings (or of any other accepted form)
--  * proper bucket table with numeric `burst` and `rate` fields
-- @param name limit name, used for log messages
-- @param data limit definition in one of the forms above
-- @return array of buckets having numeric `burst` and `rate`; empty when
--   nothing valid could be extracted
local function parse_limit(name, data)
  local buckets = {}

  if type(data) == 'table' then
    local old_burst, old_rate = tonumber(data[1]), tonumber(data[2])

    if #data == 2 and old_burst and old_rate then
      -- Old style ratelimit
      rspamd_logger.warnx(rspamd_config, 'old style ratelimit for %s', name)
      if old_burst > 0 and old_rate > 0 then
        -- Store the converted numbers: numeric strings would otherwise be
        -- discarded by the validity filter below
        table.insert(buckets, {
          burst = old_burst,
          rate = old_rate
        })
      elseif old_burst ~= 0 then
        rspamd_logger.warnx(rspamd_config, 'invalid numbers for %s', name)
      else
        rspamd_logger.infox(rspamd_config, 'disable limit %s, burst is zero', name)
      end
    elseif data.burst ~= nil and data.rate ~= nil then
      -- Proper bucket table: taken as is, validated below
      table.insert(buckets, data)
    else
      -- Recursively parse every element and flatten the results.
      -- The name goes first and the element second, as in the signature.
      for _, d in ipairs(data) do
        for _, b in ipairs(parse_limit(name, d)) do
          table.insert(buckets, b)
        end
      end
    end
  elseif type(data) == 'string' then
    local rep_rate, burst = parse_string_limit(data)

    if rep_rate and burst then
      table.insert(buckets, {
        burst = burst,
        rate = 1.0 / rep_rate -- reciprocal
      })
    end
  end

  -- Filter valid
  local valid = {}
  for _, val in ipairs(buckets) do
    if type(val.burst) == 'number' and type(val.rate) == 'number' then
      valid[#valid + 1] = val
    end
  end

  return valid
end
 | |
| 
 | |
--- Tells whether a sender name is listed in `settings.bounce_senders`
-- @param from sender name to look up
-- @return true when some configured bounce sender equals `from`
local function check_bounce(from)
  local function is_same_sender(candidate)
    return candidate == from
  end

  return fun.any(is_same_sender, settings.bounce_senders)
end
 | |
| 
 | |
-- Extractors for the components a limit name may consist of. Each entry
-- provides get_value(task), which returns the key component for that keyword
-- or nil when the keyword does not apply to the task.
local keywords = {
  -- Address returned by task:get_ip(), if valid
  ip = {
    get_value = function(task)
      local addr = task:get_ip()
      if not (addr and addr:is_valid()) then
        return nil
      end
      return tostring(addr)
    end,
  },
  -- Same as `ip`, but addresses reported as local are skipped
  rip = {
    get_value = function(task)
      local addr = task:get_ip()
      if addr and addr:is_valid() and not addr:is_local() then
        return tostring(addr)
      end
      return nil
    end,
  },
  -- Lowercased address of the first sender
  from = {
    get_value = function(task)
      local first = (task:get_from(0) or E)[1] or E
      if first.addr then
        return string.lower(first.addr)
      end
      return nil
    end,
  },
  -- '_' when there is no sender user part or when it is a bounce sender
  bounce = {
    get_value = function(task)
      local first = (task:get_from(0) or E)[1] or E
      if not first.user then
        return '_'
      end
      if check_bounce(first.user) then
        return '_'
      end
      return nil
    end,
  },
  -- Value of the `asn` mempool variable
  asn = {
    get_value = function(task)
      return task:get_mempool():get_variable('asn') or nil
    end,
  },
  -- Value of task:get_user()
  user = {
    get_value = function(task)
      return task:get_user() or nil
    end,
  },
  -- Value of task:get_principal_recipient()
  to = {
    get_value = function(task)
      return task:get_principal_recipient()
    end,
  },
}
 | |
| 
 | |
--- Builds the textual key of a bucket for the given task.
-- The key starts with round(100000 / burst) and continues with one component
-- per keyword of the limit name (keywords are separated by '_'), all joined
-- with ':'.
-- @param task task being processed
-- @param rtype limit name, e.g. a single keyword or several joined by '_'
-- @param bucket bucket definition; `burst` is read
-- @return key string, or nil when some keyword is unknown or has no value
local function gen_rate_key(task, rtype, bucket)
  local parts = { tostring(lua_util.round(100000.0 / bucket.burst)) }
  local saw_user = false

  for _, kw in ipairs(lua_util.str_split(rtype, '_')) do
    local handler = keywords[kw]
    local value

    if handler and type(handler['get_value']) == 'function' then
      value = handler['get_value'](task)
    end

    -- One missing component makes the whole limit inapplicable
    if not value then
      return nil
    end

    saw_user = saw_user or kw == 'user'
    parts[#parts + 1] = type(value) == 'string' and value or tostring(value)
  end

  if saw_user and not task:get_user() then
    return nil
  end

  return table.concat(parts, ":")
end
 | |
| 
 | |
--- Builds the descriptor of a bucket that is kept in the per-task prefix list.
-- The Redis key is `settings.prefix` followed by the leading characters of
-- the base32 hash of `redis_key` (at most 24, and never more than the length
-- of `redis_key` itself). Dynamic factors missing in `bucket` are filled in
-- place from the module settings.
-- @param redis_key textual bucket key
-- @param name limit name
-- @param bucket bucket definition (modified: defaults are filled in)
-- @return table with fields `bucket`, `name` and `hash`
local function make_prefix(redis_key, name, bucket)
  local hash_len = math.min(24, #redis_key)
  local digest = rspamd_hash.create(redis_key):base32()
  local hash = settings.prefix .. string.sub(digest, 1, hash_len)

  -- Fill defaults
  local factor_names = {
    'spam_factor_rate',
    'ham_factor_rate',
    'spam_factor_burst',
    'ham_factor_burst',
  }
  for _, factor in ipairs(factor_names) do
    if not bucket[factor] then
      bucket[factor] = settings[factor]
    end
  end

  return {
    bucket = bucket,
    name = name,
    hash = hash
  }
end
 | |
| 
 | |
--- Adds a prefix descriptor for every bucket of a limit that applies to the task.
-- @param task task being processed
-- @param k limit name
-- @param v array of buckets configured for that limit
-- @param prefixes table to fill; indexed by the textual bucket key
-- @return number of buckets for which a key could be generated
local function limit_to_prefixes(task, k, v, prefixes)
  local added = 0

  for _, bucket in ipairs(v) do
    local redis_key = gen_rate_key(task, k, bucket)

    if redis_key then
      added = added + 1
      prefixes[redis_key] = make_prefix(redis_key, k, bucket)
    end
  end

  return added
end
 | |
| 
 | |
--- Prefilter callback: checks every bucket that applies to the task.
-- Skips whitelisted IPs, recipients and users, collects the buckets coming
-- from the configured limits and from custom keyword handlers, stores them in
-- the task cache under 'ratelimit_prefixes' and runs the check script for
-- each of them. An exceeded bucket either only inserts `settings.symbol` or
-- inserts `settings.info_symbol` (if set) and sets a 'soft reject' pre-result.
-- @param task task being processed
local function ratelimit_cb(task)
  if not settings.allow_local and
      rspamd_lua_utils.is_rspamc_or_controller(task) then
    return
  end

  -- Get initial task data
  local ip = task:get_from_ip()
  if ip and ip:is_valid() and settings.whitelisted_ip then
    if settings.whitelisted_ip:get_key(ip) then
      -- Do not check whitelisted ip
      rspamd_logger.infox(task, 'skip ratelimit for whitelisted IP')
      return
    end
  end

  -- Parse all rcpts: both the user part and the full address are looked up
  local rcpts = task:get_recipients()
  if rcpts then
    local rcpts_user = {}
    fun.each(function(r)
      fun.each(function(field) table.insert(rcpts_user, r[field]) end, {'user', 'addr'})
    end, rcpts)

    if fun.any(function(r) return settings.whitelisted_rcpts:get_key(r) end, rcpts_user) then
      rspamd_logger.infox(task, 'skip ratelimit for whitelisted recipient')
      return
    end
  end

  -- Get user (authuser)
  if settings.whitelisted_user then
    local auser = task:get_user()
    -- No authenticated user means there is nothing to look up
    if auser and settings.whitelisted_user:get_key(auser) then
      rspamd_logger.infox(task, 'skip ratelimit for whitelisted user')
      return
    end
  end

  -- Now create all ratelimit prefixes
  local prefixes = {}
  local nprefixes = 0

  for k, v in pairs(settings.limits) do
    nprefixes = nprefixes + limit_to_prefixes(task, k, v, prefixes)
  end

  for k, hdl in pairs(settings.custom_keywords or E) do
    local ret, redis_key, bd = pcall(hdl, task)

    if ret then
      local bucket = parse_limit(k, bd)
      -- A handler may decline by returning nothing; a nil key must never be
      -- used as a table index, and only stored buckets are counted
      if type(redis_key) == 'string' and bucket[1] then
        prefixes[redis_key] = make_prefix(redis_key, k, bucket[1])
        nprefixes = nprefixes + 1
      end
    else
      -- On failure the second value returned by pcall is the error
      rspamd_logger.errx(task, 'cannot call handler for %s: %s',
          k, redis_key)
    end
  end

  -- Builds the Redis reply handler for one bucket
  local function gen_check_cb(prefix, bucket, lim_name)
    return function(err, data)
      if err then
        rspamd_logger.errx(task, 'cannot check limit %s: %s %s', prefix, err, data)
      elseif type(data) == 'table' and data[1] and data[1] == 1 then
        -- set symbol only and do NOT soft reject
        if settings.symbol then
          task:insert_result(settings.symbol, 0.0, lim_name .. "(" .. prefix .. ")")
          rspamd_logger.infox(task,
              'set_symbol_only: ratelimit "%s(%s)" exceeded, (%s / %s): %s (%s:%s dyn)',
              lim_name, prefix,
              bucket.burst, bucket.rate,
              data[2], data[3], data[4])
          return
        -- set INFO symbol and soft reject
        elseif settings.info_symbol then
          task:insert_result(settings.info_symbol, 1.0,
              lim_name .. "(" .. prefix .. ")")
        end
        rspamd_logger.infox(task,
            'ratelimit "%s(%s)" exceeded, (%s / %s): %s (%s:%s dyn)',
            lim_name, prefix,
            bucket.burst, bucket.rate,
            data[2], data[3], data[4])
        task:set_pre_result('soft reject',
            message_func(task, lim_name, prefix, bucket))
      end
    end
  end

  -- Don't do anything if pre-result has been already set
  if task:has_pre_result() then return end

  if nprefixes > 0 then
    -- Save prefixes to the cache to allow update
    task:cache_set('ratelimit_prefixes', prefixes)
    local now = rspamd_util.get_time()
    now = lua_util.round(now * 1000.0) -- Get milliseconds

    -- Now call check script for all defined prefixes
    for pr, value in pairs(prefixes) do
      local bucket = value.bucket
      local rate = (bucket.rate) / 1000.0 -- Leak rate in messages/ms
      rspamd_logger.debugm(N, task, "check limit %s:%s -> %s (%s/%s)",
          value.name, pr, value.hash, bucket.burst, bucket.rate)
      lua_redis.exec_redis_script(bucket_check_id,
          {key = value.hash, task = task, is_write = true},
          gen_check_cb(pr, bucket, value.name),
          {value.hash, tostring(now), tostring(rate), tostring(bucket.burst),
           tostring(settings.expire)})
    end
  end
end
 | |
| 
 | |
--- Idempotent callback: records the message in every bucket that was checked.
-- Reads the buckets stored in the task cache under 'ratelimit_prefixes' and
-- runs the update script for each of them, passing the spam factors when the
-- metric action is anything but 'no action' and the ham factors otherwise.
-- Nothing is updated when a pre-result has been set.
-- @param task task being processed
local function ratelimit_update_cb(task)
  local prefixes = task:cache_get('ratelimit_prefixes')

  if not prefixes then
    return
  end

  if task:has_pre_result() then
    -- Already rate limited/greylisted, do nothing
    rspamd_logger.debugm(N, task, 'pre-action has been set, do not update')
    return
  end

  local is_spam = task:get_metric_action() ~= 'no action'
  -- One timestamp (in milliseconds) is shared by all the buckets of the task
  local now = lua_util.round(rspamd_util.get_time() * 1000.0)

  -- Update each bucket
  for k, v in pairs(prefixes) do
    local bucket = v.bucket

    local function update_bucket_cb(err, data)
      if err then
        rspamd_logger.errx(task, 'cannot update rate bucket %s: %s',
            k, err)
      else
        rspamd_logger.debugm(N, task,
            "updated limit %s:%s -> %s (%s/%s), burst: %s, dyn_rate: %s, dyn_burst: %s",
            v.name, k, v.hash,
            bucket.burst, bucket.rate,
            data[1], data[2], data[3])
      end
    end

    local mult_burst, mult_rate
    if is_spam then
      mult_burst = bucket.spam_factor_burst or 1.0
      mult_rate = bucket.spam_factor_rate or 1.0
    else
      mult_burst = bucket.ham_factor_burst or 1.0
      -- The rate multiplier comes from the *rate* factor, not the burst one
      mult_rate = bucket.ham_factor_rate or 1.0
    end

    lua_redis.exec_redis_script(bucket_update_id,
        {key = v.hash, task = task, is_write = true},
        update_bucket_cb,
        {v.hash, tostring(now), tostring(mult_rate), tostring(mult_burst),
         tostring(settings.max_rate_mult), tostring(settings.max_bucket_mult),
         tostring(settings.expire)})
  end
end
 | |
| 
 | |
-- Module configuration: merge the options over the defaults, parse the
-- limits, set up the whitelists and custom hooks, then register the symbols
local opts = rspamd_config:get_all_opt(N)
if opts then

  settings = lua_util.override_defaults(settings, opts)

  if opts['limit'] then
    rspamd_logger.errx(rspamd_config, 'Legacy ratelimit config format no longer supported')
  end

  if opts['rates'] and type(opts['rates']) == 'table' then
    -- new way of setting limits
    fun.each(function(t, lim)
      local buckets = parse_limit(t, lim)

      if buckets and #buckets > 0 then
        settings.limits[t] = buckets
      end
    end, opts['rates'])
  end

  local enabled_limits = fun.totable(fun.map(function(t)
    return t
  end, settings.limits))
  rspamd_logger.infox(rspamd_config,
      'enabled rate buckets: [%1]', table.concat(enabled_limits, ','))

  -- Ret, ret, ret: stupid legacy stuff:
  -- If we have a string with commas then load it as a static map
  -- otherwise, apply normal logic of Rspamd maps

  local wrcpts = opts['whitelisted_rcpts']
  if type(wrcpts) == 'string' then
    if string.find(wrcpts, ',') then
      -- Same splitting helper as the one used for limit names in this file
      settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(
          lua_util.str_split(wrcpts, ','), 'set', 'Ratelimit whitelisted rcpts')
    else
      settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set',
          'Ratelimit whitelisted rcpts')
    end
  elseif type(wrcpts) == 'table' then
    settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set',
        'Ratelimit whitelisted rcpts')
  else
    -- Stupid default...
    settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(
        settings.whitelisted_rcpts, 'set', 'Ratelimit whitelisted rcpts')
  end

  if opts['whitelisted_ip'] then
    settings.whitelisted_ip = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_ip', 'radix',
        'Ratelimit whitelist ip map')
  end

  if opts['whitelisted_user'] then
    settings.whitelisted_user = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_user', 'set',
        'Ratelimit whitelist user map')
  end

  settings.custom_keywords = {}
  if opts['custom_keywords'] then
    local keywords_path = opts['custom_keywords']
    -- Load and run in two steps: when the file is missing or does not compile
    -- loadfile returns nil plus the reason, and calling pcall on that nil
    -- would replace the reason with "attempt to call a nil value"
    local chunk, load_err = loadfile(keywords_path)
    local ret, res_or_err

    if chunk then
      ret, res_or_err = pcall(chunk)
    else
      ret, res_or_err = false, load_err
    end

    if ret then
      opts['custom_keywords'] = {}
      if type(res_or_err) == 'table' then
        for k, hdl in pairs(res_or_err) do
          settings['custom_keywords'][k] = hdl
        end
      elseif type(res_or_err) == 'function' then
        settings['custom_keywords']['custom'] = res_or_err
      end
    else
      rspamd_logger.errx(rspamd_config, 'cannot execute %s: %s',
          keywords_path, res_or_err)
      settings['custom_keywords'] = {}
    end
  end

  if opts['message_func'] then
    -- NOTE(review): the configured chunk is compiled and executed as is;
    -- it is expected to return the function that builds the reject message
    message_func = assert(load(opts['message_func']))()
  end

  redis_params = lua_redis.parse_redis_server('ratelimit')

  if not redis_params then
    rspamd_logger.infox(rspamd_config, 'no servers are specified, disabling module')
    lua_util.disable_module(N, "redis")
  else
    local s = {
      type = 'prefilter,nostat',
      name = 'RATELIMIT_CHECK',
      priority = 7,
      callback = ratelimit_cb,
      flags = 'empty',
    }

    -- The check callback is registered under the configured symbol, if any
    if settings.symbol then
      s.name = settings.symbol
    elseif settings.info_symbol then
      s.name = settings.info_symbol
    end

    rspamd_config:register_symbol(s)
    rspamd_config:register_symbol {
      type = 'idempotent',
      name = 'RATELIMIT_UPDATE',
      callback = ratelimit_update_cb,
    }
  end
end
 | |
| 
 | |
-- On-load hook: register the bucket scripts with the values the hook receives
local function on_load(cfg, ev_base, worker)
  load_scripts(cfg, ev_base)
end

rspamd_config:add_on_load(on_load)
 |