1
0
mirror of https://github.com/Mailu/Mailu.git synced 2024-12-14 10:53:30 +02:00
Mailu/core/admin/mailu/schemas.py
2021-02-16 17:59:43 +01:00

946 lines
32 KiB
Python

""" Mailu marshmallow fields and schema
"""
from copy import deepcopy
from textwrap import wrap
import re
import json
import yaml
import sqlalchemy
from marshmallow import pre_load, post_load, post_dump, fields, Schema
from marshmallow.utils import ensure_text_type
from marshmallow.exceptions import ValidationError
from marshmallow_sqlalchemy import SQLAlchemyAutoSchemaOpts
from marshmallow_sqlalchemy.fields import RelatedList
from flask_marshmallow import Marshmallow
from OpenSSL import crypto
try:
from pygments import highlight
from pygments.token import Token
from pygments.lexers import get_lexer_by_name
from pygments.lexers.data import YamlLexer
from pygments.formatters import get_formatter_by_name
except ModuleNotFoundError:
canColorize = False
else:
canColorize = True
from . import models, dkim
# Shared flask_marshmallow extension object. BaseSchema in this module derives
# from its SQLAlchemyAutoSchema class.
ma = Marshmallow()
# TODO: how and where to mark keys as "required" while deserializing in api?
# - when modifying, nothing is required (only the primary key, but this key is in the uri)
# - the primary key from post data must not differ from the key in the uri
# - when creating all fields without default or auto-increment are required
# TODO: validate everything!
### class for hidden values ###
class _Hidden:
    """ Placeholder that stands in for a secret value in output.

    The object is falsy, renders as ``<hidden>``, is returned unchanged by
    ``copy`` and ``deepcopy`` and compares equal to every object whose
    string representation is ``<hidden>``.
    """

    def __bool__(self):
        return False

    def __copy__(self):
        return self

    def __deepcopy__(self, _):
        return self

    def __eq__(self, other):
        return str(other) == '<hidden>'

    def __hash__(self):
        # A class defining __eq__ without __hash__ gets __hash__ = None and
        # its instances cannot be put into sets or used as dict keys.
        # Hash like the string the object compares equal to.
        return hash('<hidden>')

    def __repr__(self):
        return '<hidden>'

    __str__ = __repr__


# the one shared placeholder object used throughout this module
HIDDEN = _Hidden()
### map model to schema ###
# registry filled by the "mapped" decorator: model class -> schema class
_model2schema = {}


def get_schema(model=None):
    """ Look up a registered schema class.

    Without an argument all registered schema classes are returned.
    Otherwise the schema registered for ``model`` is returned; when
    nothing is registered for ``model`` itself, its class is looked up
    instead (so an instance of a model works as well). The result is
    None when neither lookup succeeds.
    """
    if model is None:
        return _model2schema.values()
    direct = _model2schema.get(model)
    if direct:
        return direct
    return _model2schema.get(model.__class__)
def mapped(cls):
    """ Class decorator: remember ``cls`` as the schema of ``cls.Meta.model``.

    The class is stored in the module level model-to-schema map and
    returned unchanged.
    """
    model = cls.Meta.model
    _model2schema[model] = cls
    return cls
### helper functions ###
def get_fieldspec(exc):
    """ Derive the spec of invalid fields from the traceback of ``exc``.

    The traceback is walked from the outermost frame inwards. The local
    ``attr`` of every ``_serialize`` frame is collected as a path element.
    When a frame of ``_init_fields`` is reached, each entry of its local
    ``invalid_fields`` is prefixed with the dotted path and the entries are
    returned as one comma separated string.
    Returns None when no ``_init_fields`` frame is found.
    """
    attrs = []
    link = exc.__traceback__
    while link:
        frame = link.tb_frame
        func_name = frame.f_code.co_name
        if func_name == '_serialize':
            if 'attr' in frame.f_locals:
                attrs.append(frame.f_locals['attr'])
        elif func_name == '_init_fields':
            prefix = '.'.join(attrs)
            return ', '.join(
                [f'{prefix}.{key}' for key in frame.f_locals['invalid_fields']]
            )
        link = link.tb_next
    return None
def colorize(data, lexer='yaml', formatter='terminal', color=None, strip=False):
    """ Add color to data using pygments.

    data      -- text to colorize
    lexer     -- 'yaml' selects the extended yaml lexer defined below,
                 any other name is resolved by pygments
    formatter -- name of the pygments formatter to use
    color     -- None: colorize only when pygments could be imported,
                 false: return data unchanged,
                 true: colorize, raise ValueError if pygments is missing
    strip     -- remove trailing newlines from the result
    """
    # autodetect when the caller has no preference
    wanted = canColorize if color is None else color
    if not wanted:
        # no color wanted
        return data
    if not canColorize:
        # want color, but not supported
        raise ValueError('Please install pygments to colorize output')

    green = ('green', 'green')
    magenta = ('magenta', 'brightmagenta')
    red = ('red', 'brightred')
    scheme = {
        Token: ('', ''),
        Token.Name.Tag: ('cyan', 'brightcyan'),
        Token.Literal.Scalar: green,
        Token.Literal.String: green,
        Token.Keyword.Constant: magenta,
        Token.Literal.Number: magenta,
        Token.Error: red,
        Token.Name: red,
        Token.Operator: red,
    }

    def classify(text):
        """ token type for a plain scalar, None to keep the lexer's type """
        if text in {'true', 'false', 'null'}:
            return Token.Keyword.Constant
        if text == HIDDEN:
            return Token.Error
        try:
            int(text, 10)
        except ValueError:
            pass
        else:
            return Token.Literal.Number.Integer
        try:
            float(text)
        except ValueError:
            return None
        return Token.Literal.Number.Float

    class MyYamlLexer(YamlLexer):
        """ colorize yaml constants and integers """
        def get_tokens(self, text, unfiltered=False):
            for typ, value in super().get_tokens(text, unfiltered):
                if typ is Token.Literal.Scalar.Plain:
                    special = classify(value)
                    if special is not None:
                        typ = special
                yield typ, value

    if lexer == 'yaml':
        used_lexer = MyYamlLexer()
    else:
        used_lexer = get_lexer_by_name(lexer)
    used_formatter = get_formatter_by_name(formatter, colorscheme=scheme)
    res = highlight(data, used_lexer, used_formatter)
    if strip:
        return res.rstrip('\n')
    return res
### render modules ###
# teach yaml to dump the hidden placeholder: it is represented like
# its string form
yaml.add_representer(
    _Hidden,
    lambda dumper, hidden: dumper.represent_data(str(hidden)),
)
class RenderYAML:
    """ Marshmallow YAML Render Module

        Offers loads() and dumps() which forward to yaml.safe_load and
        yaml.dump after filling in the class defaults for missing keywords.
    """

    class SpacedDumper(yaml.Dumper):
        """ YAML Dumper to add a newline between main sections
            and double the indent used
        """

        def write_line_break(self, data=None):
            super().write_line_break(data)
            # one more line break while only a single indent is recorded
            if len(self.indents) == 1:
                super().write_line_break()

        def increase_indent(self, flow=False, indentless=False):
            # the indentless argument is ignored: always pass False
            return super().increase_indent(flow, False)

    @staticmethod
    def _augment(kwargs, defaults):
        """ copy every default into kwargs whose key is not present yet
        """
        for name, default in defaults.items():
            kwargs.setdefault(name, default)

    # keyword defaults used by loads()
    _load_defaults = {}

    @classmethod
    def loads(cls, *args, **kwargs):
        """ parse yaml data from string using yaml.safe_load
        """
        cls._augment(kwargs, cls._load_defaults)
        return yaml.safe_load(*args, **kwargs)

    # keyword defaults used by dumps()
    _dump_defaults = {
        'Dumper': SpacedDumper,
        'default_flow_style': False,
        'allow_unicode': True,
        'sort_keys': False,
    }

    @classmethod
    def dumps(cls, *args, **kwargs):
        """ render data as yaml string using yaml.dump
        """
        cls._augment(kwargs, cls._dump_defaults)
        return yaml.dump(*args, **kwargs)
class JSONEncoder(json.JSONEncoder):
    """ JSONEncoder which is able to encode the HIDDEN placeholder """

    def default(self, o):
        """ encode the hidden placeholder as its string form and
            leave everything else to the base encoder
        """
        if not isinstance(o, _Hidden):
            return super().default(o)
        return str(o)
class RenderJSON:
    """ Marshmallow JSON Render Module

        Offers loads() and dumps() which forward to json.loads and
        json.dumps after filling in the class defaults for missing keywords.
    """

    @staticmethod
    def _augment(kwargs, defaults):
        """ copy every default into kwargs whose key is not present yet
        """
        for name, default in defaults.items():
            kwargs.setdefault(name, default)

    # keyword defaults used by loads()
    _load_defaults = {}

    @classmethod
    def loads(cls, *args, **kwargs):
        """ parse json data from string
        """
        cls._augment(kwargs, cls._load_defaults)
        return json.loads(*args, **kwargs)

    # keyword defaults used by dumps(): compact separators and the
    # encoder that knows the hidden placeholder
    _dump_defaults = {
        'separators': (',',':'),
        'cls': JSONEncoder,
    }

    @classmethod
    def dumps(cls, *args, **kwargs):
        """ render data as json string
        """
        cls._augment(kwargs, cls._dump_defaults)
        return json.dumps(*args, **kwargs)
### custom fields ###
class LazyStringField(fields.String):
    """ String field which dumps every "false" value as the empty string
    """

    def _serialize(self, value, attr, obj, **kwargs):
        """ return value unchanged when it is truthy, otherwise ''
        """
        return value or ''
class CommaSeparatedListField(fields.Raw):
    """ Field which loads a string of comma-separated values as
        a list of strings
    """

    def _deserialize(self, value, attr, data, **kwargs):
        """ turn a comma separated string into a list of strings

            A "false" value yields an empty list, a value which is
            not a string is passed through unchanged.
        """
        if not value:
            return []
        if not isinstance(value, str):
            return value
        # strip whitespace around the items and drop empty ones
        parts = (part.strip() for part in value.split(','))
        return [part for part in parts if part]
class DkimKeyField(fields.String):
    """ Serialize a dkim key to a list of strings (lines) and
        Deserialize a string or list of strings to a valid dkim key
    """

    default_error_messages = {
        "invalid": "Not a valid string or list.",
        "invalid_utf8": "Not a valid utf-8 string or list.",
    }

    # matches a pem header at the start, a pem footer at the end
    # and any run of whitespace
    _clean_re = re.compile(
        r'(^-----BEGIN (RSA )?PRIVATE KEY-----|-----END (RSA )?PRIVATE KEY-----$|\s+)',
        flags=re.UNICODE
    )

    def _serialize(self, value, attr, obj, **kwargs):
        """ serialize dkim key to a list of strings (lines)

            value is expected to be the key as bytes.
            Returns None for an empty value.
        """

        # map empty string and None to None
        if not value:
            return None

        # return list of key lines without header/footer
        return value.decode('utf-8').strip().split('\n')[1:-1]

    def _deserialize(self, value, attr, data, **kwargs):
        """ deserialize a string or list of strings to dkim key data
            with verification

            Returns None for empty input, a newly generated key for the
            special value 'generate' and otherwise the key as pem encoded
            bytes.
            Raises ValidationError when the input is neither text nor a
            list, is not valid utf-8 or is not a loadable private key.
        """

        # convert list to str
        if isinstance(value, list):
            try:
                value = ''.join([ensure_text_type(item) for item in value])
            except UnicodeDecodeError as exc:
                raise self.make_error("invalid_utf8") from exc

        # only text is allowed
        else:
            if not isinstance(value, (str, bytes)):
                raise self.make_error("invalid")
            try:
                value = ensure_text_type(value)
            except UnicodeDecodeError as exc:
                raise self.make_error("invalid_utf8") from exc

        # clean value (remove whitespace and header/footer)
        value = self._clean_re.sub('', value.strip())

        # map empty string/list to None
        if not value:
            return None

        # handle special value 'generate'
        if value == 'generate':
            return dkim.gen_key()

        # remember some keydata for error message
        keydata = f'{value[:25]}...{value[-10:]}' if len(value) > 40 else value

        # wrap value into valid pem layout and check validity
        # input containing non-ascii characters can not be encoded: report it
        # as an invalid key instead of leaking an UnicodeEncodeError
        try:
            pem = (
                '-----BEGIN PRIVATE KEY-----\n' +
                '\n'.join(wrap(value, 64)) +
                '\n-----END PRIVATE KEY-----\n'
            ).encode('ascii')
            crypto.load_privatekey(crypto.FILETYPE_PEM, pem)
        except (UnicodeEncodeError, crypto.Error) as exc:
            raise ValidationError(f'invalid dkim key {keydata!r}') from exc

        return pem
class PasswordField(fields.Str):
    """ Serialize a hashed password hash by stripping the obsolete {SCHEME}
        Deserialize a plain password or hashed password into a hashed password

        The model class used for hashing and verification is taken from
        the field metadata key 'model'.
    """

    # scheme names which are stripped when found as a leading "{SCHEME}"
    _hashes = {'PBKDF2', 'BLF-CRYPT', 'SHA512-CRYPT', 'SHA256-CRYPT', 'MD5-CRYPT', 'CRYPT'}

    @classmethod
    def _strip_scheme(cls, value):
        """ return value without a leading "{SCHEME}" if SCHEME is known """
        if value.startswith('{') and (end := value.find('}', 1)) >= 0:
            if value[1:end] in cls._hashes:
                return value[end+1:]
        return value

    def _serialize(self, value, attr, obj, **kwargs):
        """ strip obsolete {password-hash} when serializing

            None is passed through, as there is nothing to strip.
        """
        if value is None:
            return None
        # strip scheme spec if in database - it's obsolete
        return self._strip_scheme(value)

    def _deserialize(self, value, attr, data, **kwargs):
        """ hashes plain password or checks hashed password
            also strips obsolete {password-hash} when deserializing

            Raises ValidationError when the input is not a string or
            when the resulting hash is not recognized by the model.
        """

        # let the string field validate the type (and decode bytes) first,
        # input which is not text is reported as validation error
        value = super()._deserialize(value, attr, data, **kwargs)

        # when hashing is requested: use model instance to hash plain password
        if data.get('hash_password'):
            # hash password using model instance
            inst = self.metadata['model']()
            inst.set_password(value)
            value = inst.password

        # strip scheme spec when specified - it's obsolete
        value = self._strip_scheme(value)

        # check if algorithm is supported
        inst = self.metadata['model'](password=value)
        try:
            # just check against empty string to see if hash is valid
            inst.check_password('')
        except ValueError as exc:
            # ValueError: hash could not be identified
            raise ValidationError(f'invalid password hash {value!r}') from exc

        return value
### base schema ###
class BaseOpts(SQLAlchemyAutoSchemaOpts):
    """ Schema options which supply defaults for the sqla session
        and the sibling flag
    """

    def __init__(self, meta, ordered=False):
        # fill in what the Meta class of the schema did not set itself
        if not hasattr(meta, 'sibling'):
            meta.sibling = False
        if not hasattr(meta, 'sqla_session'):
            meta.sqla_session = models.db.session
        super().__init__(meta, ordered=ordered)
class BaseSchema(ma.SQLAlchemyAutoSchema):
    """ Marshmallow base schema with custom exclude logic
        and option to hide sqla defaults

        The behaviour is controlled by keys of the schema context:
        - keys whose value is True are "flags", they select the fields named
          in Meta.include_by_context and Meta.hide_by_context
        - 'full': do not drop values when dumping
        - 'update': patch input data when loading (delete, prune, defaults)
        - 'callback': function called to track the import
        - 'parent' and 'parent_field': where the existing items of the
          section currently processed are found
    """

    OPTIONS_CLASS = BaseOpts

    class Meta:
        """ Schema config """
        # tuple of flags -> fields only included if one of the flags is set
        include_by_context = {}
        # field -> list of values which are not dumped
        exclude_by_value = {}
        # tuple of flags -> fields shown as HIDDEN unless one of the flags is set
        hide_by_context = {}
        # order of the fields when dumping
        order = []
        # see _patch_input: siblings do not flush the session and
        # are pruned by their parent
        sibling = False

    def __init__(self, *args, **kwargs):
        """ compile excludes and hidden fields from the context, order the
            fields to dump and rearrange the load hooks
        """

        # get context (an explicitly passed None is handled like no context)
        context = kwargs.get('context') or {}
        flags = {key for key, value in context.items() if value is True}

        # compile excludes
        exclude = set(kwargs.get('exclude', []))

        # always exclude
        exclude.update({'created_at', 'updated_at'})

        # add include_by_context
        for need, what in getattr(self.Meta, 'include_by_context', {}).items():
            if not flags & set(need):
                exclude |= set(what)

        # update excludes
        kwargs['exclude'] = exclude

        # init SQLAlchemyAutoSchema
        super().__init__(*args, **kwargs)

        # exclude_by_value
        # work on a copy: the dict and its lists belong to the Meta class and
        # are shared by all instances, appending to them below would make
        # them grow with every schema created
        self._exclude_by_value = {
            key: list(values)
            for key, values in getattr(self.Meta, 'exclude_by_value', {}).items()
        }

        # exclude default values
        if not context.get('full'):
            for column in self.opts.model.__table__.columns:
                if column.name not in exclude:
                    self._exclude_by_value.setdefault(column.name, []).append(
                        None if column.default is None else column.default.arg
                    )

        # hide by context
        self._hide_by_context = set()
        for need, what in getattr(self.Meta, 'hide_by_context', {}).items():
            if not flags & set(need):
                self._hide_by_context |= set(what)

        # remember primary keys
        self._primary = str(self.opts.model.__table__.primary_key.columns.values()[0].name)

        # determine attribute order
        if hasattr(self.Meta, 'order'):
            # use user-defined order
            order = self.Meta.order
        else:
            # default order is: primary_key + other keys alphabetically
            order = sorted(self.fields.keys())
            if self._primary in order:
                order.remove(self._primary)
                order.insert(0, self._primary)

        # order dump_fields (re-inserting a key moves it to the end)
        for field in order:
            if field in self.dump_fields:
                self.dump_fields[field] = self.dump_fields.pop(field)

        # move pre_load hook "_track_import" to the front
        hooks = self._hooks[('pre_load', False)]
        hooks.remove('_track_import')
        hooks.insert(0, '_track_import')
        # move pre_load hook "_add_required" to the end
        hooks.remove('_add_required')
        hooks.append('_add_required')

        # move post_load hook "_add_instance" to the end
        hooks = self._hooks[('post_load', False)]
        hooks.remove('_add_instance')
        hooks.append('_add_instance')

    def hide(self, data):
        """ helper method to hide input data for logging

            Returns a new dict: values of hidden keys are replaced by
            HIDDEN, all other values are deep-copied.
        """
        # always returns a copy of data
        return {
            key: HIDDEN if key in self._hide_by_context else deepcopy(value)
            for key, value in data.items()
        }

    def _call_and_store(self, *args, **kwargs):
        """ track current parent field for pruning """
        self.context['parent_field'] = kwargs['field_name']
        return super()._call_and_store(*args, **kwargs)

    # this is only needed to work around the declared attr "email" primary key in model
    def get_instance(self, data):
        """ lookup item by defined primary key instead of key(s) from model

            Uses Meta.primary_keys when defined and all of these keys have
            a value in data. Otherwise the lookup of the base class is used.
            Returns None when the schema is transient.
        """
        if self.transient:
            return None
        if keys := getattr(self.Meta, 'primary_keys', None):
            filters = {key: data.get(key) for key in keys}
            if None not in filters.values():
                return self.session.query(self.opts.model).filter_by(**filters).first()
        return super().get_instance(data)

    @pre_load(pass_many=True)
    def _patch_input(self, items, many, **kwargs): # pylint: disable=unused-argument
        """ - flush sqla session before serializing a section when requested
              (make sure all objects that could be referred to later are created)
            - when in update mode: patch input data before deserialization
              - handle "prune" and "delete" items
              - replace values in keys starting with '-' with default

            Returns the list of patched items.
            Raises ValidationError for '__delete__' in the input, for a
            non-null value when resetting to default and for resetting a
            column which has no default.
        """

        # flush sqla session
        if not self.Meta.sibling:
            self.opts.sqla_session.flush()

        # stop early when not updating
        if not self.context.get('update'):
            return items

        columns = self.opts.model.__table__.columns

        # patch "delete", "prune" and "default"
        want_prune = []
        def patch(count, data, prune):

            # don't allow __delete__ coming from input
            if '__delete__' in data:
                raise ValidationError('Unknown field.', f'{count}.__delete__')

            # handle "prune list" and "delete item" (-pkey: none and -pkey: id)
            for key in data:
                if key.startswith('-'):
                    if key[1:] == self._primary:
                        # delete or prune
                        if data[key] is None:
                            # prune
                            prune.append(True)
                            return None
                        # mark item for deletion
                        return {key[1:]: data[key], '__delete__': True}

            # handle "set to default value" (-key: none)
            def set_default(key, value):
                if not key.startswith('-'):
                    return (key, value)
                key = key[1:]
                if key not in columns:
                    return (key, None)
                if value is not None:
                    raise ValidationError(
                        'When resetting to default value must be null.',
                        f'{count}.{key}'
                    )
                value = columns[key].default
                if value is None:
                    raise ValidationError(
                        'Field has no default value.',
                        f'{count}.{key}'
                    )
                return (key, value.arg)

            return dict(set_default(key, value) for key, value in data.items())

        # convert items to "delete" and filter "prune" item
        items = [
            item for item in [
                patch(count, item, want_prune) for count, item in enumerate(items)
            ] if item
        ]

        # prune: determine if existing items in db need to be added or marked for deletion
        add_items = False
        del_items = False
        if self.Meta.sibling:
            # parent prunes automatically
            if not want_prune:
                # no prune requested => add old items
                add_items = True
        else:
            # parent does not prune automatically
            if want_prune:
                # prune requested => mark old items for deletion
                del_items = True

        if add_items or del_items:
            existing = {item[self._primary] for item in items if self._primary in item}
            for item in getattr(self.context['parent'], self.context['parent_field']):
                key = getattr(item, self._primary)
                if key not in existing:
                    if add_items:
                        items.append({self._primary: key})
                    else:
                        items.append({self._primary: key, '__delete__': True})

        return items

    @pre_load
    def _track_import(self, data, many, **kwargs): # pylint: disable=unused-argument
        """ call callback function to track import
        """
        # callback
        if callback := self.context.get('callback'):
            callback(self, data)

        return data

    def _merge_list(self, instance, key, value):
        """ merge the list value from input with the list stored in
            attribute key of instance

            - when value contains '-prune-' the stored items are dropped,
              otherwise they are kept (as strings)
            - an item '-name' removes 'name' from the result
            Returns the deduplicated and sorted list.
            Raises ValidationError when an item to remove is not found.
        """
        new_value = set(value)

        # handle list pruning
        if '-prune-' in value:
            value.remove('-prune-')
            new_value.remove('-prune-')
        else:
            for old in getattr(instance, key):
                # using str() is okay for now (see above)
                new_value.add(str(old))

        # handle item deletion
        for item in value:
            if item.startswith('-'):
                # discard: the marker is already gone if it was listed twice
                new_value.discard(item)
                try:
                    new_value.remove(item[1:])
                except KeyError as exc:
                    raise ValidationError(
                        f'item to delete not found: {item[1:]!r}',
                        field_name=f'?.{key}',
                    ) from exc

        # deduplicate and sort list
        return sorted(new_value)

    @pre_load
    def _add_required(self, data, many, **kwargs): # pylint: disable=unused-argument
        """ when updating:
            allow modification of existing items having required attributes
            by loading existing value from db

            Also deletes instances marked for deletion and merges list values
            with the lists of the existing instance.
            Raises ValidationError when an item to delete does not exist.
        """

        if not self.opts.load_instance or not self.context.get('update'):
            return data

        # stabilize import of auto-increment primary keys (not required),
        # by matching import data to existing items and setting primary key
        if self._primary not in data:
            for item in getattr(self.context['parent'], self.context['parent_field']):
                existing = self.dump(item, many=False)
                this = existing.pop(self._primary)
                if data == existing:
                    data[self._primary] = this
                    break

        # try to load instance
        instance = self.instance or self.get_instance(data)
        if instance is None:

            if '__delete__' in data:
                # deletion of non-existent item requested
                raise ValidationError(
                    f'item to delete not found: {data[self._primary]!r}',
                    field_name=f'?.{self._primary}',
                )

        else:

            if self.context.get('update'):
                # remember instance as parent for pruning siblings
                if not self.Meta.sibling:
                    self.context['parent'] = instance
                # delete instance when marked
                if '__delete__' in data:
                    self.opts.sqla_session.delete(instance)
                # delete item from lists or prune lists
                # currently: domain.alternatives, user.forward_destination,
                # user.manager_of, aliases.destination
                for key, value in data.items():
                    # keys of the input are not necessarily names of fields
                    # ('__delete__' or unknown keys): skip them here instead
                    # of failing with a KeyError
                    field = self.fields.get(key)
                    if field is None or isinstance(field, fields.Nested):
                        continue
                    if not isinstance(value, list):
                        continue
                    data[key] = self._merge_list(instance, key, value)
                    # log backref modification not catched by hook
                    if isinstance(field, RelatedList):
                        if callback := self.context.get('callback'):
                            callback(self, instance, {
                                'key': key,
                                'target': str(instance),
                                'before': [str(v) for v in getattr(instance, key)],
                                'after': data[key],
                            })

            # add attributes required for validation from db
            # TODO: this will cause validation errors if value from database does not validate
            # but there should not be an invalid value in the database
            for attr_name, field_obj in self.load_fields.items():
                if field_obj.required and attr_name not in data:
                    data[attr_name] = getattr(instance, attr_name)

        return data

    @post_load(pass_original=True)
    def _add_instance(self, item, original, many, **kwargs): # pylint: disable=unused-argument
        """ add new instances to sqla session

            For an item already in the session whose input contained
            'hash_password', the previous password hash is restored when it
            matches the password from the input.
        """

        if item in self.opts.sqla_session:
            # item was modified
            if 'hash_password' in original:
                # stabilize import of passwords to be hashed,
                # by not re-hashing an unchanged password
                if attr := getattr(sqlalchemy.inspect(item).attrs, 'password', None):
                    if attr.history.has_changes() and attr.history.deleted:
                        try:
                            # reset password hash, if password was not changed
                            inst = type(item)(password=attr.history.deleted[-1])
                            if inst.check_password(original['password']):
                                item.password = inst.password
                        except ValueError:
                            # hash in db is invalid
                            pass
        else:
            # new item
            self.opts.sqla_session.add(item)
        return item

    @post_dump
    def _hide_values(self, data, many, **kwargs): # pylint: disable=unused-argument
        """ hide secrets and order output

            Unless the context has 'full' set, keys whose value is listed in
            the excludes for that key are dropped. Values of hidden keys are
            replaced by HIDDEN. The result has the same type as data.
        """

        # stop early when not excluding/hiding
        if not self._exclude_by_value and not self._hide_by_context:
            return data

        # exclude or hide values
        full = self.context.get('full')
        return type(data)([
            (key, HIDDEN if key in self._hide_by_context else value)
            for key, value in data.items()
            if full or key not in self._exclude_by_value or value not in self._exclude_by_value[key]
        ])

    # this field is used to mark items for deletion
    mark_delete = fields.Boolean(data_key='__delete__', load_only=True)

    # TODO: remove LazyStringField (when model was changed - IMHO comment should not be nullable)
    comment = LazyStringField()
### schema definitions ###
# registered in the model-to-schema map by the "mapped" decorator
@mapped
class DomainSchema(BaseSchema):
    """ Marshmallow schema for Domain model """
    class Meta:
        """ Schema config """
        model = models.Domain
        load_instance = True
        include_relationships = True
        exclude = ['users', 'managers', 'aliases']
        # BaseSchema excludes these fields unless the context flag "dns" is True
        include_by_context = {
            ('dns',): {'dkim_publickey', 'dns_mx', 'dns_spf', 'dns_dkim', 'dns_dmarc'},
        }
        # BaseSchema dumps dkim_key as HIDDEN unless the context flag "secrets" is True
        hide_by_context = {
            ('secrets',): {'dkim_key'},
        }
        # BaseSchema drops a key from the dump when its value is listed here
        # (not applied when the context has "full" set)
        exclude_by_value = {
            'alternatives': [[]],
            'dkim_key': [None],
            'dkim_publickey': [None],
            'dns_mx': [None],
            'dns_spf': [None],
            'dns_dkim': [None],
            'dns_dmarc': [None],
        }
    # the key is validated and converted by DkimKeyField
    dkim_key = DkimKeyField(allow_none=True)
    # the following fields are only dumped, never loaded
    dkim_publickey = fields.String(dump_only=True)
    dns_mx = fields.String(dump_only=True)
    dns_spf = fields.String(dump_only=True)
    dns_dkim = fields.String(dump_only=True)
    dns_dmarc = fields.String(dump_only=True)
# registered in the model-to-schema map by the "mapped" decorator
@mapped
class TokenSchema(BaseSchema):
    """ Marshmallow schema for Token model """
    class Meta:
        """ Schema config """
        model = models.Token
        load_instance = True
        # sibling: BaseSchema does not flush the session for this schema
        # and handles pruning of existing items differently
        sibling = True
    # PasswordField uses the class given as metadata "model" to hash and check
    # the password. NOTE(review): this is models.User although the schema is
    # for models.Token - presumably intended, confirm.
    password = PasswordField(required=True, metadata={'model': models.User})
    # input only: PasswordField hashes the password when this is true
    hash_password = fields.Boolean(load_only=True, missing=False)
# registered in the model-to-schema map by the "mapped" decorator
@mapped
class FetchSchema(BaseSchema):
    """ Marshmallow schema for Fetch model """
    class Meta:
        """ Schema config """
        model = models.Fetch
        load_instance = True
        # sibling: BaseSchema does not flush the session for this schema
        # and handles pruning of existing items differently
        sibling = True
        # BaseSchema excludes these fields unless one of the context flags
        # "full" or "import" is True
        include_by_context = {
            ('full', 'import'): {'last_check', 'error'},
        }
        # BaseSchema dumps password as HIDDEN unless the context flag "secrets" is True
        hide_by_context = {
            ('secrets',): {'password'},
        }
# registered in the model-to-schema map by the "mapped" decorator
@mapped
class UserSchema(BaseSchema):
    """ Marshmallow schema for User model """
    class Meta:
        """ Schema config """
        model = models.User
        load_instance = True
        include_relationships = True
        exclude = ['_email', 'domain', 'localpart', 'domain_name', 'quota_bytes_used']
        # BaseSchema.get_instance looks up existing items by these keys
        primary_keys = ['email']
        # BaseSchema drops a key from the dump when its value is listed here
        # (not applied when the context has "full" set)
        exclude_by_value = {
            'forward_destination': [[]],
            'tokens': [[]],
            'fetches': [[]],
            'manager_of': [[]],
            'reply_enddate': ['2999-12-31'],
            'reply_startdate': ['1900-01-01'],
        }
    email = fields.String(required=True)
    # tokens and fetches are nested lists handled by their own schemas
    tokens = fields.Nested(TokenSchema, many=True)
    fetches = fields.Nested(FetchSchema, many=True)
    # PasswordField uses the class given as metadata "model" to hash and check the password
    password = PasswordField(required=True, metadata={'model': models.User})
    # input only: PasswordField hashes the password when this is true
    hash_password = fields.Boolean(load_only=True, missing=False)
# registered in the model-to-schema map by the "mapped" decorator
@mapped
class AliasSchema(BaseSchema):
    """ Marshmallow schema for Alias model """
    class Meta:
        """ Schema config """
        model = models.Alias
        load_instance = True
        exclude = ['_email', 'domain', 'localpart', 'domain_name']
        # BaseSchema.get_instance looks up existing items by these keys
        primary_keys = ['email']
        # BaseSchema drops a key from the dump when its value is listed here
        # (not applied when the context has "full" set)
        exclude_by_value = {
            'destination': [[]],
        }
    email = fields.String(required=True)
    # accepts a list or a string of comma separated values when loading
    destination = CommaSeparatedListField()
# registered in the model-to-schema map by the "mapped" decorator
@mapped
class ConfigSchema(BaseSchema):
    """ Marshmallow schema for Config model """
    class Meta:
        """ Schema config """
        model = models.Config
        load_instance = True
# registered in the model-to-schema map by the "mapped" decorator
@mapped
class RelaySchema(BaseSchema):
    """ Marshmallow schema for Relay model """
    class Meta:
        """ Schema config """
        model = models.Relay
        load_instance = True
class MailuSchema(Schema):
    """ Marshmallow schema for complete Mailu config

        Every section is a list of items handled by the nested schema of
        the section. Loading returns the config object kept in the context.
    """

    class Meta:
        """ Schema config """
        render_module = RenderYAML
        order = ['domain', 'user', 'alias', 'relay'] # 'config'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # bring dump_fields into the configured order:
        # a key which is removed and added again moves to the end
        for name in self.Meta.order:
            if name in self.dump_fields:
                moved = self.dump_fields.pop(name)
                self.dump_fields[name] = moved

    def _call_and_store(self, *args, **kwargs):
        """ track current parent and field for pruning """
        self.context['parent'] = self.context.get('config')
        self.context['parent_field'] = kwargs['field_name']
        return super()._call_and_store(*args, **kwargs)

    @pre_load
    def _clear_config(self, data, many, **kwargs): # pylint: disable=unused-argument
        """ make sure there is a config object in the context and
            clear it when the context requests it
        """
        context = self.context
        if 'config' not in context:
            context['config'] = models.MailuConfig()
        if context.get('clear'):
            # clear the models of all nested schemas
            nested_models = {field.nested.opts.model for field in self.fields.values()}
            context['config'].clear(models=nested_models)
        return data

    @post_load
    def _make_config(self, data, many, **kwargs): # pylint: disable=unused-argument
        """ feed the loaded sections into the config object and return it """
        config = self.context['config']
        present = [section for section in self.Meta.order if section in data]
        for section in present:
            config.update(data[section], section)
        return config

    domain = fields.Nested(DomainSchema, many=True)
    user = fields.Nested(UserSchema, many=True)
    alias = fields.Nested(AliasSchema, many=True)
    relay = fields.Nested(RelaySchema, many=True)
    # config = fields.Nested(ConfigSchema, many=True)