diff --git a/README.md b/README.md index 1f7a084..18ae0e7 100644 --- a/README.md +++ b/README.md @@ -192,6 +192,7 @@ docker-compose -f docker-compose.cuda.yml up -d --build | --get-api-key-link | Show a link in the UI where to direct users to get an API key | `Don't show a link` | LT_GET_API_KEY_LINK | | --require-api-key-origin | Require use of an API key for programmatic access to the API, unless the request origin matches this domain | `No restrictions on domain origin` | LT_REQUIRE_API_KEY_ORIGIN | | --require-api-key-secret | Require use of an API key for programmatic access to the API, unless the client also sends a secret match | `No secrets required` | LT_REQUIRE_API_KEY_SECRET | +| --shared-storage | Shared storage URI to use for multi-process data sharing (e.g. when using gunicorn) | `memory://` | LT_SHARED_STORAGE | | --load-only | Set available languages | `all from argostranslate` | LT_LOAD_ONLY | | --threads | Set number of threads | `4` | LT_THREADS | | --suggestions | Allow user suggestions | `False` | LT_SUGGESTIONS | diff --git a/libretranslate/app.py b/libretranslate/app.py index d6d7ec0..ea4c0b6 100644 --- a/libretranslate/app.py +++ b/libretranslate/app.py @@ -21,7 +21,7 @@ from werkzeug.exceptions import HTTPException from werkzeug.http import http_date from flask_babel import Babel -from libretranslate import flood, remove_translated_files, security +from libretranslate import scheduler, flood, secret, remove_translated_files, security, storage from libretranslate.language import detect_languages, improve_translation_formatting from libretranslate.locales import (_, _lazy, get_available_locales, get_available_locale_codes, gettext_escaped, gettext_html, lazy_swag, get_alternate_locale_links) @@ -127,6 +127,8 @@ def create_app(args): bp = Blueprint('Main app', __name__) + storage.setup(args.shared_storage) + if not args.disable_files_translation: remove_translated_files.setup(get_upload_dir()) languages = load_languages() @@ -202,8 +204,12 @@ def create_app(args): limiter = Limiter() - if args.req_flood_threshold > 0: - flood.setup(args.req_flood_threshold) + if not "gunicorn" in os.environ.get("SERVER_SOFTWARE", ""): + # Gunicorn starts the scheduler in the master process + scheduler.setup(args) + + flood.setup(args) + secret.setup(args) measure_request = None gauge_request = None @@ -261,16 +267,16 @@ def create_app(args): if (args.require_api_key_secret and key_missing - and not flood.secret_match(get_req_secret()) + and not secret.secret_match(get_req_secret()) ): need_key = True - + if need_key: description = _("Please contact the server operator to get an API key") if args.get_api_key_link: description = _("Visit %(url)s to get an API key", url=args.get_api_key_link) abort( - 403, + 400, description=description, ) return f(*a, **kw) @@ -347,7 +353,7 @@ def create_app(args): response = Response(render_template("app.js.template", url_prefix=args.url_prefix, get_api_key_link=args.get_api_key_link, - api_secret=flood.get_current_secret() if args.require_api_key_secret else ""), content_type='application/javascript; charset=utf-8') + api_secret=secret.get_current_secret() if args.require_api_key_secret else ""), content_type='application/javascript; charset=utf-8') if args.require_api_key_secret: response.headers['Last-Modified'] = http_date(datetime.now()) diff --git a/libretranslate/default_values.py b/libretranslate/default_values.py index a9d007b..3e14580 100644 --- a/libretranslate/default_values.py +++ b/libretranslate/default_values.py @@ -136,6 +136,11 @@ _default_options_objects = [ 'default_value': False, 'value_type': 'bool' }, + { + 'name': 'SHARED_STORAGE', + 'default_value': 'memory://', + 'value_type': 'str' + }, { 'name': 'LOAD_ONLY', 'default_value': None, diff --git a/libretranslate/flood.py b/libretranslate/flood.py index 9d2d711..63cd392 100644 --- a/libretranslate/flood.py +++ b/libretranslate/flood.py @@ -1,75 +1,47 @@ -import atexit -import random -import string +from libretranslate.storage import get_storage -from apscheduler.schedulers.background import BackgroundScheduler - -def generate_secret(): - return ''.join(random.choices(string.ascii_uppercase + string.digits, k=7)) - -banned = {} active = False threshold = -1 -secrets = [generate_secret(), generate_secret()] def forgive_banned(): - global banned global threshold clear_list = [] + s = get_storage() + banned = s.get_all_hash_int("banned") for ip in banned: if banned[ip] <= 0: clear_list.append(ip) else: - banned[ip] = min(threshold, banned[ip]) - 1 + s.set_hash_int("banned", ip, min(threshold, banned[ip]) - 1) for ip in clear_list: - del banned[ip] + s.del_hash("banned", ip) -def rotate_secrets(): - global secrets - secrets[0] = secrets[1] - secrets[1] = generate_secret() - -def secret_match(s): - return s in secrets - -def get_current_secret(): - return secrets[1] - -def setup(violations_threshold=100): +def setup(args): global active global threshold - active = True - threshold = violations_threshold - - scheduler = BackgroundScheduler() - scheduler.add_job(func=forgive_banned, trigger="interval", minutes=30) - scheduler.add_job(func=rotate_secrets, trigger="interval", minutes=30) - - scheduler.start() - - # Shut down the scheduler when exiting the app - atexit.register(lambda: scheduler.shutdown()) - + if args.req_flood_threshold > 0: + active = True + threshold = args.req_flood_threshold def report(request_ip): if active: - banned[request_ip] = banned.get(request_ip, 0) - banned[request_ip] += 1 - + get_storage().inc_hash_int("banned", request_ip) def decrease(request_ip): - if banned[request_ip] > 0: - banned[request_ip] -= 1 - + s = get_storage() + if s.get_hash_int("banned", request_ip) > 0: + s.dec_hash_int("banned", request_ip) def has_violation(request_ip): - return request_ip in banned and banned[request_ip] > 0 - + s = get_storage() + return s.get_hash_int("banned", request_ip) > 0 def is_banned(request_ip): + s = get_storage() + # More than X offences? - return active and banned.get(request_ip, 0) >= threshold + return active and s.get_hash_int("banned", request_ip) >= threshold diff --git a/libretranslate/main.py b/libretranslate/main.py index 5a58885..0e02145 100644 --- a/libretranslate/main.py +++ b/libretranslate/main.py @@ -126,6 +126,13 @@ def get_args(): action="store_true", help="Require use of an API key for programmatic access to the API, unless the client also sends a secret match", ) + parser.add_argument( + "--shared-storage", + type=str, + default=DEFARGS['SHARED_STORAGE'], + metavar="", + help="Shared storage URI to use for multi-process data sharing (e.g. via gunicorn)", + ) parser.add_argument( "--load-only", type=operator.methodcaller("split", ","), diff --git a/libretranslate/scheduler.py b/libretranslate/scheduler.py new file mode 100644 index 0000000..3300095 --- /dev/null +++ b/libretranslate/scheduler.py @@ -0,0 +1,23 @@ +import atexit +from apscheduler.schedulers.background import BackgroundScheduler +scheduler = None + +def setup(args): + from libretranslate.flood import forgive_banned + from libretranslate.secret import rotate_secrets + + global scheduler + + if scheduler is None: + scheduler = BackgroundScheduler() + + if args.req_flood_threshold > 0: + scheduler.add_job(func=forgive_banned, trigger="interval", minutes=10) + + if args.api_keys and args.require_api_key_secret: + scheduler.add_job(func=rotate_secrets, trigger="interval", minutes=30) + + scheduler.start() + + # Shut down the scheduler when exiting the app + atexit.register(lambda: scheduler.shutdown()) \ No newline at end of file diff --git a/libretranslate/secret.py b/libretranslate/secret.py new file mode 100644 index 0000000..ce69910 --- /dev/null +++ b/libretranslate/secret.py @@ -0,0 +1,28 @@ +import atexit +import random +import string + +from libretranslate.storage import get_storage + +def generate_secret(): + return ''.join(random.choices(string.ascii_uppercase + string.digits, k=7)) + +def rotate_secrets(): + s = get_storage() + secret_1 = s.get_str("secret_1") + s.set_str("secret_0", secret_1) + s.set_str("secret_1", generate_secret()) + + +def secret_match(secret): + s = get_storage() + return secret == s.get_str("secret_0") or secret == s.get_str("secret_1") + +def get_current_secret(): + return get_storage().get_str("secret_1") + +def setup(args): + if args.api_keys and args.require_api_key_secret: + s = get_storage() + s.set_str("secret_0", generate_secret()) + s.set_str("secret_1", generate_secret()) diff --git a/libretranslate/storage.py b/libretranslate/storage.py new file mode 100644 index 0000000..7ffe849 --- /dev/null +++ b/libretranslate/storage.py @@ -0,0 +1,158 @@ +import redis + +storage = None +def get_storage(): + return storage + +class Storage: + def set_bool(self, key, value): + raise Exception("not implemented") + def get_bool(self, key): + raise Exception("not implemented") + + def set_int(self, key, value): + raise Exception("not implemented") + def get_int(self, key): + raise Exception("not implemented") + + def set_str(self, key, value): + raise Exception("not implemented") + def get_str(self, key): + raise Exception("not implemented") + + def set_hash_int(self, ns, key, value): + raise Exception("not implemented") + def get_hash_int(self, ns, key): + raise Exception("not implemented") + def inc_hash_int(self, ns, key): + raise Exception("not implemented") + def dec_hash_int(self, ns, key): + raise Exception("not implemented") + + def get_hash_keys(self, ns): + raise Exception("not implemented") + def del_hash(self, ns, key): + raise Exception("not implemented") + +class MemoryStorage(Storage): + def __init__(self): + self.store = {} + + def set_bool(self, key, value): + self.store[key] = bool(value) + + def get_bool(self, key): + return bool(self.store[key]) + + def set_int(self, key, value): + self.store[key] = int(value) + + def get_int(self, key): + return int(self.store.get(key, 0)) + + def set_str(self, key, value): + self.store[key] = value + + def get_str(self, key): + return str(self.store.get(key, "")) + + def set_hash_int(self, ns, key, value): + if ns not in self.store: + self.store[ns] = {} + self.store[ns][key] = int(value) + + def get_hash_int(self, ns, key): + d = self.store.get(ns, {}) + return int(d.get(key, 0)) + + def inc_hash_int(self, ns, key): + if ns not in self.store: + self.store[ns] = {} + + if key not in self.store[ns]: + self.store[ns][key] = 0 + else: + self.store[ns][key] += 1 + + def dec_hash_int(self, ns, key): + if ns not in self.store: + self.store[ns] = {} + + if key not in self.store[ns]: + self.store[ns][key] = 0 + else: + self.store[ns][key] -= 1 + + def get_all_hash_int(self, ns): + if ns in self.store: + return [{str(k): int(v)} for k,v in self.store[ns].items()] + else: + return [] + + def del_hash(self, ns, key): + del self.store[ns][key] + + +class RedisStorage(Storage): + def __init__(self, redis_uri): + self.conn = redis.from_url(redis_uri) + self.conn.ping() + + def set_bool(self, key, value): + self.conn.set(key, "1" if value else "0") + + def get_bool(self, key): + return bool(self.conn.get(key)) + + def set_int(self, key, value): + self.conn.set(key, str(value)) + + def get_int(self, key): + v = self.conn.get(key) + if v is None: + return 0 + else: + return v + + def set_str(self, key, value): + self.conn.set(key, value) + + def get_str(self, key): + v = self.conn.get(key) + if v is None: + return "" + else: + return v.decode('utf-8') + + def get_hash_int(self, ns, key): + v = self.conn.hget(ns, key) + if v is None: + return 0 + else: + return int(v) + + def set_hash_int(self, ns, key, value): + self.conn.hset(ns, key, value) + + def inc_hash_int(self, ns, key): + return int(self.conn.hincrby(ns, key)) + + def dec_hash_int(self, ns, key): + return int(self.conn.hincrby(ns, key, -1)) + + def get_all_hash_int(self, ns): + return {k.decode("utf-8"): int(v) for k,v in self.conn.hgetall(ns).items()} + + def del_hash(self, ns, key): + self.conn.hdel(ns, key) + +def setup(storage_uri): + global storage + if storage_uri.startswith("memory://"): + storage = MemoryStorage() + elif storage_uri.startswith("redis://"): + storage = RedisStorage(storage_uri) + else: + raise Exception("Invalid storage URI: " + storage_uri) + + return storage \ No newline at end of file diff --git a/libretranslate/templates/app.js.template b/libretranslate/templates/app.js.template index 4992d5b..b788830 100644 --- a/libretranslate/templates/app.js.template +++ b/libretranslate/templates/app.js.template @@ -243,9 +243,8 @@ document.addEventListener('DOMContentLoaded', function(){ request.onload = function() { try{ {% if api_secret != "" %} - if (this.status === 403){ - window.location.reload(true); - return; + if (this.status === 400){ + if (self.refreshOnce()) return; } {% endif %} @@ -362,6 +361,15 @@ document.addEventListener('DOMContentLoaded', function(){ this.translatedFileUrl = false; this.loadingFileTranslation = false; }, + refreshOnce: function(){ + var lastRefreshed = parseInt(localStorage.getItem("refreshed") || 0); + var now = new Date().getTime(); + if (now - lastRefreshed > 1000 * 60 * 1){ + localStorage.setItem("refreshed", now); + window.location.reload(); + return true; + } + }, translateFile: function(e) { e.preventDefault(); @@ -383,9 +391,8 @@ document.addEventListener('DOMContentLoaded', function(){ if (translateFileRequest.readyState === 4 && translateFileRequest.status === 200) { try{ {% if api_secret != "" %} - if (this.status === 403){ - window.location.reload(true); - return; + if (this.status === 400){ + if (self.refreshOnce()) return; } {% endif %} self.loadingFileTranslation = false; diff --git a/scripts/gunicorn_conf.py b/scripts/gunicorn_conf.py index 6ec7e40..845da5b 100644 --- a/scripts/gunicorn_conf.py +++ b/scripts/gunicorn_conf.py @@ -1,4 +1,42 @@ from prometheus_client import multiprocess +import re +import sys def child_exit(server, worker): - multiprocess.mark_process_dead(worker.pid) \ No newline at end of file + multiprocess.mark_process_dead(worker.pid) + +def on_starting(server): + # Parse command line arguments + proc_name = server.cfg.default_proc_name + kwargs = {} + if proc_name.startswith("wsgi:app"): + str_args = re.sub('wsgi:app\s*\(\s*(.*)\s*\)', '\\1', proc_name).strip().split(",") + for a in str_args: + if "=" in a: + k,v = a.split("=") + k = k.strip() + v = v.strip() + + if v.lower() in ["true", "false"]: + v = v.lower() == "true" + elif v[0] == '"': + v = v[1:-1] + kwargs[k] = v + + from libretranslate.main import get_args + sys.argv = ['--wsgi'] + for k in kwargs: + ck = k.replace("_", "-") + if isinstance(kwargs[k], bool) and kwargs[k]: + sys.argv.append("--" + ck) + else: + sys.argv.append("--" + ck) + sys.argv.append(kwargs[k]) + + args = get_args() + + from libretranslate import storage, scheduler, flood, secret + storage.setup(args.shared_storage) + scheduler.setup(args) + flood.setup(args) + secret.setup(args) \ No newline at end of file