1
0
mirror of https://github.com/Mailu/Mailu.git synced 2024-12-20 20:22:38 +02:00
Mailu/core/admin/mailu/schemas.py
2021-07-03 11:40:32 +02:00

1270 lines
44 KiB
Python

""" Mailu marshmallow fields and schema
"""
from copy import deepcopy
from collections import Counter
from datetime import timezone
import json
import logging
import yaml
import sqlalchemy
from marshmallow import pre_load, post_load, post_dump, fields, Schema
from marshmallow.utils import ensure_text_type
from marshmallow.exceptions import ValidationError
from marshmallow_sqlalchemy import SQLAlchemyAutoSchemaOpts
from marshmallow_sqlalchemy.fields import RelatedList
from flask_marshmallow import Marshmallow
from OpenSSL import crypto
from pygments import highlight
from pygments.token import Token
from pygments.lexers import get_lexer_by_name
from pygments.lexers.data import YamlLexer
from pygments.formatters import get_formatter_by_name
from mailu import models, dkim
# flask-marshmallow extension object; BaseSchema below derives from ma.SQLAlchemyAutoSchema
ma = Marshmallow()
### import logging and schema colorization ###
# registry filled by the @mapped decorator: model class -> schema class
_model2schema = {}
def get_schema(cls=None):
    """ look up the schema class registered for a model

        Called without argument, a view on all registered schema
        classes is returned. For a model that was never registered
        the result is None.
    """
    if cls is not None:
        return _model2schema.get(cls)
    return _model2schema.values()
def mapped(cls):
    """ class decorator: remember the schema class under the model
        named in its Meta and hand the class back unchanged
    """
    model = cls.Meta.model
    _model2schema[model] = cls
    return cls
class Logger:
    """ helps with counting and colorizing
        imported and exported data
    """

    class MyYamlLexer(YamlLexer):
        """ colorize yaml constants and integers """

        def get_tokens(self, text, unfiltered=False):
            """ re-tag plain scalars coming from the stock yaml lexer:
                true/false/null become constants, the HIDDEN marker becomes
                an error token and values parseable as int or float become
                number tokens; all other tokens pass through unchanged
            """
            for typ, value in super().get_tokens(text, unfiltered):
                if typ is Token.Literal.Scalar.Plain:
                    if value in {'true', 'false', 'null'}:
                        typ = Token.Keyword.Constant
                    elif value == HIDDEN:
                        typ = Token.Error
                    else:
                        try:
                            int(value, 10)
                        except ValueError:
                            try:
                                float(value)
                            except ValueError:
                                pass
                            else:
                                typ = Token.Literal.Number.Float
                        else:
                            typ = Token.Literal.Number.Integer
                yield typ, value

    def __init__(self, want_color=None, can_color=False, debug=False, secrets=False):
        """ set up output options and register sqlalchemy listeners

            want_color, can_color: color output is enabled if either is truthy
            debug: enable INFO logging of 'sqlalchemy.engine'; also makes
                   format_exception() return None for unhandled exceptions
            secrets: show secrets instead of replacing them with HIDDEN
        """

        self.lexer = 'yaml'
        self.formatter = 'terminal'
        self.strip = False
        self.verbose = 0
        self.quiet = False
        self.secrets = secrets
        self.debug = debug
        self.print = print

        self.color = want_color or can_color

        self._counter = Counter()
        self._schemas = {}

        # log contexts
        self._diff_context = {
            'full': True,
            'secrets': secrets,
        }
        log_context = {
            'secrets': secrets,
        }

        # register listeners
        for schema in get_schema():
            model = schema.Meta.model
            self._schemas[model] = schema(context=log_context)
            sqlalchemy.event.listen(model, 'after_insert', self._listen_insert)
            sqlalchemy.event.listen(model, 'after_update', self._listen_update)
            sqlalchemy.event.listen(model, 'after_delete', self._listen_delete)

        # special listener for dkim_key changes
        # TODO: _listen_dkim can be removed when dkim keys are stored in database
        self._dedupe_dkim = set()
        sqlalchemy.event.listen(models.db.session, 'after_flush', self._listen_dkim)

        # register debug logger for sqlalchemy
        if self.debug:
            logging.basicConfig()
            logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)

    def _log(self, action, target, message=None):
        """ print one log line "<action> <table>: <message>"

            Without message the target is dumped using its registered
            schema (or used as-is when no schema is known).
            Non-string messages are converted using repr().
        """

        if message is None:
            try:
                message = self._schemas[target.__class__].dump(target)
            except KeyError:
                message = target
        if not isinstance(message, str):
            message = repr(message)
        self.print(f'{action} {target.__table__}: {self.colorize(message)}')

    def _listen_insert(self, mapper, connection, target): # pylint: disable=unused-argument
        """ 'after_insert' listener: count (and log when verbose) created items """
        self._counter.update([('Created', target.__table__.name)])
        if self.verbose:
            self._log('Created', target)

    def _listen_update(self, mapper, connection, target): # pylint: disable=unused-argument
        """ 'after_update' listener: count (and log when verbose) modified items

            An item is counted as modified when at least one column attribute
            has a previous value differing from the current one.
        """

        changes = {}
        inspection = sqlalchemy.inspect(target)
        for attr in sqlalchemy.orm.class_mapper(target.__class__).column_attrs:
            history = getattr(inspection.attrs, attr.key).history
            if history.has_changes() and history.deleted:
                before = history.deleted[-1]
                after = getattr(target, attr.key)
                # TODO: this can be removed when comment is not nullable in model
                if attr.key == 'comment' and not before and not after:
                    pass
                # only remember changed keys
                elif before != after:
                    # always record the change: the counter below relies on it
                    # (recording only when verbose left quiet runs uncounted)
                    changes[str(attr.key)] = (before, after)
                    if not self.verbose:
                        # a single change is sufficient for counting
                        break

        if self.verbose:
            # use schema to log changed attributes
            schema = get_schema(target.__class__)
            only = set(changes.keys()) & set(schema().fields.keys())
            if only:
                for key, value in schema(
                        only=only,
                        context=self._diff_context
                ).dump(target).items():
                    before, after = changes[key]
                    if value == HIDDEN:
                        before = HIDDEN if before else before
                        after = HIDDEN if after else after
                    else:
                        # also hide this
                        after = value
                    self._log('Modified', target, f'{str(target)!r} {key}: {before!r} -> {after!r}')

        if changes:
            self._counter.update([('Modified', target.__table__.name)])

    def _listen_delete(self, mapper, connection, target): # pylint: disable=unused-argument
        """ 'after_delete' listener: count (and log when verbose) deleted items """
        self._counter.update([('Deleted', target.__table__.name)])
        if self.verbose:
            self._log('Deleted', target)

    # TODO: _listen_dkim can be removed when dkim keys are stored in database
    def _listen_dkim(self, session, flush_context): # pylint: disable=unused-argument
        """ 'after_flush' listener: count (and log when verbose) dkim key
            changes of Domain objects found in the session's identity map
        """
        for target in session.identity_map.values():
            # look at Domains originally loaded from db
            if not isinstance(target, models.Domain) or not target._sa_instance_state.load_path:
                continue
            before = target._dkim_key_on_disk
            after = target._dkim_key
            # "de-dupe" messages; this event is fired at every flush
            if before == after or (target, before, after) in self._dedupe_dkim:
                continue
            self._dedupe_dkim.add((target, before, after))
            self._counter.update([('Modified', target.__table__.name)])
            if self.verbose:
                if self.secrets:
                    before = before.decode('ascii', 'ignore')
                    after = after.decode('ascii', 'ignore')
                else:
                    before = HIDDEN if before else ''
                    after = HIDDEN if after else ''
                self._log('Modified', target, f'{str(target)!r} dkim_key: {before!r} -> {after!r}')

    def track_serialize(self, obj, item, backref=None):
        """ callback for schemas to report handled input data

            obj: the schema reporting
            item: the input data (or the instance when backref is given)
            backref: optional mapping with keys target, key, before and
                     after describing a modified relationship
        """
        # called for backref modification?
        if backref is not None:
            self._log(
                'Modified', item, '{target!r} {key}: {before!r} -> {after!r}'.format_map(backref))
            return
        # show input data?
        if self.verbose < 2:
            return
        # hide secrets in data
        if not self.secrets:
            # hide() returns a copy, so the input data is not modified below
            item = self._schemas[obj.opts.model].hide(item)
            if 'hash_password' in item:
                item['password'] = HIDDEN
            if 'fetches' in item:
                for fetch in item['fetches']:
                    fetch['password'] = HIDDEN
        self._log('Handling', obj.opts.model, item)

    def changes(self, *messages, **kwargs):
        """ show changes gathered in counter

            Prints the given messages followed by a summary like
            "Created: user(2) / Modified: domain(1)" or "No changes.".
            kwargs are handed to the print function. Silent when quiet.
        """
        if self.quiet:
            return
        if self._counter:
            changes = []
            last = None
            for (action, what), count in sorted(self._counter.items()):
                if action != last:
                    if last:
                        changes.append('/')
                    changes.append(f'{action}:')
                    last = action
                changes.append(f'{what}({count})')
        else:
            changes = ['No changes.']
        self.print(*messages, *changes, **kwargs)

    def _format_errors(self, store, path=None):
        """ flatten a nested dict of error messages

            Recursive calls (path given) return a list of
            (dotted location, message) tuples; the top-level call
            returns the final multi-line string.
        """

        res = []
        if path is None:
            path = []
        for key in sorted(store):
            location = path + [str(key)]
            value = store[key]
            if isinstance(value, dict):
                res.extend(self._format_errors(value, location))
            else:
                for message in value:
                    res.append((".".join(location), message))

        if path:
            return res

        # default guards against an empty error store
        maxlen = max((len(loc) for loc, msg in res), default=0)
        res = [f' - {loc.ljust(maxlen)} : {msg}' for loc, msg in res]
        errors = f'{len(res)} error{["s",""][len(res)==1]}'
        res.insert(0, f'[ValidationError] {errors} occurred during input validation')

        return '\n'.join(res)

    def _is_validation_error(self, exc):
        """ walk traceback to extract invalid field from marshmallow

            Collects 'attr' from frames of functions named '_serialize' and
            returns an "Invalid filter" message built from 'invalid_fields'
            when a frame of a function named '_init_fields' is found.
            Returns None otherwise.
        """
        path = []
        trace = exc.__traceback__
        while trace:
            if trace.tb_frame.f_code.co_name == '_serialize':
                if 'attr' in trace.tb_frame.f_locals:
                    path.append(trace.tb_frame.f_locals['attr'])
            elif trace.tb_frame.f_code.co_name == '_init_fields':
                spec = ', '.join(
                    '.'.join(path + [key])
                    for key in trace.tb_frame.f_locals['invalid_fields'])
                return f'Invalid filter: {spec}'
            trace = trace.tb_next
        return None

    def format_exception(self, exc):
        """ format ValidationErrors and other exceptions when not debugging

            Returns a message string, or None when debugging is enabled
            and the exception is not a recognized validation problem.
        """
        if isinstance(exc, ValidationError):
            return self._format_errors(exc.messages)
        if isinstance(exc, ValueError):
            if msg := self._is_validation_error(exc):
                return msg
        if self.debug:
            return None
        # collapse all whitespace of the message into single blanks
        msg = ' '.join(str(exc).split())
        return f'[{exc.__class__.__name__}] {msg}'

    # token type -> colors handed to the pygments formatter as "colorscheme"
    colorscheme = {
        Token:                  ('',        ''),
        Token.Name.Tag:         ('cyan',    'cyan'),
        Token.Literal.Scalar:   ('green',   'green'),
        Token.Literal.String:   ('green',   'green'),
        Token.Name.Constant:    ('green',   'green'), # multiline strings
        Token.Keyword.Constant: ('magenta', 'magenta'),
        Token.Literal.Number:   ('magenta', 'magenta'),
        Token.Error:            ('red',     'red'),
        Token.Name:             ('red',     'red'),
        Token.Operator:         ('red',     'red'),
    }

    def colorize(self, data, lexer=None, formatter=None, color=None, strip=None):
        """ add ANSI color to data

            Returns data unchanged when color is False or coloring is
            disabled for this logger. lexer, formatter and strip default
            to the respective attributes of the logger.
        """

        if color is False or not self.color:
            return data

        lexer = lexer or self.lexer
        lexer = Logger.MyYamlLexer() if lexer == 'yaml' else get_lexer_by_name(lexer)
        formatter = get_formatter_by_name(formatter or self.formatter, colorscheme=self.colorscheme)
        if strip is None:
            strip = self.strip

        res = highlight(data, lexer, formatter)
        if strip:
            return res.rstrip('\n')
        return res
### marshmallow render modules ###
# hidden attributes
class _Hidden:
    """ marker standing in for secret values in dumped or logged data

        The marker is falsy, survives copy and deepcopy as the very same
        object, renders as '<hidden>' and compares equal to any object
        whose string form is '<hidden>'.
    """

    def __bool__(self):
        return False

    def __copy__(self):
        return self

    def __deepcopy__(self, _):
        return self

    def __eq__(self, other):
        return str(other) == '<hidden>'

    def __hash__(self):
        # defining __eq__ alone sets __hash__ to None (unhashable);
        # hash like the string this marker compares equal to
        return hash('<hidden>')

    def __repr__(self):
        return '<hidden>'

    __str__ = __repr__
def _represent_hidden(dumper, data):
    """ dump the hidden marker to yaml using its string form """
    return dumper.represent_data(str(data))

yaml.add_representer(_Hidden, _represent_hidden)

# module-wide marker instance used to replace secret values
HIDDEN = _Hidden()
# multiline attributes
class _Multiline(str):
    """ string that is dumped to yaml in literal block style """

def _represent_multiline(dumper, data):
    """ emit the string as yaml str scalar using style '|' """
    return dumper.represent_scalar(u'tag:yaml.org,2002:str', data, style='|')

yaml.add_representer(_Multiline, _represent_multiline)
# yaml render module
class RenderYAML:
    """ Marshmallow YAML Render Module

        loads() and dumps() accept the arguments of yaml.safe_load() and
        yaml.dump(); keyword arguments not given by the caller are
        taken from the class defaults.
    """

    class SpacedDumper(yaml.Dumper):
        """ YAML Dumper to add a newline between main sections
            and double the indent used
        """

        def write_line_break(self, data=None):
            """ write line break; write a second one on top level """
            super().write_line_break(data)
            on_top_level = len(self.indents) == 1
            if on_top_level:
                super().write_line_break()

        def increase_indent(self, flow=False, indentless=False):
            """ increase indent, ignoring any request for "indentless" """
            return super().increase_indent(flow, False)

    # keyword defaults for loads()
    _load_defaults = {}

    # keyword defaults for dumps()
    _dump_defaults = {
        'Dumper': SpacedDumper,
        'default_flow_style': False,
        'allow_unicode': True,
        'sort_keys': False,
    }

    @staticmethod
    def _augment(kwargs, defaults):
        """ copy every default into kwargs unless the key is present """
        for name, fallback in defaults.items():
            kwargs.setdefault(name, fallback)

    @classmethod
    def loads(cls, *args, **kwargs):
        """ parse yaml from string using the safe loader """
        cls._augment(kwargs, cls._load_defaults)
        return yaml.safe_load(*args, **kwargs)

    @classmethod
    def dumps(cls, *args, **kwargs):
        """ serialize data to a yaml string """
        cls._augment(kwargs, cls._dump_defaults)
        return yaml.dump(*args, **kwargs)
# json encoder
class JSONEncoder(json.JSONEncoder):
    """ json encoder able to serialize the HIDDEN marker """

    def default(self, o):
        """ render _Hidden instances as their string form and
            leave everything else to the stock encoder
        """
        if not isinstance(o, _Hidden):
            return super().default(o)
        return str(o)
# json render module
class RenderJSON:
    """ Marshmallow JSON Render Module

        loads() and dumps() accept the arguments of json.loads() and
        json.dumps(); keyword arguments not given by the caller are
        taken from the class defaults.
    """

    # keyword defaults for loads()
    _load_defaults = {}

    # keyword defaults for dumps(): compact output, HIDDEN-aware encoder
    _dump_defaults = {
        'separators': (',',':'),
        'cls': JSONEncoder,
    }

    @staticmethod
    def _augment(kwargs, defaults):
        """ copy every default into kwargs unless the key is present """
        for name, fallback in defaults.items():
            kwargs.setdefault(name, fallback)

    @classmethod
    def loads(cls, *args, **kwargs):
        """ parse json from string """
        cls._augment(kwargs, cls._load_defaults)
        return json.loads(*args, **kwargs)

    @classmethod
    def dumps(cls, *args, **kwargs):
        """ serialize data to a json string """
        cls._augment(kwargs, cls._dump_defaults)
        return json.dumps(*args, **kwargs)
### marshmallow: custom fields ###
def _rfc3339(datetime):
    """ format a datetime object as rfc3339 string

        Values without tzinfo are converted using astimezone(timezone.utc).
        A trailing utc offset of '+00:00' is written as 'Z'.
    """
    moment = datetime
    if moment.tzinfo is None:
        moment = moment.astimezone(timezone.utc)
    text = moment.isoformat()
    utc_suffix = '+00:00'
    if not text.endswith(utc_suffix):
        return text
    return f'{text[:-len(utc_suffix)]}Z'
# register _rfc3339 as serializer for the DateTime format named 'rfc3339'
fields.DateTime.SERIALIZATION_FUNCS['rfc3339'] = _rfc3339
# input in format 'rfc3339' is parsed by the deserializer registered for 'iso'
fields.DateTime.DESERIALIZATION_FUNCS['rfc3339'] = fields.DateTime.DESERIALIZATION_FUNCS['iso']
# make 'rfc3339' the default format of marshmallow DateTime fields
fields.DateTime.DEFAULT_FORMAT = 'rfc3339'
class LazyStringField(fields.String):
    """ String field dumping every "false" value as empty string """

    def _serialize(self, value, attr, obj, **kwargs):
        """ return value unchanged when it is truthy,
            the empty string otherwise (e.g. for None)
        """
        if not value:
            return ''
        return value
class CommaSeparatedListField(fields.Raw):
    """ Field loading either a list or a string of
        comma-separated values into a list of strings
    """

    default_error_messages = {
        "invalid": "Not a valid string or list.",
        "invalid_utf8": "Not a valid utf-8 string or list.",
    }

    def _deserialize(self, value, attr, data, **kwargs):
        """ convert input to list of strings

            - any "false" input yields the empty list
            - items of a list are converted to text
            - text is split at ',', parts are stripped and empty parts dropped
            Anything else is rejected with the "invalid" error.
        """

        # nothing given
        if not value:
            return []

        # list input: only make sure items are text
        if isinstance(value, list):
            try:
                return [ensure_text_type(item) for item in value]
            except UnicodeDecodeError as exc:
                raise self.make_error("invalid_utf8") from exc

        # everything else needs to be text
        if not isinstance(value, (str, bytes)):
            raise self.make_error("invalid")
        try:
            text = ensure_text_type(value)
        except UnicodeDecodeError as exc:
            raise self.make_error("invalid_utf8") from exc

        stripped = (part.strip() for part in text.split(','))
        return [part for part in stripped if part]
class DkimKeyField(fields.String):
    """ Field dumping a dkim key as multiline string and loading
        a dkim key given as string or list of strings into
        verified PEM data
    """

    default_error_messages = {
        "invalid": "Not a valid string or list.",
        "invalid_utf8": "Not a valid utf-8 string or list.",
    }

    def _serialize(self, value, attr, obj, **kwargs):
        """ dump key (bytes) as multiline string; "false" values
            like None are dumped as empty string
        """
        if value:
            return _Multiline(value.decode('utf-8'))
        return ''

    def _wrap_key(self, begin, data, end):
        """ generator to wrap key into RFC 7468 format:
            header, data in lines of 64 characters, footer
            and a final empty item (for a trailing newline when joined)
        """
        yield begin
        for start in range(0, len(data), 64):
            yield data[start:start+64]
        yield end
        yield ''

    def _deserialize(self, value, attr, data, **kwargs):
        """ load a string or list of strings as dkim key

            Returns a new key when the input is '-generate-'
            (in any case), None when the input is empty and
            the re-wrapped key as ascii bytes otherwise.
            Raises ValidationError when the key can not be parsed.
        """

        # accept text or a list of text fragments only
        if isinstance(value, list):
            fragments = value
        elif isinstance(value, (str, bytes)):
            fragments = [value]
        else:
            raise self.make_error("invalid")
        try:
            text = ''.join(ensure_text_type(fragment) for fragment in fragments).strip()
        except UnicodeDecodeError as exc:
            raise self.make_error("invalid_utf8") from exc

        # request for a new key
        if text.lower() == '-generate-':
            return dkim.gen_key()

        # no key given
        if not text:
            return None

        # split off header and footer; use defaults when they are missing
        header = '-----BEGIN PRIVATE KEY-----'
        footer = '-----END PRIVATE KEY-----'
        body = text
        try:
            if body.startswith('-----BEGIN '):
                cut = body.index('-----', 11) + 5
                header, body = body[:cut], body[cut:]
            if (start := body.find('-----END ')) >= 0:
                cut = body.index('-----', start+9) + 5
                footer, body = body[start:cut], body[:start]
        except ValueError as exc:
            # report the complete input when markers are incomplete
            raise ValidationError(f'invalid dkim key {text!r}') from exc

        # drop all whitespace from key data
        body = ''.join(body.split())

        # abbreviated key data for the error message
        if len(body) > 40:
            shown = f'{body[:25]}...{body[-10:]}'
        else:
            shown = body

        # wrap key according to RFC 7468
        pem = '\n'.join(self._wrap_key(header, body, footer)).encode('ascii')

        # let openssl decide if the key is valid
        try:
            crypto.load_privatekey(crypto.FILETYPE_PEM, pem)
        except crypto.Error as exc:
            raise ValidationError(f'invalid dkim key {shown!r}') from exc
        return pem
class PasswordField(fields.Str):
    """ Field dumping a password hash without the obsolete {SCHEME}
        prefix and loading a plain or hashed password into a
        password hash

        The model class used for hashing and verification is
        taken from the field metadata key 'model'.
    """

    _hashes = {'PBKDF2', 'BLF-CRYPT', 'SHA512-CRYPT', 'SHA256-CRYPT', 'MD5-CRYPT', 'CRYPT'}

    def _strip_scheme(self, value):
        """ remove a leading {SCHEME} when SCHEME is a known hash name """
        if value.startswith('{') and (end := value.find('}', 1)) >= 0:
            scheme = value[1:end]
            if scheme in self._hashes:
                return value[end+1:]
        return value

    def _serialize(self, value, attr, obj, **kwargs):
        """ dump hash without the scheme spec possibly stored in database """
        return self._strip_scheme(value)

    def _deserialize(self, value, attr, data, **kwargs):
        """ load password

            When the input data has a truthy 'hash_password', the
            value is treated as plain password and hashed by a fresh
            model instance. The resulting (or given) hash is stripped
            of its scheme spec and verified: ValidationError is raised
            when check_password() fails with a ValueError.
        """

        # hash plain password on request
        if data.get('hash_password'):
            hasher = self.metadata['model']()
            hasher.set_password(value)
            value = hasher.password

        # the scheme spec is obsolete
        value = self._strip_scheme(value)

        # checking against the empty string reveals unsupported hashes
        probe = self.metadata['model'](password=value)
        try:
            probe.check_password('')
        except ValueError as exc:
            # ValueError: hash could not be identified
            raise ValidationError(f'invalid password hash {value!r}') from exc

        return value
### base schema ###
class Storage:
    """ Mixin keeping values in the dict self.context['_track']

        Values are filed under the tuple (owner, key). The owner is
        derived from the "bind" argument:
          True      - the class of self
          a string  - the schema registered for the class of the
                      value recalled (unbound) under that string
          otherwise - the argument itself (None by default)
    """

    context = {}

    def _bind(self, key, bind):
        """ build the (owner, key) tuple used as storage key """
        if bind is True:
            owner = self.__class__
        elif isinstance(bind, str):
            owner = get_schema(self.recall(bind).__class__)
        else:
            owner = bind
        return (owner, key)

    def store(self, key, value, bind=None):
        """ remember value under key """
        tracked = self.context.setdefault('_track', {})
        tracked[self._bind(key, bind)] = value

    def recall(self, key, bind=None):
        """ return value remembered under key (KeyError if unknown) """
        tracked = self.context['_track']
        return tracked[self._bind(key, bind)]
class BaseOpts(SQLAlchemyAutoSchemaOpts):
    """ Schema options providing defaults for Meta:
        sqla_session (the session of models.db) and sibling (False)
    """

    def __init__(self, meta, ordered=False):
        """ complete meta before handing it to SQLAlchemyAutoSchemaOpts """
        if not hasattr(meta, 'sibling'):
            meta.sibling = False
        if not hasattr(meta, 'sqla_session'):
            meta.sqla_session = models.db.session
        super().__init__(meta, ordered=ordered)
class BaseSchema(ma.SQLAlchemyAutoSchema, Storage):
    """ Marshmallow base schema with custom exclude logic
        and option to hide sqla defaults

        Recognized Meta attributes:
          include_by_context - {tuple of flags: set of fields}; the fields are
                               excluded unless one of the flags is True in context
          exclude_by_value   - {field: list of values}; the field is not dumped
                               when it has one of the values (and not "full")
          hide_by_context    - {tuple of flags: set of fields}; the fields are
                               replaced by HIDDEN unless one of the flags is True
          order              - order of fields
          sibling            - schema is nested below another mapped schema
          primary_keys       - keys used by get_instance to look up items

        Context keys read: full, update, callback (and the flags above).
    """

    OPTIONS_CLASS = BaseOpts

    class Meta:
        """ Schema config """
        include_by_context = {}
        exclude_by_value = {}
        hide_by_context = {}
        order = []
        sibling = False

    def __init__(self, *args, **kwargs):
        """ compile excludes and hidden fields from Meta and context,
            initialize the schema and order its fields
        """

        # prepare only to auto-include explicitly specified attributes
        only = set(kwargs.get('only') or [])

        # get context (an explicit context=None is treated like no context)
        context = kwargs.get('context') or {}
        flags = {key for key, value in context.items() if value is True}

        # compile excludes
        exclude = set(kwargs.get('exclude', []))

        # always exclude
        exclude.update({'created_at', 'updated_at'} - only)

        # add include_by_context
        for need, what in getattr(self.Meta, 'include_by_context', {}).items():
            if not flags & set(need):
                exclude |= what - only

        # update excludes
        kwargs['exclude'] = exclude

        # init SQLAlchemyAutoSchema
        super().__init__(*args, **kwargs)

        # exclude_by_value
        # copy the lists: defaults are appended below and this must
        # not modify the configuration shared via the Meta class
        self._exclude_by_value = {
            key: list(values)
            for key, values in getattr(self.Meta, 'exclude_by_value', {}).items()
            if key not in only
        }

        # exclude default values
        if not context.get('full'):
            for column in self.opts.model.__table__.columns:
                if column.name not in exclude and column.name not in only:
                    self._exclude_by_value.setdefault(column.name, []).append(
                        None if column.default is None else column.default.arg
                    )

        # hide by context
        self._hide_by_context = set()
        for need, what in getattr(self.Meta, 'hide_by_context', {}).items():
            if not flags & set(need):
                self._hide_by_context |= what - only

        # remember primary keys
        self._primary = str(self.opts.model.__table__.primary_key.columns.values()[0].name)

        # determine attribute order
        if hasattr(self.Meta, 'order'):
            # use user-defined order
            order = self.Meta.order
        else:
            # default order is: primary_key + other keys alphabetically
            order = list(sorted(self.fields.keys()))
            if self._primary in order:
                order.remove(self._primary)
                order.insert(0, self._primary)

        # order fieldlists
        # (re-inserting moves a key to the end; doing so in order sorts the dict)
        for fieldlist in (self.fields, self.load_fields, self.dump_fields):
            for field in order:
                if field in fieldlist:
                    fieldlist[field] = fieldlist.pop(field)

        # move post_load hook "_add_instance" to the end (after load_instance mixin)
        hooks = self._hooks[('post_load', False)]
        hooks.remove('_add_instance')
        hooks.append('_add_instance')

    def hide(self, data):
        """ helper method to hide input data for logging

            Values of fields to hide are replaced by HIDDEN,
            all other values are deep-copied.
        """
        # always returns a copy of data
        return {
            key: HIDDEN if key in self._hide_by_context else deepcopy(value)
            for key, value in data.items()
        }

    def _call_and_store(self, *args, **kwargs):
        """ track current parent field for pruning """
        self.store('field', kwargs['field_name'], True)
        return super()._call_and_store(*args, **kwargs)

    # this is only needed to work around the declared attr "email" primary key in model
    def get_instance(self, data):
        """ lookup item by defined primary key instead of key(s) from model

            Returns None for transient schemas. Uses Meta.primary_keys
            when defined and all of these keys have a value in data,
            the default lookup otherwise.
        """
        if self.transient:
            return None
        if keys := getattr(self.Meta, 'primary_keys', None):
            filters = {key: data.get(key) for key in keys}
            if None not in filters.values():
                return self.session.query(self.opts.model).filter_by(**filters).first()
        return super().get_instance(data)

    @pre_load(pass_many=True)
    def _patch_many(self, items, many, **kwargs): # pylint: disable=unused-argument
        """ - flush sqla session before serializing a section when requested
              (make sure all objects that could be referred to later are created)
            - when in update mode: patch input data before deserialization
              - handle "prune" and "delete" items
              - replace values in keys starting with '-' with default
        """

        # flush sqla session
        if not self.Meta.sibling:
            self.opts.sqla_session.flush()

        # stop early when not updating
        if not self.context.get('update'):
            return items

        # patch "delete", "prune" and "default"
        want_prune = []
        def patch(count, data):
            """ patch a single item; returns None for a "prune" item """

            # don't allow __delete__ coming from input
            if '__delete__' in data:
                raise ValidationError('Unknown field.', f'{count}.__delete__')

            # fail when hash_password is specified without password
            if 'hash_password' in data and not 'password' in data:
                raise ValidationError(
                    'Nothing to hash. Field "password" is missing.',
                    field_name = f'{count}.hash_password',
                )

            # handle "prune list" and "delete item" (-pkey: none and -pkey: id)
            for key in data:
                if key.startswith('-'):
                    if key[1:] == self._primary:
                        # delete or prune
                        if data[key] is None:
                            # prune
                            want_prune.append(True)
                            return None
                        # mark item for deletion
                        return {key[1:]: data[key], '__delete__': count}

            # handle "set to default value" (-key: none)
            def set_default(key, value):
                """ map "-key: null" to (key, default of column) """
                if not key.startswith('-'):
                    return (key, value)
                key = key[1:]
                if not key in self.opts.model.__table__.columns:
                    return (key, None)
                if value is not None:
                    raise ValidationError(
                        'Value must be "null" when resetting to default.',
                        f'{count}.{key}'
                    )
                value = self.opts.model.__table__.columns[key].default
                if value is None:
                    raise ValidationError(
                        'Field has no default value.',
                        f'{count}.{key}'
                    )
                return (key, value.arg)

            return dict(set_default(key, value) for key, value in data.items())

        # convert items to "delete" and filter "prune" item
        items = [
            item for item in [
                patch(count, item) for count, item in enumerate(items)
            ] if item
        ]

        # remember if prune was requested for _prune_items@post_load
        self.store('prune', bool(want_prune), True)

        # remember original items to stabilize password-changes in _add_instance@post_load
        self.store('original', items, True)

        return items

    @pre_load
    def _patch_item(self, data, many, **kwargs): # pylint: disable=unused-argument
        """ - call callback function to track import
            - stabilize import of items with auto-increment primary key
            - delete items
            - delete/prune list attributes
            - add missing required attributes
        """

        # callback
        if callback := self.context.get('callback'):
            callback(self, data)

        # stop early when not updating
        if not self.opts.load_instance or not self.context.get('update'):
            return data

        # stabilize import of auto-increment primary keys (not required),
        # by matching import data to existing items and setting primary key
        if not self._primary in data:
            for item in getattr(self.recall('parent'), self.recall('field', 'parent')):
                existing = self.dump(item, many=False)
                this = existing.pop(self._primary)
                if data == existing:
                    data[self._primary] = this
                    break

        # try to load instance
        instance = self.instance or self.get_instance(data)
        if instance is None:

            if '__delete__' in data:
                # deletion of non-existent item requested
                raise ValidationError(
                    f'Item to delete not found: {data[self._primary]!r}.',
                    field_name = f'{data["__delete__"]}.{self._primary}',
                )

        else:

            if self.context.get('update'):
                # remember instance as parent for pruning siblings
                if not self.Meta.sibling:
                    self.store('parent', instance)
                # delete instance from session when marked
                if '__delete__' in data:
                    self.opts.sqla_session.delete(instance)
                # delete item from lists or prune lists
                # currently: domain.alternatives, user.forward_destination,
                # user.manager_of, aliases.destination
                for key, value in data.items():
                    if not isinstance(self.fields.get(key), (
                        RelatedList, CommaSeparatedListField, fields.Raw)
                    ) or not isinstance(value, list):
                        continue
                    # deduplicate new value
                    new_value = set(value)
                    # handle list pruning
                    if '-prune-' in value:
                        value.remove('-prune-')
                        new_value.remove('-prune-')
                    else:
                        for old in getattr(instance, key):
                            # using str() is okay for now (see above)
                            new_value.add(str(old))
                    # handle item deletion
                    for item in value:
                        if item.startswith('-'):
                            new_value.remove(item)
                            try:
                                new_value.remove(item[1:])
                            except KeyError as exc:
                                raise ValidationError(
                                    f'Item to delete not found: {item[1:]!r}.',
                                    field_name=f'?.{key}',
                                ) from exc
                    # sort list of new values
                    # (replaces the value of an existing key, which is safe while iterating)
                    data[key] = sorted(new_value)
                    # log backref modification not caught by modify hook
                    if isinstance(self.fields[key], RelatedList):
                        if callback := self.context.get('callback'):
                            before = {str(v) for v in getattr(instance, key)}
                            after = set(data[key])
                            if before != after:
                                callback(self, instance, {
                                    'key': key,
                                    'target': str(instance),
                                    'before': before,
                                    'after': after,
                                })

            # add attributes required for validation from db
            for attr_name, field_obj in self.load_fields.items():
                if field_obj.required and attr_name not in data:
                    data[attr_name] = getattr(instance, attr_name)

        return data

    @post_load(pass_many=True)
    def _prune_items(self, items, many, **kwargs): # pylint: disable=unused-argument
        """ handle list pruning

            Siblings: existing items not present in the input are added
            unless pruning was requested. Other schemas: existing items not
            present in the input are marked for deletion when pruning was
            requested.
        """

        # stop early when not updating
        if not self.context.get('update'):
            return items

        # get prune flag from _patch_many@pre_load
        want_prune = self.recall('prune', True)

        # prune: determine if existing items in db need to be added or marked for deletion
        add_items = False
        del_items = False
        if self.Meta.sibling:
            # parent prunes automatically
            if not want_prune:
                # no prune requested => add old items
                add_items = True
        else:
            # parent does not prune automatically
            if want_prune:
                # prune requested => mark old items for deletion
                del_items = True

        if add_items or del_items:
            existing = {item[self._primary] for item in items if self._primary in item}
            for item in getattr(self.recall('parent'), self.recall('field', 'parent')):
                key = getattr(item, self._primary)
                if key not in existing:
                    if add_items:
                        items.append({self._primary: key})
                    else:
                        items.append({self._primary: key, '__delete__': '?'})

        return items

    @post_load
    def _add_instance(self, item, many, **kwargs): # pylint: disable=unused-argument
        """ - undo password change in existing instances when plain password did not change
            - add new instances to sqla session
        """

        if not item in self.opts.sqla_session:
            self.opts.sqla_session.add(item)
            return item

        # stop early when not updating or item has no password attribute
        if not self.context.get('update') or not hasattr(item, 'password'):
            return item

        # did we hash a new plaintext password?
        original = None
        pkey = getattr(item, self._primary)
        for data in self.recall('original', True):
            if 'hash_password' in data and data.get(self._primary) == pkey:
                original = data['password']
                break
        if original is None:
            # no plaintext password was given for this item => nothing to undo
            return item

        # reset hash if plain password matches hash from db
        if attr := getattr(sqlalchemy.inspect(item).attrs, 'password', None):
            if attr.history.has_changes() and attr.history.deleted:
                try:
                    # reset password hash
                    inst = type(item)(password=attr.history.deleted[-1])
                    if inst.check_password(original):
                        item.password = inst.password
                except ValueError:
                    # hash in db is invalid
                    pass
                else:
                    del inst

        return item

    @post_dump
    def _hide_values(self, data, many, **kwargs): # pylint: disable=unused-argument
        """ hide secrets

            Drops fields having one of their "exclude_by_value" values
            (unless "full" is set in context) and replaces the values
            of fields to hide by HIDDEN.
        """

        # stop early when not excluding/hiding
        if not self._exclude_by_value and not self._hide_by_context:
            return data

        # exclude or hide values
        full = self.context.get('full')
        return type(data)(
            (key, HIDDEN if key in self._hide_by_context else value)
            for key, value in data.items()
            if full or key not in self._exclude_by_value or value not in self._exclude_by_value[key]
        )

    # this field is used to mark items for deletion
    mark_delete = fields.Boolean(data_key='__delete__', load_only=True)

    # TODO: this can be removed when comment is not nullable in model
    comment = LazyStringField()
### schema definitions ###
@mapped
class DomainSchema(BaseSchema):
    """ Schema of the Domain model (registered via @mapped) """

    class Meta:
        """ Schema config """

        model = models.Domain
        load_instance = True
        include_relationships = True
        exclude = ['users', 'managers', 'aliases']

        # dns related fields are only included when "dns" is set in context
        include_by_context = {
            ('dns',): {'dkim_publickey', 'dns_mx', 'dns_spf', 'dns_dkim', 'dns_dmarc'},
        }

        # the dkim key is only shown when "secrets" is set in context
        hide_by_context = {
            ('secrets',): {'dkim_key'},
        }

        # values not worth dumping
        exclude_by_value = {
            'alternatives': [[]],
            'dkim_key': [None],
            'dkim_publickey': [None],
            'dns_mx': [None],
            'dns_spf': [None],
            'dns_dkim': [None],
            'dns_dmarc': [None],
        }

    dkim_key = DkimKeyField(allow_none=True)

    # read-only fields
    dkim_publickey = fields.String(dump_only=True)
    dns_mx = fields.String(dump_only=True)
    dns_spf = fields.String(dump_only=True)
    dns_dkim = fields.String(dump_only=True)
    dns_dmarc = fields.String(dump_only=True)
@mapped
class TokenSchema(BaseSchema):
    """ Schema of the Token model (registered via @mapped) """

    class Meta:
        """ Schema config """

        model = models.Token
        load_instance = True

        # flagged as sibling (see BaseSchema)
        sibling = True

    # passwords are hashed and verified using the User model
    password = PasswordField(required=True, metadata={'model': models.User})

    # input only: treat "password" as plain text and hash it
    hash_password = fields.Boolean(load_only=True, missing=False)
@mapped
class FetchSchema(BaseSchema):
    """ Schema of the Fetch model (registered via @mapped) """

    class Meta:
        """ Schema config """

        model = models.Fetch
        load_instance = True

        # flagged as sibling (see BaseSchema)
        sibling = True

        # status fields need "full" or "import" set in context
        include_by_context = {
            ('full', 'import'): {'last_check', 'error'},
        }

        # the password is only shown when "secrets" is set in context
        hide_by_context = {
            ('secrets',): {'password'},
        }
@mapped
class UserSchema(BaseSchema):
    """ Schema of the User model (registered via @mapped) """

    class Meta:
        """ Schema config """

        model = models.User
        load_instance = True
        include_relationships = True
        exclude = ['_email', 'domain', 'localpart', 'domain_name', 'quota_bytes_used']

        # users are looked up by email (see BaseSchema.get_instance)
        primary_keys = ['email']

        # values not worth dumping
        exclude_by_value = {
            'forward_destination': [[]],
            'tokens': [[]],
            'fetches': [[]],
            'manager_of': [[]],
            'reply_enddate': ['2999-12-31'],
            'reply_startdate': ['1900-01-01'],
        }

    email = fields.String(required=True)

    # nested sibling schemas
    tokens = fields.Nested(TokenSchema, many=True)
    fetches = fields.Nested(FetchSchema, many=True)

    password = PasswordField(required=True, metadata={'model': models.User})

    # input only: treat "password" as plain text and hash it
    hash_password = fields.Boolean(load_only=True, missing=False)
@mapped
class AliasSchema(BaseSchema):
    """ Schema of the Alias model (registered via @mapped) """

    class Meta:
        """ Schema config """

        model = models.Alias
        load_instance = True
        exclude = ['_email', 'domain', 'localpart', 'domain_name']

        # aliases are looked up by email (see BaseSchema.get_instance)
        primary_keys = ['email']

        # an empty destination list is not dumped
        exclude_by_value = {
            'destination': [[]],
        }

    email = fields.String(required=True)

    # accepts a list or a comma-separated string
    destination = CommaSeparatedListField()
@mapped
class ConfigSchema(BaseSchema):
    """ Schema of the Config model (registered via @mapped) """

    class Meta:
        """ Schema config """

        model = models.Config
        load_instance = True
@mapped
class RelaySchema(BaseSchema):
    """ Schema of the Relay model (registered via @mapped) """

    class Meta:
        """ Schema config """

        model = models.Relay
        load_instance = True
@mapped
class MailuSchema(Schema, Storage):
    """ Schema of a complete Mailu config: one list per section """

    class Meta:
        """ Schema config """

        model = models.MailuConfig
        render_module = RenderYAML

        # order of sections
        order = ['domain', 'user', 'alias', 'relay'] # 'config'

    # sections
    domain = fields.Nested(DomainSchema, many=True)
    user = fields.Nested(UserSchema, many=True)
    alias = fields.Nested(AliasSchema, many=True)
    relay = fields.Nested(RelaySchema, many=True)
#    config = fields.Nested(ConfigSchema, many=True)

    def __init__(self, *args, **kwargs):
        """ initialize schema and bring the sections into Meta.order """
        super().__init__(*args, **kwargs)
        # re-inserting a key moves it to the end of the dict
        for sections in (self.fields, self.load_fields, self.dump_fields):
            for name in self.Meta.order:
                if name in sections:
                    sections[name] = sections.pop(name)

    def _call_and_store(self, *args, **kwargs):
        """ remember the current field name and the config object
            (as "parent") for pruning
        """
        self.store('field', kwargs['field_name'], True)
        self.store('parent', self.context.get('config'))
        return super()._call_and_store(*args, **kwargs)

    @pre_load
    def _clear_config(self, data, many, **kwargs): # pylint: disable=unused-argument
        """ make sure context has a config object; clear the models
            of all sections when "clear" is set in context
        """
        if 'config' not in self.context:
            self.context['config'] = models.MailuConfig()
        if self.context.get('clear'):
            section_models = {field.nested.opts.model for field in self.fields.values()}
            self.context['config'].clear(models = section_models)
        return data

    @post_load
    def _make_config(self, data, many, **kwargs): # pylint: disable=unused-argument
        """ update config object with loaded sections (in Meta.order)
            and return it
        """
        config = self.context['config']
        for section in self.Meta.order:
            if section not in data:
                continue
            config.update(data[section], section)
        return config