6
0
mirror of https://github.com/ChristianLight/tutor.git synced 2024-12-12 14:17:46 +00:00

refactor: better config type checking

I stumbled upon a bug that should have been detected by the type
checking. Turns out, considering that config is of type Dict[str, Any]
means that we can use just any method on all config values -- which is
terrible. I discovered this after I set `config["PLUGINS"] = None`:
this triggered a crash when I enabled a plugin.
We resolve this by making the Config type more explicit. We also take
the opportunity to remove a few cast statements.
This commit is contained in:
Régis Behmo 2021-04-06 12:09:00 +02:00 committed by Régis Behmo
parent 887ba31e09
commit 336cb79fa8
22 changed files with 269 additions and 201 deletions

View File

@ -1,10 +1,10 @@
from typing import Any, Dict
import unittest
from unittest.mock import Mock, patch
import tempfile
from tutor import config as tutor_config
from tutor import interactive
from tutor.types import get_typed, Config
class ConfigTests(unittest.TestCase):
@ -13,13 +13,13 @@ class ConfigTests(unittest.TestCase):
self.assertNotIn("TUTOR_VERSION", defaults)
def test_merge(self) -> None:
config1 = {"x": "y"}
config2 = {"x": "z"}
config1: Config = {"x": "y"}
config2: Config = {"x": "z"}
tutor_config.merge(config1, config2)
self.assertEqual({"x": "y"}, config1)
def test_merge_render(self) -> None:
config: Dict[str, Any] = {}
config: Config = {}
defaults = tutor_config.load_defaults()
with patch.object(tutor_config.utils, "random_string", return_value="abcd"):
tutor_config.merge(config, defaults)
@ -62,13 +62,13 @@ class ConfigTests(unittest.TestCase):
config, defaults = interactive.load_all(rootdir, interactive=False)
self.assertIn("MYSQL_ROOT_PASSWORD", config)
self.assertEqual(8, len(config["MYSQL_ROOT_PASSWORD"]))
self.assertEqual(8, len(get_typed(config, "MYSQL_ROOT_PASSWORD", str)))
self.assertNotIn("LMS_HOST", config)
self.assertEqual("www.myopenedx.com", defaults["LMS_HOST"])
self.assertEqual("studio.{{ LMS_HOST }}", defaults["CMS_HOST"])
def test_is_service_activated(self) -> None:
config = {"RUN_SERVICE1": True, "RUN_SERVICE2": False}
config: Config = {"RUN_SERVICE1": True, "RUN_SERVICE2": False}
self.assertTrue(tutor_config.is_service_activated(config, "service1"))
self.assertFalse(tutor_config.is_service_activated(config, "service2"))

View File

@ -1,6 +1,5 @@
import os
import tempfile
from typing import Any, Dict
import unittest
from unittest.mock import patch, Mock
@ -8,6 +7,7 @@ from tutor import config as tutor_config
from tutor import env
from tutor import fmt
from tutor import exceptions
from tutor.types import Config
class EnvTests(unittest.TestCase):
@ -55,7 +55,7 @@ class EnvTests(unittest.TestCase):
self.assertRaises(exceptions.TutorError, env.render_str, {}, "hello {{ name }}")
def test_render_file(self) -> None:
config: Dict[str, Any] = {}
config: Config = {}
tutor_config.merge(config, tutor_config.load_defaults())
config["MYSQL_ROOT_PASSWORD"] = "testpassword"
rendered = env.render_file(config, "hooks", "mysql", "init")
@ -125,7 +125,7 @@ class EnvTests(unittest.TestCase):
f.write("Hello my ID is {{ ID }}")
# Create configuration
config = {"ID": "abcd"}
config: Config = {"ID": "abcd"}
# Render templates
with patch.object(
@ -162,7 +162,7 @@ class EnvTests(unittest.TestCase):
f.write("some content")
# Load env once
config: Dict[str, Any] = {"PLUGINS": []}
config: Config = {"PLUGINS": []}
env1 = env.Renderer.instance(config).environment
with patch.object(
@ -171,7 +171,7 @@ class EnvTests(unittest.TestCase):
return_value=[plugin1],
):
# Load env a second time
config["PLUGINS"].append("myplugin")
config["PLUGINS"] = ["myplugin"]
env2 = env.Renderer.instance(config).environment
self.assertNotIn("plugin1/myplugin.txt", env1.loader.list_templates()) # type: ignore

View File

@ -1,10 +1,11 @@
import unittest
from tutor import images
from tutor.types import Config
class ImagesTests(unittest.TestCase):
def test_get_tag(self) -> None:
config = {
config: Config = {
"DOCKER_IMAGE_OPENEDX": "registry/openedx",
"DOCKER_IMAGE_OPENEDX_DEV": "registry/openedxdev",
}

View File

@ -1,4 +1,3 @@
from typing import Any, Dict
import unittest
from unittest.mock import Mock, patch
@ -6,6 +5,7 @@ from tutor import config as tutor_config
from tutor import exceptions
from tutor import fmt
from tutor import plugins
from tutor.types import get_typed, Config
class PluginsTests(unittest.TestCase):
@ -37,26 +37,26 @@ class PluginsTests(unittest.TestCase):
)
def test_enable(self) -> None:
config: Dict[str, Any] = {plugins.CONFIG_KEY: []}
config: Config = {plugins.CONFIG_KEY: []}
with patch.object(plugins, "is_installed", return_value=True):
plugins.enable(config, "plugin2")
plugins.enable(config, "plugin1")
self.assertEqual(["plugin1", "plugin2"], config[plugins.CONFIG_KEY])
def test_enable_twice(self) -> None:
config: Dict[str, Any] = {plugins.CONFIG_KEY: []}
config: Config = {plugins.CONFIG_KEY: []}
with patch.object(plugins, "is_installed", return_value=True):
plugins.enable(config, "plugin1")
plugins.enable(config, "plugin1")
self.assertEqual(["plugin1"], config[plugins.CONFIG_KEY])
def test_enable_not_installed_plugin(self) -> None:
config: Dict[str, Any] = {"PLUGINS": []}
config: Config = {"PLUGINS": []}
with patch.object(plugins, "is_installed", return_value=False):
self.assertRaises(exceptions.TutorError, plugins.enable, config, "plugin1")
def test_disable(self) -> None:
config: Dict[str, Any] = {"PLUGINS": ["plugin1", "plugin2"]}
config: Config = {"PLUGINS": ["plugin1", "plugin2"]}
with patch.object(fmt, "STDOUT"):
plugins.disable(config, "plugin1")
self.assertEqual(["plugin2"], config["PLUGINS"])
@ -75,14 +75,14 @@ class PluginsTests(unittest.TestCase):
)
],
):
config = {"PLUGINS": ["plugin1"], "KEY": "value"}
config: Config = {"PLUGINS": ["plugin1"], "KEY": "value"}
with patch.object(fmt, "STDOUT"):
plugins.disable(config, "plugin1")
self.assertEqual([], config["PLUGINS"])
self.assertNotIn("KEY", config)
def test_none_plugins(self) -> None:
config = {plugins.CONFIG_KEY: None}
config: Config = {plugins.CONFIG_KEY: None}
self.assertFalse(plugins.is_enabled(config, "myplugin"))
def test_patches(self) -> None:
@ -107,11 +107,11 @@ class PluginsTests(unittest.TestCase):
self.assertEqual([], patches)
def test_configure(self) -> None:
config = {"ID": "id"}
defaults: Dict[str, Any] = {}
config: Config = {"ID": "id"}
defaults: Config = {}
class plugin1:
config = {
config: Config = {
"add": {"PARAM1": "value1", "PARAM2": "value2"},
"set": {"PARAM3": "value3"},
"defaults": {"PARAM4": "value4"},
@ -136,10 +136,10 @@ class PluginsTests(unittest.TestCase):
self.assertEqual({"PLUGIN1_PARAM4": "value4"}, defaults)
def test_configure_set_does_not_override(self) -> None:
config = {"ID": "oldid"}
config: Config = {"ID": "oldid"}
class plugin1:
config = {"set": {"ID": "newid"}}
config: Config = {"set": {"ID": "newid"}}
with patch.object(
plugins.Plugins,
@ -151,10 +151,10 @@ class PluginsTests(unittest.TestCase):
self.assertEqual({"ID": "oldid"}, config)
def test_configure_set_random_string(self) -> None:
config: Dict[str, Any] = {}
config: Config = {}
class plugin1:
config = {"set": {"PARAM1": "{{ 128|random_string }}"}}
config: Config = {"set": {"PARAM1": "{{ 128|random_string }}"}}
with patch.object(
plugins.Plugins,
@ -162,14 +162,14 @@ class PluginsTests(unittest.TestCase):
return_value=[plugins.BasePlugin("plugin1", plugin1)],
):
tutor_config.load_plugins(config, {})
self.assertEqual(128, len(config["PARAM1"]))
self.assertEqual(128, len(get_typed(config, "PARAM1", str)))
def test_configure_default_value_with_previous_definition(self) -> None:
config: Dict[str, Any] = {}
defaults = {"PARAM1": "value"}
config: Config = {}
defaults: Config = {"PARAM1": "value"}
class plugin1:
config = {"defaults": {"PARAM2": "{{ PARAM1 }}"}}
config: Config = {"defaults": {"PARAM2": "{{ PARAM1 }}"}}
with patch.object(
plugins.Plugins,
@ -180,10 +180,10 @@ class PluginsTests(unittest.TestCase):
self.assertEqual("{{ PARAM1 }}", defaults["PLUGIN1_PARAM2"])
def test_configure_add_twice(self) -> None:
config: Dict[str, Any] = {}
config: Config = {}
class plugin1:
config = {"add": {"PARAM1": "{{ 10|random_string }}"}}
config: Config = {"add": {"PARAM1": "{{ 10|random_string }}"}}
with patch.object(
plugins.Plugins,
@ -191,14 +191,14 @@ class PluginsTests(unittest.TestCase):
return_value=[plugins.BasePlugin("plugin1", plugin1)],
):
tutor_config.load_plugins(config, {})
value1 = config["PLUGIN1_PARAM1"]
value1 = get_typed(config, "PLUGIN1_PARAM1", str)
with patch.object(
plugins.Plugins,
"iter_enabled",
return_value=[plugins.BasePlugin("plugin1", plugin1)],
):
tutor_config.load_plugins(config, {})
value2 = config["PLUGIN1_PARAM1"]
value2 = get_typed(config, "PLUGIN1_PARAM1", str)
self.assertEqual(10, len(value1))
self.assertEqual(10, len(value2))
@ -218,10 +218,10 @@ class PluginsTests(unittest.TestCase):
)
def test_plugins_are_updated_on_config_change(self) -> None:
config: Dict[str, Any] = {"PLUGINS": []}
config: Config = {"PLUGINS": []}
plugins1 = plugins.Plugins(config)
self.assertEqual(0, len(list(plugins1.iter_enabled())))
config["PLUGINS"].append("plugin1")
config["PLUGINS"] = ["plugin1"]
with patch.object(
plugins.Plugins,
"iter_installed",

View File

@ -1,17 +1,18 @@
import os
from typing import Any, Callable, Dict, List, Tuple
from typing import Callable, List, Tuple
import click
from mypy_extensions import VarArg
from .exceptions import TutorError
from .types import Config
from .utils import get_user_id
def create(
root: str,
config: Dict[str, Any],
docker_compose_func: Callable[[str, Dict[str, Any], VarArg(str)], int],
config: Config,
docker_compose_func: Callable[[str, Config, VarArg(str)], int],
service: str,
path: str,
) -> str:

View File

@ -1,5 +1,3 @@
from typing import Dict
import click
from .compose import ComposeJobRunner
@ -7,6 +5,7 @@ from .local import docker_compose as local_docker_compose
from .. import config as tutor_config
from .. import env as tutor_env
from .. import fmt
from ..types import Config
from .context import Context
@ -28,7 +27,7 @@ def build(context: Context, mode: str) -> None:
)
def build_command(config: Dict[str, str], target: str) -> str:
def build_command(config: Config, target: str) -> str:
gradle_target = {
"debug": "assembleProdDebuggable",
"release": "assembleProdRelease",

View File

@ -1,5 +1,5 @@
import os
from typing import Any, Callable, Dict, List
from typing import Callable, List
import click
from mypy_extensions import VarArg
@ -11,6 +11,7 @@ from ..exceptions import TutorError
from .. import fmt
from .. import jobs
from .. import serialize
from ..types import Config
from .. import utils
from .context import Context
@ -19,8 +20,8 @@ class ComposeJobRunner(jobs.BaseJobRunner):
def __init__(
self,
root: str,
config: Dict[str, Any],
docker_compose_func: Callable[[str, Dict[str, Any], VarArg(str)], int],
config: Config,
docker_compose_func: Callable[[str, Config, VarArg(str)], int],
):
super().__init__(root, config)
self.docker_compose_func = docker_compose_func

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import List
import click
@ -8,6 +8,7 @@ from .. import exceptions
from .. import fmt
from .. import interactive as interactive_config
from .. import serialize
from ..types import Config
from .context import Context
@ -40,7 +41,7 @@ def config_command() -> None:
)
@click.pass_obj
def save(
context: Context, interactive: bool, set_vars: Dict[str, Any], unset_vars: List[str]
context: Context, interactive: bool, set_vars: Config, unset_vars: List[str]
) -> None:
config, defaults = interactive_config.load_all(
context.root, interactive=interactive
@ -91,7 +92,7 @@ def printvalue(context: Context, key: str) -> None:
config = tutor_config.load(context.root)
try:
# Note that this will incorrectly print None values
fmt.echo(config[key])
fmt.echo(str(config[key]))
except KeyError as e:
raise exceptions.TutorError(
"Missing configuration value: {}".format(key)

View File

@ -1,9 +1,7 @@
from typing import Any, Dict
from ..types import Config
def unimplemented_docker_compose(
root: str, config: Dict[str, Any], *command: str
) -> int:
def unimplemented_docker_compose(root: str, config: Config, *command: str) -> int:
raise NotImplementedError
@ -13,5 +11,5 @@ class Context:
self.root = root
self.docker_compose_func = unimplemented_docker_compose
def docker_compose(self, root: str, config: Dict[str, Any], *command: str) -> int:
def docker_compose(self, root: str, config: Config, *command: str) -> int:
return self.docker_compose_func(root, config, *command)

View File

@ -1,17 +1,18 @@
import os
from typing import Any, Dict, List
from typing import List
import click
from .. import config as tutor_config
from .. import env as tutor_env
from .. import fmt
from ..types import Config
from .. import utils
from . import compose
from .context import Context
def docker_compose(root: str, config: Dict[str, Any], *command: str) -> int:
def docker_compose(root: str, config: Config, *command: str) -> int:
"""
Run docker-compose with dev arguments.
"""

View File

@ -1,11 +1,13 @@
from typing import cast, Any, Dict, Iterator, List, Tuple
from typing import Iterator, List, Tuple
import click
from .. import config as tutor_config
from .. import env as tutor_env
from .. import exceptions
from .. import images
from .. import plugins
from ..types import Config
from .. import utils
from .context import Context
@ -105,7 +107,7 @@ def printtag(context: Context, image_names: List[str]) -> None:
print(tag)
def build_image(root: str, config: Dict[str, Any], image: str, *args: str) -> None:
def build_image(root: str, config: Config, image: str, *args: str) -> None:
# Build base images
for img, tag in iter_images(config, image, BASE_IMAGE_NAMES):
images.build(tutor_env.pathjoin(root, "build", img), tag, *args)
@ -122,14 +124,14 @@ def build_image(root: str, config: Dict[str, Any], image: str, *args: str) -> No
images.build(tutor_env.pathjoin(root, "build", img), tag, *dev_build_arg, *args)
def pull_image(config: Dict[str, Any], image: str) -> None:
def pull_image(config: Config, image: str) -> None:
for _img, tag in iter_images(config, image, all_image_names(config)):
images.pull(tag)
for _plugin, _img, tag in iter_plugin_images(config, image, "remote-image"):
images.pull(tag)
def push_image(config: Dict[str, Any], image: str) -> None:
def push_image(config: Config, image: str) -> None:
for _img, tag in iter_images(config, image, BASE_IMAGE_NAMES):
images.push(tag)
for _plugin, _img, tag in iter_plugin_images(config, image, "remote-image"):
@ -137,7 +139,7 @@ def push_image(config: Dict[str, Any], image: str) -> None:
def iter_images(
config: Dict[str, Any], image: str, image_list: List[str]
config: Config, image: str, image_list: List[str]
) -> Iterator[Tuple[str, str]]:
for img in image_list:
if image in [img, "all"]:
@ -146,21 +148,26 @@ def iter_images(
def iter_plugin_images(
config: Dict[str, Any], image: str, hook_name: str
config: Config, image: str, hook_name: str
) -> Iterator[Tuple[str, str, str]]:
for plugin, hook in plugins.iter_hooks(config, hook_name):
hook = cast(Dict[str, str], hook)
if not isinstance(hook, dict):
raise exceptions.TutorError(
"Invalid hook '{}': expected dict, got {}".format(
hook_name, hook.__class__
)
)
for img, tag in hook.items():
if image in [img, "all"]:
tag = tutor_env.render_str(config, tag)
yield plugin, img, tag
def all_image_names(config: Dict[str, Any]) -> List[str]:
def all_image_names(config: Config) -> List[str]:
return BASE_IMAGE_NAMES + vendor_image_names(config)
def vendor_image_names(config: Dict[str, Any]) -> List[str]:
def vendor_image_names(config: Config) -> List[str]:
vendor_images = VENDOR_IMAGES[:]
for image in VENDOR_IMAGES:
if not config.get("RUN_" + image.upper(), True):

View File

@ -1,6 +1,6 @@
from datetime import datetime
from time import sleep
from typing import cast, Any, Dict, List, Optional, Type
from typing import Any, List, Optional, Type
import click
@ -11,6 +11,7 @@ from .. import fmt
from .. import interactive as interactive_config
from .. import jobs
from .. import serialize
from ..types import Config, get_typed
from .. import utils
from .context import Context
@ -50,7 +51,11 @@ class K8sJobRunner(jobs.BaseJobRunner):
def load_job(self, name: str) -> Any:
all_jobs = self.render("k8s", "jobs.yml")
for job in serialize.load_all(all_jobs):
job_name = cast(str, job["metadata"]["name"])
job_name = job["metadata"]["name"]
if not isinstance(job_name, str):
raise exceptions.TutorError(
"Invalid job name: '{}'. Expected str.".format(job_name)
)
if job_name == name:
return job
raise ValueError("Could not find job '{}'".format(name))
@ -64,7 +69,7 @@ class K8sJobRunner(jobs.BaseJobRunner):
api = K8sClients.instance().batch_api
return [
job.metadata.name
for job in api.list_namespaced_job(self.config["K8S_NAMESPACE"]).items
for job in api.list_namespaced_job(k8s_namespace(self.config)).items
if job.status.active
]
@ -139,7 +144,7 @@ class K8sJobRunner(jobs.BaseJobRunner):
""" kubectl logs --namespace={namespace} --follow $(kubectl get --namespace={namespace} pods """
"""--selector=job-name={job_name} -o=jsonpath="{{.items[0].metadata.name}}")\n\n"""
"Waiting for job completion..."
).format(job_name=job_name, namespace=self.config["K8S_NAMESPACE"])
).format(job_name=job_name, namespace=k8s_namespace(self.config))
fmt.echo_info(message)
# Wait for completion
@ -257,14 +262,15 @@ def reboot(context: click.Context) -> None:
context.invoke(start)
def resource_selector(config: Dict[str, str], *selectors: str) -> List[str]:
def resource_selector(config: Config, *selectors: str) -> List[str]:
"""
Convenient utility for filtering only the resources that belong to this project.
"""
selector = ",".join(
["app.kubernetes.io/instance=openedx-" + config["ID"]] + list(selectors)
["app.kubernetes.io/instance=openedx-" + get_typed(config, "ID", str)]
+ list(selectors)
)
return ["--namespace", config["K8S_NAMESPACE"], "--selector=" + selector]
return ["--namespace", k8s_namespace(config), "--selector=" + selector]
@click.command(help="Completely delete an existing platform")
@ -398,7 +404,7 @@ def upgrade(context: Context, from_version: str) -> None:
running_version = "koa"
def upgrade_from_ironwood(config: Dict[str, Any]) -> None:
def upgrade_from_ironwood(config: Config) -> None:
if not config["RUN_MONGODB"]:
fmt.echo_info(
"You are not running MongDB (RUN_MONGODB=false). It is your "
@ -425,7 +431,7 @@ your MongoDb cluster from v3.2 to v3.6. You should run something similar to:
fmt.echo_info(message)
def upgrade_from_juniper(config: Dict[str, Any]) -> None:
def upgrade_from_juniper(config: Config) -> None:
if not config["RUN_MYSQL"]:
fmt.echo_info(
"You are not running MySQL (RUN_MYSQL=false). It is your "
@ -446,7 +452,7 @@ your MySQL database from v5.6 to v5.7. You should run something similar to:
def kubectl_exec(
config: Dict[str, Any], service: str, command: str, attach: bool = False
config: Config, service: str, command: str, attach: bool = False
) -> int:
selector = "app.kubernetes.io/name={}".format(service)
pods = K8sClients.instance().core_api.list_namespaced_pod(
@ -464,7 +470,7 @@ def kubectl_exec(
"exec",
*attach_opts,
"--namespace",
config["K8S_NAMESPACE"],
k8s_namespace(config),
pod_name,
"--",
"sh",
@ -474,7 +480,7 @@ def kubectl_exec(
)
def wait_for_pod_ready(config: Dict[str, str], service: str) -> None:
def wait_for_pod_ready(config: Config, service: str) -> None:
fmt.echo_info("Waiting for a {} pod to be ready...".format(service))
utils.kubectl(
"wait",
@ -485,6 +491,10 @@ def wait_for_pod_ready(config: Dict[str, str], service: str) -> None:
)
def k8s_namespace(config: Config) -> str:
return get_typed(config, "K8S_NAMESPACE", str)
k8s.add_command(quickstart)
k8s.add_command(start)
k8s.add_command(stop)

View File

@ -1,18 +1,18 @@
import os
from typing import Dict, Any
import click
from .. import config as tutor_config
from .. import env as tutor_env
from .. import fmt
from ..types import get_typed, Config
from .. import utils
from . import compose
from .config import save as config_save_command
from .context import Context
def docker_compose(root: str, config: Dict[str, Any], *command: str) -> int:
def docker_compose(root: str, config: Config, *command: str) -> int:
"""
Run docker-compose with local and production yml files.
"""
@ -27,7 +27,7 @@ def docker_compose(root: str, config: Dict[str, Any], *command: str) -> int:
tutor_env.pathjoin(root, "local", "docker-compose.prod.yml"),
*args,
"--project-name",
config["LOCAL_PROJECT_NAME"],
get_typed(config, "LOCAL_PROJECT_NAME", str),
*command
)
@ -118,7 +118,7 @@ Are you sure you want to continue?"""
running_version = "koa"
def upgrade_from_ironwood(context: click.Context, config: Dict[str, Any]) -> None:
def upgrade_from_ironwood(context: click.Context, config: Config) -> None:
click.echo(fmt.title("Upgrading from Ironwood"))
tutor_env.save(context.obj.root, config)
@ -166,7 +166,7 @@ def upgrade_from_ironwood(context: click.Context, config: Dict[str, Any]) -> Non
context.invoke(compose.stop)
def upgrade_from_juniper(context: click.Context, config: Dict[str, Any]) -> None:
def upgrade_from_juniper(context: click.Context, config: Config) -> None:
click.echo(fmt.title("Upgrading from Juniper"))
tutor_env.save(context.obj.root, config)

View File

@ -4,7 +4,7 @@ import platform
import subprocess
import sys
import tarfile
from typing import Any, Dict, Optional
from typing import Dict, Optional
from urllib.request import urlopen
import click
@ -15,6 +15,7 @@ from .. import fmt
from .. import env as tutor_env
from .. import exceptions
from .. import serialize
from ..types import Config
from .context import Context
@ -135,7 +136,7 @@ def load_config(root: str) -> Dict[str, Optional[str]]:
return config
def save_webui_config_file(root: str, config: Dict[str, Any]) -> None:
def save_webui_config_file(root: str, config: Config) -> None:
path = config_path(root)
directory = os.path.dirname(path)
if not os.path.exists(directory):

View File

@ -1,15 +1,11 @@
import os
from typing import cast, Dict, Any, Tuple
from typing import Tuple
from . import exceptions
from . import env
from . import fmt
from . import plugins
from . import serialize
from . import utils
from . import env, exceptions, fmt, plugins, serialize, utils
from .types import Config, cast_config
def update(root: str) -> Dict[str, Any]:
def update(root: str) -> Config:
"""
Load and save the configuration.
"""
@ -19,7 +15,7 @@ def update(root: str) -> Dict[str, Any]:
return config
def load(root: str) -> Dict[str, Any]:
def load(root: str) -> Config:
"""
Load full configuration. This will raise an exception if there is no current
configuration in the project root.
@ -28,13 +24,13 @@ def load(root: str) -> Dict[str, Any]:
return load_no_check(root)
def load_no_check(root: str) -> Dict[str, Any]:
def load_no_check(root: str) -> Config:
config, defaults = load_all(root)
merge(config, defaults)
return config
def load_all(root: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
def load_all(root: str) -> Tuple[Config, Config]:
"""
Return:
current (dict): params currently saved in config.yml
@ -46,9 +42,7 @@ def load_all(root: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
return current, defaults
def merge(
config: Dict[str, str], defaults: Dict[str, str], force: bool = False
) -> None:
def merge(config: Config, defaults: Config, force: bool = False) -> None:
"""
Merge default values with user configuration and perform rendering of "{{...}}"
values.
@ -58,23 +52,18 @@ def merge(
config[key] = env.render_unknown(config, value)
def load_defaults() -> Dict[str, Any]:
def load_defaults() -> Config:
config = serialize.load(env.read_template_file("config.yml"))
return cast(Dict[str, Any], config)
return cast_config(config)
def load_config_file(path: str) -> Dict[str, Any]:
def load_config_file(path: str) -> Config:
with open(path) as f:
config = serialize.load(f.read())
if not isinstance(config, dict):
raise exceptions.TutorError(
"Invalid configuration: expected dict, got {}".format(config.__class__)
)
return config
return cast_config(config)
def load_current(root: str, defaults: Dict[str, str]) -> Dict[str, Any]:
def load_current(root: str, defaults: Config) -> Config:
"""
Load the configuration currently stored on disk.
Note: this modifies the defaults with the plugin default values.
@ -87,7 +76,7 @@ def load_current(root: str, defaults: Dict[str, str]) -> Dict[str, Any]:
return config
def load_user(root: str) -> Dict[str, Any]:
def load_user(root: str) -> Config:
path = config_path(root)
if not os.path.exists(path):
return {}
@ -97,14 +86,14 @@ def load_user(root: str) -> Dict[str, Any]:
return config
def load_env(config: Dict[str, str], defaults: Dict[str, str]) -> None:
def load_env(config: Config, defaults: Config) -> None:
for k in defaults.keys():
env_var = "TUTOR_" + k
if env_var in os.environ:
config[k] = serialize.parse(os.environ[env_var])
def load_required(config: Dict[str, str], defaults: Dict[str, str]) -> None:
def load_required(config: Config, defaults: Config) -> None:
"""
All these keys must be present in the user's config.yml. This includes all values
that are generated once and must be kept after that, such as passwords.
@ -121,7 +110,7 @@ def load_required(config: Dict[str, str], defaults: Dict[str, str]) -> None:
config[key] = env.render_unknown(config, defaults[key])
def load_plugins(config: Dict[str, str], defaults: Dict[str, str]) -> None:
def load_plugins(config: Config, defaults: Config) -> None:
"""
Add, override and set new defaults from plugins.
"""
@ -143,11 +132,11 @@ def load_plugins(config: Dict[str, str], defaults: Dict[str, str]) -> None:
config[key] = env.render_unknown(config, value)
def is_service_activated(config: Dict[str, Any], service: str) -> bool:
def is_service_activated(config: Config, service: str) -> bool:
return config["RUN_" + service.upper()] is not False
def upgrade_obsolete(config: Dict[str, Any]) -> None:
def upgrade_obsolete(config: Config) -> None:
# Openedx-specific mysql passwords
if "MYSQL_PASSWORD" in config:
config["MYSQL_ROOT_PASSWORD"] = config["MYSQL_PASSWORD"]
@ -209,7 +198,7 @@ def convert_json2yml(root: str) -> None:
)
def save_config_file(root: str, config: Dict[str, str]) -> None:
def save_config_file(root: str, config: Config) -> None:
path = config_path(root)
utils.ensure_file_directory_exists(path)
with open(path, "w") as of:

View File

@ -1,17 +1,14 @@
import codecs
from copy import deepcopy
import os
from typing import Dict, Any, Iterable, List, Optional, Type, Union
from copy import deepcopy
from typing import Any, Iterable, List, Optional, Type, Union
import jinja2
import pkg_resources
from . import exceptions
from . import fmt
from . import plugins
from . import utils
from . import exceptions, fmt, plugins, utils
from .__about__ import __version__
from .types import Config
TEMPLATES_ROOT = pkg_resources.resource_filename("tutor", "templates")
VERSION_FILENAME = "version"
@ -20,7 +17,7 @@ BIN_FILE_EXTENSIONS = [".ico", ".jpg", ".png", ".ttf", ".woff", ".woff2"]
class Renderer:
@classmethod
def instance(cls: Type["Renderer"], config: Dict[str, Any]) -> "Renderer":
def instance(cls: Type["Renderer"], config: Config) -> "Renderer":
# Load template roots: these are required to be able to use
# {% include .. %} directives
template_roots = [TEMPLATES_ROOT]
@ -32,7 +29,7 @@ class Renderer:
def __init__(
self,
config: Dict[str, Any],
config: Config,
template_roots: List[str],
ignore_folders: Optional[List[str]] = None,
):
@ -64,7 +61,10 @@ class Renderer:
The elements of `prefix` must contain only "/", and not os.sep.
"""
full_prefix = "/".join(prefix)
for template in self.environment.loader.list_templates(): # type: ignore
env_templates: List[
str
] = self.environment.loader.list_templates() # type:ignore[no-untyped-call]
for template in env_templates:
if template.startswith(full_prefix) and self.is_part_of_env(template):
yield template
@ -171,7 +171,7 @@ class Renderer:
)
def save(root: str, config: Dict[str, Any]) -> None:
def save(root: str, config: Config) -> None:
"""
Save the full environment, including version information.
"""
@ -206,7 +206,7 @@ def upgrade_obsolete(root: str) -> None:
def save_plugin_templates(
plugin: plugins.BasePlugin, root: str, config: Dict[str, Any]
plugin: plugins.BasePlugin, root: str, config: Config
) -> None:
"""
Save plugin templates to plugins/<plugin name>/*.
@ -218,7 +218,7 @@ def save_plugin_templates(
save_all_from(subdir_path, plugins_root, config)
def save_all_from(prefix: str, root: str, config: Dict[str, Any]) -> None:
def save_all_from(prefix: str, root: str, config: Config) -> None:
"""
Render the templates that start with `prefix` and store them with the same
hierarchy at `root`. Here, `prefix` can be the result of os.path.join(...).
@ -240,7 +240,7 @@ def write_to(content: Union[str, bytes], path: str) -> None:
of_text.write(content)
def render_file(config: Dict[str, Any], *path: str) -> Union[str, bytes]:
def render_file(config: Config, *path: str) -> Union[str, bytes]:
"""
Return the rendered contents of a template.
"""
@ -249,7 +249,7 @@ def render_file(config: Dict[str, Any], *path: str) -> Union[str, bytes]:
return renderer.render_template(template_name)
def render_dict(config: Dict[str, Any]) -> None:
def render_dict(config: Config) -> None:
"""
Render the values from the dict. This is useful for rendering the default
values from config.yml.
@ -257,7 +257,7 @@ def render_dict(config: Dict[str, Any]) -> None:
Args:
config (dict)
"""
rendered = {}
rendered: Config = {}
for key, value in config.items():
if isinstance(value, str):
rendered[key] = render_str(config, value)
@ -267,13 +267,13 @@ def render_dict(config: Dict[str, Any]) -> None:
config[k] = v
def render_unknown(config: Dict[str, Any], value: Any) -> Any:
def render_unknown(config: Config, value: Any) -> Any:
if isinstance(value, str):
return render_str(config, value)
return value
def render_str(config: Dict[str, Any], text: str) -> str:
def render_str(config: Config, text: str) -> str:
"""
Args:
text (str)

View File

@ -1,11 +1,10 @@
from typing import Any, Dict
from . import fmt
from . import utils
from . import fmt, utils
from .types import Config, get_typed
def get_tag(config: Dict[str, Any], name: str) -> Any:
return config["DOCKER_IMAGE_" + name.upper().replace("-", "_")]
def get_tag(config: Config, name: str) -> str:
key = "DOCKER_IMAGE_" + name.upper().replace("-", "_")
return get_typed(config, key, str)
def build(path: str, tag: str, *args: str) -> None:

View File

@ -1,14 +1,14 @@
from typing import Any, Dict, List, Tuple
from typing import List, Tuple
import click
from . import config as tutor_config
from . import env
from . import exceptions
from . import fmt
from . import env, exceptions, fmt
from .__about__ import __version__
from .types import Config, get_typed
def update(root: str, interactive: bool = True) -> Dict[str, Any]:
def update(root: str, interactive: bool = True) -> Config:
"""
Load and save the configuration.
"""
@ -18,9 +18,7 @@ def update(root: str, interactive: bool = True) -> Dict[str, Any]:
return config
def load_all(
root: str, interactive: bool = True
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
def load_all(root: str, interactive: bool = True) -> Tuple[Config, Config]:
"""
Load configuration and interactively ask questions to collect param values from the user.
"""
@ -30,7 +28,7 @@ def load_all(
return config, defaults
def ask_questions(config: Dict[str, Any], defaults: Dict[str, Any]) -> None:
def ask_questions(config: Config, defaults: Config) -> None:
run_for_prod = config.get("LMS_HOST") != "local.overhang.io"
run_for_prod = click.confirm(
fmt.question(
@ -40,7 +38,7 @@ def ask_questions(config: Dict[str, Any], defaults: Dict[str, Any]) -> None:
default=run_for_prod,
)
if not run_for_prod:
dev_values = {
dev_values: Config = {
"LMS_HOST": "local.overhang.io",
"CMS_HOST": "studio.local.overhang.io",
"ENABLE_HTTPS": False,
@ -54,7 +52,8 @@ def ask_questions(config: Dict[str, Any], defaults: Dict[str, Any]) -> None:
if run_for_prod:
ask("Your website domain name for students (LMS)", "LMS_HOST", config, defaults)
if "localhost" in config["LMS_HOST"]:
lms_host = get_typed(config, "LMS_HOST", str)
if "localhost" in lms_host:
raise exceptions.TutorError(
"You may not use 'localhost' as the LMS domain name. To run a local platform for testing purposes you should answer 'n' to the previous question."
)
@ -159,19 +158,18 @@ def ask_questions(config: Dict[str, Any], defaults: Dict[str, Any]) -> None:
)
def ask(
question: str, key: str, config: Dict[str, Any], defaults: Dict[str, Any]
) -> None:
default = env.render_str(config, config.get(key, defaults[key]))
def ask(question: str, key: str, config: Config, defaults: Config) -> None:
default = get_typed(defaults, key, str)
default = get_typed(config, key, str, default=default)
default = env.render_str(config, default)
config[key] = click.prompt(
fmt.question(question), prompt_suffix=" ", default=default, show_default=True
)
def ask_bool(
question: str, key: str, config: Dict[str, Any], defaults: Dict[str, Any]
) -> None:
default = config.get(key, defaults[key])
def ask_bool(question: str, key: str, config: Config, defaults: Config) -> None:
default = get_typed(defaults, key, bool)
default = get_typed(config, key, bool, default=default)
config[key] = click.confirm(
fmt.question(question), prompt_suffix=" ", default=default
)
@ -180,11 +178,11 @@ def ask_bool(
def ask_choice(
question: str,
key: str,
config: Dict[str, Any],
defaults: Dict[str, Any],
config: Config,
defaults: Config,
choices: List[str],
) -> None:
default = config.get(key, defaults[key])
default = str(config.get(key, defaults[key]))
answer = click.prompt(
fmt.question(question),
type=click.Choice(choices),

View File

@ -1,8 +1,7 @@
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Dict, Iterator, List, Optional, Tuple, Union
from . import env
from . import fmt
from . import plugins
from . import env, fmt, plugins
from .types import Config
BASE_OPENEDX_COMMAND = """
export DJANGO_SETTINGS_MODULE=$SERVICE_VARIANT.envs.$SETTINGS
@ -11,7 +10,7 @@ echo "Loading settings $DJANGO_SETTINGS_MODULE"
class BaseJobRunner:
def __init__(self, root: str, config: Dict[str, Any]):
def __init__(self, root: str, config: Config):
self.root = root
self.config = config

View File

@ -1,18 +1,15 @@
from collections import namedtuple
from copy import deepcopy
from glob import glob
import importlib
import os
from copy import deepcopy
from glob import glob
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
import appdirs
import click
import pkg_resources
from . import exceptions
from . import fmt
from . import serialize
from . import exceptions, fmt, serialize
from .types import Config, get_typed
CONFIG_KEY = "PLUGINS"
@ -69,7 +66,7 @@ class BasePlugin:
self.command: click.Command = command
@staticmethod
def load_config(obj: Any, plugin_name: str) -> Dict[str, Dict[str, Any]]:
def load_config(obj: Any, plugin_name: str) -> Dict[str, Config]:
"""
Load config and check types.
"""
@ -181,15 +178,15 @@ class BasePlugin:
return self.name.upper() + "_" + key
@property
def config_add(self) -> Dict[str, Any]:
def config_add(self) -> Config:
return self.config.get("add", {})
@property
def config_set(self) -> Dict[str, Any]:
def config_set(self) -> Config:
return self.config.get("set", {})
@property
def config_defaults(self) -> Dict[str, Any]:
def config_defaults(self) -> Config:
return self.config.get("defaults", {})
@property
@ -269,16 +266,32 @@ class DictPlugin(BasePlugin):
os.environ.get(ROOT_ENV_VAR_NAME, "")
) or appdirs.user_data_dir(appname="tutor-plugins")
def __init__(self, data: Dict[str, Any]):
Module = namedtuple("Module", data.keys()) # type: ignore
obj = Module(**data) # type: ignore
super().__init__(data["name"], obj)
self._version = data["version"]
def __init__(self, data: Config):
name = data["name"]
if not isinstance(name, str):
raise exceptions.TutorError(
"Invalid plugin name: '{}'. Expected str, got {}".format(
name, name.__class__
)
)
# Create a generic object (sort of a named tuple) which will contain all key/values from data
class Module:
pass
obj = Module()
for key, value in data.items():
setattr(obj, key, value)
super().__init__(name, obj)
version = data["version"]
if not isinstance(version, str):
raise TypeError("DictPlugin.__version__ must be str")
self._version: str = version
@property
def version(self) -> str:
if not isinstance(self._version, str):
raise TypeError("DictPlugin.__version__ must be str")
return self._version
@classmethod
@ -305,7 +318,7 @@ class Plugins:
DictPlugin,
]
def __init__(self, config: Dict[str, Any]):
def __init__(self, config: Config):
self.config = deepcopy(config)
# patches has the following structure:
# {patch_name -> {plugin_name -> "content"}}
@ -380,18 +393,17 @@ def iter_installed() -> Iterator[BasePlugin]:
yield from Plugins.iter_installed()
def enable(config: Dict[str, Any], name: str) -> None:
def enable(config: Config, name: str) -> None:
if not is_installed(name):
raise exceptions.TutorError("plugin '{}' is not installed.".format(name))
if is_enabled(config, name):
return
if CONFIG_KEY not in config:
config[CONFIG_KEY] = []
config[CONFIG_KEY].append(name)
config[CONFIG_KEY].sort()
enabled = enabled_plugins(config)
enabled.append(name)
enabled.sort()
def disable(config: Dict[str, Any], name: str) -> None:
def disable(config: Config, name: str) -> None:
fmt.echo_info("Disabling plugin {}...".format(name))
for plugin in Plugins(config).iter_enabled():
if name == plugin.name:
@ -400,25 +412,32 @@ def disable(config: Dict[str, Any], name: str) -> None:
config.pop(key, None)
fmt.echo_info(" Removed config entry {}={}".format(key, value))
# Remove plugin from list
while name in config[CONFIG_KEY]:
config[CONFIG_KEY].remove(name)
enabled = enabled_plugins(config)
while name in enabled:
enabled.remove(name)
fmt.echo_info(" Plugin disabled")
def iter_enabled(config: Dict[str, Any]) -> Iterator[BasePlugin]:
def iter_enabled(config: Config) -> Iterator[BasePlugin]:
yield from Plugins(config).iter_enabled()
def is_enabled(config: Dict[str, Any], name: str) -> bool:
plugin_list = config.get(CONFIG_KEY) or []
return name in plugin_list
def is_enabled(config: Config, name: str) -> bool:
return name in enabled_plugins(config)
def iter_patches(config: Dict[str, str], name: str) -> Iterator[Tuple[str, str]]:
def enabled_plugins(config: Config) -> List[str]:
if not config.get(CONFIG_KEY):
config[CONFIG_KEY] = []
plugins = get_typed(config, CONFIG_KEY, list)
return plugins
def iter_patches(config: Config, name: str) -> Iterator[Tuple[str, str]]:
yield from Plugins(config).iter_patches(name)
def iter_hooks(
config: Dict[str, Any], hook_name: str
config: Config, hook_name: str
) -> Iterator[Tuple[str, Union[Dict[str, str], List[str]]]]:
yield from Plugins(config).iter_hooks(hook_name)

View File

@ -1,13 +1,12 @@
import re
from typing import Any, IO, Iterator, Tuple, Union
from typing import IO, Any, Iterator, Tuple, Union
import click
import yaml
from _io import TextIOWrapper
from yaml.parser import ParserError
from yaml.scanner import ScannerError
import click
def load(stream: Union[str, IO[str]]) -> Any:
return yaml.load(stream, Loader=yaml.SafeLoader)
@ -21,6 +20,12 @@ def dump(content: Any, fileobj: TextIOWrapper) -> None:
yaml.dump(content, stream=fileobj, default_flow_style=False)
def dumps(content: Any) -> str:
result = yaml.dump(content, default_flow_style=False)
assert isinstance(result, str)
return result
def parse(v: Union[str, IO[str]]) -> Any:
"""
Parse a yaml-formatted string.

39
tutor/types.py Normal file
View File

@ -0,0 +1,39 @@
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
from . import exceptions
ConfigValue = Union[
str, float, None, bool, List[str], List[Any], Dict[str, Any], Dict[Any, Any]
]
Config = Dict[str, ConfigValue]
def cast_config(config: Any) -> Config:
if not isinstance(config, dict):
raise exceptions.TutorError(
"Invalid configuration: expected dict, got {}".format(config.__class__)
)
for key in config.keys():
if not isinstance(key, str):
raise exceptions.TutorError(
"Invalid configuration: expected str, got {} for key '{}'".format(
key.__class__, key
)
)
return config
T = TypeVar("T")
def get_typed(
config: Config, key: str, expected_type: Type[T], default: Optional[T] = None
) -> T:
value = config.get(key, default)
if not isinstance(value, expected_type):
raise exceptions.TutorError(
"Invalid config entry: expected {}, got {} for key '{}'".format(
expected_type.__name__, value.__class__, key
)
)
return value