Skip to content

Commit

Permalink
Refactor HomeserverConfig so it can be typechecked (matrix-org#6137)
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkowl authored Oct 10, 2019
1 parent def5413 commit f743108
Show file tree
Hide file tree
Showing 37 changed files with 415 additions and 94 deletions.
1 change: 1 addition & 0 deletions changelog.d/6137.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor configuration loading to allow better typechecking.
16 changes: 12 additions & 4 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@ plugins=mypy_zope:plugin
follow_imports=skip
mypy_path=stubs

[mypy-synapse.config.homeserver]
# this is a mess because of the metaclass shenanigans
ignore_errors = True

[mypy-zope]
ignore_missing_imports = True

Expand Down Expand Up @@ -52,3 +48,15 @@ ignore_missing_imports = True

[mypy-signedjson.*]
ignore_missing_imports = True

[mypy-prometheus_client.*]
ignore_missing_imports = True

[mypy-service_identity.*]
ignore_missing_imports = True

[mypy-daemonize]
ignore_missing_imports = True

[mypy-sentry_sdk]
ignore_missing_imports = True
191 changes: 148 additions & 43 deletions synapse/config/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import argparse
import errno
import os
from collections import OrderedDict
from textwrap import dedent
from typing import Any, MutableMapping, Optional

from six import integer_types

Expand Down Expand Up @@ -51,7 +53,56 @@ class ConfigError(Exception):
"""


def path_exists(file_path):
"""Check if a file exists
Unlike os.path.exists, this throws an exception if there is an error
checking if the file exists (for example, if there is a perms error on
the parent dir).
Returns:
bool: True if the file exists; False if not.
"""
try:
os.stat(file_path)
return True
except OSError as e:
if e.errno != errno.ENOENT:
raise e
return False


class Config(object):
"""
A configuration section, containing configuration keys and values.
Attributes:
section (str): The section title of this config object, such as
"tls" or "logger". This is used to refer to it on the root
logger (for example, `config.tls.some_option`). Must be
defined in subclasses.
"""

section = None

def __init__(self, root_config=None):
self.root = root_config

def __getattr__(self, item: str) -> Any:
"""
Try and fetch a configuration option that does not exist on this class.
This is so that existing configs that rely on `self.value`, where value
is actually from a different config section, continue to work.
"""
if item in ["generate_config_section", "read_config"]:
raise AttributeError(item)

if self.root is None:
raise AttributeError(item)
else:
return self.root._get_unclassed_config(self.section, item)

@staticmethod
def parse_size(value):
if isinstance(value, integer_types):
Expand Down Expand Up @@ -88,22 +139,7 @@ def abspath(file_path):

@classmethod
def path_exists(cls, file_path):
"""Check if a file exists
Unlike os.path.exists, this throws an exception if there is an error
checking if the file exists (for example, if there is a perms error on
the parent dir).
Returns:
bool: True if the file exists; False if not.
"""
try:
os.stat(file_path)
return True
except OSError as e:
if e.errno != errno.ENOENT:
raise e
return False
return path_exists(file_path)

@classmethod
def check_file(cls, file_path, config_name):
Expand Down Expand Up @@ -136,42 +172,106 @@ def read_file(cls, file_path, config_name):
with open(file_path) as file_stream:
return file_stream.read()

def invoke_all(self, name, *args, **kargs):
"""Invoke all instance methods with the given name and arguments in the
class's MRO.

class RootConfig(object):
"""
Holder of an application's configuration.
What configuration this object holds is defined by `config_classes`, a list
of Config classes that will be instantiated and given the contents of a
configuration file to read. They can then be accessed on this class by their
section name, defined in the Config or dynamically set to be the name of the
class, lower-cased and with "Config" removed.
"""

config_classes = []

def __init__(self):
self._configs = OrderedDict()

for config_class in self.config_classes:
if config_class.section is None:
raise ValueError("%r requires a section name" % (config_class,))

try:
conf = config_class(self)
except Exception as e:
raise Exception("Failed making %s: %r" % (config_class.section, e))
self._configs[config_class.section] = conf

def __getattr__(self, item: str) -> Any:
"""
Redirect lookups on this object either to config objects, or values on
config objects, so that `config.tls.blah` works, as well as legacy uses
of things like `config.server_name`. It will first look up the config
section name, and then values on those config classes.
"""
if item in self._configs.keys():
return self._configs[item]

return self._get_unclassed_config(None, item)

def _get_unclassed_config(self, asking_section: Optional[str], item: str):
"""
Fetch a config value from one of the instantiated config classes that
has not been fetched directly.
Args:
asking_section: If this check is coming from a Config child, which
one? This section will not be asked if it has the value.
item: The configuration value key.
Raises:
AttributeError if no config classes have the config key. The body
will contain what sections were checked.
"""
for key, val in self._configs.items():
if key == asking_section:
continue

if item in dir(val):
return getattr(val, item)

raise AttributeError(item, "not found in %s" % (list(self._configs.keys()),))

def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]:
"""
Invoke a function on all instantiated config objects this RootConfig is
configured to use.
Args:
name (str): Name of function to invoke
func_name: Name of function to invoke
*args
**kwargs
Returns:
list: The list of the return values from each method called
ordered dictionary of config section name and the result of the
function from it.
"""
results = []
for cls in type(self).mro():
if name in cls.__dict__:
results.append(getattr(cls, name)(self, *args, **kargs))
return results
res = OrderedDict()

for name, config in self._configs.items():
if hasattr(config, func_name):
res[name] = getattr(config, func_name)(*args, **kwargs)

return res

@classmethod
def invoke_all_static(cls, name, *args, **kargs):
"""Invoke all static methods with the given name and arguments in the
class's MRO.
def invoke_all_static(cls, func_name: str, *args, **kwargs):
"""
Invoke a static function on config objects this RootConfig is
configured to use.
Args:
name (str): Name of function to invoke
func_name: Name of function to invoke
*args
**kwargs
Returns:
list: The list of the return values from each method called
ordered dictionary of config section name and the result of the
function from it.
"""
results = []
for c in cls.mro():
if name in c.__dict__:
results.append(getattr(c, name)(*args, **kargs))
return results
for config in cls.config_classes:
if hasattr(config, func_name):
getattr(config, func_name)(*args, **kwargs)

def generate_config(
self,
Expand All @@ -187,7 +287,8 @@ def generate_config(
tls_private_key_path=None,
acme_domain=None,
):
"""Build a default configuration file
"""
Build a default configuration file
This is used when the user explicitly asks us to generate a config file
(eg with --generate_config).
Expand Down Expand Up @@ -242,6 +343,7 @@ def generate_config(
Returns:
str: the yaml config file
"""

return "\n\n".join(
dedent(conf)
for conf in self.invoke_all(
Expand All @@ -257,7 +359,7 @@ def generate_config(
tls_certificate_path=tls_certificate_path,
tls_private_key_path=tls_private_key_path,
acme_domain=acme_domain,
)
).values()
)

@classmethod
Expand Down Expand Up @@ -444,7 +546,7 @@ def load_or_generate_config(cls, description, argv):
)

(config_path,) = config_files
if not cls.path_exists(config_path):
if not path_exists(config_path):
print("Generating config file %s" % (config_path,))

if config_args.data_directory:
Expand All @@ -469,7 +571,7 @@ def load_or_generate_config(cls, description, argv):
open_private_ports=config_args.open_private_ports,
)

if not cls.path_exists(config_dir_path):
if not path_exists(config_dir_path):
os.makedirs(config_dir_path)
with open(config_path, "w") as config_file:
config_file.write("# vim:ft=yaml\n\n")
Expand Down Expand Up @@ -518,7 +620,7 @@ def load_or_generate_config(cls, description, argv):

return obj

def parse_config_dict(self, config_dict, config_dir_path, data_dir_path):
def parse_config_dict(self, config_dict, config_dir_path=None, data_dir_path=None):
"""Read the information from the config dict into this Config object.
Args:
Expand Down Expand Up @@ -607,3 +709,6 @@ def find_config_files(search_paths):
else:
config_files.append(config_path)
return config_files


__all__ = ["Config", "RootConfig"]
Loading

0 comments on commit f743108

Please sign in to comment.