1
0
mirror of https://github.com/Mailu/Mailu.git synced 2025-01-18 03:21:36 +02:00

added config_import using marshmallow

This commit is contained in:
Alexander Graf 2021-01-14 01:11:04 +01:00
parent 7413f9b7b4
commit c24bff1c1b
4 changed files with 611 additions and 593 deletions

View File

@ -26,7 +26,7 @@ def register(app):
# add redirect to current api version # add redirect to current api version
@app.route(f'{ROOT}/') @app.route(f'{ROOT}/')
def redir(): def _redirect_to_active_api():
return redirect(url_for(f'{ACTIVE.blueprint.name}.root')) return redirect(url_for(f'{ACTIVE.blueprint.name}.root'))
# swagger ui config # swagger ui config

View File

@ -1,21 +1,25 @@
from mailu import models """ Mailu command line interface
from .schemas import MailuConfig, MailuSchema """
from flask import current_app as app
from flask.cli import FlaskGroup, with_appcontext
import sys
import os import os
import socket import socket
import uuid import uuid
import click import click
import yaml
import sys from flask import current_app as app
from flask.cli import FlaskGroup, with_appcontext
from marshmallow.exceptions import ValidationError
from . import models
from .schemas import MailuSchema
db = models.db db = models.db
@click.group(cls=FlaskGroup) @click.group(cls=FlaskGroup, context_settings={'help_option_names': ['-?', '-h', '--help']})
def mailu(): def mailu():
""" Mailu command line """ Mailu command line
""" """
@ -26,17 +30,17 @@ def mailu():
def advertise(): def advertise():
""" Advertise this server against statistic services. """ Advertise this server against statistic services.
""" """
if os.path.isfile(app.config["INSTANCE_ID_PATH"]): if os.path.isfile(app.config['INSTANCE_ID_PATH']):
with open(app.config["INSTANCE_ID_PATH"], "r") as handle: with open(app.config['INSTANCE_ID_PATH'], 'r') as handle:
instance_id = handle.read() instance_id = handle.read()
else: else:
instance_id = str(uuid.uuid4()) instance_id = str(uuid.uuid4())
with open(app.config["INSTANCE_ID_PATH"], "w") as handle: with open(app.config['INSTANCE_ID_PATH'], 'w') as handle:
handle.write(instance_id) handle.write(instance_id)
if not app.config["DISABLE_STATISTICS"]: if not app.config['DISABLE_STATISTICS']:
try: try:
socket.gethostbyname(app.config["STATS_ENDPOINT"].format(instance_id)) socket.gethostbyname(app.config['STATS_ENDPOINT'].format(instance_id))
except: except OSError:
pass pass
@ -171,156 +175,196 @@ def user_import(localpart, domain_name, password_hash, hash_scheme = None):
db.session.commit() db.session.commit()
yaml_sections = [ # @mailu.command()
('domains', models.Domain), # @click.option('-v', '--verbose', is_flag=True, help='Increase verbosity')
('relays', models.Relay), # @click.option('-d', '--delete-objects', is_flag=True, help='Remove objects not included in yaml')
('users', models.User), # @click.option('-n', '--dry-run', is_flag=True, help='Perform a trial run with no changes made')
('aliases', models.Alias), # @click.argument('source', metavar='[FILENAME|-]', type=click.File(mode='r'), default=sys.stdin)
# ('config', models.Config), # @with_appcontext
] # def config_update(verbose=False, delete_objects=False, dry_run=False, source=None):
# """ Update configuration with data from YAML-formatted input
# """
# try:
# new_config = yaml.safe_load(source)
# except (yaml.scanner.ScannerError, yaml.parser.ParserError) as exc:
# print(f'[ERROR] Invalid yaml: {exc}')
# sys.exit(1)
# else:
# if isinstance(new_config, str):
# print(f'[ERROR] Invalid yaml: {new_config!r}')
# sys.exit(1)
# elif new_config is None or not new_config:
# print('[ERROR] Empty yaml: Please pipe yaml into stdin')
# sys.exit(1)
# error = False
# tracked = {}
# for section, model in yaml_sections:
# items = new_config.get(section)
# if items is None:
# if delete_objects:
# print(f'[ERROR] Invalid yaml: Section "{section}" is missing')
# error = True
# break
# else:
# continue
# del new_config[section]
# if not isinstance(items, list):
# print(f'[ERROR] Section "{section}" must be a list, not {items.__class__.__name__}')
# error = True
# break
# elif not items:
# continue
# # create items
# for data in items:
# if verbose:
# print(f'Handling {model.__table__} data: {data!r}')
# try:
# changed = model.from_dict(data, delete_objects)
# except Exception as exc:
# print(f'[ERROR] {exc.args[0]} in data: {data}')
# error = True
# break
# for item, created in changed:
# if created is True:
# # flush newly created item
# db.session.add(item)
# db.session.flush()
# if verbose:
# print(f'Added {item!r}: {item.to_dict()}')
# else:
# print(f'Added {item!r}')
# elif created:
# # modified instance
# if verbose:
# for key, old, new in created:
# print(f'Updated {key!r} of {item!r}: {old!r} -> {new!r}')
# else:
# print(f'Updated {item!r}: {", ".join(sorted([kon[0] for kon in created]))}')
# # track primary key of all items
# tracked.setdefault(item.__class__, set()).update(set([item._dict_pval()]))
# if error:
# break
# # on error: stop early
# if error:
# print('[ERROR] An error occured. Not committing changes.')
# db.session.rollback()
# sys.exit(1)
# # are there sections left in new_config?
# if new_config:
# print(f'[ERROR] Unknown section(s) in yaml: {", ".join(sorted(new_config.keys()))}')
# error = True
# # test for conflicting domains
# domains = set()
# for model, items in tracked.items():
# if model in (models.Domain, models.Alternative, models.Relay):
# if domains & items:
# for fqdn in domains & items:
# print(f'[ERROR] Duplicate domain name used: {fqdn}')
# error = True
# domains.update(items)
# # delete items not tracked
# if delete_objects:
# for model, items in tracked.items():
# for item in model.query.all():
# if not item._dict_pval() in items:
# print(f'Deleted {item!r} {item}')
# db.session.delete(item)
# # don't commit when running dry
# if dry_run:
# print('Dry run. Not commiting changes.')
# db.session.rollback()
# else:
# db.session.commit()
SECTIONS = {'domains', 'relays', 'users', 'aliases'}
@mailu.command() @mailu.command()
@click.option('-v', '--verbose', is_flag=True, help='Increase verbosity') @click.option('-v', '--verbose', is_flag=True, help='Increase verbosity')
@click.option('-d', '--delete-objects', is_flag=True, help='Remove objects not included in yaml')
@click.option('-n', '--dry-run', is_flag=True, help='Perform a trial run with no changes made') @click.option('-n', '--dry-run', is_flag=True, help='Perform a trial run with no changes made')
@click.argument('source', metavar='[FILENAME|-]', type=click.File(mode='r'), default=sys.stdin)
@with_appcontext @with_appcontext
def config_update(verbose=False, delete_objects=False, dry_run=False, file=None): def config_import(verbose=False, dry_run=False, source=None):
"""sync configuration with data from YAML-formatted stdin""" """ Import configuration YAML
"""
out = (lambda *args: print('(DRY RUN)', *args)) if dry_run else print context = {
'verbose': verbose, # TODO: use callback function to be verbose?
'import': True,
}
try: try:
new_config = yaml.safe_load(sys.stdin) config = MailuSchema(context=context).loads(source)
except (yaml.scanner.ScannerError, yaml.parser.ParserError) as reason: except ValidationError as exc:
out(f'[ERROR] Invalid yaml: {reason}') print(f'[ERROR] {exc}')
# TODO: show nice errors
from pprint import pprint
pprint(exc.messages)
sys.exit(1) sys.exit(1)
else: else:
if type(new_config) is str: print(config)
out(f'[ERROR] Invalid yaml: {new_config!r}') print(MailuSchema().dumps(config))
sys.exit(1) # TODO: does not commit yet.
elif new_config is None or not len(new_config): # TODO: delete other entries?
out('[ERROR] Empty yaml: Please pipe yaml into stdin')
sys.exit(1)
error = False
tracked = {}
for section, model in yaml_sections:
items = new_config.get(section)
if items is None:
if delete_objects:
out(f'[ERROR] Invalid yaml: Section "{section}" is missing')
error = True
break
else:
continue
del new_config[section]
if type(items) is not list:
out(f'[ERROR] Section "{section}" must be a list, not {items.__class__.__name__}')
error = True
break
elif not items:
continue
# create items
for data in items:
if verbose:
out(f'Handling {model.__table__} data: {data!r}')
try:
changed = model.from_dict(data, delete_objects)
except Exception as reason:
out(f'[ERROR] {reason.args[0]} in data: {data}')
error = True
break
for item, created in changed:
if created is True:
# flush newly created item
db.session.add(item)
db.session.flush()
if verbose:
out(f'Added {item!r}: {item.to_dict()}')
else:
out(f'Added {item!r}')
elif len(created):
# modified instance
if verbose:
for key, old, new in created:
out(f'Updated {key!r} of {item!r}: {old!r} -> {new!r}')
else:
out(f'Updated {item!r}: {", ".join(sorted([kon[0] for kon in created]))}')
# track primary key of all items
tracked.setdefault(item.__class__, set()).update(set([item._dict_pval()]))
if error:
break
# on error: stop early
if error:
out('An error occured. Not committing changes.')
db.session.rollback()
sys.exit(1)
# are there sections left in new_config?
if new_config:
out(f'[ERROR] Unknown section(s) in yaml: {", ".join(sorted(new_config.keys()))}')
error = True
# test for conflicting domains
domains = set()
for model, items in tracked.items():
if model in (models.Domain, models.Alternative, models.Relay):
if domains & items:
for domain in domains & items:
out(f'[ERROR] Duplicate domain name used: {domain}')
error = True
domains.update(items)
# delete items not tracked
if delete_objects:
for model, items in tracked.items():
for item in model.query.all():
if not item._dict_pval() in items:
out(f'Deleted {item!r} {item}')
db.session.delete(item)
# don't commit when running dry # don't commit when running dry
if dry_run: if True: #dry_run:
print('Dry run. Not commiting changes.')
db.session.rollback() db.session.rollback()
else: else:
db.session.commit() db.session.commit()
@mailu.command() @mailu.command()
@click.option('-f', '--full', is_flag=True, help='Include default attributes') @click.option('-f', '--full', is_flag=True, help='Include attributes with default value')
@click.option('-s', '--secrets', is_flag=True, help='Include secrets (dkim-key, plain-text / not hashed)') @click.option('-s', '--secrets', is_flag=True,
help='Include secret attributes (dkim-key, passwords)')
@click.option('-d', '--dns', is_flag=True, help='Include dns records') @click.option('-d', '--dns', is_flag=True, help='Include dns records')
@click.option('-o', '--output-file', 'output', default=sys.stdout, type=click.File(mode='w'),
help='save yaml to file')
@click.argument('sections', nargs=-1) @click.argument('sections', nargs=-1)
@with_appcontext @with_appcontext
def config_dump(full=False, secrets=False, dns=False, sections=None): def config_dump(full=False, secrets=False, dns=False, output=None, sections=None):
"""dump configuration as YAML-formatted data to stdout """ Dump configuration as YAML to stdout or file
SECTIONS can be: domains, relays, users, aliases SECTIONS can be: domains, relays, users, aliases
""" """
try: if sections:
config = MailuConfig(sections) for section in sections:
except ValueError as reason: if section not in SECTIONS:
print(f'[ERROR] {reason}') print(f'[ERROR] Unknown section: {section!r}')
return 1 sys.exit(1)
sections = set(sections)
else:
sections = SECTIONS
MailuSchema(context={ context={
'full': full, 'full': full,
'secrets': secrets, 'secrets': secrets,
'dns': dns, 'dns': dns,
}).dumps(config, sys.stdout) }
MailuSchema(only=sections, context=context).dumps(models.MailuConfig(), output)
@mailu.command() @mailu.command()

View File

@ -1,23 +1,26 @@
from mailu import dkim """ Mailu config storage model
"""
from sqlalchemy.ext import declarative import re
from datetime import datetime, date import os
import smtplib
import json
from datetime import date
from email.mime import text from email.mime import text
from flask import current_app as app
from textwrap import wrap
import flask_sqlalchemy import flask_sqlalchemy
import sqlalchemy import sqlalchemy
import re
import time
import os
import passlib import passlib
import glob
import smtplib
import idna import idna
import dns import dns
import json
import itertools from flask import current_app as app
from sqlalchemy.ext import declarative
from sqlalchemy.inspection import inspect
from werkzeug.utils import cached_property
from . import dkim
db = flask_sqlalchemy.SQLAlchemy() db = flask_sqlalchemy.SQLAlchemy()
@ -30,9 +33,11 @@ class IdnaDomain(db.TypeDecorator):
impl = db.String(80) impl = db.String(80)
def process_bind_param(self, value, dialect): def process_bind_param(self, value, dialect):
""" encode unicode domain name to punycode """
return idna.encode(value).decode('ascii').lower() return idna.encode(value).decode('ascii').lower()
def process_result_value(self, value, dialect): def process_result_value(self, value, dialect):
""" decode punycode domain name to unicode """
return idna.decode(value) return idna.decode(value)
python_type = str python_type = str
@ -44,6 +49,7 @@ class IdnaEmail(db.TypeDecorator):
impl = db.String(255) impl = db.String(255)
def process_bind_param(self, value, dialect): def process_bind_param(self, value, dialect):
""" encode unicode domain part of email address to punycode """
try: try:
localpart, domain_name = value.split('@') localpart, domain_name = value.split('@')
return '{0}@{1}'.format( return '{0}@{1}'.format(
@ -54,6 +60,7 @@ class IdnaEmail(db.TypeDecorator):
pass pass
def process_result_value(self, value, dialect): def process_result_value(self, value, dialect):
""" decode punycode domain part of email to unicode """
localpart, domain_name = value.split('@') localpart, domain_name = value.split('@')
return '{0}@{1}'.format( return '{0}@{1}'.format(
localpart, localpart,
@ -69,14 +76,16 @@ class CommaSeparatedList(db.TypeDecorator):
impl = db.String impl = db.String
def process_bind_param(self, value, dialect): def process_bind_param(self, value, dialect):
if not isinstance(value, (list, set)): """ join list of items to comma separated string """
raise TypeError('Must be a list') if not isinstance(value, (list, tuple, set)):
raise TypeError('Must be a list of strings')
for item in value: for item in value:
if ',' in item: if ',' in item:
raise ValueError('Item must not contain a comma') raise ValueError('Item must not contain a comma')
return ','.join(sorted(value)) return ','.join(sorted(value))
def process_result_value(self, value, dialect): def process_result_value(self, value, dialect):
""" split comma separated string to list """
return list(filter(bool, value.split(','))) if value else [] return list(filter(bool, value.split(','))) if value else []
python_type = list python_type = list
@ -88,9 +97,11 @@ class JSONEncoded(db.TypeDecorator):
impl = db.String impl = db.String
def process_bind_param(self, value, dialect): def process_bind_param(self, value, dialect):
""" encode data as json """
return json.dumps(value) if value else None return json.dumps(value) if value else None
def process_result_value(self, value, dialect): def process_result_value(self, value, dialect):
""" decode json to data """
return json.loads(value) if value else None return json.loads(value) if value else None
python_type = str python_type = str
@ -112,246 +123,172 @@ class Base(db.Model):
updated_at = db.Column(db.Date, nullable=True, onupdate=date.today) updated_at = db.Column(db.Date, nullable=True, onupdate=date.today)
comment = db.Column(db.String(255), nullable=True, default='') comment = db.Column(db.String(255), nullable=True, default='')
@classmethod # @classmethod
def _dict_pkey(cls): # def from_dict(cls, data, delete=False):
return cls.__mapper__.primary_key[0].name
def _dict_pval(self): # changed = []
return getattr(self, self._dict_pkey())
def to_dict(self, full=False, include_secrets=False, include_extra=None, recursed=False, hide=None): # pkey = cls._dict_pkey()
""" Return a dictionary representation of this model.
"""
if recursed and not getattr(self, '_dict_recurse', False): # # handle "primary key" only
return str(self) # if not isinstance(data, dict):
# data = {pkey: data}
hide = set(hide or []) | {'created_at', 'updated_at'} # # modify input data
if hasattr(self, '_dict_hide'): # if hasattr(cls, '_dict_input'):
hide |= self._dict_hide # try:
# cls._dict_input(data)
# except Exception as exc:
# raise ValueError(f'{exc}', cls, None, data) from exc
secret = set() # # check for primary key (if not recursed)
if not include_secrets and hasattr(self, '_dict_secret'): # if not getattr(cls, '_dict_recurse', False):
secret |= self._dict_secret # if not pkey in data:
# raise KeyError(f'primary key {cls.__table__}.{pkey} is missing', cls, pkey, data)
convert = getattr(self, '_dict_output', {}) # # check data keys and values
# for key in list(data.keys()):
extra_keys = getattr(self, '_dict_extra', {}) # # check key
if include_extra is None: # if not hasattr(cls, key) and not key in cls.__mapper__.relationships:
include_extra = [] # raise KeyError(f'unknown key {cls.__table__}.{key}', cls, key, data)
res = {} # # check value type
# value = data[key]
# col = cls.__mapper__.columns.get(key)
# if col is not None:
# if not ((value is None and col.nullable) or (isinstance(value, col.type.python_type))):
# raise TypeError(f'{cls.__table__}.{key} {value!r} has invalid type {type(value).__name__!r}', cls, key, data)
# else:
# rel = cls.__mapper__.relationships.get(key)
# if rel is None:
# itype = getattr(cls, '_dict_types', {}).get(key)
# if itype is not None:
# if itype is False: # ignore value. TODO: emit warning?
# del data[key]
# continue
# elif not isinstance(value, itype):
# raise TypeError(f'{cls.__table__}.{key} {value!r} has invalid type {type(value).__name__!r}', cls, key, data)
# else:
# raise NotImplementedError(f'type not defined for {cls.__table__}.{key}')
for key in itertools.chain( # # handle relationships
self.__table__.columns.keys(), # if key in cls.__mapper__.relationships:
getattr(self, '_dict_show', []), # rel_model = cls.__mapper__.relationships[key].argument
*[extra_keys.get(extra, []) for extra in include_extra] # if not isinstance(rel_model, sqlalchemy.orm.Mapper):
): # add = rel_model.from_dict(value, delete)
if key in hide: # assert len(add) == 1
continue # rel_item, updated = add[0]
if key in self.__table__.columns: # changed.append((rel_item, updated))
default = self.__table__.columns[key].default # data[key] = rel_item
if isinstance(default, sqlalchemy.sql.schema.ColumnDefault):
default = default.arg
else:
default = None
value = getattr(self, key)
if full or ((default or value) and value != default):
if key in secret:
value = '<hidden>'
elif value is not None and key in convert:
value = convert[key](value)
res[key] = value
for key in self.__mapper__.relationships.keys(): # # create item if necessary
if key in hide: # created = False
continue # item = cls.query.get(data[pkey]) if pkey in data else None
if self.__mapper__.relationships[key].uselist: # if item is None:
items = getattr(self, key)
if self.__mapper__.relationships[key].query_class is not None:
if hasattr(items, 'all'):
items = items.all()
if full or items:
if key in secret:
res[key] = '<hidden>'
else:
res[key] = [item.to_dict(full, include_secrets, include_extra, True) for item in items]
else:
value = getattr(self, key)
if full or value is not None:
if key in secret:
res[key] = '<hidden>'
else:
res[key] = value.to_dict(full, include_secrets, include_extra, True)
return res # # check for mandatory keys
# missing = getattr(cls, '_dict_mandatory', set()) - set(data.keys())
# if missing:
# raise ValueError(f'mandatory key(s) {", ".join(sorted(missing))} for {cls.__table__} missing', cls, missing, data)
@classmethod # # remove mapped relationships from data
def from_dict(cls, data, delete=False): # mapped = {}
# for key in list(data.keys()):
# if key in cls.__mapper__.relationships:
# if isinstance(cls.__mapper__.relationships[key].argument, sqlalchemy.orm.Mapper):
# mapped[key] = data[key]
# del data[key]
changed = [] # # create new item
# item = cls(**data)
# created = True
pkey = cls._dict_pkey() # # and update mapped relationships (below)
# data = mapped
# handle "primary key" only # # update item
if isinstance(data, dict): # updated = []
data = {pkey: data} # for key, value in data.items():
# modify input data # # skip primary key
if hasattr(cls, '_dict_input'): # if key == pkey:
try: # continue
cls._dict_input(data)
except Exception as reason:
raise ValueError(f'{reason}', cls, None, data)
# check for primary key (if not recursed) # if key in cls.__mapper__.relationships:
if not getattr(cls, '_dict_recurse', False): # # update relationship
if not pkey in data: # rel_model = cls.__mapper__.relationships[key].argument
raise KeyError(f'primary key {cls.__table__}.{pkey} is missing', cls, pkey, data) # if isinstance(rel_model, sqlalchemy.orm.Mapper):
# rel_model = rel_model.class_
# # add (and create) referenced items
# cur = getattr(item, key)
# old = sorted(cur, key=id)
# new = []
# for rel_data in value:
# # get or create related item
# add = rel_model.from_dict(rel_data, delete)
# assert len(add) == 1
# rel_item, rel_updated = add[0]
# changed.append((rel_item, rel_updated))
# if rel_item not in cur:
# cur.append(rel_item)
# new.append(rel_item)
# check data keys and values # # delete referenced items missing in yaml
for key in list(data.keys()): # rel_pkey = rel_model._dict_pkey()
# new_data = list([i.to_dict(True, True, None, True, [rel_pkey]) for i in new])
# for rel_item in old:
# if rel_item not in new:
# # check if item with same data exists to stabilze import without primary key
# rel_data = rel_item.to_dict(True, True, None, True, [rel_pkey])
# try:
# same_idx = new_data.index(rel_data)
# except ValueError:
# same = None
# else:
# same = new[same_idx]
# check key # if same is None:
if not hasattr(cls, key) and not key in cls.__mapper__.relationships: # # delete items missing in new
raise KeyError(f'unknown key {cls.__table__}.{key}', cls, key, data) # if delete:
# cur.remove(rel_item)
# else:
# new.append(rel_item)
# else:
# # swap found item with same data with newly created item
# new.append(rel_item)
# new_data.append(rel_data)
# new.remove(same)
# del new_data[same_idx]
# for i, (ch_item, _) in enumerate(changed):
# if ch_item is same:
# changed[i] = (rel_item, [])
# db.session.flush()
# db.session.delete(ch_item)
# break
# check value type # # remember changes
value = data[key] # new = sorted(new, key=id)
col = cls.__mapper__.columns.get(key) # if new != old:
if col is not None: # updated.append((key, old, new))
if not ((value is None and col.nullable) or (isinstance(value, col.type.python_type))):
raise TypeError(f'{cls.__table__}.{key} {value!r} has invalid type {type(value).__name__!r}', cls, key, data)
else:
rel = cls.__mapper__.relationships.get(key)
if rel is None:
itype = getattr(cls, '_dict_types', {}).get(key)
if itype is not None:
if itype is False: # ignore value. TODO: emit warning?
del data[key]
continue
elif not isinstance(value, itype):
raise TypeError(f'{cls.__table__}.{key} {value!r} has invalid type {type(value).__name__!r}', cls, key, data)
else:
raise NotImplementedError(f'type not defined for {cls.__table__}.{key}')
# handle relationships # else:
if key in cls.__mapper__.relationships: # # update key
rel_model = cls.__mapper__.relationships[key].argument # old = getattr(item, key)
if not isinstance(rel_model, sqlalchemy.orm.Mapper): # if isinstance(old, list):
add = rel_model.from_dict(value, delete) # # deduplicate list value
assert len(add) == 1 # assert isinstance(value, list)
rel_item, updated = add[0] # value = set(value)
changed.append((rel_item, updated)) # old = set(old)
data[key] = rel_item # if not delete:
# value = old | value
# if value != old:
# updated.append((key, old, value))
# setattr(item, key, value)
# create item if necessary # changed.append((item, created if created else updated))
created = False
item = cls.query.get(data[pkey]) if pkey in data else None
if item is None:
# check for mandatory keys # return changed
missing = getattr(cls, '_dict_mandatory', set()) - set(data.keys())
if missing:
raise ValueError(f'mandatory key(s) {", ".join(sorted(missing))} for {cls.__table__} missing', cls, missing, data)
# remove mapped relationships from data
mapped = {}
for key in list(data.keys()):
if key in cls.__mapper__.relationships:
if isinstance(cls.__mapper__.relationships[key].argument, sqlalchemy.orm.Mapper):
mapped[key] = data[key]
del data[key]
# create new item
item = cls(**data)
created = True
# and update mapped relationships (below)
data = mapped
# update item
updated = []
for key, value in data.items():
# skip primary key
if key == pkey:
continue
if key in cls.__mapper__.relationships:
# update relationship
rel_model = cls.__mapper__.relationships[key].argument
if isinstance(rel_model, sqlalchemy.orm.Mapper):
rel_model = rel_model.class_
# add (and create) referenced items
cur = getattr(item, key)
old = sorted(cur, key=id)
new = []
for rel_data in value:
# get or create related item
add = rel_model.from_dict(rel_data, delete)
assert len(add) == 1
rel_item, rel_updated = add[0]
changed.append((rel_item, rel_updated))
if rel_item not in cur:
cur.append(rel_item)
new.append(rel_item)
# delete referenced items missing in yaml
rel_pkey = rel_model._dict_pkey()
new_data = list([i.to_dict(True, True, None, True, [rel_pkey]) for i in new])
for rel_item in old:
if rel_item not in new:
# check if item with same data exists to stabilze import without primary key
rel_data = rel_item.to_dict(True, True, None, True, [rel_pkey])
try:
same_idx = new_data.index(rel_data)
except ValueError:
same = None
else:
same = new[same_idx]
if same is None:
# delete items missing in new
if delete:
cur.remove(rel_item)
else:
new.append(rel_item)
else:
# swap found item with same data with newly created item
new.append(rel_item)
new_data.append(rel_data)
new.remove(same)
del new_data[same_idx]
for i, (ch_item, _) in enumerate(changed):
if ch_item is same:
changed[i] = (rel_item, [])
db.session.flush()
db.session.delete(ch_item)
break
# remember changes
new = sorted(new, key=id)
if new != old:
updated.append((key, old, new))
else:
# update key
old = getattr(item, key)
if isinstance(old, list):
# deduplicate list value
assert isinstance(value, list)
value = set(value)
old = set(old)
if not delete:
value = old | value
if value != old:
updated.append((key, old, value))
setattr(item, key, value)
changed.append((item, created if created else updated))
return changed
# Many-to-many association table for domain managers # Many-to-many association table for domain managers
@ -391,48 +328,6 @@ class Domain(Base):
__tablename__ = 'domain' __tablename__ = 'domain'
_dict_hide = {'users', 'managers', 'aliases'}
_dict_show = {'dkim_key'}
_dict_extra = {'dns':{'dkim_publickey', 'dns_mx', 'dns_spf', 'dns_dkim', 'dns_dmarc'}}
_dict_secret = {'dkim_key'}
_dict_types = {
'dkim_key': (bytes, type(None)),
'dkim_publickey': False,
'dns_mx': False,
'dns_spf': False,
'dns_dkim': False,
'dns_dmarc': False,
}
_dict_output = {'dkim_key': lambda key: key.decode('utf-8').strip().split('\n')[1:-1]}
@staticmethod
def _dict_input(data):
if 'dkim_key' in data:
key = data['dkim_key']
if key is not None:
if isinstance(key, list):
key = ''.join(key)
if isinstance(key, str):
key = ''.join(key.strip().split()) # removes all whitespace
if key == 'generate':
data['dkim_key'] = dkim.gen_key()
elif key:
match = re.match('^-----BEGIN (RSA )?PRIVATE KEY-----', key)
if match is not None:
key = key[match.end():]
match = re.search('-----END (RSA )?PRIVATE KEY-----$', key)
if match is not None:
key = key[:match.start()]
key = '\n'.join(wrap(key, 64))
key = f'-----BEGIN PRIVATE KEY-----\n{key}\n-----END PRIVATE KEY-----\n'.encode('ascii')
try:
dkim.strip_key(key)
except:
raise ValueError('invalid dkim key')
else:
data['dkim_key'] = key
else:
data['dkim_key'] = None
name = db.Column(IdnaDomain, primary_key=True, nullable=False) name = db.Column(IdnaDomain, primary_key=True, nullable=False)
managers = db.relationship('User', secondary=managers, managers = db.relationship('User', secondary=managers,
backref=db.backref('manager_of'), lazy='dynamic') backref=db.backref('manager_of'), lazy='dynamic')
@ -440,7 +335,7 @@ class Domain(Base):
max_aliases = db.Column(db.Integer, nullable=False, default=-1) max_aliases = db.Column(db.Integer, nullable=False, default=-1)
max_quota_bytes = db.Column(db.BigInteger, nullable=False, default=0) max_quota_bytes = db.Column(db.BigInteger, nullable=False, default=0)
signup_enabled = db.Column(db.Boolean, nullable=False, default=False) signup_enabled = db.Column(db.Boolean, nullable=False, default=False)
_dkim_key = None _dkim_key = None
_dkim_key_changed = False _dkim_key_changed = False
@ -452,17 +347,20 @@ class Domain(Base):
def dns_mx(self): def dns_mx(self):
hostname = app.config['HOSTNAMES'].split(',')[0] hostname = app.config['HOSTNAMES'].split(',')[0]
return f'{self.name}. 600 IN MX 10 {hostname}.' return f'{self.name}. 600 IN MX 10 {hostname}.'
@property @property
def dns_spf(self): def dns_spf(self):
hostname = app.config['HOSTNAMES'].split(',')[0] hostname = app.config['HOSTNAMES'].split(',')[0]
return f'{self.name}. 600 IN TXT "v=spf1 mx a:{hostname} ~all"' return f'{self.name}. 600 IN TXT "v=spf1 mx a:{hostname} ~all"'
@property @property
def dns_dkim(self): def dns_dkim(self):
if os.path.exists(self._dkim_file()): if os.path.exists(self._dkim_file()):
selector = app.config['DKIM_SELECTOR'] selector = app.config['DKIM_SELECTOR']
return f'{selector}._domainkey.{self.name}. 600 IN TXT "v=DKIM1; k=rsa; p={self.dkim_publickey}"' return (
f'{selector}._domainkey.{self.name}. 600 IN TXT'
f'"v=DKIM1; k=rsa; p={self.dkim_publickey}"'
)
@property @property
def dns_dmarc(self): def dns_dmarc(self):
@ -473,7 +371,7 @@ class Domain(Base):
ruf = app.config['DMARC_RUF'] ruf = app.config['DMARC_RUF']
ruf = f' ruf=mailto:{ruf}@{domain};' if ruf else '' ruf = f' ruf=mailto:{ruf}@{domain};' if ruf else ''
return f'_dmarc.{self.name}. 600 IN TXT "v=DMARC1; p=reject;{rua}{ruf} adkim=s; aspf=s"' return f'_dmarc.{self.name}. 600 IN TXT "v=DMARC1; p=reject;{rua}{ruf} adkim=s; aspf=s"'
@property @property
def dkim_key(self): def dkim_key(self):
if self._dkim_key is None: if self._dkim_key is None:
@ -525,7 +423,11 @@ class Domain(Base):
try: try:
return self.name == other.name return self.name == other.name
except AttributeError: except AttributeError:
return False return NotImplemented
def __hash__(self):
return hash(str(self.name))
class Alternative(Base): class Alternative(Base):
@ -551,8 +453,6 @@ class Relay(Base):
__tablename__ = 'relay' __tablename__ = 'relay'
_dict_mandatory = {'smtp'}
name = db.Column(IdnaDomain, primary_key=True, nullable=False) name = db.Column(IdnaDomain, primary_key=True, nullable=False)
smtp = db.Column(db.String(80), nullable=True) smtp = db.Column(db.String(80), nullable=True)
@ -566,18 +466,8 @@ class Email(object):
localpart = db.Column(db.String(80), nullable=False) localpart = db.Column(db.String(80), nullable=False)
@staticmethod
def _dict_input(data):
if 'email' in data:
if 'localpart' in data or 'domain' in data:
raise ValueError('ambigous key email and localpart/domain')
elif isinstance(data['email'], str):
data['localpart'], data['domain'] = data['email'].rsplit('@', 1)
else:
data['email'] = f'{data["localpart"]}@{data["domain"]}'
@declarative.declared_attr @declarative.declared_attr
def domain_name(cls): def domain_name(self):
return db.Column(IdnaDomain, db.ForeignKey(Domain.name), return db.Column(IdnaDomain, db.ForeignKey(Domain.name),
nullable=False, default=IdnaDomain) nullable=False, default=IdnaDomain)
@ -585,7 +475,7 @@ class Email(object):
# It is however very useful for quick lookups without joining tables, # It is however very useful for quick lookups without joining tables,
# especially when the mail server is reading the database. # especially when the mail server is reading the database.
@declarative.declared_attr @declarative.declared_attr
def email(cls): def email(self):
updater = lambda context: '{0}@{1}'.format( updater = lambda context: '{0}@{1}'.format(
context.current_parameters['localpart'], context.current_parameters['localpart'],
context.current_parameters['domain_name'], context.current_parameters['domain_name'],
@ -662,30 +552,6 @@ class User(Base, Email):
__tablename__ = 'user' __tablename__ = 'user'
_dict_hide = {'domain_name', 'domain', 'localpart', 'quota_bytes_used'}
_dict_mandatory = {'localpart', 'domain', 'password'}
@classmethod
def _dict_input(cls, data):
Email._dict_input(data)
# handle password
if 'password' in data:
if 'password_hash' in data or 'hash_scheme' in data:
raise ValueError('ambigous key password and password_hash/hash_scheme')
# check (hashed) password
password = data['password']
if password.startswith('{') and '}' in password:
scheme = password[1:password.index('}')]
if scheme not in cls.scheme_dict:
raise ValueError(f'invalid password scheme {scheme!r}')
else:
raise ValueError(f'invalid hashed password {password!r}')
elif 'password_hash' in data and 'hash_scheme' in data:
if data['hash_scheme'] not in cls.scheme_dict:
raise ValueError(f'invalid password scheme {scheme!r}')
data['password'] = '{'+data['hash_scheme']+'}'+ data['password_hash']
del data['hash_scheme']
del data['password_hash']
domain = db.relationship(Domain, domain = db.relationship(Domain,
backref=db.backref('users', cascade='all, delete-orphan')) backref=db.backref('users', cascade='all, delete-orphan'))
password = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=False)
@ -775,7 +641,8 @@ class User(Base, Email):
if raw: if raw:
self.password = '{'+hash_scheme+'}' + password self.password = '{'+hash_scheme+'}' + password
else: else:
self.password = '{'+hash_scheme+'}' + self.get_password_context().encrypt(password, self.scheme_dict[hash_scheme]) self.password = '{'+hash_scheme+'}' + \
self.get_password_context().encrypt(password, self.scheme_dict[hash_scheme])
def get_managed_domains(self): def get_managed_domains(self):
if self.global_admin: if self.global_admin:
@ -812,15 +679,6 @@ class Alias(Base, Email):
__tablename__ = 'alias' __tablename__ = 'alias'
_dict_hide = {'domain_name', 'domain', 'localpart'}
@staticmethod
def _dict_input(data):
Email._dict_input(data)
# handle comma delimited string for backwards compability
dst = data.get('destination')
if isinstance(dst, str):
data['destination'] = list([adr.strip() for adr in dst.split(',')])
domain = db.relationship(Domain, domain = db.relationship(Domain,
backref=db.backref('aliases', cascade='all, delete-orphan')) backref=db.backref('aliases', cascade='all, delete-orphan'))
wildcard = db.Column(db.Boolean, nullable=False, default=False) wildcard = db.Column(db.Boolean, nullable=False, default=False)
@ -832,10 +690,10 @@ class Alias(Base, Email):
sqlalchemy.and_(cls.domain_name == domain_name, sqlalchemy.and_(cls.domain_name == domain_name,
sqlalchemy.or_( sqlalchemy.or_(
sqlalchemy.and_( sqlalchemy.and_(
cls.wildcard == False, cls.wildcard is False,
cls.localpart == localpart cls.localpart == localpart
), sqlalchemy.and_( ), sqlalchemy.and_(
cls.wildcard == True, cls.wildcard is True,
sqlalchemy.bindparam('l', localpart).like(cls.localpart) sqlalchemy.bindparam('l', localpart).like(cls.localpart)
) )
) )
@ -847,10 +705,10 @@ class Alias(Base, Email):
sqlalchemy.and_(cls.domain_name == domain_name, sqlalchemy.and_(cls.domain_name == domain_name,
sqlalchemy.or_( sqlalchemy.or_(
sqlalchemy.and_( sqlalchemy.and_(
cls.wildcard == False, cls.wildcard is False,
sqlalchemy.func.lower(cls.localpart) == localpart_lower sqlalchemy.func.lower(cls.localpart) == localpart_lower
), sqlalchemy.and_( ), sqlalchemy.and_(
cls.wildcard == True, cls.wildcard is True,
sqlalchemy.bindparam('l', localpart_lower).like(sqlalchemy.func.lower(cls.localpart)) sqlalchemy.bindparam('l', localpart_lower).like(sqlalchemy.func.lower(cls.localpart))
) )
) )
@ -875,10 +733,6 @@ class Token(Base):
__tablename__ = 'token' __tablename__ = 'token'
_dict_recurse = True
_dict_hide = {'user', 'user_email'}
_dict_mandatory = {'password'}
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
user_email = db.Column(db.String(255), db.ForeignKey(User.email), user_email = db.Column(db.String(255), db.ForeignKey(User.email),
nullable=False) nullable=False)
@ -904,11 +758,6 @@ class Fetch(Base):
__tablename__ = 'fetch' __tablename__ = 'fetch'
_dict_recurse = True
_dict_hide = {'user_email', 'user', 'last_check', 'error'}
_dict_mandatory = {'protocol', 'host', 'port', 'username', 'password'}
_dict_secret = {'password'}
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
user_email = db.Column(db.String(255), db.ForeignKey(User.email), user_email = db.Column(db.String(255), db.ForeignKey(User.email),
nullable=False) nullable=False)
@ -926,3 +775,124 @@ class Fetch(Base):
def __str__(self): def __str__(self):
return f'{self.protocol}{"s" if self.tls else ""}://{self.username}@{self.host}:{self.port}' return f'{self.protocol}{"s" if self.tls else ""}://{self.username}@{self.host}:{self.port}'
class MailuConfig:
""" Class which joins whole Mailu config for dumping
and loading
"""
# TODO: add sqlalchemy session updating (.add & .del)
class MailuCollection:
""" Provides dict- and list-like access to all instances
of a sqlalchemy model
"""
def __init__(self, model : db.Model):
self._model = model
@cached_property
def _items(self):
return {
inspect(item).identity: item
for item in self._model.query.all()
}
def __len__(self):
return len(self._items)
def __iter__(self):
return iter(self._items.values())
def __getitem__(self, key):
return self._items[key]
def __setitem__(self, key, item):
if not isinstance(item, self._model):
raise TypeError(f'expected {self._model.name}')
if key != inspect(item).identity:
raise ValueError(f'item identity != key {key!r}')
self._items[key] = item
def __delitem__(self, key):
del self._items[key]
def append(self, item):
""" list-like append """
if not isinstance(item, self._model):
raise TypeError(f'expected {self._model.name}')
key = inspect(item).identity
if key in self._items:
raise ValueError(f'item {key!r} already present in collection')
self._items[key] = item
def extend(self, items):
""" list-like extend """
add = {}
for item in items:
if not isinstance(item, self._model):
raise TypeError(f'expected {self._model.name}')
key = inspect(item).identity
if key in self._items:
raise ValueError(f'item {key!r} already present in collection')
add[key] = item
self._items.update(add)
def pop(self, *args):
""" list-like (no args) and dict-like (1 or 2 args) pop """
if args:
if len(args) > 2:
raise TypeError(f'pop expected at most 2 arguments, got {len(args)}')
return self._items.pop(*args)
else:
return self._items.popitem()[1]
def popitem(self):
""" dict-like popitem """
return self._items.popitem()
def remove(self, item):
""" list-like remove """
if not isinstance(item, self._model):
raise TypeError(f'expected {self._model.name}')
key = inspect(item).identity
if not key in self._items:
raise ValueError(f'item {key!r} not found in collection')
del self._items[key]
def clear(self):
""" dict-like clear """
while True:
try:
self.pop()
except IndexError:
break
def update(self, items):
""" dict-like update """
for key, item in items:
if not isinstance(item, self._model):
raise TypeError(f'expected {self._model.name}')
if key != inspect(item).identity:
raise ValueError(f'item identity != key {key!r}')
if key in self._items:
raise ValueError(f'item {key!r} already present in collection')
def setdefault(self, key, item=None):
""" dict-like setdefault """
if key in self._items:
return self._items[key]
if item is None:
return None
if not isinstance(item, self._model):
raise TypeError(f'expected {self._model.name}')
if key != inspect(item).identity:
raise ValueError(f'item identity != key {key!r}')
self._items[key] = item
return item
domains = MailuCollection(Domain)
relays = MailuCollection(Relay)
users = MailuCollection(User)
aliases = MailuCollection(Alias)
config = MailuCollection(Config)

View File

@ -1,13 +1,15 @@
""" Mailu marshmallow fields and schema
""" """
Mailu marshmallow schema
""" import re
from textwrap import wrap from textwrap import wrap
import re
import yaml import yaml
from marshmallow import post_dump, fields, Schema from marshmallow import pre_load, post_dump, fields, Schema
from marshmallow.exceptions import ValidationError
from marshmallow_sqlalchemy import SQLAlchemyAutoSchemaOpts
from flask_marshmallow import Marshmallow from flask_marshmallow import Marshmallow
from OpenSSL import crypto from OpenSSL import crypto
@ -15,9 +17,9 @@ from . import models, dkim
ma = Marshmallow() ma = Marshmallow()
# TODO:
# how to mark keys as "required" while unserializing (in certain use cases/API)? # TODO: how and where to mark keys as "required" while unserializing (on commandline, in api)?
# - fields withoud default => required # - fields without default => required
# - fields which are the primary key => unchangeable when updating # - fields which are the primary key => unchangeable when updating
@ -41,7 +43,7 @@ class RenderYAML:
return super().increase_indent(flow, False) return super().increase_indent(flow, False)
@staticmethod @staticmethod
def _update_dict(dict1, dict2): def _update_items(dict1, dict2):
""" sets missing keys in dict1 to values of dict2 """ sets missing keys in dict1 to values of dict2
""" """
for key, value in dict2.items(): for key, value in dict2.items():
@ -53,8 +55,8 @@ class RenderYAML:
def loads(cls, *args, **kwargs): def loads(cls, *args, **kwargs):
""" load yaml data from string """ load yaml data from string
""" """
cls._update_dict(kwargs, cls._load_defaults) cls._update_items(kwargs, cls._load_defaults)
return yaml.load(*args, **kwargs) return yaml.safe_load(*args, **kwargs)
_dump_defaults = { _dump_defaults = {
'Dumper': SpacedDumper, 'Dumper': SpacedDumper,
@ -65,13 +67,33 @@ class RenderYAML:
def dumps(cls, *args, **kwargs): def dumps(cls, *args, **kwargs):
""" dump yaml data to string """ dump yaml data to string
""" """
cls._update_dict(kwargs, cls._dump_defaults) cls._update_items(kwargs, cls._dump_defaults)
return yaml.dump(*args, **kwargs) return yaml.dump(*args, **kwargs)
### functions ###
def handle_email(data):
""" merge separate localpart and domain to email
"""
localpart = 'localpart' in data
domain = 'domain' in data
if 'email' in data:
if localpart or domain:
raise ValidationError('duplicate email and localpart/domain')
elif localpart and domain:
data['email'] = f'{data["localpart"]}@{data["domain"]}'
elif localpart or domain:
raise ValidationError('incomplete localpart/domain')
return data
### field definitions ### ### field definitions ###
class LazyString(fields.String): class LazyStringField(fields.String):
""" Field that serializes a "false" value to the empty string """ Field that serializes a "false" value to the empty string
""" """
@ -81,14 +103,27 @@ class LazyString(fields.String):
return value if value else '' return value if value else ''
class CommaSeparatedList(fields.Raw): class CommaSeparatedListField(fields.Raw):
""" Field that deserializes a string containing comma-separated values to """ Field that deserializes a string containing comma-separated values to
a list of strings a list of strings
""" """
# TODO: implement this
def _deserialize(self, value, attr, data, **kwargs):
""" deserialize comma separated string to list of strings
"""
# empty
if not value:
return []
# split string
if isinstance(value, str):
return list([item.strip() for item in value.split(',') if item.strip()])
else:
return value
class DkimKey(fields.String): class DkimKeyField(fields.String):
""" Field that serializes a dkim key to a list of strings (lines) and """ Field that serializes a dkim key to a list of strings (lines) and
deserializes a string or list of strings. deserializes a string or list of strings.
""" """
@ -120,7 +155,7 @@ class DkimKey(fields.String):
# only strings are allowed # only strings are allowed
if not isinstance(value, str): if not isinstance(value, str):
raise TypeError(f'invalid type: {type(value).__name__!r}') raise ValidationError(f'invalid type {type(value).__name__!r}')
# clean value (remove whitespace and header/footer) # clean value (remove whitespace and header/footer)
value = self._clean_re.sub('', value.strip()) value = self._clean_re.sub('', value.strip())
@ -133,6 +168,11 @@ class DkimKey(fields.String):
elif value == 'generate': elif value == 'generate':
return dkim.gen_key() return dkim.gen_key()
# remember some keydata for error message
keydata = value
if len(keydata) > 40:
keydata = keydata[:25] + '...' + keydata[-10:]
# wrap value into valid pem layout and check validity # wrap value into valid pem layout and check validity
value = ( value = (
'-----BEGIN PRIVATE KEY-----\n' + '-----BEGIN PRIVATE KEY-----\n' +
@ -142,26 +182,37 @@ class DkimKey(fields.String):
try: try:
crypto.load_privatekey(crypto.FILETYPE_PEM, value) crypto.load_privatekey(crypto.FILETYPE_PEM, value)
except crypto.Error as exc: except crypto.Error as exc:
raise ValueError('invalid dkim key') from exc raise ValidationError(f'invalid dkim key {keydata!r}') from exc
else: else:
return value return value
### schema definitions ### ### base definitions ###
class BaseOpts(SQLAlchemyAutoSchemaOpts):
""" Option class with sqla session
"""
def __init__(self, meta, ordered=False):
if not hasattr(meta, 'sqla_session'):
meta.sqla_session = models.db.session
super(BaseOpts, self).__init__(meta, ordered=ordered)
class BaseSchema(ma.SQLAlchemyAutoSchema): class BaseSchema(ma.SQLAlchemyAutoSchema):
""" Marshmallow base schema with custom exclude logic """ Marshmallow base schema with custom exclude logic
and option to hide sqla defaults and option to hide sqla defaults
""" """
OPTIONS_CLASS = BaseOpts
class Meta: class Meta:
""" Schema config """ """ Schema config """
model = None model = None
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
# get and remove config from kwargs # context?
context = kwargs.get('context', {}) context = kwargs.get('context', {})
flags = set([key for key, value in context.items() if value is True])
# compile excludes # compile excludes
exclude = set(kwargs.get('exclude', [])) exclude = set(kwargs.get('exclude', []))
@ -171,8 +222,8 @@ class BaseSchema(ma.SQLAlchemyAutoSchema):
# add include_by_context # add include_by_context
if context is not None: if context is not None:
for ctx, what in getattr(self.Meta, 'include_by_context', {}).items(): for need, what in getattr(self.Meta, 'include_by_context', {}).items():
if not context.get(ctx): if not flags & set(need):
exclude |= set(what) exclude |= set(what)
# update excludes # update excludes
@ -192,8 +243,8 @@ class BaseSchema(ma.SQLAlchemyAutoSchema):
# hide by context # hide by context
self._hide_by_context = set() self._hide_by_context = set()
if context is not None: if context is not None:
for ctx, what in getattr(self.Meta, 'hide_by_context', {}).items(): for need, what in getattr(self.Meta, 'hide_by_context', {}).items():
if not context.get(ctx): if not flags & set(need):
self._hide_by_context |= set(what) self._hide_by_context |= set(what)
# init SQLAlchemyAutoSchema # init SQLAlchemyAutoSchema
@ -212,23 +263,26 @@ class BaseSchema(ma.SQLAlchemyAutoSchema):
if full or key not in self._exclude_by_value or value not in self._exclude_by_value[key] if full or key not in self._exclude_by_value or value not in self._exclude_by_value[key]
} }
# TODO: remove LazyString and fix model definition (comment should not be nullable) # TODO: remove LazyString and change model (IMHO comment should not be nullable)
comment = LazyString() comment = LazyStringField()
### schema definitions ###
class DomainSchema(BaseSchema): class DomainSchema(BaseSchema):
""" Marshmallow schema for Domain model """ """ Marshmallow schema for Domain model """
class Meta: class Meta:
""" Schema config """ """ Schema config """
model = models.Domain model = models.Domain
load_instance = True
include_relationships = True include_relationships = True
#include_fk = True
exclude = ['users', 'managers', 'aliases'] exclude = ['users', 'managers', 'aliases']
include_by_context = { include_by_context = {
'dns': {'dkim_publickey', 'dns_mx', 'dns_spf', 'dns_dkim', 'dns_dmarc'}, ('dns',): {'dkim_publickey', 'dns_mx', 'dns_spf', 'dns_dkim', 'dns_dmarc'},
} }
hide_by_context = { hide_by_context = {
'secrets': {'dkim_key'}, ('secrets',): {'dkim_key'},
} }
exclude_by_value = { exclude_by_value = {
'alternatives': [[]], 'alternatives': [[]],
@ -240,40 +294,20 @@ class DomainSchema(BaseSchema):
'dns_dmarc': [None], 'dns_dmarc': [None],
} }
dkim_key = DkimKey() dkim_key = DkimKeyField(allow_none=True)
dkim_publickey = fields.String(dump_only=True) dkim_publickey = fields.String(dump_only=True)
dns_mx = fields.String(dump_only=True) dns_mx = fields.String(dump_only=True)
dns_spf = fields.String(dump_only=True) dns_spf = fields.String(dump_only=True)
dns_dkim = fields.String(dump_only=True) dns_dkim = fields.String(dump_only=True)
dns_dmarc = fields.String(dump_only=True) dns_dmarc = fields.String(dump_only=True)
# _dict_types = {
# 'dkim_key': (bytes, type(None)),
# 'dkim_publickey': False,
# 'dns_mx': False,
# 'dns_spf': False,
# 'dns_dkim': False,
# 'dns_dmarc': False,
# }
class TokenSchema(BaseSchema): class TokenSchema(BaseSchema):
""" Marshmallow schema for Token model """ """ Marshmallow schema for Token model """
class Meta: class Meta:
""" Schema config """ """ Schema config """
model = models.Token model = models.Token
load_instance = True
# _dict_recurse = True
# _dict_hide = {'user', 'user_email'}
# _dict_mandatory = {'password'}
# id = db.Column(db.Integer(), primary_key=True)
# user_email = db.Column(db.String(255), db.ForeignKey(User.email),
# nullable=False)
# user = db.relationship(User,
# backref=db.backref('tokens', cascade='all, delete-orphan'))
# password = db.Column(db.String(255), nullable=False)
# ip = db.Column(db.String(255))
class FetchSchema(BaseSchema): class FetchSchema(BaseSchema):
@ -281,58 +315,57 @@ class FetchSchema(BaseSchema):
class Meta: class Meta:
""" Schema config """ """ Schema config """
model = models.Fetch model = models.Fetch
load_instance = True
include_by_context = { include_by_context = {
'full': {'last_check', 'error'}, ('full', 'import'): {'last_check', 'error'},
} }
hide_by_context = { hide_by_context = {
'secrets': {'password'}, ('secrets',): {'password'},
} }
# TODO: What about mandatory keys?
# _dict_mandatory = {'protocol', 'host', 'port', 'username', 'password'}
class UserSchema(BaseSchema): class UserSchema(BaseSchema):
""" Marshmallow schema for User model """ """ Marshmallow schema for User model """
class Meta: class Meta:
""" Schema config """ """ Schema config """
model = models.User model = models.User
load_instance = True
include_relationships = True include_relationships = True
exclude = ['localpart', 'domain', 'quota_bytes_used'] exclude = ['localpart', 'domain', 'quota_bytes_used']
exclude_by_value = { exclude_by_value = {
'forward_destination': [[]], 'forward_destination': [[]],
'tokens': [[]], 'tokens': [[]],
'manager_of': [[]],
'reply_enddate': ['2999-12-31'], 'reply_enddate': ['2999-12-31'],
'reply_startdate': ['1900-01-01'], 'reply_startdate': ['1900-01-01'],
} }
@pre_load
def _handle_password(self, data, many, **kwargs): # pylint: disable=unused-argument
data = handle_email(data)
if 'password' in data:
if 'password_hash' in data or 'hash_scheme' in data:
raise ValidationError('ambigous key password and password_hash/hash_scheme')
# check (hashed) password
password = data['password']
if password.startswith('{') and '}' in password:
scheme = password[1:password.index('}')]
if scheme not in self.Meta.model.scheme_dict:
raise ValidationError(f'invalid password scheme {scheme!r}')
else:
raise ValidationError(f'invalid hashed password {password!r}')
elif 'password_hash' in data and 'hash_scheme' in data:
if data['hash_scheme'] not in self.Meta.model.scheme_dict:
raise ValidationError(f'invalid password scheme {scheme!r}')
data['password'] = '{'+data['hash_scheme']+'}'+ data['password_hash']
del data['hash_scheme']
del data['password_hash']
return data
tokens = fields.Nested(TokenSchema, many=True) tokens = fields.Nested(TokenSchema, many=True)
fetches = fields.Nested(FetchSchema, many=True) fetches = fields.Nested(FetchSchema, many=True)
# TODO: deserialize password/password_hash! What about mandatory keys?
# _dict_mandatory = {'localpart', 'domain', 'password'}
# @classmethod
# def _dict_input(cls, data):
# Email._dict_input(data)
# # handle password
# if 'password' in data:
# if 'password_hash' in data or 'hash_scheme' in data:
# raise ValueError('ambigous key password and password_hash/hash_scheme')
# # check (hashed) password
# password = data['password']
# if password.startswith('{') and '}' in password:
# scheme = password[1:password.index('}')]
# if scheme not in cls.scheme_dict:
# raise ValueError(f'invalid password scheme {scheme!r}')
# else:
# raise ValueError(f'invalid hashed password {password!r}')
# elif 'password_hash' in data and 'hash_scheme' in data:
# if data['hash_scheme'] not in cls.scheme_dict:
# raise ValueError(f'invalid password scheme {scheme!r}')
# data['password'] = '{'+data['hash_scheme']+'}'+ data['password_hash']
# del data['hash_scheme']
# del data['password_hash']
class AliasSchema(BaseSchema): class AliasSchema(BaseSchema):
@ -340,20 +373,18 @@ class AliasSchema(BaseSchema):
class Meta: class Meta:
""" Schema config """ """ Schema config """
model = models.Alias model = models.Alias
load_instance = True
exclude = ['localpart'] exclude = ['localpart']
exclude_by_value = { exclude_by_value = {
'destination': [[]], 'destination': [[]],
} }
# TODO: deserialize destination! @pre_load
# @staticmethod def _handle_password(self, data, many, **kwargs): # pylint: disable=unused-argument
# def _dict_input(data): return handle_email(data)
# Email._dict_input(data)
# # handle comma delimited string for backwards compability destination = CommaSeparatedListField()
# dst = data.get('destination')
# if type(dst) is str:
# data['destination'] = list([adr.strip() for adr in dst.split(',')])
class ConfigSchema(BaseSchema): class ConfigSchema(BaseSchema):
@ -361,6 +392,7 @@ class ConfigSchema(BaseSchema):
class Meta: class Meta:
""" Schema config """ """ Schema config """
model = models.Config model = models.Config
load_instance = True
class RelaySchema(BaseSchema): class RelaySchema(BaseSchema):
@ -368,45 +400,17 @@ class RelaySchema(BaseSchema):
class Meta: class Meta:
""" Schema config """ """ Schema config """
model = models.Relay model = models.Relay
load_instance = True
class MailuSchema(Schema): class MailuSchema(Schema):
""" Marshmallow schema for Mailu config """ """ Marshmallow schema for complete Mailu config """
class Meta: class Meta:
""" Schema config """ """ Schema config """
render_module = RenderYAML render_module = RenderYAML
domains = fields.Nested(DomainSchema, many=True) domains = fields.Nested(DomainSchema, many=True)
relays = fields.Nested(RelaySchema, many=True) relays = fields.Nested(RelaySchema, many=True)
users = fields.Nested(UserSchema, many=True) users = fields.Nested(UserSchema, many=True)
aliases = fields.Nested(AliasSchema, many=True) aliases = fields.Nested(AliasSchema, many=True)
config = fields.Nested(ConfigSchema, many=True) config = fields.Nested(ConfigSchema, many=True)
### config class ###
class MailuConfig:
""" Class which joins whole Mailu config for dumping
"""
_models = {
'domains': models.Domain,
'relays': models.Relay,
'users': models.User,
'aliases': models.Alias,
# 'config': models.Config,
}
def __init__(self, sections):
if sections:
for section in sections:
if section not in self._models:
raise ValueError(f'Unknown section: {section!r}')
self._sections = set(sections)
else:
self._sections = set(self._models.keys())
def __getattr__(self, section):
if section in self._sections:
return self._models[section].query.all()
else:
raise AttributeError