1
0
mirror of https://github.com/Mailu/Mailu.git synced 2025-01-18 03:21:36 +02:00
Mailu/core/admin/mailu/schemas.py

1275 lines
44 KiB
Python
Raw Normal View History

2021-01-14 02:11:04 +02:00
""" Mailu marshmallow fields and schema
2021-01-13 01:05:43 +02:00
"""
2021-01-14 02:11:04 +02:00
2021-02-15 01:46:59 +02:00
from copy import deepcopy
from collections import Counter
from datetime import timezone
2021-02-15 01:46:59 +02:00
import json
import logging
2021-01-13 01:05:43 +02:00
import yaml
2021-02-15 01:46:59 +02:00
import sqlalchemy
from marshmallow import pre_load, post_load, post_dump, fields, Schema
2021-02-15 01:46:59 +02:00
from marshmallow.utils import ensure_text_type
2021-01-14 02:11:04 +02:00
from marshmallow.exceptions import ValidationError
from marshmallow_sqlalchemy import SQLAlchemyAutoSchemaOpts
from marshmallow_sqlalchemy.fields import RelatedList
2021-02-15 01:46:59 +02:00
2021-01-13 01:05:43 +02:00
from flask_marshmallow import Marshmallow
2021-02-15 01:46:59 +02:00
2021-01-13 01:05:43 +02:00
from OpenSSL import crypto
2021-03-11 19:38:00 +02:00
from pygments import highlight
from pygments.token import Token
from pygments.lexers import get_lexer_by_name
from pygments.lexers.data import YamlLexer
from pygments.formatters import get_formatter_by_name
2021-02-15 01:46:59 +02:00
from mailu import models, dkim
2021-01-13 01:05:43 +02:00
# flask_marshmallow instance; BaseSchema below derives from ma.SQLAlchemyAutoSchema
ma = Marshmallow()
2021-01-14 02:11:04 +02:00
2021-02-15 01:46:59 +02:00
### import logging and schema colorization ###
_model2schema = {}


def get_schema(cls=None):
    """ look up the schema class registered for a model class

    Called without argument, all registered schema classes are returned
    (as a dict values view); an unregistered model yields None.
    """
    return _model2schema.values() if cls is None else _model2schema.get(cls)


def mapped(cls):
    """ class decorator: register schema ``cls`` under ``cls.Meta.model`` """
    model = cls.Meta.model
    _model2schema[model] = cls
    return cls
class Logger:
    """ helps with counting and colorizing
    imported and exported data

    An instance registers sqlalchemy listeners for the model of every
    registered schema, counts created/modified/deleted rows per table and
    optionally prints (colorized) details about each change.
    """

    class MyYamlLexer(YamlLexer):
        """ colorize yaml constants and integers """

        def get_tokens(self, text, unfiltered=False):
            """ re-tag plain scalars: true/false/null become constants,
            HIDDEN becomes an error token and numbers become int/float tokens
            """
            for typ, value in super().get_tokens(text, unfiltered):
                if typ is Token.Literal.Scalar.Plain:
                    if value in {'true', 'false', 'null'}:
                        typ = Token.Keyword.Constant
                    elif value == HIDDEN:
                        typ = Token.Error
                    else:
                        try:
                            int(value, 10)
                        except ValueError:
                            try:
                                float(value)
                            except ValueError:
                                pass
                            else:
                                typ = Token.Literal.Number.Float
                        else:
                            typ = Token.Literal.Number.Integer
                yield typ, value

    def __init__(self, want_color=None, can_color=False, debug=False, secrets=False):
        """ set output defaults and register sqlalchemy event listeners

        :param want_color: request colorized output
        :param can_color: presumably whether the output supports color;
            colorizing is enabled when either of both flags is truthy
        :param debug: enable sqlalchemy engine logging; format_exception
            then returns None for non-validation errors
        :param secrets: log secret values instead of HIDDEN
        """
        self.lexer = 'yaml'
        self.formatter = 'terminal'
        self.strip = False
        self.verbose = 0
        self.quiet = False
        self.secrets = secrets
        self.debug = debug
        self.print = print

        self.color = want_color or can_color

        self._counter = Counter()
        self._schemas = {}

        # log contexts
        self._diff_context = {
            'full': True,
            'secrets': secrets,
        }
        log_context = {
            'secrets': secrets,
        }

        # register listeners
        for schema in get_schema():
            model = schema.Meta.model
            self._schemas[model] = schema(context=log_context)
            sqlalchemy.event.listen(model, 'after_insert', self._listen_insert)
            sqlalchemy.event.listen(model, 'after_update', self._listen_update)
            sqlalchemy.event.listen(model, 'after_delete', self._listen_delete)

        # special listener for dkim_key changes
        # TODO: _listen_dkim can be removed when dkim keys are stored in database
        self._dedupe_dkim = set()
        sqlalchemy.event.listen(models.db.session, 'after_flush', self._listen_dkim)

        # register debug logger for sqlalchemy
        if self.debug:
            logging.basicConfig()
            logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)

    def _log(self, action, target, message=None):
        """ print "<action> <table>: <message>"; without a message the target
        is dumped using its registered schema (or used as-is if none exists)
        """
        if message is None:
            try:
                message = self._schemas[target.__class__].dump(target)
            except KeyError:
                message = target
        if not isinstance(message, str):
            message = repr(message)
        self.print(f'{action} {target.__table__}: {self.colorize(message)}')

    def _listen_insert(self, mapper, connection, target): # pylint: disable=unused-argument
        """ callback method to track import """
        self._counter.update([('Created', target.__table__.name)])
        if self.verbose:
            self._log('Created', target)

    def _listen_update(self, mapper, connection, target): # pylint: disable=unused-argument
        """ callback method to track import

        Counts the target as modified when at least one column really
        changed; in verbose mode every changed attribute is also logged.
        """
        changes = {}
        changed = False
        inspection = sqlalchemy.inspect(target)
        for attr in sqlalchemy.orm.class_mapper(target.__class__).column_attrs:
            history = getattr(inspection.attrs, attr.key).history
            if history.has_changes() and history.deleted:
                before = history.deleted[-1]
                after = getattr(target, attr.key)
                # we don't have ordered lists
                if isinstance(before, list):
                    before = set(before)
                if isinstance(after, list):
                    after = set(after)
                # TODO: this can be removed when comment is not nullable in model
                if attr.key == 'comment' and not before and not after:
                    pass
                # only remember changed keys
                elif before != after:
                    changed = True
                    if self.verbose:
                        changes[str(attr.key)] = (before, after)
                    else:
                        # details are not needed: one change is enough to count
                        break

        if self.verbose:
            # use schema to log changed attributes
            schema = get_schema(target.__class__)
            only = set(changes.keys()) & set(schema().fields.keys())
            if only:
                for key, value in schema(
                    only=only,
                    context=self._diff_context
                ).dump(target).items():
                    before, after = changes[key]
                    if value == HIDDEN:
                        before = HIDDEN if before else before
                        after = HIDDEN if after else after
                    else:
                        # log the new value as serialized by the schema
                        after = value
                    self._log('Modified', target, f'{str(target)!r} {key}: {before!r} -> {after!r}')

        if changed:
            self._counter.update([('Modified', target.__table__.name)])

    def _listen_delete(self, mapper, connection, target): # pylint: disable=unused-argument
        """ callback method to track import """
        self._counter.update([('Deleted', target.__table__.name)])
        if self.verbose:
            self._log('Deleted', target)

    # TODO: _listen_dkim can be removed when dkim keys are stored in database
    def _listen_dkim(self, session, flush_context): # pylint: disable=unused-argument
        """ callback method to track import

        Compares the dkim key on disk with the key set on every Domain
        loaded from the database and counts/logs each difference once.
        """
        for target in session.identity_map.values():
            # look at Domains originally loaded from db
            if not isinstance(target, models.Domain) or not target._sa_instance_state.load_path:
                continue
            before = target._dkim_key_on_disk
            after = target._dkim_key
            # "de-dupe" messages; this event is fired at every flush
            if before == after or (target, before, after) in self._dedupe_dkim:
                continue
            self._dedupe_dkim.add((target, before, after))
            self._counter.update([('Modified', target.__table__.name)])
            if self.verbose:
                if self.secrets:
                    # guard against a missing key (None) on either side
                    before = before.decode('ascii', 'ignore') if before else ''
                    after = after.decode('ascii', 'ignore') if after else ''
                else:
                    before = HIDDEN if before else ''
                    after = HIDDEN if after else ''
                self._log('Modified', target, f'{str(target)!r} dkim_key: {before!r} -> {after!r}')

    def track_serialize(self, obj, item, backref=None):
        """ callback method to track import

        Presumably used as the schema context "callback": logs backref
        modifications, and (verbose >= 2) the input data being handled.
        """
        # called for backref modification?
        if backref is not None:
            self._log(
                'Modified', item, '{target!r} {key}: {before!r} -> {after!r}'.format_map(backref))
            return
        # show input data?
        if self.verbose < 2:
            return
        # hide secrets in data (hide() returns a copy, so item can be modified)
        if not self.secrets:
            item = self._schemas[obj.opts.model].hide(item)
            if 'hash_password' in item:
                item['password'] = HIDDEN
            if 'fetches' in item:
                for fetch in item['fetches']:
                    fetch['password'] = HIDDEN
        self._log('Handling', obj.opts.model, item)

    def changes(self, *messages, **kwargs):
        """ show changes gathered in counter

        Prints messages followed by e.g. "Created: domain(1) / Modified:
        user(2)" or "No changes."; nothing is printed in quiet mode.
        """
        if self.quiet:
            return

        if self._counter:
            changes = []
            last = None
            for (action, what), count in sorted(self._counter.items()):
                if action != last:
                    if last:
                        changes.append('/')
                    changes.append(f'{action}:')
                    last = action
                changes.append(f'{what}({count})')
        else:
            changes = ['No changes.']

        self.print(*messages, *changes, **kwargs)

    def _format_errors(self, store, path=None):
        """ flatten a (nested) marshmallow error dict

        Nested calls return a list of (location, message) pairs; the
        top-level call returns the complete formatted report as a string.
        """
        res = []
        if path is None:
            path = []
        for key in sorted(store):
            location = path + [str(key)]
            value = store[key]
            if isinstance(value, dict):
                res.extend(self._format_errors(value, location))
            else:
                for message in value:
                    res.append((".".join(location), message))

        if path:
            return res

        # default=0: an empty error store must not crash max()
        maxlen = max((len(loc) for loc, _ in res), default=0)
        res = [f' - {loc.ljust(maxlen)} : {msg}' for loc, msg in res]
        errors = f'{len(res)} error{["s",""][len(res)==1]}'
        res.insert(0, f'[ValidationError] {errors} occurred during input validation')

        return '\n'.join(res)

    def _is_validation_error(self, exc):
        """ walk traceback to extract invalid field from marshmallow

        Returns "Invalid filter: ..." when the traceback reaches
        marshmallow's _init_fields, otherwise None.
        """
        path = []
        trace = exc.__traceback__
        while trace:
            if trace.tb_frame.f_code.co_name == '_serialize':
                if 'attr' in trace.tb_frame.f_locals:
                    path.append(trace.tb_frame.f_locals['attr'])
            elif trace.tb_frame.f_code.co_name == '_init_fields':
                spec = ', '.join(
                    '.'.join(path + [key])
                    for key in trace.tb_frame.f_locals['invalid_fields'])
                return f'Invalid filter: {spec}'
            trace = trace.tb_next
        return None

    def format_exception(self, exc):
        """ format ValidationErrors and other exceptions when not debugging

        Returns None for other exceptions when debugging is enabled.
        """
        if isinstance(exc, ValidationError):
            return self._format_errors(exc.messages)
        if isinstance(exc, ValueError):
            if msg := self._is_validation_error(exc):
                return msg
        if self.debug:
            return None
        msg = ' '.join(str(exc).split())
        return f'[{exc.__class__.__name__}] {msg}'

    # (light, dark) color pair per token type, passed to the pygments formatter
    colorscheme = {
        Token:                  ('', ''),
        Token.Name.Tag:         ('cyan', 'cyan'),
        Token.Literal.Scalar:   ('green', 'green'),
        Token.Literal.String:   ('green', 'green'),
        Token.Name.Constant:    ('green', 'green'), # multiline strings
        Token.Keyword.Constant: ('magenta', 'magenta'),
        Token.Literal.Number:   ('magenta', 'magenta'),
        Token.Error:            ('red', 'red'),
        Token.Name:             ('red', 'red'),
        Token.Operator:         ('red', 'red'),
    }

    def colorize(self, data, lexer=None, formatter=None, color=None, strip=None):
        """ add ANSI color to data

        Returns data unchanged when color is False or coloring is disabled.
        """
        if color is False or not self.color:
            return data

        lexer = lexer or self.lexer
        lexer = Logger.MyYamlLexer() if lexer == 'yaml' else get_lexer_by_name(lexer)
        formatter = get_formatter_by_name(formatter or self.formatter, colorscheme=self.colorscheme)
        if strip is None:
            strip = self.strip

        res = highlight(data, lexer, formatter)
        if strip:
            return res.rstrip('\n')
        return res
2021-02-15 01:46:59 +02:00
### marshmallow render modules ###
2021-02-15 01:46:59 +02:00
# hidden attributes
class _Hidden:
    """ Placeholder for values which must not be shown (e.g. secrets).

    It is falsy, renders as '<hidden>', compares equal to anything whose
    str() is '<hidden>' and is returned unchanged by copy and deepcopy.
    """

    def __repr__(self):
        return '<hidden>'

    __str__ = __repr__

    def __bool__(self):
        return False

    def __eq__(self, other):
        return str(other) == '<hidden>'

    def __copy__(self):
        return self

    def __deepcopy__(self, memo):
        return self
2021-01-13 01:05:43 +02:00
2021-02-15 01:46:59 +02:00
# dump _Hidden values to yaml as the plain string '<hidden>'
yaml.add_representer(
    _Hidden,
    lambda dumper, data: dumper.represent_data(str(data))
)
# shared placeholder substituted for hidden (secret) values in dumps and logs
HIDDEN = _Hidden()
# multiline attributes
class _Multiline(str):
    """ str subclass marking text to be dumped as a yaml literal block """
    pass
# dump _Multiline strings using the yaml literal block style ('|')
yaml.add_representer(
    _Multiline,
    lambda dumper, data: dumper.represent_scalar(u'tag:yaml.org,2002:str', data, style='|')
)
# yaml render module
2021-01-13 01:05:43 +02:00
class RenderYAML:
    """ Marshmallow YAML Render Module

    Provides loads/dumps for marshmallow's render_module option; keyword
    arguments given by the caller take precedence over the defaults.
    """

    class SpacedDumper(yaml.Dumper):
        """ YAML Dumper to add a newline between main sections
        and double the indent used
        """

        def write_line_break(self, data=None):
            super().write_line_break(data)
            # an extra empty line after each top-level entry
            if len(self.indents) == 1:
                super().write_line_break()

        def increase_indent(self, flow=False, indentless=False):
            # always indent (never emit "indentless" sequences)
            return super().increase_indent(flow, False)

    @staticmethod
    def _augment(kwargs, defaults):
        """ add defaults to kwargs if missing
        """
        for name, default in defaults.items():
            kwargs.setdefault(name, default)

    _load_defaults = {}

    @classmethod
    def loads(cls, *args, **kwargs):
        """ load yaml data from string (using the safe loader)
        """
        cls._augment(kwargs, cls._load_defaults)
        return yaml.safe_load(*args, **kwargs)

    _dump_defaults = {
        'Dumper': SpacedDumper,
        'default_flow_style': False,
        'allow_unicode': True,
        'sort_keys': False,
    }

    @classmethod
    def dumps(cls, *args, **kwargs):
        """ dump data to yaml string (block style, unsorted keys)
        """
        cls._augment(kwargs, cls._dump_defaults)
        return yaml.dump(*args, **kwargs)
# json encoder
2021-02-15 01:46:59 +02:00
class JSONEncoder(json.JSONEncoder):
    """ json encoder which additionally serializes HIDDEN placeholders """

    def default(self, o):
        """ encode _Hidden values as their string form, defer all others """
        if not isinstance(o, _Hidden):
            return super().default(o)
        return str(o)
# json render module
2021-02-15 01:46:59 +02:00
class RenderJSON:
    """ Marshmallow JSON Render Module

    Provides loads/dumps for marshmallow's render_module option; keyword
    arguments given by the caller take precedence over the defaults.
    """

    @staticmethod
    def _augment(kwargs, defaults):
        """ add defaults to kwargs if missing
        """
        for name, default in defaults.items():
            kwargs.setdefault(name, default)

    _load_defaults = {}

    @classmethod
    def loads(cls, *args, **kwargs):
        """ load json data from string
        """
        cls._augment(kwargs, cls._load_defaults)
        return json.loads(*args, **kwargs)

    # compact separators and an encoder that understands HIDDEN
    _dump_defaults = {
        'separators': (',',':'),
        'cls': JSONEncoder,
    }

    @classmethod
    def dumps(cls, *args, **kwargs):
        """ dump data to json string
        """
        cls._augment(kwargs, cls._dump_defaults)
        return json.dumps(*args, **kwargs)
### marshmallow: custom fields ###
def _rfc3339(datetime):
    """ dump datetime according to rfc3339

    Naive values are converted with astimezone() (i.e. interpreted as
    local time) to UTC; a UTC offset of +00:00 is written as 'Z'.
    """
    aware = datetime if datetime.tzinfo is not None else datetime.astimezone(timezone.utc)
    text = aware.isoformat()
    return text[:-6] + 'Z' if text.endswith('+00:00') else text
# register an 'rfc3339' format for marshmallow DateTime fields (parsing reuses
# the 'iso' parser) and make it DateTime's DEFAULT_FORMAT
fields.DateTime.SERIALIZATION_FUNCS['rfc3339'] = _rfc3339
fields.DateTime.DESERIALIZATION_FUNCS['rfc3339'] = fields.DateTime.DESERIALIZATION_FUNCS['iso']
fields.DateTime.DEFAULT_FORMAT = 'rfc3339'
2021-01-13 01:05:43 +02:00
2021-01-14 02:11:04 +02:00
class LazyStringField(fields.String):
    """ String field which dumps every falsy value (None, '', ...) as ''
    """

    def _serialize(self, value, attr, obj, **kwargs):
        """ map falsy values to the empty string, keep all others as-is
        """
        if not value:
            return ''
        return value
2021-01-14 02:11:04 +02:00
class CommaSeparatedListField(fields.Raw):
    """ Deserialize a string containing comma-separated values to
    a list of strings
    """

    default_error_messages = {
        "invalid": "Not a valid string or list.",
        "invalid_utf8": "Not a valid utf-8 string or list.",
    }

    def _deserialize(self, value, attr, data, **kwargs):
        """ turn a list, or a comma separated str/bytes value, into a list of str

        String input is split at commas, stripped and empty parts dropped;
        list entries are only converted to text.
        """
        # empty input => empty list
        if not value:
            return []

        if not isinstance(value, list):
            # only text (str or bytes) is accepted besides lists
            if not isinstance(value, (str, bytes)):
                raise self.make_error("invalid")
            try:
                text = ensure_text_type(value)
            except UnicodeDecodeError as exc:
                raise self.make_error("invalid_utf8") from exc
            parts = (part.strip() for part in text.split(','))
            return [part for part in parts if part]

        # list: convert every entry to text
        try:
            return [ensure_text_type(entry) for entry in value]
        except UnicodeDecodeError as exc:
            raise self.make_error("invalid_utf8") from exc
2021-01-14 02:11:04 +02:00
class DkimKeyField(fields.String):
    """ Serialize a dkim key to a multiline string and
    deserialize a dkim key data as string or list of strings
    to a valid dkim key
    """

    default_error_messages = {
        "invalid": "Not a valid string or list.",
        "invalid_utf8": "Not a valid utf-8 string or list.",
    }

    @staticmethod
    def _abbreviate(value):
        """ shorten key text for error messages so that a complete
        private key is never echoed back
        """
        return f'{value[:25]}...{value[-10:]}' if len(value) > 40 else value

    def _serialize(self, value, attr, obj, **kwargs):
        """ serialize dkim key as multiline string
        """

        # map an empty key (b'' or None) to the empty string
        if not value:
            return ''

        # return multiline string
        return _Multiline(value.decode('utf-8'))

    def _wrap_key(self, begin, data, end):
        """ generator to wrap key into RFC 7468 format """
        yield begin
        pos = 0
        while pos < len(data):
            yield data[pos:pos+64]
            pos += 64
        yield end
        yield ''

    def _deserialize(self, value, attr, data, **kwargs):
        """ deserialize a string or list of strings to dkim key data
        with verification

        Returns the PEM encoded key as bytes, a freshly generated key for
        the special value '-generate-', or None for an empty value.
        Raises ValidationError when the key is malformed or invalid.
        """

        # convert list to str
        if isinstance(value, list):
            try:
                value = ''.join(ensure_text_type(item) for item in value).strip()
            except UnicodeDecodeError as exc:
                raise self.make_error("invalid_utf8") from exc

        # only text is allowed
        else:
            if not isinstance(value, (str, bytes)):
                raise self.make_error("invalid")
            try:
                value = ensure_text_type(value).strip()
            except UnicodeDecodeError as exc:
                raise self.make_error("invalid_utf8") from exc

        # generate new key?
        if value.lower() == '-generate-':
            return dkim.gen_key()

        # no key?
        if not value:
            return None

        # remember (abbreviated) part of value for ValidationError
        bad_key = self._abbreviate(value)

        # strip header and footer, clean whitespace and wrap to 64 characters
        try:
            if value.startswith('-----BEGIN '):
                end = value.index('-----', 11) + 5
                header = value[:end]
                value = value[end:]
            else:
                header = '-----BEGIN PRIVATE KEY-----'

            if (pos := value.find('-----END ')) >= 0:
                end = value.index('-----', pos+9) + 5
                footer = value[pos:end]
                value = value[:pos]
            else:
                footer = '-----END PRIVATE KEY-----'
        except ValueError as exc:
            raise ValidationError(f'invalid dkim key {bad_key!r}') from exc

        # remove whitespace from key data
        value = ''.join(value.split())

        # remember part of value for ValidationError
        bad_key = self._abbreviate(value)

        # wrap key according to RFC 7468 (non-ascii input is not a valid key)
        try:
            value = ('\n'.join(self._wrap_key(header, value, footer))).encode('ascii')
        except UnicodeEncodeError as exc:
            raise ValidationError(f'invalid dkim key {bad_key!r}') from exc

        # check key validity
        try:
            crypto.load_privatekey(crypto.FILETYPE_PEM, value)
        except crypto.Error as exc:
            raise ValidationError(f'invalid dkim key {bad_key!r}') from exc
        else:
            return value
2021-02-15 01:46:59 +02:00
class PasswordField(fields.Str):
    """ Serialize a hashed password hash by stripping the obsolete {SCHEME}
    Deserialize a plain password or hashed password into a hashed password

    The model class used for hashing/checking is taken from the field's
    metadata (``metadata['model']``).
    """

    # schemes which are stripped when they prefix a hash as "{SCHEME}"
    _hashes = {'PBKDF2', 'BLF-CRYPT', 'SHA512-CRYPT', 'SHA256-CRYPT', 'MD5-CRYPT', 'CRYPT'}

    def _strip_scheme(self, value):
        """ remove a known, obsolete "{SCHEME}" prefix from value """
        if value.startswith('{') and (end := value.find('}', 1)) >= 0:
            if value[1:end] in self._hashes:
                return value[end+1:]
        return value

    def _serialize(self, value, attr, obj, **kwargs):
        """ strip obsolete {password-hash} when serializing """
        # nothing to strip from a missing password
        if value is None:
            return None
        # strip scheme spec if in database - it's obsolete
        return self._strip_scheme(value)

    def _deserialize(self, value, attr, data, **kwargs):
        """ hashes plain password or checks hashed password
        also strips obsolete {password-hash} when deserializing

        Raises ValidationError for non-string input and unsupported hashes.
        """
        # validate like a plain string field: rejects non-text, decodes bytes
        value = super()._deserialize(value, attr, data, **kwargs)

        # when hashing is requested: use model instance to hash plain password
        if data.get('hash_password'):
            inst = self.metadata['model']()
            inst.set_password(value)
            value = inst.password

        # strip scheme spec when specified - it's obsolete
        value = self._strip_scheme(value)

        # check if algorithm is supported
        inst = self.metadata['model'](password=value)
        try:
            # just check against empty string to see if hash is valid
            inst.check_password('')
        except ValueError as exc:
            # ValueError: hash could not be identified
            raise ValidationError(f'invalid password hash {value!r}') from exc

        return value
2021-02-15 01:46:59 +02:00
### base schema ###
class Storage:
    """ Storage class to save information in context

    Values live in ``self.context['_track']`` keyed by (owner, key): with
    bind=True the owner is this schema class; with a str the owner is the
    schema registered for the class of the value recalled under that
    name; any other bind value (None by default) is used as owner itself.
    """

    context = {}

    def _bind(self, key, bind):
        if bind is True:
            owner = self.__class__
        elif isinstance(bind, str):
            owner = get_schema(self.recall(bind).__class__)
        else:
            owner = bind
        return (owner, key)

    def store(self, key, value, bind=None):
        """ store value under key """
        track = self.context.setdefault('_track', {})
        track[self._bind(key, bind)] = value

    def recall(self, key, bind=None):
        """ recall value from key """
        track = self.context['_track']
        return track[self._bind(key, bind)]
2021-01-14 02:11:04 +02:00
class BaseOpts(SQLAlchemyAutoSchemaOpts):
    """ Option class with sqla session

    Missing Meta attributes default to ``sqla_session = models.db.session``
    and ``sibling = False``.
    """

    def __init__(self, meta, ordered=False):
        # fill in defaults on Meta before the parent options read it
        if not hasattr(meta, 'sqla_session'):
            meta.sqla_session = models.db.session
        if not hasattr(meta, 'sibling'):
            meta.sibling = False
        super().__init__(meta, ordered=ordered)
class BaseSchema(ma.SQLAlchemyAutoSchema, Storage):
2021-01-13 01:05:43 +02:00
""" Marshmallow base schema with custom exclude logic
and option to hide sqla defaults
"""
2021-01-14 02:11:04 +02:00
OPTIONS_CLASS = BaseOpts
class Meta:
2021-01-13 01:05:43 +02:00
""" Schema config """
2021-02-15 01:46:59 +02:00
include_by_context = {}
exclude_by_value = {}
hide_by_context = {}
order = []
sibling = False
2021-01-13 01:05:43 +02:00
def __init__(self, *args, **kwargs):
# prepare only to auto-include explicitly specified attributes
only = set(kwargs.get('only') or [])
2021-02-15 01:46:59 +02:00
# get context
2021-01-13 01:05:43 +02:00
context = kwargs.get('context', {})
flags = {key for key, value in context.items() if value is True}
2021-01-13 01:05:43 +02:00
# compile excludes
exclude = set(kwargs.get('exclude', []))
2021-01-13 01:05:43 +02:00
# always exclude
exclude.update({'created_at', 'updated_at'} - only)
2021-01-13 01:05:43 +02:00
# add include_by_context
if context is not None:
2021-01-14 02:11:04 +02:00
for need, what in getattr(self.Meta, 'include_by_context', {}).items():
if not flags & set(need):
exclude |= what - only
2021-01-13 01:05:43 +02:00
# update excludes
kwargs['exclude'] = exclude
# init SQLAlchemyAutoSchema
super().__init__(*args, **kwargs)
2021-01-13 01:05:43 +02:00
# exclude_by_value
self._exclude_by_value = {
key: values for key, values in getattr(self.Meta, 'exclude_by_value', {}).items()
if key not in only
}
2021-01-13 01:05:43 +02:00
# exclude default values
if not context.get('full'):
2021-02-15 01:46:59 +02:00
for column in self.opts.model.__table__.columns:
if column.name not in exclude and column.name not in only:
2021-01-13 01:05:43 +02:00
self._exclude_by_value.setdefault(column.name, []).append(
None if column.default is None else column.default.arg
)
# hide by context
self._hide_by_context = set()
if context is not None:
2021-01-14 02:11:04 +02:00
for need, what in getattr(self.Meta, 'hide_by_context', {}).items():
if not flags & set(need):
self._hide_by_context |= what - only
2021-02-15 01:46:59 +02:00
# remember primary keys
self._primary = str(self.opts.model.__table__.primary_key.columns.values()[0].name)
2021-02-15 01:46:59 +02:00
# determine attribute order
if hasattr(self.Meta, 'order'):
# use user-defined order
order = self.Meta.order
else:
# default order is: primary_key + other keys alphabetically
order = list(sorted(self.fields.keys()))
if self._primary in order:
order.remove(self._primary)
order.insert(0, self._primary)
# order fieldlists
for fieldlist in (self.fields, self.load_fields, self.dump_fields):
for field in order:
if field in fieldlist:
fieldlist[field] = fieldlist.pop(field)
2021-02-15 01:46:59 +02:00
# move post_load hook "_add_instance" to the end (after load_instance mixin)
hooks = self._hooks[('post_load', False)]
2021-02-15 01:46:59 +02:00
hooks.remove('_add_instance')
hooks.append('_add_instance')
def hide(self, data):
""" helper method to hide input data for logging """
# always returns a copy of data
return {
key: HIDDEN if key in self._hide_by_context else deepcopy(value)
for key, value in data.items()
}
def _call_and_store(self, *args, **kwargs):
""" track current parent field for pruning """
self.store('field', kwargs['field_name'], True)
2021-02-15 01:46:59 +02:00
return super()._call_and_store(*args, **kwargs)
# this is only needed to work around the declared attr "email" primary key in model
def get_instance(self, data):
""" lookup item by defined primary key instead of key(s) from model """
if self.transient:
return None
if keys := getattr(self.Meta, 'primary_keys', None):
filters = {key: data.get(key) for key in keys}
if None not in filters.values():
res= self.session.query(self.opts.model).filter_by(**filters).first()
return res
res= super().get_instance(data)
return res
2021-02-15 01:46:59 +02:00
@pre_load(pass_many=True)
def _patch_many(self, items, many, **kwargs): # pylint: disable=unused-argument
2021-02-15 01:46:59 +02:00
""" - flush sqla session before serializing a section when requested
(make sure all objects that could be referred to later are created)
- when in update mode: patch input data before deserialization
- handle "prune" and "delete" items
- replace values in keys starting with '-' with default
"""
# flush sqla session
if not self.Meta.sibling:
self.opts.sqla_session.flush()
# stop early when not updating
if not self.context.get('update'):
return items
# patch "delete", "prune" and "default"
want_prune = []
def patch(count, data):
2021-02-15 01:46:59 +02:00
# don't allow __delete__ coming from input
if '__delete__' in data:
raise ValidationError('Unknown field.', f'{count}.__delete__')
# fail when hash_password is specified without password
if 'hash_password' in data and not 'password' in data:
raise ValidationError(
'Nothing to hash. Field "password" is missing.',
field_name = f'{count}.hash_password',
)
2021-02-15 01:46:59 +02:00
# handle "prune list" and "delete item" (-pkey: none and -pkey: id)
for key in data:
if key.startswith('-'):
if key[1:] == self._primary:
# delete or prune
if data[key] is None:
# prune
want_prune.append(True)
2021-02-15 01:46:59 +02:00
return None
# mark item for deletion
return {key[1:]: data[key], '__delete__': count}
2021-02-15 01:46:59 +02:00
# handle "set to default value" (-key: none)
def set_default(key, value):
if not key.startswith('-'):
return (key, value)
key = key[1:]
if not key in self.opts.model.__table__.columns:
return (key, None)
if value is not None:
raise ValidationError(
'Value must be "null" when resetting to default.',
2021-02-15 01:46:59 +02:00
f'{count}.{key}'
)
value = self.opts.model.__table__.columns[key].default
if value is None:
raise ValidationError(
'Field has no default value.',
f'{count}.{key}'
)
return (key, value.arg)
return dict(set_default(key, value) for key, value in data.items())
2021-02-15 01:46:59 +02:00
# convert items to "delete" and filter "prune" item
items = [
item for item in [
patch(count, item) for count, item in enumerate(items)
2021-02-15 01:46:59 +02:00
] if item
]
# remember if prune was requested for _prune_items@post_load
self.store('prune', bool(want_prune), True)
2021-02-15 01:46:59 +02:00
# remember original items to stabilize password-changes in _add_instance@post_load
self.store('original', items, True)
2021-02-15 01:46:59 +02:00
return items
@pre_load
def _patch_item(self, data, many, **kwargs): # pylint: disable=unused-argument
""" - call callback function to track import
- stabilize import of items with auto-increment primary key
- delete items
- delete/prune list attributes
- add missing required attributes
2021-02-15 01:46:59 +02:00
"""
2021-02-15 01:46:59 +02:00
# callback
if callback := self.context.get('callback'):
callback(self, data)
2021-02-15 01:46:59 +02:00
# stop early when not updating
2021-02-15 01:46:59 +02:00
if not self.opts.load_instance or not self.context.get('update'):
return data
# stabilize import of auto-increment primary keys (not required),
# by matching import data to existing items and setting primary key
if not self._primary in data:
for item in getattr(self.recall('parent'), self.recall('field', 'parent')):
2021-02-15 01:46:59 +02:00
existing = self.dump(item, many=False)
this = existing.pop(self._primary)
if data == existing:
instance = item
data[self._primary] = this
break
# try to load instance
instance = self.instance or self.get_instance(data)
if instance is None:
if '__delete__' in data:
# deletion of non-existent item requested
raise ValidationError(
f'Item to delete not found: {data[self._primary]!r}.',
field_name = f'{data["__delete__"]}.{self._primary}',
2021-02-15 01:46:59 +02:00
)
else:
if self.context.get('update'):
# remember instance as parent for pruning siblings
if not self.Meta.sibling:
self.store('parent', instance)
# delete instance from session when marked
2021-02-15 01:46:59 +02:00
if '__delete__' in data:
self.opts.sqla_session.delete(instance)
# delete item from lists or prune lists
# currently: domain.alternatives, user.forward_destination,
# user.manager_of, aliases.destination
for key, value in data.items():
if not isinstance(self.fields.get(key), (
RelatedList, CommaSeparatedListField, fields.Raw)
) or not isinstance(value, list):
continue
# deduplicate new value
new_value = set(value)
# handle list pruning
if '-prune-' in value:
value.remove('-prune-')
new_value.remove('-prune-')
else:
for old in getattr(instance, key):
# using str() is okay for now (see above)
new_value.add(str(old))
# handle item deletion
for item in value:
if item.startswith('-'):
new_value.remove(item)
try:
new_value.remove(item[1:])
except KeyError as exc:
raise ValidationError(
f'Item to delete not found: {item[1:]!r}.',
field_name=f'?.{key}',
) from exc
# sort list of new values
data[key] = sorted(new_value)
# log backref modification not catched by modify hook
if isinstance(self.fields[key], RelatedList):
if callback := self.context.get('callback'):
before = {str(v) for v in getattr(instance, key)}
after = set(data[key])
if before != after:
callback(self, instance, {
'key': key,
'target': str(instance),
'before': before,
'after': after,
})
2021-02-15 01:46:59 +02:00
# add attributes required for validation from db
for attr_name, field_obj in self.load_fields.items():
if field_obj.required and attr_name not in data:
data[attr_name] = getattr(instance, attr_name)
return data
@post_load(pass_many=True)
def _prune_items(self, items, many, **kwargs): # pylint: disable=unused-argument
""" handle list pruning """
# stop early when not updating
if not self.context.get('update'):
return items
# get prune flag from _patch_many@pre_load
want_prune = self.recall('prune', True)
# prune: determine if existing items in db need to be added or marked for deletion
add_items = False
del_items = False
if self.Meta.sibling:
# parent prunes automatically
if not want_prune:
# no prune requested => add old items
add_items = True
2021-02-15 01:46:59 +02:00
else:
# parent does not prune automatically
if want_prune:
# prune requested => mark old items for deletion
del_items = True
if add_items or del_items:
existing = {item[self._primary] for item in items if self._primary in item}
for item in getattr(self.recall('parent'), self.recall('field', 'parent')):
key = getattr(item, self._primary)
if key not in existing:
if add_items:
items.append({self._primary: key})
else:
items.append({self._primary: key, '__delete__': '?'})
return items
@post_load
def _add_instance(self, item, many, **kwargs): # pylint: disable=unused-argument
""" - undo password change in existing instances when plain password did not change
- add new instances to sqla session
"""
if not item in self.opts.sqla_session:
2021-02-15 01:46:59 +02:00
self.opts.sqla_session.add(item)
return item
# stop early when not updating or item has no password attribute
if not self.context.get('update') or not hasattr(item, 'password'):
return item
# did we hash a new plaintext password?
original = None
pkey = getattr(item, self._primary)
for data in self.recall('original', True):
if 'hash_password' in data and data.get(self._primary) == pkey:
original = data['password']
break
if original is None:
# password was hashed by us
return item
# reset hash if plain password matches hash from db
if attr := getattr(sqlalchemy.inspect(item).attrs, 'password', None):
if attr.history.has_changes() and attr.history.deleted:
try:
# reset password hash
inst = type(item)(password=attr.history.deleted[-1])
if inst.check_password(original):
item.password = inst.password
except ValueError:
# hash in db is invalid
pass
else:
del inst
return item
2021-01-13 01:05:43 +02:00
@post_dump
def _hide_values(self, data, many, **kwargs): # pylint: disable=unused-argument
    """ post-dump hook filtering the serialized output

        - a key listed in ``_exclude_by_value`` is dropped when its value is
          one of the values listed for it, unless the 'full' context is set
        - a key listed in ``_hide_by_context`` is kept, but its value is
          replaced by ``HIDDEN``

        The result is rebuilt with the same mapping type as ``data``.
    """

    # nothing configured: return the dump untouched
    if not (self._exclude_by_value or self._hide_by_context):
        return data

    show_all = self.context.get('full')
    excluded = self._exclude_by_value
    hidden = self._hide_by_context

    kept = []
    for key, value in data.items():
        # drop "uninteresting" values unless a full dump was requested
        if not show_all and key in excluded and value in excluded[key]:
            continue
        kept.append((key, HIDDEN if key in hidden else value))

    return type(data)(kept)
2021-02-15 01:46:59 +02:00
# load-only flag marking an item for deletion; it is read from the input
# key '__delete__' and never dumped
mark_delete = fields.Boolean(load_only=True, data_key='__delete__')

# TODO: this can be removed when comment is not nullable in model
# (presumably LazyStringField copes with a NULL comment - confirm)
comment = LazyStringField()
### schema definitions ###
2021-01-13 01:05:43 +02:00
@mapped
class DomainSchema(BaseSchema):
    """ Marshmallow schema for Domain model

        The DNS helper values and the DKIM public key are context-dependent
        (include_by_context 'dns'), the DKIM private key is subject to
        hide_by_context 'secrets', and values listed in exclude_by_value are
        dropped from dumps unless the 'full' context is set.
    """

    class Meta:
        """ Schema config """
        model = models.Domain
        exclude = ['users', 'managers', 'aliases']
        include_relationships = True
        load_instance = True

        # key -> list of values that are left out of a (non-full) dump
        exclude_by_value = {
            'alternatives': [[]],
            'dkim_key': [None],
            'dkim_publickey': [None],
            'dns_mx': [None],
            'dns_spf': [None],
            'dns_dkim': [None],
            'dns_dmarc': [None],
        }
        # fields whose presence depends on the 'dns' context flag
        include_by_context = {
            ('dns',): {'dkim_publickey', 'dns_mx', 'dns_spf', 'dns_dkim', 'dns_dmarc'},
        }
        # fields whose value depends on the 'secrets' context flag
        hide_by_context = {
            ('secrets',): {'dkim_key'},
        }

    dkim_key = DkimKeyField(allow_none=True)

    # read-only values derived from the domain
    dkim_publickey = fields.String(dump_only=True)
    dns_mx = fields.String(dump_only=True)
    dns_spf = fields.String(dump_only=True)
    dns_dkim = fields.String(dump_only=True)
    dns_dmarc = fields.String(dump_only=True)
2021-01-13 01:05:43 +02:00
@mapped
class TokenSchema(BaseSchema):
    """ Marshmallow schema for Token model

        Used as the nested ``tokens`` list of UserSchema.
    """

    class Meta:
        """ Schema config """
        model = models.Token
        # sibling: per BaseSchema's prune handling, the parent prunes these
        # items automatically
        sibling = True
        load_instance = True

    # plain or hashed password; metadata['model'] is presumably the model
    # class PasswordField uses to hash plain passwords.
    # NOTE(review): this is models.User, not models.Token - confirm tokens
    # are meant to be hashed the same way as user passwords
    password = PasswordField(required=True, metadata={'model': models.User})
    # load-only switch; its presence in the raw input also drives the
    # password-restore logic of BaseSchema._add_instance
    hash_password = fields.Boolean(load_only=True, missing=False)
2021-01-13 01:05:43 +02:00
@mapped
class FetchSchema(BaseSchema):
    """ Marshmallow schema for Fetch model

        Used as the nested ``fetches`` list of UserSchema.
    """

    class Meta:
        """ Schema config """
        model = models.Fetch
        # sibling: per BaseSchema's prune handling, the parent prunes these
        # items automatically
        sibling = True
        load_instance = True

        # fetch password depends on the 'secrets' context flag
        hide_by_context = {
            ('secrets',): {'password'},
        }
        # fetch status fields depend on the 'full' / 'import' context flags
        include_by_context = {
            ('full', 'import'): {'last_check', 'error'},
        }
@mapped
class UserSchema(BaseSchema):
    """ Marshmallow schema for User model

        Users are keyed by their full 'email' address; tokens and fetches
        are (de)serialized as nested lists.
    """

    class Meta:
        """ Schema config """
        model = models.User
        primary_keys = ['email']
        exclude = ['_email', 'domain', 'localpart', 'domain_name', 'quota_bytes_used']
        include_relationships = True
        load_instance = True

        # key -> list of values that are left out of a (non-full) dump
        exclude_by_value = {
            'forward_destination': [[]],
            'tokens': [[]],
            'fetches': [[]],
            'manager_of': [[]],
            'reply_enddate': ['2999-12-31'],
            'reply_startdate': ['1900-01-01'],
        }

    email = fields.String(required=True)

    tokens = fields.Nested(TokenSchema, many=True)
    fetches = fields.Nested(FetchSchema, many=True)

    # plain or hashed password; metadata['model'] presumably names the model
    # class PasswordField uses to hash plain passwords
    password = PasswordField(required=True, metadata={'model': models.User})
    # load-only switch; its presence in the raw input also drives the
    # password-restore logic of BaseSchema._add_instance
    hash_password = fields.Boolean(load_only=True, missing=False)
@mapped
class AliasSchema(BaseSchema):
    """ Marshmallow schema for Alias model

        Aliases are keyed by their full 'email' address.
    """

    class Meta:
        """ Schema config """
        model = models.Alias
        primary_keys = ['email']
        exclude = ['_email', 'domain', 'localpart', 'domain_name']
        load_instance = True

        # an empty destination list is left out of a (non-full) dump
        exclude_by_value = {
            'destination': [[]],
        }

    email = fields.String(required=True)
    # presumably accepts a comma separated string as well as a list
    destination = CommaSeparatedListField()
@mapped
class ConfigSchema(BaseSchema):
    """ Marshmallow schema for Config model

        Registered via @mapped, but not (yet) a section of MailuSchema.
    """

    class Meta:
        """ Schema config """
        load_instance = True
        model = models.Config
@mapped
class RelaySchema(BaseSchema):
    """ Marshmallow schema for Relay model """

    class Meta:
        """ Schema config """
        load_instance = True
        model = models.Relay
@mapped
class MailuSchema(Schema, Storage):
    """ Marshmallow schema for complete Mailu config

        Each section ('domain', 'user', 'alias', 'relay') is a nested list of
        the corresponding model schema. Loading updates (and optionally
        clears) a MailuConfig object kept in context['config'].
    """

    class Meta:
        """ Schema config """
        model = models.MailuConfig
        render_module = RenderYAML
        # section order used for the field mappings and when applying data
        order = ['domain', 'user', 'alias', 'relay'] # 'config'

    domain = fields.Nested(DomainSchema, many=True)
    user = fields.Nested(UserSchema, many=True)
    alias = fields.Nested(AliasSchema, many=True)
    relay = fields.Nested(RelaySchema, many=True)
    # config = fields.Nested(ConfigSchema, many=True)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # re-insert the known sections in Meta.order so every field mapping
        # lists them in that order
        wanted = self.Meta.order
        for mapping in (self.fields, self.load_fields, self.dump_fields):
            for name in [name for name in wanted if name in mapping]:
                mapping[name] = mapping.pop(name)

    def _call_and_store(self, *args, **kwargs):
        """ remember the current field name and parent config object
            (consumed by the pruning logic) before delegating to marshmallow
        """
        self.store('field', kwargs['field_name'], True)
        self.store('parent', self.context.get('config'))
        return super()._call_and_store(*args, **kwargs)

    @pre_load
    def _clear_config(self, data, many, **kwargs): # pylint: disable=unused-argument
        """ ensure context['config'] holds a MailuConfig object and, when the
            'clear' context flag is set, clear the models of all sections
        """
        if 'config' not in self.context:
            self.context['config'] = models.MailuConfig()

        if not self.context.get('clear'):
            return data

        section_models = {field.nested.opts.model for field in self.fields.values()}
        self.context['config'].clear(models=section_models)
        return data

    @post_load
    def _make_config(self, data, many, **kwargs): # pylint: disable=unused-argument
        """ apply every loaded section (in Meta.order) to the config object
            from context and return that object
        """
        config = self.context['config']
        present = [section for section in self.Meta.order if section in data]
        for section in present:
            config.update(data[section], section)
        return config