mirror of
https://github.com/Mailu/Mailu.git
synced 2024-12-16 10:59:53 +02:00
73ab4327c2
This works as intended, but it introduces an SQLAlchemy warning when using config-import: /app/mailu/schemas.py:822: SAWarning: Identity map already had an identity for (...), replacing it with newly flushed object. Are there load operations occurring inside of an event handler within the flush?
1275 lines
44 KiB
Python
1275 lines
44 KiB
Python
""" Mailu marshmallow fields and schema
|
|
"""
|
|
|
|
from copy import deepcopy
|
|
from collections import Counter
|
|
from datetime import timezone
|
|
|
|
import json
|
|
import logging
|
|
import yaml
|
|
|
|
import sqlalchemy
|
|
|
|
from marshmallow import pre_load, post_load, post_dump, fields, Schema
|
|
from marshmallow.utils import ensure_text_type
|
|
from marshmallow.exceptions import ValidationError
|
|
from marshmallow_sqlalchemy import SQLAlchemyAutoSchemaOpts
|
|
from marshmallow_sqlalchemy.fields import RelatedList
|
|
|
|
from flask_marshmallow import Marshmallow
|
|
|
|
from OpenSSL import crypto
|
|
|
|
from pygments import highlight
|
|
from pygments.token import Token
|
|
from pygments.lexers import get_lexer_by_name
|
|
from pygments.lexers.data import YamlLexer
|
|
from pygments.formatters import get_formatter_by_name
|
|
|
|
from mailu import models, dkim
|
|
|
|
|
|
# flask-marshmallow extension object; BaseSchema derives from ma.SQLAlchemyAutoSchema
ma = Marshmallow()
|
|
|
|
|
|
### import logging and schema colorization ###

# registry: model class -> schema class (filled by the @mapped decorator)
_model2schema = {}


def get_schema(cls=None):
    """ look up the schema class registered for model ``cls``

        Called without argument, a view of all registered schema classes
        is returned. Unregistered models yield None.
    """
    registry = _model2schema
    if cls is not None:
        return registry.get(cls)
    return registry.values()


def mapped(cls):
    """ class decorator registering schema ``cls`` under ``cls.Meta.model`` """
    model = cls.Meta.model
    _model2schema[model] = cls
    return cls
|
|
|
|
class Logger:
    """ helps with counting and colorizing
        imported and exported data

        On construction, SQLAlchemy insert/update/delete listeners are
        registered for the model of every mapped schema, plus an
        after_flush listener on the db session for dkim key changes.
        The gathered counts are printed by ``changes()``.
    """

    class MyYamlLexer(YamlLexer):
        """ colorize yaml constants and integers """
        def get_tokens(self, text, unfiltered=False):
            for typ, value in super().get_tokens(text, unfiltered):
                if typ is Token.Literal.Scalar.Plain:
                    if value in {'true', 'false', 'null'}:
                        typ = Token.Keyword.Constant
                    elif value == HIDDEN:
                        typ = Token.Error
                    else:
                        try:
                            int(value, 10)
                        except ValueError:
                            try:
                                float(value)
                            except ValueError:
                                pass
                            else:
                                typ = Token.Literal.Number.Float
                        else:
                            typ = Token.Literal.Number.Integer
                yield typ, value

    def __init__(self, want_color=None, can_color=False, debug=False, secrets=False):
        """ want_color: None = colorize when ``can_color`` is true,
                        otherwise the explicit caller choice wins
            can_color:  whether the output is able to display color
            debug:      enable sqlalchemy.engine INFO logging
            secrets:    show secret values instead of hiding them
        """

        self.lexer = 'yaml'
        self.formatter = 'terminal'
        self.strip = False
        self.verbose = 0
        self.quiet = False
        self.secrets = secrets
        self.debug = debug
        self.print = print

        # an explicit want_color=False must be able to disable color
        self.color = can_color if want_color is None else want_color

        self._counter = Counter()
        self._schemas = {}

        # log contexts
        self._diff_context = {
            'full': True,
            'secrets': secrets,
        }
        log_context = {
            'secrets': secrets,
        }

        # register listeners
        for schema in get_schema():
            model = schema.Meta.model
            self._schemas[model] = schema(context=log_context)
            sqlalchemy.event.listen(model, 'after_insert', self._listen_insert)
            sqlalchemy.event.listen(model, 'after_update', self._listen_update)
            sqlalchemy.event.listen(model, 'after_delete', self._listen_delete)

        # special listener for dkim_key changes
        # TODO: _listen_dkim can be removed when dkim keys are stored in database
        self._dedupe_dkim = set()
        sqlalchemy.event.listen(models.db.session, 'after_flush', self._listen_dkim)

        # register debug logger for sqlalchemy
        if self.debug:
            logging.basicConfig()
            logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)

    def _log(self, action, target, message=None):
        """ print one log line; without message the target is dumped
            using its schema (or repr'd when no schema is known)
        """
        if message is None:
            try:
                message = self._schemas[target.__class__].dump(target)
            except KeyError:
                message = target
        if not isinstance(message, str):
            message = repr(message)
        self.print(f'{action} {target.__table__}: {self.colorize(message)}')

    def _listen_insert(self, mapper, connection, target): # pylint: disable=unused-argument
        """ callback method to track import """
        self._counter.update([('Created', target.__table__.name)])
        if self.verbose:
            self._log('Created', target)

    def _listen_update(self, mapper, connection, target): # pylint: disable=unused-argument
        """ callback method to track import

            Counts the target as "Modified" when at least one column value
            really changed; in verbose mode every changed attribute is logged.
        """

        changes = {}
        modified = False
        inspection = sqlalchemy.inspect(target)
        for attr in sqlalchemy.orm.class_mapper(target.__class__).column_attrs:
            history = getattr(inspection.attrs, attr.key).history
            if history.has_changes() and history.deleted:
                before = history.deleted[-1]
                after = getattr(target, attr.key)
                # we don't have ordered lists
                if isinstance(before, list):
                    before = set(before)
                if isinstance(after, list):
                    after = set(after)
                # TODO: this can be removed when comment is not nullable in model
                if attr.key == 'comment' and not before and not after:
                    pass
                # only remember changed keys
                elif before != after:
                    modified = True
                    if not self.verbose:
                        # one real change is enough to count the item
                        break
                    changes[str(attr.key)] = (before, after)

        if self.verbose:
            # use schema to log changed attributes
            schema = get_schema(target.__class__)
            only = set(changes.keys()) & set(schema().fields.keys())
            if only:
                for key, value in schema(
                    only=only,
                    context=self._diff_context
                ).dump(target).items():
                    before, after = changes[key]
                    if value == HIDDEN:
                        before = HIDDEN if before else before
                        after = HIDDEN if after else after
                    else:
                        # show the schema-serialized new value
                        after = value
                    self._log('Modified', target, f'{str(target)!r} {key}: {before!r} -> {after!r}')

        if modified:
            self._counter.update([('Modified', target.__table__.name)])

    def _listen_delete(self, mapper, connection, target): # pylint: disable=unused-argument
        """ callback method to track import """
        self._counter.update([('Deleted', target.__table__.name)])
        if self.verbose:
            self._log('Deleted', target)

    # TODO: _listen_dkim can be removed when dkim keys are stored in database
    def _listen_dkim(self, session, flush_context): # pylint: disable=unused-argument
        """ callback method to track import

            Compares ``_dkim_key_on_disk`` with ``_dkim_key`` of every Domain
            in the session that was loaded from the database.
        """
        for target in session.identity_map.values():
            # look at Domains originally loaded from db
            if not isinstance(target, models.Domain) or not target._sa_instance_state.load_path:
                continue
            before = target._dkim_key_on_disk
            after = target._dkim_key
            # "de-dupe" messages; this event is fired at every flush
            if before == after or (target, before, after) in self._dedupe_dkim:
                continue
            self._dedupe_dkim.add((target, before, after))
            self._counter.update([('Modified', target.__table__.name)])
            if self.verbose:
                if self.secrets:
                    # guard against a missing (None) key on either side
                    before = before.decode('ascii', 'ignore') if before else ''
                    after = after.decode('ascii', 'ignore') if after else ''
                else:
                    before = HIDDEN if before else ''
                    after = HIDDEN if after else ''
                self._log('Modified', target, f'{str(target)!r} dkim_key: {before!r} -> {after!r}')

    def track_serialize(self, obj, item, backref=None):
        """ callback method to track import

            obj:     schema handling the item
            item:    input data (or instance when backref is given)
            backref: dict with 'target', 'key', 'before', 'after' describing
                     a list modification to log
        """
        # called for backref modification?
        if backref is not None:
            self._log(
                'Modified', item, '{target!r} {key}: {before!r} -> {after!r}'.format_map(backref))
            return
        # show input data?
        if self.verbose < 2:
            return
        # hide secrets in data
        if not self.secrets:
            item = self._schemas[obj.opts.model].hide(item)
            if 'hash_password' in item:
                item['password'] = HIDDEN
            if 'fetches' in item:
                for fetch in item['fetches']:
                    fetch['password'] = HIDDEN
        self._log('Handling', obj.opts.model, item)

    def changes(self, *messages, **kwargs):
        """ show changes gathered in counter """
        if self.quiet:
            return
        if self._counter:
            changes = []
            last = None
            for (action, what), count in sorted(self._counter.items()):
                if action != last:
                    if last:
                        changes.append('/')
                    changes.append(f'{action}:')
                    last = action
                changes.append(f'{what}({count})')
        else:
            changes = ['No changes.']
        self.print(*messages, *changes, **kwargs)

    def _format_errors(self, store, path=None):
        """ flatten a (nested) marshmallow error dict

            Recursive calls (path given) return a list of
            (location, message) tuples; the top-level call returns a
            printable multi-line report.
        """

        res = []
        if path is None:
            path = []
        for key in sorted(store):
            location = path + [str(key)]
            value = store[key]
            if isinstance(value, dict):
                res.extend(self._format_errors(value, location))
            else:
                # marshmallow stores lists of messages; tolerate a bare string
                if isinstance(value, str):
                    value = [value]
                for message in value:
                    res.append((".".join(location), message))

        if path:
            return res

        # default=0 keeps an empty error store from crashing max()
        maxlen = max((len(loc) for loc, _ in res), default=0)
        res = [f' - {loc.ljust(maxlen)} : {msg}' for loc, msg in res]
        errors = f'{len(res)} error{["s",""][len(res)==1]}'
        res.insert(0, f'[ValidationError] {errors} occurred during input validation')

        return '\n'.join(res)

    def _is_validation_error(self, exc):
        """ walk traceback to extract invalid field from marshmallow """
        path = []
        trace = exc.__traceback__
        while trace:
            if trace.tb_frame.f_code.co_name == '_serialize':
                if 'attr' in trace.tb_frame.f_locals:
                    path.append(trace.tb_frame.f_locals['attr'])
            elif trace.tb_frame.f_code.co_name == '_init_fields':
                spec = ', '.join(
                    '.'.join(path + [key])
                    for key in trace.tb_frame.f_locals['invalid_fields'])
                return f'Invalid filter: {spec}'
            trace = trace.tb_next
        return None

    def format_exception(self, exc):
        """ format ValidationErrors and other exceptions when not debugging

            Returns None for non-validation errors in debug mode.
        """
        if isinstance(exc, ValidationError):
            # normalized_messages() always yields a dict, even when the
            # error was raised with a plain message (list)
            return self._format_errors(exc.normalized_messages())
        if isinstance(exc, ValueError):
            if msg := self._is_validation_error(exc):
                return msg
        if self.debug:
            return None
        msg = ' '.join(str(exc).split())
        return f'[{exc.__class__.__name__}] {msg}'

    colorscheme = {
        Token: ('', ''),
        Token.Name.Tag: ('cyan', 'cyan'),
        Token.Literal.Scalar: ('green', 'green'),
        Token.Literal.String: ('green', 'green'),
        Token.Name.Constant: ('green', 'green'), # multiline strings
        Token.Keyword.Constant: ('magenta', 'magenta'),
        Token.Literal.Number: ('magenta', 'magenta'),
        Token.Error: ('red', 'red'),
        Token.Name: ('red', 'red'),
        Token.Operator: ('red', 'red'),
    }

    def colorize(self, data, lexer=None, formatter=None, color=None, strip=None):
        """ add ANSI color to data """

        if color is False or not self.color:
            return data

        lexer = lexer or self.lexer
        lexer = Logger.MyYamlLexer() if lexer == 'yaml' else get_lexer_by_name(lexer)
        formatter = get_formatter_by_name(formatter or self.formatter, colorscheme=self.colorscheme)
        if strip is None:
            strip = self.strip

        res = highlight(data, lexer, formatter)
        if strip:
            return res.rstrip('\n')
        return res
|
|
|
|
|
|
### marshmallow render modules ###
|
|
|
|
# hidden attributes
|
|
class _Hidden:
    """ placeholder for secret values that must not be shown

        Instances are falsy, are returned unchanged by copy/deepcopy and
        compare equal to anything whose str() is '<hidden>'.
    """
    def __bool__(self):
        return False

    def __copy__(self):
        return self

    def __deepcopy__(self, _):
        return self

    def __eq__(self, other):
        return str(other) == '<hidden>'

    def __hash__(self):
        # defining __eq__ alone makes the class unhashable; hash like the
        # string '<hidden>' so objects that compare equal share a hash
        return hash('<hidden>')

    def __repr__(self):
        return '<hidden>'

    __str__ = __repr__
|
|
|
|
# let yaml dump _Hidden values as the plain string '<hidden>'
yaml.add_representer(
    _Hidden,
    lambda dumper, data: dumper.represent_data(str(data))
)

# shared placeholder substituted for secret values when logging or dumping
HIDDEN = _Hidden()
|
|
|
|
# multiline attributes
|
|
class _Multiline(str):
    """ str subclass marking values that yaml renders as a literal ('|') block """
|
|
|
|
# dump _Multiline values as yaml literal block scalars (style '|')
yaml.add_representer(
    _Multiline,
    lambda dumper, data: dumper.represent_scalar(u'tag:yaml.org,2002:str', data, style='|')
)
|
|
|
|
# yaml render module
|
|
class RenderYAML:
    """ Marshmallow YAML Render Module

        Provides ``loads``/``dumps`` with module-specific defaults; callers'
        keyword arguments always take precedence.
    """

    class SpacedDumper(yaml.Dumper):
        """ YAML Dumper to add a newline between main sections
            and double the indent used
        """

        def write_line_break(self, data=None):
            super().write_line_break(data)
            # one additional empty line after each top-level entry
            if len(self.indents) == 1:
                super().write_line_break()

        def increase_indent(self, flow=False, indentless=False):
            # always indent nested sequences
            return super().increase_indent(flow, False)

    @staticmethod
    def _augment(kwargs, defaults):
        """ add defaults to kwargs if missing
        """
        for name, default in defaults.items():
            kwargs.setdefault(name, default)

    _load_defaults = {}

    @classmethod
    def loads(cls, *args, **kwargs):
        """ load yaml data from string (using yaml.safe_load)
        """
        cls._augment(kwargs, cls._load_defaults)
        parsed = yaml.safe_load(*args, **kwargs)
        return parsed

    _dump_defaults = {
        'Dumper': SpacedDumper,
        'default_flow_style': False,
        'allow_unicode': True,
        'sort_keys': False,
    }

    @classmethod
    def dumps(cls, *args, **kwargs):
        """ dump data to yaml string
        """
        cls._augment(kwargs, cls._dump_defaults)
        rendered = yaml.dump(*args, **kwargs)
        return rendered
|
|
|
|
# json encoder
|
|
class JSONEncoder(json.JSONEncoder):
    """ JSONEncoder that can additionally serialize HIDDEN placeholders """

    def default(self, o):
        """ render _Hidden values as their string form, defer all else """
        if not isinstance(o, _Hidden):
            return json.JSONEncoder.default(self, o)
        return str(o)
|
|
|
|
# json render module
|
|
class RenderJSON:
    """ Marshmallow JSON Render Module

        Provides ``loads``/``dumps`` with module-specific defaults; callers'
        keyword arguments always take precedence.
    """

    @staticmethod
    def _augment(kwargs, defaults):
        """ add defaults to kwargs if missing
        """
        for name, default in defaults.items():
            kwargs.setdefault(name, default)

    _load_defaults = {}

    @classmethod
    def loads(cls, *args, **kwargs):
        """ load json data from string
        """
        cls._augment(kwargs, cls._load_defaults)
        parsed = json.loads(*args, **kwargs)
        return parsed

    # compact output; the module's JSONEncoder can render HIDDEN values
    _dump_defaults = {
        'separators': (',',':'),
        'cls': JSONEncoder,
    }

    @classmethod
    def dumps(cls, *args, **kwargs):
        """ dump data to json string
        """
        cls._augment(kwargs, cls._dump_defaults)
        rendered = json.dumps(*args, **kwargs)
        return rendered
|
|
|
|
|
|
### marshmallow: custom fields ###
|
|
|
|
def _rfc3339(datetime):
    """ format a datetime according to rfc3339

        Naive values are converted with astimezone(timezone.utc) (Python
        interprets naive datetimes as local time there); a '+00:00' offset
        is written as 'Z'.
    """
    if datetime.tzinfo is None:
        datetime = datetime.astimezone(timezone.utc)
    text = datetime.isoformat()
    utc_suffix = '+00:00'
    if not text.endswith(utc_suffix):
        return text
    return f'{text[:-len(utc_suffix)]}Z'
|
|
|
|
# register an 'rfc3339' DateTime format: dump via _rfc3339, parse with the
# 'iso' parser; then make it the default format of all DateTime fields
fields.DateTime.SERIALIZATION_FUNCS['rfc3339'] = _rfc3339
fields.DateTime.DESERIALIZATION_FUNCS['rfc3339'] = fields.DateTime.DESERIALIZATION_FUNCS['iso']
fields.DateTime.DEFAULT_FORMAT = 'rfc3339'
|
|
|
|
class LazyStringField(fields.String):
    """ String field that dumps every falsy value as the empty string
    """

    def _serialize(self, value, attr, obj, **kwargs):
        """ return truthy values unchanged and '' for None, '' and the like
        """
        return value or ''
|
|
|
|
class CommaSeparatedListField(fields.Raw):
    """ Deserialize a string containing comma-separated values to
        a list of strings
    """

    default_error_messages = {
        "invalid": "Not a valid string or list.",
        "invalid_utf8": "Not a valid utf-8 string or list.",
    }

    def _deserialize(self, value, attr, data, **kwargs):
        """ deserialize comma separated string to list of strings

            Lists are converted item by item to text. Strings are split at
            commas, each part is stripped and empty parts are dropped.
        """

        # empty input -> empty list
        if not value:
            return []

        # list input: convert every element to text
        if isinstance(value, list):
            try:
                return [ensure_text_type(element) for element in value]
            except UnicodeDecodeError as exc:
                raise self.make_error("invalid_utf8") from exc

        # anything but text is rejected
        if not isinstance(value, (str, bytes)):
            raise self.make_error("invalid")
        try:
            text = ensure_text_type(value)
        except UnicodeDecodeError as exc:
            raise self.make_error("invalid_utf8") from exc

        parts = (chunk.strip() for chunk in text.split(','))
        return [part for part in parts if part]
|
|
|
|
|
|
class DkimKeyField(fields.String):
    """ Serialize a dkim key to a multiline string and
        deserialize a dkim key data as string or list of strings
        to a valid dkim key
    """

    default_error_messages = {
        "invalid": "Not a valid string or list.",
        "invalid_utf8": "Not a valid utf-8 string or list.",
    }

    @staticmethod
    def _abbreviate(value):
        """ shorten key text for error messages: values longer than 40
            characters keep only their first 25 and last 10 characters
        """
        return f'{value[:25]}...{value[-10:]}' if len(value) > 40 else value

    def _serialize(self, value, attr, obj, **kwargs):
        """ serialize dkim key as multiline string
        """

        # map empty value and None to the empty string
        if not value:
            return ''

        # return multiline string
        return _Multiline(value.decode('utf-8'))

    def _wrap_key(self, begin, data, end):
        """ generator to wrap key into RFC 7468 format

            Yields the header, the data in 64-character lines, the footer
            and a final '' so a newline-join ends with a newline.
        """
        yield begin
        pos = 0
        while pos < len(data):
            yield data[pos:pos+64]
            pos += 64
        yield end
        yield ''

    def _deserialize(self, value, attr, data, **kwargs):
        """ deserialize a string or list of strings to dkim key data
            with verification

            '-generate-' creates a new key, an empty value yields None.
            Otherwise the PEM text is normalized (header/footer kept or
            added, whitespace removed, re-wrapped) and returned as ascii
            bytes after checking that OpenSSL can load it.

            Raises ValidationError for malformed or invalid keys.
        """

        # convert list to str
        if isinstance(value, list):
            try:
                value = ''.join(ensure_text_type(item) for item in value).strip()
            except UnicodeDecodeError as exc:
                raise self.make_error("invalid_utf8") from exc

        # only text is allowed
        else:
            if not isinstance(value, (str, bytes)):
                raise self.make_error("invalid")
            try:
                value = ensure_text_type(value).strip()
            except UnicodeDecodeError as exc:
                raise self.make_error("invalid_utf8") from exc

        # generate new key?
        if value.lower() == '-generate-':
            return dkim.gen_key()

        # no key?
        if not value:
            return None

        # remember part of value for ValidationError
        # (never echo the complete private key back in an error message)
        bad_key = self._abbreviate(value)

        # strip header and footer, clean whitespace and wrap to 64 characters
        try:
            if value.startswith('-----BEGIN '):
                end = value.index('-----', 11) + 5
                header = value[:end]
                value = value[end:]
            else:
                header = '-----BEGIN PRIVATE KEY-----'

            if (pos := value.find('-----END ')) >= 0:
                end = value.index('-----', pos+9) + 5
                footer = value[pos:end]
                value = value[:pos]
            else:
                footer = '-----END PRIVATE KEY-----'
        except ValueError as exc:
            raise ValidationError(f'invalid dkim key {bad_key!r}') from exc

        # remove whitespace from key data
        value = ''.join(value.split())

        # remember part of value for ValidationError
        bad_key = self._abbreviate(value)

        # wrap key according to RFC 7468
        value = ('\n'.join(self._wrap_key(header, value, footer))).encode('ascii')

        # check key validity
        try:
            crypto.load_privatekey(crypto.FILETYPE_PEM, value)
        except crypto.Error as exc:
            raise ValidationError(f'invalid dkim key {bad_key!r}') from exc
        else:
            return value
|
|
|
|
class PasswordField(fields.Str):
    """ Serialize a hashed password hash by stripping the obsolete {SCHEME}
        Deserialize a plain password or hashed password into a hashed password
    """

    _hashes = {'PBKDF2', 'BLF-CRYPT', 'SHA512-CRYPT', 'SHA256-CRYPT', 'MD5-CRYPT', 'CRYPT'}

    def _strip_scheme(self, value):
        """ drop a leading '{SCHEME}' when SCHEME is one of _hashes """
        if value.startswith('{'):
            close = value.find('}', 1)
            if close >= 0 and value[1:close] in self._hashes:
                return value[close+1:]
        return value

    def _serialize(self, value, attr, obj, **kwargs):
        """ strip obsolete {password-hash} when serializing """
        return self._strip_scheme(value)

    def _deserialize(self, value, attr, data, **kwargs):
        """ hashes plain password or checks hashed password
            also strips obsolete {password-hash} when deserializing

            Raises ValidationError when the resulting hash is not recognized
            by the model's check_password.
        """

        # hash plain password with a throw-away model instance when requested
        if data.get('hash_password'):
            hasher = self.metadata['model']()
            hasher.set_password(value)
            value = hasher.password

        # the scheme prefix is obsolete
        value = self._strip_scheme(value)

        # probe the hash: check_password raises ValueError for unknown hashes
        probe = self.metadata['model'](password=value)
        try:
            probe.check_password('')
        except ValueError as exc:
            raise ValidationError(f'invalid password hash {value!r}') from exc

        return value
|
|
|
|
|
|
### base schema ###
|
|
|
|
class Storage:
    """ Storage class to save information in context

        Values live in ``self.context['_track']`` under (owner, key) tuples.
        The owner is chosen by ``bind``: True -> this object's class,
        a str -> the schema of the object recalled under that key,
        anything else (default None) is used as-is.
    """

    # NOTE(review): shared class-level dict; schemas replace it per instance
    context = {}

    def _bind(self, key, bind):
        if bind is True:
            owner = self.__class__
        elif isinstance(bind, str):
            owner = get_schema(self.recall(bind).__class__)
        else:
            owner = bind
        return (owner, key)

    def store(self, key, value, bind=None):
        """ store value under key """
        tracked = self.context.setdefault('_track', {})
        tracked[self._bind(key, bind)] = value

    def recall(self, key, bind=None):
        """ recall value from key """
        tracked = self.context['_track']
        return tracked[self._bind(key, bind)]
|
|
|
|
class BaseOpts(SQLAlchemyAutoSchemaOpts):
    """ Option class with sqla session

        Defaults Meta.sqla_session to models.db.session and Meta.sibling
        to False when a schema's Meta does not set them.
    """
    def __init__(self, meta, ordered=False):
        missing_session = not hasattr(meta, 'sqla_session')
        if missing_session:
            meta.sqla_session = models.db.session
        missing_sibling = not hasattr(meta, 'sibling')
        if missing_sibling:
            meta.sibling = False
        super().__init__(meta, ordered=ordered)
|
|
|
|
class BaseSchema(ma.SQLAlchemyAutoSchema, Storage):
|
|
""" Marshmallow base schema with custom exclude logic
|
|
and option to hide sqla defaults
|
|
"""
|
|
|
|
OPTIONS_CLASS = BaseOpts
|
|
|
|
class Meta:
    """ Schema config """
    # {flags: set of field names}: the fields are excluded unless one of the
    # flags is True in the schema context
    include_by_context = {}
    # {field name: list of values}: the field is left out of dump output
    # when its value is in the list (unless context 'full' is set)
    exclude_by_value = {}
    # {flags: set of field names}: like include_by_context, but the fields
    # are dumped as HIDDEN instead of being excluded
    hide_by_context = {}
    # explicit field order; schemas without it get primary key first,
    # then the remaining fields alphabetically
    order = []
    # sibling schemas do not flush the session in _patch_many, do not store
    # their instances as 'parent', and prune differently in _prune_items
    sibling = False
|
|
|
|
def __init__(self, *args, **kwargs):
    """ initialize schema with context-dependent excludes, hidden fields and order

        only:    explicitly requested fields are never auto-excluded or hidden
        context: keys set to True act as flags for Meta.include_by_context
                 and Meta.hide_by_context; 'full' disables the exclusion of
                 default values. context=None is treated like {}.
    """

    # prepare only to auto-include explicitly specified attributes
    only = set(kwargs.get('only') or [])

    # get context (marshmallow accepts context=None, treat it like {})
    context = kwargs.get('context') or {}
    flags = {key for key, value in context.items() if value is True}

    # compile excludes
    exclude = set(kwargs.get('exclude') or [])

    # always exclude
    exclude.update({'created_at', 'updated_at'} - only)

    # add include_by_context: drop field groups none of whose flags is set
    for need, what in getattr(self.Meta, 'include_by_context', {}).items():
        if not flags & set(need):
            exclude |= what - only

    # update excludes
    kwargs['exclude'] = exclude

    # init SQLAlchemyAutoSchema
    super().__init__(*args, **kwargs)

    # exclude_by_value
    self._exclude_by_value = {
        key: values for key, values in getattr(self.Meta, 'exclude_by_value', {}).items()
        if key not in only
    }

    # exclude default values
    if not context.get('full'):
        for column in self.opts.model.__table__.columns:
            if column.name not in exclude and column.name not in only:
                self._exclude_by_value.setdefault(column.name, []).append(
                    None if column.default is None else column.default.arg
                )

    # hide by context
    self._hide_by_context = set()
    for need, what in getattr(self.Meta, 'hide_by_context', {}).items():
        if not flags & set(need):
            self._hide_by_context |= what - only

    # remember primary keys
    self._primary = str(self.opts.model.__table__.primary_key.columns.values()[0].name)

    # determine attribute order
    if hasattr(self.Meta, 'order'):
        # use user-defined order
        order = self.Meta.order
    else:
        # default order is: primary_key + other keys alphabetically
        order = list(sorted(self.fields.keys()))
        if self._primary in order:
            order.remove(self._primary)
            order.insert(0, self._primary)

    # order fieldlists
    for fieldlist in (self.fields, self.load_fields, self.dump_fields):
        for field in order:
            if field in fieldlist:
                fieldlist[field] = fieldlist.pop(field)

    # move post_load hook "_add_instance" to the end (after load_instance mixin)
    hooks = self._hooks[('post_load', False)]
    hooks.remove('_add_instance')
    hooks.append('_add_instance')
|
|
|
|
def hide(self, data):
    """ helper method to hide input data for logging

        Returns a new dict: keys hidden by context map to HIDDEN, all other
        values are deep copies, so the caller may modify the result freely.
    """
    hidden_keys = self._hide_by_context
    result = {}
    for name, item in data.items():
        result[name] = HIDDEN if name in hidden_keys else deepcopy(item)
    return result
|
|
|
|
def _call_and_store(self, *args, **kwargs):
    """ track current parent field for pruning

        The field name is stored as 'field' bound to this schema class, so
        nested schemas can recall it via recall('field', 'parent').
    """
    current_field = kwargs['field_name']
    self.store('field', current_field, True)
    return super()._call_and_store(*args, **kwargs)
|
|
|
|
# this is only needed to work around the declared attr "email" primary key in model
|
|
def get_instance(self, data):
    """ lookup item by defined primary key instead of key(s) from model

        Uses Meta.primary_keys when defined and all of those keys are set
        in data; otherwise defers to the default lookup. Transient schemas
        never return an instance.
    """
    if self.transient:
        return None
    custom_keys = getattr(self.Meta, 'primary_keys', None)
    if custom_keys:
        criteria = {name: data.get(name) for name in custom_keys}
        if None not in criteria.values():
            query = self.session.query(self.opts.model).filter_by(**criteria)
            return query.first()
    return super().get_instance(data)
|
|
|
|
@pre_load(pass_many=True)
def _patch_many(self, items, many, **kwargs): # pylint: disable=unused-argument
    """ - flush sqla session before serializing a section when requested
          (make sure all objects that could be referred to later are created)
        - when in update mode: patch input data before deserialization
        - handle "prune" and "delete" items
        - replace values in keys starting with '-' with default

        Raises ValidationError for input carrying '__delete__', for
        'hash_password' without 'password', and for invalid '-key' resets.
    """

    # flush sqla session (skipped for sibling schemas)
    if not self.Meta.sibling:
        self.opts.sqla_session.flush()

    # stop early when not updating
    if not self.context.get('update'):
        return items

    # patch "delete", "prune" and "default"
    # (want_prune is a list so the nested patch() can record a prune request)
    want_prune = []
    def patch(count, data):
        # returns None for a prune marker, a deletion stub for
        # "-<primary>: <value>", otherwise the item with "-key" entries resolved

        # don't allow __delete__ coming from input
        if '__delete__' in data:
            raise ValidationError('Unknown field.', f'{count}.__delete__')

        # fail when hash_password is specified without password
        if 'hash_password' in data and not 'password' in data:
            raise ValidationError(
                'Nothing to hash. Field "password" is missing.',
                field_name = f'{count}.hash_password',
            )

        # handle "prune list" and "delete item" (-pkey: none and -pkey: id)
        for key in data:
            if key.startswith('-'):
                if key[1:] == self._primary:
                    # delete or prune
                    if data[key] is None:
                        # prune
                        want_prune.append(True)
                        return None
                    # mark item for deletion
                    return {key[1:]: data[key], '__delete__': count}

        # handle "set to default value" (-key: none)
        def set_default(key, value):
            if not key.startswith('-'):
                return (key, value)
            key = key[1:]
            # keys that are not table columns are simply set to None
            if not key in self.opts.model.__table__.columns:
                return (key, None)
            if value is not None:
                raise ValidationError(
                    'Value must be "null" when resetting to default.',
                    f'{count}.{key}'
                )
            value = self.opts.model.__table__.columns[key].default
            if value is None:
                raise ValidationError(
                    'Field has no default value.',
                    f'{count}.{key}'
                )
            return (key, value.arg)

        return dict(set_default(key, value) for key, value in data.items())

    # convert items to "delete" and filter "prune" item
    # (falsy results - prune markers and empty items - are dropped)
    items = [
        item for item in [
            patch(count, item) for count, item in enumerate(items)
        ] if item
    ]

    # remember if prune was requested for _prune_items@post_load
    self.store('prune', bool(want_prune), True)

    # remember original items to stabilize password-changes in _add_instance@post_load
    self.store('original', items, True)

    return items
|
|
|
|
@pre_load
def _patch_item(self, data, many, **kwargs): # pylint: disable=unused-argument
    """ - call callback function to track import
        - stabilize import of items with auto-increment primary key
        - delete items
        - delete/prune list attributes
        - add missing required attributes

        Raises ValidationError when an item or list entry to delete does
        not exist.
    """

    # callback
    if callback := self.context.get('callback'):
        callback(self, data)

    # stop early when not updating
    if not self.opts.load_instance or not self.context.get('update'):
        return data

    # stabilize import of auto-increment primary keys (not required),
    # by matching import data to existing items and setting primary key
    if not self._primary in data:
        for item in getattr(self.recall('parent'), self.recall('field', 'parent')):
            existing = self.dump(item, many=False)
            this = existing.pop(self._primary)
            if data == existing:
                # NOTE(review): this assignment is overwritten by the lookup
                # below; only the primary key written into data takes effect
                instance = item
                data[self._primary] = this
                break

    # try to load instance
    instance = self.instance or self.get_instance(data)
    if instance is None:

        if '__delete__' in data:
            # deletion of non-existent item requested
            raise ValidationError(
                f'Item to delete not found: {data[self._primary]!r}.',
                field_name = f'{data["__delete__"]}.{self._primary}',
            )

    else:

        if self.context.get('update'):
            # remember instance as parent for pruning siblings
            if not self.Meta.sibling:
                self.store('parent', instance)
            # delete instance from session when marked
            if '__delete__' in data:
                self.opts.sqla_session.delete(instance)
            # delete item from lists or prune lists
            # currently: domain.alternatives, user.forward_destination,
            # user.manager_of, aliases.destination
            for key, value in data.items():
                if not isinstance(self.fields.get(key), (
                    RelatedList, CommaSeparatedListField, fields.Raw)
                ) or not isinstance(value, list):
                    continue
                # deduplicate new value
                new_value = set(value)
                # handle list pruning ('-prune-' replaces the list,
                # otherwise existing entries are merged in)
                if '-prune-' in value:
                    value.remove('-prune-')
                    new_value.remove('-prune-')
                else:
                    for old in getattr(instance, key):
                        # using str() is okay for now (see above)
                        new_value.add(str(old))
                # handle item deletion ('-entry' removes 'entry')
                for item in value:
                    if item.startswith('-'):
                        new_value.remove(item)
                        try:
                            new_value.remove(item[1:])
                        except KeyError as exc:
                            raise ValidationError(
                                f'Item to delete not found: {item[1:]!r}.',
                                field_name=f'?.{key}',
                            ) from exc
                # sort list of new values
                data[key] = sorted(new_value)
                # log backref modification not catched by modify hook
                if isinstance(self.fields[key], RelatedList):
                    if callback := self.context.get('callback'):
                        before = {str(v) for v in getattr(instance, key)}
                        after = set(data[key])
                        if before != after:
                            callback(self, instance, {
                                'key': key,
                                'target': str(instance),
                                'before': before,
                                'after': after,
                            })

        # add attributes required for validation from db
        for attr_name, field_obj in self.load_fields.items():
            if field_obj.required and attr_name not in data:
                data[attr_name] = getattr(instance, attr_name)

    return data
|
|
|
|
@post_load(pass_many=True)
def _prune_items(self, items, many, **kwargs): # pylint: disable=unused-argument
    """ handle list pruning

        In update mode, existing db items of the parent field that are
        missing from the input are appended: sibling schemas re-add them
        unless pruning was requested; other schemas mark them for deletion
        ('__delete__') only when pruning was requested.
    """

    if not self.context.get('update'):
        return items

    # prune flag stored by _patch_many@pre_load
    want_prune = self.recall('prune', True)

    if self.Meta.sibling:
        # parent prunes automatically: keep old items unless pruning
        add_items, del_items = not want_prune, False
    else:
        # parent does not prune: delete old items only when pruning
        add_items, del_items = False, bool(want_prune)

    if not (add_items or del_items):
        return items

    seen = {entry[self._primary] for entry in items if self._primary in entry}
    parent = self.recall('parent')
    for old in getattr(parent, self.recall('field', 'parent')):
        pkey = getattr(old, self._primary)
        if pkey in seen:
            continue
        stub = {self._primary: pkey}
        if not add_items:
            stub['__delete__'] = '?'
        items.append(stub)

    return items
|
|
|
|
@post_load
def _add_instance(self, item, many, **kwargs): # pylint: disable=unused-argument
    """ - undo password change in existing instances when plain password did not change
        - add new instances to sqla session
    """

    # new instance: add it to the session, there is no stored hash to keep
    if not item in self.opts.sqla_session:
        self.opts.sqla_session.add(item)
        return item

    # stop early when not updating or item has no password attribute
    if not self.context.get('update') or not hasattr(item, 'password'):
        return item

    # did we hash a new plaintext password?
    # look up the plain password supplied (with hash_password) for this item
    original = None
    pkey = getattr(item, self._primary)
    for data in self.recall('original', True):
        if 'hash_password' in data and data.get(self._primary) == pkey:
            original = data['password']
            break
    if original is None:
        # no plaintext password was hashed for this item: nothing to undo
        return item

    # reset hash if plain password matches hash from db
    # (re-hashing the same plaintext produces a different hash - presumably
    # salted - so the previous db hash is restored to avoid a spurious change)
    if attr := getattr(sqlalchemy.inspect(item).attrs, 'password', None):
        if attr.history.has_changes() and attr.history.deleted:
            try:
                # reset password hash
                inst = type(item)(password=attr.history.deleted[-1])
                if inst.check_password(original):
                    item.password = inst.password
            except ValueError:
                # hash in db is invalid
                pass
            else:
                del inst

    return item
|
|
|
|
@post_dump
def _hide_values(self, data, many, **kwargs): # pylint: disable=unused-argument
    """ hide secrets

    - a key whose dumped value is listed in ``self._exclude_by_value[key]``
      is dropped, unless the ``full`` context flag is set
    - a key listed in ``self._hide_by_context`` keeps its slot but its value
      is replaced by ``HIDDEN``

    :returns: a new mapping of the same type as ``data`` (or ``data`` itself
              when there is nothing to exclude or hide)
    """

    if not (self._exclude_by_value or self._hide_by_context):
        return data

    show_all = self.context.get('full')
    kept = []
    for key, value in data.items():
        if not show_all and key in self._exclude_by_value and value in self._exclude_by_value[key]:
            # value matches an excluded value for this key => leave it out
            continue
        kept.append((key, HIDDEN if key in self._hide_by_context else value))
    return type(data)(kept)
|
|
|
|
# load-only flag marking an item for deletion; read from the "__delete__" key
mark_delete = fields.Boolean(load_only=True, data_key='__delete__')

# TODO: this can be removed when comment is not nullable in model
comment = LazyStringField()
|
|
|
|
|
|
### schema definitions ###
|
|
|
|
@mapped
class DomainSchema(BaseSchema):
    """ Marshmallow schema for Domain model

    Users, managers and aliases are not serialized here. The DKIM public key
    and the DNS record helpers are dump-only strings.
    """
    class Meta:
        """ Schema config """
        model = models.Domain
        exclude = ['users', 'managers', 'aliases']
        include_relationships = True
        load_instance = True

        # dump-only helper fields tied to the "dns" context flag
        include_by_context = {
            ('dns',): {'dkim_publickey', 'dns_mx', 'dns_spf', 'dns_dkim', 'dns_dmarc'},
        }
        # the DKIM private key is tied to the "secrets" context flag
        hide_by_context = {
            ('secrets',): {'dkim_key'},
        }
        # per-field values that are left out when dumping
        exclude_by_value = {
            'alternatives': [[]],
            'dkim_key': [None],
            'dkim_publickey': [None],
            'dns_mx': [None],
            'dns_spf': [None],
            'dns_dkim': [None],
            'dns_dmarc': [None],
        }

    dkim_key = DkimKeyField(allow_none=True)
    dkim_publickey = fields.String(dump_only=True)
    dns_mx = fields.String(dump_only=True)
    dns_spf = fields.String(dump_only=True)
    dns_dkim = fields.String(dump_only=True)
    dns_dmarc = fields.String(dump_only=True)
|
|
|
|
|
|
@mapped
class TokenSchema(BaseSchema):
    """ Marshmallow schema for Token model (nested inside users) """
    class Meta:
        """ Schema config """
        sibling = True
        load_instance = True
        model = models.Token

    # NOTE(review): password handling is configured with the User model as
    # metadata, not Token — confirm this is intended for token passwords
    password = PasswordField(required=True, metadata={'model': models.User})
    hash_password = fields.Boolean(load_only=True, missing=False)
|
|
|
|
|
|
@mapped
class FetchSchema(BaseSchema):
    """ Marshmallow schema for Fetch model (nested inside users) """
    class Meta:
        """ Schema config """
        model = models.Fetch
        sibling = True
        load_instance = True

        # the fetch account password is tied to the "secrets" context flag
        hide_by_context = {
            ('secrets',): {'password'},
        }
        # status fields tied to the "full" / "import" context flags
        include_by_context = {
            ('full', 'import'): {'last_check', 'error'},
        }
|
|
|
|
|
|
@mapped
class UserSchema(BaseSchema):
    """ Marshmallow schema for User model

    Users are keyed by their e-mail address; tokens and fetches are
    serialized inline through their nested schemas.
    """
    class Meta:
        """ Schema config """
        model = models.User
        primary_keys = ['email']
        load_instance = True
        include_relationships = True
        # columns not handled by this schema
        exclude = ['_email', 'domain', 'localpart', 'domain_name', 'quota_bytes_used']

        # per-field values that are left out when dumping
        exclude_by_value = {
            'forward_destination': [[]],
            'tokens': [[]],
            'fetches': [[]],
            'manager_of': [[]],
            'reply_enddate': ['2999-12-31'],
            'reply_startdate': ['1900-01-01'],
        }

    email = fields.String(required=True)
    tokens = fields.Nested(TokenSchema, many=True)
    fetches = fields.Nested(FetchSchema, many=True)

    password = PasswordField(required=True, metadata={'model': models.User})
    hash_password = fields.Boolean(load_only=True, missing=False)
|
|
|
|
|
|
@mapped
class AliasSchema(BaseSchema):
    """ Marshmallow schema for Alias model, keyed by e-mail address """
    class Meta:
        """ Schema config """
        primary_keys = ['email']
        model = models.Alias
        load_instance = True
        # columns not handled by this schema
        exclude = ['_email', 'domain', 'localpart', 'domain_name']

        # an empty destination list is left out when dumping
        exclude_by_value = {
            'destination': [[]],
        }

    email = fields.String(required=True)
    destination = CommaSeparatedListField()
|
|
|
|
|
|
@mapped
class ConfigSchema(BaseSchema):
    """ Marshmallow schema for Config model (all fields derived from the model) """
    class Meta:
        """ Schema config """
        load_instance = True
        model = models.Config
|
|
|
|
|
|
@mapped
class RelaySchema(BaseSchema):
    """ Marshmallow schema for Relay model (all fields derived from the model) """
    class Meta:
        """ Schema config """
        load_instance = True
        model = models.Relay
|
|
|
|
|
|
@mapped
class MailuSchema(Schema, Storage):
    """ Marshmallow schema for complete Mailu config

    Loading works on the MailuConfig object kept in the ``config`` context
    entry (created when missing, cleared when the ``clear`` flag is set) and
    updates it section by section in ``Meta.order``.
    """
    class Meta:
        """ Schema config """
        render_module = RenderYAML
        model = models.MailuConfig

        # section order; 'config' is omitted (its field is commented out below)
        order = ['domain', 'user', 'alias', 'relay']

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # re-insert known sections in Meta.order so every field mapping follows it
        for mapping in (self.fields, self.load_fields, self.dump_fields):
            present = [name for name in self.Meta.order if name in mapping]
            for name in present:
                mapping[name] = mapping.pop(name)

    def _call_and_store(self, *args, **kwargs):
        """ remember the current section name and the parent config object,
            so nested schemas can use them for pruning
        """
        field_name = kwargs['field_name']
        self.store('field', field_name, True)
        self.store('parent', self.context.get('config'))
        return super()._call_and_store(*args, **kwargs)

    @pre_load
    def _clear_config(self, data, many, **kwargs): # pylint: disable=unused-argument
        """ create config object in context if missing
            and clear it if requested
        """
        if 'config' not in self.context:
            self.context['config'] = models.MailuConfig()
        if not self.context.get('clear'):
            return data
        # clear only the models handled by the nested section schemas
        section_models = {field.nested.opts.model for field in self.fields.values()}
        self.context['config'].clear(models=section_models)
        return data

    @post_load
    def _make_config(self, data, many, **kwargs): # pylint: disable=unused-argument
        """ update and return config object """
        config = self.context['config']
        for section in (name for name in self.Meta.order if name in data):
            config.update(data[section], section)
        return config

    domain = fields.Nested(DomainSchema, many=True)
    user = fields.Nested(UserSchema, many=True)
    alias = fields.Nested(AliasSchema, many=True)
    relay = fields.Nested(RelaySchema, many=True)
    # config = fields.Nested(ConfigSchema, many=True)
|