6
0
mirror of https://github.com/ChristianLight/tutor.git synced 2024-12-12 22:27:47 +00:00

refactor: better config type checking

I stumbled upon a bug that should have been detected by the type
checking. Turns out, considering that config is of type Dict[str, Any]
means that we can use just any method on all config values -- which is
terrible. I discovered this after I set `config["PLUGINS"] = None`:
this triggered a crash when I enabled a plugin.
We resolve this by making the Config type more explicit. We also take
the opportunity to remove a few cast statements.
This commit is contained in:
Régis Behmo 2021-04-06 12:09:00 +02:00 committed by Régis Behmo
parent 887ba31e09
commit 336cb79fa8
22 changed files with 269 additions and 201 deletions

View File

@ -1,10 +1,10 @@
from typing import Any, Dict
import unittest import unittest
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import tempfile import tempfile
from tutor import config as tutor_config from tutor import config as tutor_config
from tutor import interactive from tutor import interactive
from tutor.types import get_typed, Config
class ConfigTests(unittest.TestCase): class ConfigTests(unittest.TestCase):
@ -13,13 +13,13 @@ class ConfigTests(unittest.TestCase):
self.assertNotIn("TUTOR_VERSION", defaults) self.assertNotIn("TUTOR_VERSION", defaults)
def test_merge(self) -> None: def test_merge(self) -> None:
config1 = {"x": "y"} config1: Config = {"x": "y"}
config2 = {"x": "z"} config2: Config = {"x": "z"}
tutor_config.merge(config1, config2) tutor_config.merge(config1, config2)
self.assertEqual({"x": "y"}, config1) self.assertEqual({"x": "y"}, config1)
def test_merge_render(self) -> None: def test_merge_render(self) -> None:
config: Dict[str, Any] = {} config: Config = {}
defaults = tutor_config.load_defaults() defaults = tutor_config.load_defaults()
with patch.object(tutor_config.utils, "random_string", return_value="abcd"): with patch.object(tutor_config.utils, "random_string", return_value="abcd"):
tutor_config.merge(config, defaults) tutor_config.merge(config, defaults)
@ -62,13 +62,13 @@ class ConfigTests(unittest.TestCase):
config, defaults = interactive.load_all(rootdir, interactive=False) config, defaults = interactive.load_all(rootdir, interactive=False)
self.assertIn("MYSQL_ROOT_PASSWORD", config) self.assertIn("MYSQL_ROOT_PASSWORD", config)
self.assertEqual(8, len(config["MYSQL_ROOT_PASSWORD"])) self.assertEqual(8, len(get_typed(config, "MYSQL_ROOT_PASSWORD", str)))
self.assertNotIn("LMS_HOST", config) self.assertNotIn("LMS_HOST", config)
self.assertEqual("www.myopenedx.com", defaults["LMS_HOST"]) self.assertEqual("www.myopenedx.com", defaults["LMS_HOST"])
self.assertEqual("studio.{{ LMS_HOST }}", defaults["CMS_HOST"]) self.assertEqual("studio.{{ LMS_HOST }}", defaults["CMS_HOST"])
def test_is_service_activated(self) -> None: def test_is_service_activated(self) -> None:
config = {"RUN_SERVICE1": True, "RUN_SERVICE2": False} config: Config = {"RUN_SERVICE1": True, "RUN_SERVICE2": False}
self.assertTrue(tutor_config.is_service_activated(config, "service1")) self.assertTrue(tutor_config.is_service_activated(config, "service1"))
self.assertFalse(tutor_config.is_service_activated(config, "service2")) self.assertFalse(tutor_config.is_service_activated(config, "service2"))

View File

@ -1,6 +1,5 @@
import os import os
import tempfile import tempfile
from typing import Any, Dict
import unittest import unittest
from unittest.mock import patch, Mock from unittest.mock import patch, Mock
@ -8,6 +7,7 @@ from tutor import config as tutor_config
from tutor import env from tutor import env
from tutor import fmt from tutor import fmt
from tutor import exceptions from tutor import exceptions
from tutor.types import Config
class EnvTests(unittest.TestCase): class EnvTests(unittest.TestCase):
@ -55,7 +55,7 @@ class EnvTests(unittest.TestCase):
self.assertRaises(exceptions.TutorError, env.render_str, {}, "hello {{ name }}") self.assertRaises(exceptions.TutorError, env.render_str, {}, "hello {{ name }}")
def test_render_file(self) -> None: def test_render_file(self) -> None:
config: Dict[str, Any] = {} config: Config = {}
tutor_config.merge(config, tutor_config.load_defaults()) tutor_config.merge(config, tutor_config.load_defaults())
config["MYSQL_ROOT_PASSWORD"] = "testpassword" config["MYSQL_ROOT_PASSWORD"] = "testpassword"
rendered = env.render_file(config, "hooks", "mysql", "init") rendered = env.render_file(config, "hooks", "mysql", "init")
@ -125,7 +125,7 @@ class EnvTests(unittest.TestCase):
f.write("Hello my ID is {{ ID }}") f.write("Hello my ID is {{ ID }}")
# Create configuration # Create configuration
config = {"ID": "abcd"} config: Config = {"ID": "abcd"}
# Render templates # Render templates
with patch.object( with patch.object(
@ -162,7 +162,7 @@ class EnvTests(unittest.TestCase):
f.write("some content") f.write("some content")
# Load env once # Load env once
config: Dict[str, Any] = {"PLUGINS": []} config: Config = {"PLUGINS": []}
env1 = env.Renderer.instance(config).environment env1 = env.Renderer.instance(config).environment
with patch.object( with patch.object(
@ -171,7 +171,7 @@ class EnvTests(unittest.TestCase):
return_value=[plugin1], return_value=[plugin1],
): ):
# Load env a second time # Load env a second time
config["PLUGINS"].append("myplugin") config["PLUGINS"] = ["myplugin"]
env2 = env.Renderer.instance(config).environment env2 = env.Renderer.instance(config).environment
self.assertNotIn("plugin1/myplugin.txt", env1.loader.list_templates()) # type: ignore self.assertNotIn("plugin1/myplugin.txt", env1.loader.list_templates()) # type: ignore

View File

@ -1,10 +1,11 @@
import unittest import unittest
from tutor import images from tutor import images
from tutor.types import Config
class ImagesTests(unittest.TestCase): class ImagesTests(unittest.TestCase):
def test_get_tag(self) -> None: def test_get_tag(self) -> None:
config = { config: Config = {
"DOCKER_IMAGE_OPENEDX": "registry/openedx", "DOCKER_IMAGE_OPENEDX": "registry/openedx",
"DOCKER_IMAGE_OPENEDX_DEV": "registry/openedxdev", "DOCKER_IMAGE_OPENEDX_DEV": "registry/openedxdev",
} }

View File

@ -1,4 +1,3 @@
from typing import Any, Dict
import unittest import unittest
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
@ -6,6 +5,7 @@ from tutor import config as tutor_config
from tutor import exceptions from tutor import exceptions
from tutor import fmt from tutor import fmt
from tutor import plugins from tutor import plugins
from tutor.types import get_typed, Config
class PluginsTests(unittest.TestCase): class PluginsTests(unittest.TestCase):
@ -37,26 +37,26 @@ class PluginsTests(unittest.TestCase):
) )
def test_enable(self) -> None: def test_enable(self) -> None:
config: Dict[str, Any] = {plugins.CONFIG_KEY: []} config: Config = {plugins.CONFIG_KEY: []}
with patch.object(plugins, "is_installed", return_value=True): with patch.object(plugins, "is_installed", return_value=True):
plugins.enable(config, "plugin2") plugins.enable(config, "plugin2")
plugins.enable(config, "plugin1") plugins.enable(config, "plugin1")
self.assertEqual(["plugin1", "plugin2"], config[plugins.CONFIG_KEY]) self.assertEqual(["plugin1", "plugin2"], config[plugins.CONFIG_KEY])
def test_enable_twice(self) -> None: def test_enable_twice(self) -> None:
config: Dict[str, Any] = {plugins.CONFIG_KEY: []} config: Config = {plugins.CONFIG_KEY: []}
with patch.object(plugins, "is_installed", return_value=True): with patch.object(plugins, "is_installed", return_value=True):
plugins.enable(config, "plugin1") plugins.enable(config, "plugin1")
plugins.enable(config, "plugin1") plugins.enable(config, "plugin1")
self.assertEqual(["plugin1"], config[plugins.CONFIG_KEY]) self.assertEqual(["plugin1"], config[plugins.CONFIG_KEY])
def test_enable_not_installed_plugin(self) -> None: def test_enable_not_installed_plugin(self) -> None:
config: Dict[str, Any] = {"PLUGINS": []} config: Config = {"PLUGINS": []}
with patch.object(plugins, "is_installed", return_value=False): with patch.object(plugins, "is_installed", return_value=False):
self.assertRaises(exceptions.TutorError, plugins.enable, config, "plugin1") self.assertRaises(exceptions.TutorError, plugins.enable, config, "plugin1")
def test_disable(self) -> None: def test_disable(self) -> None:
config: Dict[str, Any] = {"PLUGINS": ["plugin1", "plugin2"]} config: Config = {"PLUGINS": ["plugin1", "plugin2"]}
with patch.object(fmt, "STDOUT"): with patch.object(fmt, "STDOUT"):
plugins.disable(config, "plugin1") plugins.disable(config, "plugin1")
self.assertEqual(["plugin2"], config["PLUGINS"]) self.assertEqual(["plugin2"], config["PLUGINS"])
@ -75,14 +75,14 @@ class PluginsTests(unittest.TestCase):
) )
], ],
): ):
config = {"PLUGINS": ["plugin1"], "KEY": "value"} config: Config = {"PLUGINS": ["plugin1"], "KEY": "value"}
with patch.object(fmt, "STDOUT"): with patch.object(fmt, "STDOUT"):
plugins.disable(config, "plugin1") plugins.disable(config, "plugin1")
self.assertEqual([], config["PLUGINS"]) self.assertEqual([], config["PLUGINS"])
self.assertNotIn("KEY", config) self.assertNotIn("KEY", config)
def test_none_plugins(self) -> None: def test_none_plugins(self) -> None:
config = {plugins.CONFIG_KEY: None} config: Config = {plugins.CONFIG_KEY: None}
self.assertFalse(plugins.is_enabled(config, "myplugin")) self.assertFalse(plugins.is_enabled(config, "myplugin"))
def test_patches(self) -> None: def test_patches(self) -> None:
@ -107,11 +107,11 @@ class PluginsTests(unittest.TestCase):
self.assertEqual([], patches) self.assertEqual([], patches)
def test_configure(self) -> None: def test_configure(self) -> None:
config = {"ID": "id"} config: Config = {"ID": "id"}
defaults: Dict[str, Any] = {} defaults: Config = {}
class plugin1: class plugin1:
config = { config: Config = {
"add": {"PARAM1": "value1", "PARAM2": "value2"}, "add": {"PARAM1": "value1", "PARAM2": "value2"},
"set": {"PARAM3": "value3"}, "set": {"PARAM3": "value3"},
"defaults": {"PARAM4": "value4"}, "defaults": {"PARAM4": "value4"},
@ -136,10 +136,10 @@ class PluginsTests(unittest.TestCase):
self.assertEqual({"PLUGIN1_PARAM4": "value4"}, defaults) self.assertEqual({"PLUGIN1_PARAM4": "value4"}, defaults)
def test_configure_set_does_not_override(self) -> None: def test_configure_set_does_not_override(self) -> None:
config = {"ID": "oldid"} config: Config = {"ID": "oldid"}
class plugin1: class plugin1:
config = {"set": {"ID": "newid"}} config: Config = {"set": {"ID": "newid"}}
with patch.object( with patch.object(
plugins.Plugins, plugins.Plugins,
@ -151,10 +151,10 @@ class PluginsTests(unittest.TestCase):
self.assertEqual({"ID": "oldid"}, config) self.assertEqual({"ID": "oldid"}, config)
def test_configure_set_random_string(self) -> None: def test_configure_set_random_string(self) -> None:
config: Dict[str, Any] = {} config: Config = {}
class plugin1: class plugin1:
config = {"set": {"PARAM1": "{{ 128|random_string }}"}} config: Config = {"set": {"PARAM1": "{{ 128|random_string }}"}}
with patch.object( with patch.object(
plugins.Plugins, plugins.Plugins,
@ -162,14 +162,14 @@ class PluginsTests(unittest.TestCase):
return_value=[plugins.BasePlugin("plugin1", plugin1)], return_value=[plugins.BasePlugin("plugin1", plugin1)],
): ):
tutor_config.load_plugins(config, {}) tutor_config.load_plugins(config, {})
self.assertEqual(128, len(config["PARAM1"])) self.assertEqual(128, len(get_typed(config, "PARAM1", str)))
def test_configure_default_value_with_previous_definition(self) -> None: def test_configure_default_value_with_previous_definition(self) -> None:
config: Dict[str, Any] = {} config: Config = {}
defaults = {"PARAM1": "value"} defaults: Config = {"PARAM1": "value"}
class plugin1: class plugin1:
config = {"defaults": {"PARAM2": "{{ PARAM1 }}"}} config: Config = {"defaults": {"PARAM2": "{{ PARAM1 }}"}}
with patch.object( with patch.object(
plugins.Plugins, plugins.Plugins,
@ -180,10 +180,10 @@ class PluginsTests(unittest.TestCase):
self.assertEqual("{{ PARAM1 }}", defaults["PLUGIN1_PARAM2"]) self.assertEqual("{{ PARAM1 }}", defaults["PLUGIN1_PARAM2"])
def test_configure_add_twice(self) -> None: def test_configure_add_twice(self) -> None:
config: Dict[str, Any] = {} config: Config = {}
class plugin1: class plugin1:
config = {"add": {"PARAM1": "{{ 10|random_string }}"}} config: Config = {"add": {"PARAM1": "{{ 10|random_string }}"}}
with patch.object( with patch.object(
plugins.Plugins, plugins.Plugins,
@ -191,14 +191,14 @@ class PluginsTests(unittest.TestCase):
return_value=[plugins.BasePlugin("plugin1", plugin1)], return_value=[plugins.BasePlugin("plugin1", plugin1)],
): ):
tutor_config.load_plugins(config, {}) tutor_config.load_plugins(config, {})
value1 = config["PLUGIN1_PARAM1"] value1 = get_typed(config, "PLUGIN1_PARAM1", str)
with patch.object( with patch.object(
plugins.Plugins, plugins.Plugins,
"iter_enabled", "iter_enabled",
return_value=[plugins.BasePlugin("plugin1", plugin1)], return_value=[plugins.BasePlugin("plugin1", plugin1)],
): ):
tutor_config.load_plugins(config, {}) tutor_config.load_plugins(config, {})
value2 = config["PLUGIN1_PARAM1"] value2 = get_typed(config, "PLUGIN1_PARAM1", str)
self.assertEqual(10, len(value1)) self.assertEqual(10, len(value1))
self.assertEqual(10, len(value2)) self.assertEqual(10, len(value2))
@ -218,10 +218,10 @@ class PluginsTests(unittest.TestCase):
) )
def test_plugins_are_updated_on_config_change(self) -> None: def test_plugins_are_updated_on_config_change(self) -> None:
config: Dict[str, Any] = {"PLUGINS": []} config: Config = {"PLUGINS": []}
plugins1 = plugins.Plugins(config) plugins1 = plugins.Plugins(config)
self.assertEqual(0, len(list(plugins1.iter_enabled()))) self.assertEqual(0, len(list(plugins1.iter_enabled())))
config["PLUGINS"].append("plugin1") config["PLUGINS"] = ["plugin1"]
with patch.object( with patch.object(
plugins.Plugins, plugins.Plugins,
"iter_installed", "iter_installed",

View File

@ -1,17 +1,18 @@
import os import os
from typing import Any, Callable, Dict, List, Tuple from typing import Callable, List, Tuple
import click import click
from mypy_extensions import VarArg from mypy_extensions import VarArg
from .exceptions import TutorError from .exceptions import TutorError
from .types import Config
from .utils import get_user_id from .utils import get_user_id
def create( def create(
root: str, root: str,
config: Dict[str, Any], config: Config,
docker_compose_func: Callable[[str, Dict[str, Any], VarArg(str)], int], docker_compose_func: Callable[[str, Config, VarArg(str)], int],
service: str, service: str,
path: str, path: str,
) -> str: ) -> str:

View File

@ -1,5 +1,3 @@
from typing import Dict
import click import click
from .compose import ComposeJobRunner from .compose import ComposeJobRunner
@ -7,6 +5,7 @@ from .local import docker_compose as local_docker_compose
from .. import config as tutor_config from .. import config as tutor_config
from .. import env as tutor_env from .. import env as tutor_env
from .. import fmt from .. import fmt
from ..types import Config
from .context import Context from .context import Context
@ -28,7 +27,7 @@ def build(context: Context, mode: str) -> None:
) )
def build_command(config: Dict[str, str], target: str) -> str: def build_command(config: Config, target: str) -> str:
gradle_target = { gradle_target = {
"debug": "assembleProdDebuggable", "debug": "assembleProdDebuggable",
"release": "assembleProdRelease", "release": "assembleProdRelease",

View File

@ -1,5 +1,5 @@
import os import os
from typing import Any, Callable, Dict, List from typing import Callable, List
import click import click
from mypy_extensions import VarArg from mypy_extensions import VarArg
@ -11,6 +11,7 @@ from ..exceptions import TutorError
from .. import fmt from .. import fmt
from .. import jobs from .. import jobs
from .. import serialize from .. import serialize
from ..types import Config
from .. import utils from .. import utils
from .context import Context from .context import Context
@ -19,8 +20,8 @@ class ComposeJobRunner(jobs.BaseJobRunner):
def __init__( def __init__(
self, self,
root: str, root: str,
config: Dict[str, Any], config: Config,
docker_compose_func: Callable[[str, Dict[str, Any], VarArg(str)], int], docker_compose_func: Callable[[str, Config, VarArg(str)], int],
): ):
super().__init__(root, config) super().__init__(root, config)
self.docker_compose_func = docker_compose_func self.docker_compose_func = docker_compose_func

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List from typing import List
import click import click
@ -8,6 +8,7 @@ from .. import exceptions
from .. import fmt from .. import fmt
from .. import interactive as interactive_config from .. import interactive as interactive_config
from .. import serialize from .. import serialize
from ..types import Config
from .context import Context from .context import Context
@ -40,7 +41,7 @@ def config_command() -> None:
) )
@click.pass_obj @click.pass_obj
def save( def save(
context: Context, interactive: bool, set_vars: Dict[str, Any], unset_vars: List[str] context: Context, interactive: bool, set_vars: Config, unset_vars: List[str]
) -> None: ) -> None:
config, defaults = interactive_config.load_all( config, defaults = interactive_config.load_all(
context.root, interactive=interactive context.root, interactive=interactive
@ -91,7 +92,7 @@ def printvalue(context: Context, key: str) -> None:
config = tutor_config.load(context.root) config = tutor_config.load(context.root)
try: try:
# Note that this will incorrectly print None values # Note that this will incorrectly print None values
fmt.echo(config[key]) fmt.echo(str(config[key]))
except KeyError as e: except KeyError as e:
raise exceptions.TutorError( raise exceptions.TutorError(
"Missing configuration value: {}".format(key) "Missing configuration value: {}".format(key)

View File

@ -1,9 +1,7 @@
from typing import Any, Dict from ..types import Config
def unimplemented_docker_compose( def unimplemented_docker_compose(root: str, config: Config, *command: str) -> int:
root: str, config: Dict[str, Any], *command: str
) -> int:
raise NotImplementedError raise NotImplementedError
@ -13,5 +11,5 @@ class Context:
self.root = root self.root = root
self.docker_compose_func = unimplemented_docker_compose self.docker_compose_func = unimplemented_docker_compose
def docker_compose(self, root: str, config: Dict[str, Any], *command: str) -> int: def docker_compose(self, root: str, config: Config, *command: str) -> int:
return self.docker_compose_func(root, config, *command) return self.docker_compose_func(root, config, *command)

View File

@ -1,17 +1,18 @@
import os import os
from typing import Any, Dict, List from typing import List
import click import click
from .. import config as tutor_config from .. import config as tutor_config
from .. import env as tutor_env from .. import env as tutor_env
from .. import fmt from .. import fmt
from ..types import Config
from .. import utils from .. import utils
from . import compose from . import compose
from .context import Context from .context import Context
def docker_compose(root: str, config: Dict[str, Any], *command: str) -> int: def docker_compose(root: str, config: Config, *command: str) -> int:
""" """
Run docker-compose with dev arguments. Run docker-compose with dev arguments.
""" """

View File

@ -1,11 +1,13 @@
from typing import cast, Any, Dict, Iterator, List, Tuple from typing import Iterator, List, Tuple
import click import click
from .. import config as tutor_config from .. import config as tutor_config
from .. import env as tutor_env from .. import env as tutor_env
from .. import exceptions
from .. import images from .. import images
from .. import plugins from .. import plugins
from ..types import Config
from .. import utils from .. import utils
from .context import Context from .context import Context
@ -105,7 +107,7 @@ def printtag(context: Context, image_names: List[str]) -> None:
print(tag) print(tag)
def build_image(root: str, config: Dict[str, Any], image: str, *args: str) -> None: def build_image(root: str, config: Config, image: str, *args: str) -> None:
# Build base images # Build base images
for img, tag in iter_images(config, image, BASE_IMAGE_NAMES): for img, tag in iter_images(config, image, BASE_IMAGE_NAMES):
images.build(tutor_env.pathjoin(root, "build", img), tag, *args) images.build(tutor_env.pathjoin(root, "build", img), tag, *args)
@ -122,14 +124,14 @@ def build_image(root: str, config: Dict[str, Any], image: str, *args: str) -> No
images.build(tutor_env.pathjoin(root, "build", img), tag, *dev_build_arg, *args) images.build(tutor_env.pathjoin(root, "build", img), tag, *dev_build_arg, *args)
def pull_image(config: Dict[str, Any], image: str) -> None: def pull_image(config: Config, image: str) -> None:
for _img, tag in iter_images(config, image, all_image_names(config)): for _img, tag in iter_images(config, image, all_image_names(config)):
images.pull(tag) images.pull(tag)
for _plugin, _img, tag in iter_plugin_images(config, image, "remote-image"): for _plugin, _img, tag in iter_plugin_images(config, image, "remote-image"):
images.pull(tag) images.pull(tag)
def push_image(config: Dict[str, Any], image: str) -> None: def push_image(config: Config, image: str) -> None:
for _img, tag in iter_images(config, image, BASE_IMAGE_NAMES): for _img, tag in iter_images(config, image, BASE_IMAGE_NAMES):
images.push(tag) images.push(tag)
for _plugin, _img, tag in iter_plugin_images(config, image, "remote-image"): for _plugin, _img, tag in iter_plugin_images(config, image, "remote-image"):
@ -137,7 +139,7 @@ def push_image(config: Dict[str, Any], image: str) -> None:
def iter_images( def iter_images(
config: Dict[str, Any], image: str, image_list: List[str] config: Config, image: str, image_list: List[str]
) -> Iterator[Tuple[str, str]]: ) -> Iterator[Tuple[str, str]]:
for img in image_list: for img in image_list:
if image in [img, "all"]: if image in [img, "all"]:
@ -146,21 +148,26 @@ def iter_images(
def iter_plugin_images( def iter_plugin_images(
config: Dict[str, Any], image: str, hook_name: str config: Config, image: str, hook_name: str
) -> Iterator[Tuple[str, str, str]]: ) -> Iterator[Tuple[str, str, str]]:
for plugin, hook in plugins.iter_hooks(config, hook_name): for plugin, hook in plugins.iter_hooks(config, hook_name):
hook = cast(Dict[str, str], hook) if not isinstance(hook, dict):
raise exceptions.TutorError(
"Invalid hook '{}': expected dict, got {}".format(
hook_name, hook.__class__
)
)
for img, tag in hook.items(): for img, tag in hook.items():
if image in [img, "all"]: if image in [img, "all"]:
tag = tutor_env.render_str(config, tag) tag = tutor_env.render_str(config, tag)
yield plugin, img, tag yield plugin, img, tag
def all_image_names(config: Dict[str, Any]) -> List[str]: def all_image_names(config: Config) -> List[str]:
return BASE_IMAGE_NAMES + vendor_image_names(config) return BASE_IMAGE_NAMES + vendor_image_names(config)
def vendor_image_names(config: Dict[str, Any]) -> List[str]: def vendor_image_names(config: Config) -> List[str]:
vendor_images = VENDOR_IMAGES[:] vendor_images = VENDOR_IMAGES[:]
for image in VENDOR_IMAGES: for image in VENDOR_IMAGES:
if not config.get("RUN_" + image.upper(), True): if not config.get("RUN_" + image.upper(), True):

View File

@ -1,6 +1,6 @@
from datetime import datetime from datetime import datetime
from time import sleep from time import sleep
from typing import cast, Any, Dict, List, Optional, Type from typing import Any, List, Optional, Type
import click import click
@ -11,6 +11,7 @@ from .. import fmt
from .. import interactive as interactive_config from .. import interactive as interactive_config
from .. import jobs from .. import jobs
from .. import serialize from .. import serialize
from ..types import Config, get_typed
from .. import utils from .. import utils
from .context import Context from .context import Context
@ -50,7 +51,11 @@ class K8sJobRunner(jobs.BaseJobRunner):
def load_job(self, name: str) -> Any: def load_job(self, name: str) -> Any:
all_jobs = self.render("k8s", "jobs.yml") all_jobs = self.render("k8s", "jobs.yml")
for job in serialize.load_all(all_jobs): for job in serialize.load_all(all_jobs):
job_name = cast(str, job["metadata"]["name"]) job_name = job["metadata"]["name"]
if not isinstance(job_name, str):
raise exceptions.TutorError(
"Invalid job name: '{}'. Expected str.".format(job_name)
)
if job_name == name: if job_name == name:
return job return job
raise ValueError("Could not find job '{}'".format(name)) raise ValueError("Could not find job '{}'".format(name))
@ -64,7 +69,7 @@ class K8sJobRunner(jobs.BaseJobRunner):
api = K8sClients.instance().batch_api api = K8sClients.instance().batch_api
return [ return [
job.metadata.name job.metadata.name
for job in api.list_namespaced_job(self.config["K8S_NAMESPACE"]).items for job in api.list_namespaced_job(k8s_namespace(self.config)).items
if job.status.active if job.status.active
] ]
@ -139,7 +144,7 @@ class K8sJobRunner(jobs.BaseJobRunner):
""" kubectl logs --namespace={namespace} --follow $(kubectl get --namespace={namespace} pods """ """ kubectl logs --namespace={namespace} --follow $(kubectl get --namespace={namespace} pods """
"""--selector=job-name={job_name} -o=jsonpath="{{.items[0].metadata.name}}")\n\n""" """--selector=job-name={job_name} -o=jsonpath="{{.items[0].metadata.name}}")\n\n"""
"Waiting for job completion..." "Waiting for job completion..."
).format(job_name=job_name, namespace=self.config["K8S_NAMESPACE"]) ).format(job_name=job_name, namespace=k8s_namespace(self.config))
fmt.echo_info(message) fmt.echo_info(message)
# Wait for completion # Wait for completion
@ -257,14 +262,15 @@ def reboot(context: click.Context) -> None:
context.invoke(start) context.invoke(start)
def resource_selector(config: Dict[str, str], *selectors: str) -> List[str]: def resource_selector(config: Config, *selectors: str) -> List[str]:
""" """
Convenient utility for filtering only the resources that belong to this project. Convenient utility for filtering only the resources that belong to this project.
""" """
selector = ",".join( selector = ",".join(
["app.kubernetes.io/instance=openedx-" + config["ID"]] + list(selectors) ["app.kubernetes.io/instance=openedx-" + get_typed(config, "ID", str)]
+ list(selectors)
) )
return ["--namespace", config["K8S_NAMESPACE"], "--selector=" + selector] return ["--namespace", k8s_namespace(config), "--selector=" + selector]
@click.command(help="Completely delete an existing platform") @click.command(help="Completely delete an existing platform")
@ -398,7 +404,7 @@ def upgrade(context: Context, from_version: str) -> None:
running_version = "koa" running_version = "koa"
def upgrade_from_ironwood(config: Dict[str, Any]) -> None: def upgrade_from_ironwood(config: Config) -> None:
if not config["RUN_MONGODB"]: if not config["RUN_MONGODB"]:
fmt.echo_info( fmt.echo_info(
"You are not running MongDB (RUN_MONGODB=false). It is your " "You are not running MongDB (RUN_MONGODB=false). It is your "
@ -425,7 +431,7 @@ your MongoDb cluster from v3.2 to v3.6. You should run something similar to:
fmt.echo_info(message) fmt.echo_info(message)
def upgrade_from_juniper(config: Dict[str, Any]) -> None: def upgrade_from_juniper(config: Config) -> None:
if not config["RUN_MYSQL"]: if not config["RUN_MYSQL"]:
fmt.echo_info( fmt.echo_info(
"You are not running MySQL (RUN_MYSQL=false). It is your " "You are not running MySQL (RUN_MYSQL=false). It is your "
@ -446,7 +452,7 @@ your MySQL database from v5.6 to v5.7. You should run something similar to:
def kubectl_exec( def kubectl_exec(
config: Dict[str, Any], service: str, command: str, attach: bool = False config: Config, service: str, command: str, attach: bool = False
) -> int: ) -> int:
selector = "app.kubernetes.io/name={}".format(service) selector = "app.kubernetes.io/name={}".format(service)
pods = K8sClients.instance().core_api.list_namespaced_pod( pods = K8sClients.instance().core_api.list_namespaced_pod(
@ -464,7 +470,7 @@ def kubectl_exec(
"exec", "exec",
*attach_opts, *attach_opts,
"--namespace", "--namespace",
config["K8S_NAMESPACE"], k8s_namespace(config),
pod_name, pod_name,
"--", "--",
"sh", "sh",
@ -474,7 +480,7 @@ def kubectl_exec(
) )
def wait_for_pod_ready(config: Dict[str, str], service: str) -> None: def wait_for_pod_ready(config: Config, service: str) -> None:
fmt.echo_info("Waiting for a {} pod to be ready...".format(service)) fmt.echo_info("Waiting for a {} pod to be ready...".format(service))
utils.kubectl( utils.kubectl(
"wait", "wait",
@ -485,6 +491,10 @@ def wait_for_pod_ready(config: Dict[str, str], service: str) -> None:
) )
def k8s_namespace(config: Config) -> str:
return get_typed(config, "K8S_NAMESPACE", str)
k8s.add_command(quickstart) k8s.add_command(quickstart)
k8s.add_command(start) k8s.add_command(start)
k8s.add_command(stop) k8s.add_command(stop)

View File

@ -1,18 +1,18 @@
import os import os
from typing import Dict, Any
import click import click
from .. import config as tutor_config from .. import config as tutor_config
from .. import env as tutor_env from .. import env as tutor_env
from .. import fmt from .. import fmt
from ..types import get_typed, Config
from .. import utils from .. import utils
from . import compose from . import compose
from .config import save as config_save_command from .config import save as config_save_command
from .context import Context from .context import Context
def docker_compose(root: str, config: Dict[str, Any], *command: str) -> int: def docker_compose(root: str, config: Config, *command: str) -> int:
""" """
Run docker-compose with local and production yml files. Run docker-compose with local and production yml files.
""" """
@ -27,7 +27,7 @@ def docker_compose(root: str, config: Dict[str, Any], *command: str) -> int:
tutor_env.pathjoin(root, "local", "docker-compose.prod.yml"), tutor_env.pathjoin(root, "local", "docker-compose.prod.yml"),
*args, *args,
"--project-name", "--project-name",
config["LOCAL_PROJECT_NAME"], get_typed(config, "LOCAL_PROJECT_NAME", str),
*command *command
) )
@ -118,7 +118,7 @@ Are you sure you want to continue?"""
running_version = "koa" running_version = "koa"
def upgrade_from_ironwood(context: click.Context, config: Dict[str, Any]) -> None: def upgrade_from_ironwood(context: click.Context, config: Config) -> None:
click.echo(fmt.title("Upgrading from Ironwood")) click.echo(fmt.title("Upgrading from Ironwood"))
tutor_env.save(context.obj.root, config) tutor_env.save(context.obj.root, config)
@ -166,7 +166,7 @@ def upgrade_from_ironwood(context: click.Context, config: Dict[str, Any]) -> Non
context.invoke(compose.stop) context.invoke(compose.stop)
def upgrade_from_juniper(context: click.Context, config: Dict[str, Any]) -> None: def upgrade_from_juniper(context: click.Context, config: Config) -> None:
click.echo(fmt.title("Upgrading from Juniper")) click.echo(fmt.title("Upgrading from Juniper"))
tutor_env.save(context.obj.root, config) tutor_env.save(context.obj.root, config)

View File

@ -4,7 +4,7 @@ import platform
import subprocess import subprocess
import sys import sys
import tarfile import tarfile
from typing import Any, Dict, Optional from typing import Dict, Optional
from urllib.request import urlopen from urllib.request import urlopen
import click import click
@ -15,6 +15,7 @@ from .. import fmt
from .. import env as tutor_env from .. import env as tutor_env
from .. import exceptions from .. import exceptions
from .. import serialize from .. import serialize
from ..types import Config
from .context import Context from .context import Context
@ -135,7 +136,7 @@ def load_config(root: str) -> Dict[str, Optional[str]]:
return config return config
def save_webui_config_file(root: str, config: Dict[str, Any]) -> None: def save_webui_config_file(root: str, config: Config) -> None:
path = config_path(root) path = config_path(root)
directory = os.path.dirname(path) directory = os.path.dirname(path)
if not os.path.exists(directory): if not os.path.exists(directory):

View File

@ -1,15 +1,11 @@
import os import os
from typing import cast, Dict, Any, Tuple from typing import Tuple
from . import exceptions from . import env, exceptions, fmt, plugins, serialize, utils
from . import env from .types import Config, cast_config
from . import fmt
from . import plugins
from . import serialize
from . import utils
def update(root: str) -> Dict[str, Any]: def update(root: str) -> Config:
""" """
Load and save the configuration. Load and save the configuration.
""" """
@ -19,7 +15,7 @@ def update(root: str) -> Dict[str, Any]:
return config return config
def load(root: str) -> Dict[str, Any]: def load(root: str) -> Config:
""" """
Load full configuration. This will raise an exception if there is no current Load full configuration. This will raise an exception if there is no current
configuration in the project root. configuration in the project root.
@ -28,13 +24,13 @@ def load(root: str) -> Dict[str, Any]:
return load_no_check(root) return load_no_check(root)
def load_no_check(root: str) -> Dict[str, Any]: def load_no_check(root: str) -> Config:
config, defaults = load_all(root) config, defaults = load_all(root)
merge(config, defaults) merge(config, defaults)
return config return config
def load_all(root: str) -> Tuple[Dict[str, Any], Dict[str, Any]]: def load_all(root: str) -> Tuple[Config, Config]:
""" """
Return: Return:
current (dict): params currently saved in config.yml current (dict): params currently saved in config.yml
@ -46,9 +42,7 @@ def load_all(root: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
return current, defaults return current, defaults
def merge( def merge(config: Config, defaults: Config, force: bool = False) -> None:
config: Dict[str, str], defaults: Dict[str, str], force: bool = False
) -> None:
""" """
Merge default values with user configuration and perform rendering of "{{...}}" Merge default values with user configuration and perform rendering of "{{...}}"
values. values.
@ -58,23 +52,18 @@ def merge(
config[key] = env.render_unknown(config, value) config[key] = env.render_unknown(config, value)
def load_defaults() -> Dict[str, Any]: def load_defaults() -> Config:
config = serialize.load(env.read_template_file("config.yml")) config = serialize.load(env.read_template_file("config.yml"))
return cast(Dict[str, Any], config) return cast_config(config)
def load_config_file(path: str) -> Dict[str, Any]: def load_config_file(path: str) -> Config:
with open(path) as f: with open(path) as f:
config = serialize.load(f.read()) config = serialize.load(f.read())
if not isinstance(config, dict): return cast_config(config)
raise exceptions.TutorError(
"Invalid configuration: expected dict, got {}".format(config.__class__)
)
return config
def load_current(root: str, defaults: Dict[str, str]) -> Dict[str, Any]: def load_current(root: str, defaults: Config) -> Config:
""" """
Load the configuration currently stored on disk. Load the configuration currently stored on disk.
Note: this modifies the defaults with the plugin default values. Note: this modifies the defaults with the plugin default values.
@ -87,7 +76,7 @@ def load_current(root: str, defaults: Dict[str, str]) -> Dict[str, Any]:
return config return config
def load_user(root: str) -> Dict[str, Any]: def load_user(root: str) -> Config:
path = config_path(root) path = config_path(root)
if not os.path.exists(path): if not os.path.exists(path):
return {} return {}
@ -97,14 +86,14 @@ def load_user(root: str) -> Dict[str, Any]:
return config return config
def load_env(config: Dict[str, str], defaults: Dict[str, str]) -> None: def load_env(config: Config, defaults: Config) -> None:
for k in defaults.keys(): for k in defaults.keys():
env_var = "TUTOR_" + k env_var = "TUTOR_" + k
if env_var in os.environ: if env_var in os.environ:
config[k] = serialize.parse(os.environ[env_var]) config[k] = serialize.parse(os.environ[env_var])
def load_required(config: Dict[str, str], defaults: Dict[str, str]) -> None: def load_required(config: Config, defaults: Config) -> None:
""" """
All these keys must be present in the user's config.yml. This includes all values All these keys must be present in the user's config.yml. This includes all values
that are generated once and must be kept after that, such as passwords. that are generated once and must be kept after that, such as passwords.
@ -121,7 +110,7 @@ def load_required(config: Dict[str, str], defaults: Dict[str, str]) -> None:
config[key] = env.render_unknown(config, defaults[key]) config[key] = env.render_unknown(config, defaults[key])
def load_plugins(config: Dict[str, str], defaults: Dict[str, str]) -> None: def load_plugins(config: Config, defaults: Config) -> None:
""" """
Add, override and set new defaults from plugins. Add, override and set new defaults from plugins.
""" """
@ -143,11 +132,11 @@ def load_plugins(config: Dict[str, str], defaults: Dict[str, str]) -> None:
config[key] = env.render_unknown(config, value) config[key] = env.render_unknown(config, value)
def is_service_activated(config: Dict[str, Any], service: str) -> bool: def is_service_activated(config: Config, service: str) -> bool:
return config["RUN_" + service.upper()] is not False return config["RUN_" + service.upper()] is not False
def upgrade_obsolete(config: Dict[str, Any]) -> None: def upgrade_obsolete(config: Config) -> None:
# Openedx-specific mysql passwords # Openedx-specific mysql passwords
if "MYSQL_PASSWORD" in config: if "MYSQL_PASSWORD" in config:
config["MYSQL_ROOT_PASSWORD"] = config["MYSQL_PASSWORD"] config["MYSQL_ROOT_PASSWORD"] = config["MYSQL_PASSWORD"]
@ -209,7 +198,7 @@ def convert_json2yml(root: str) -> None:
) )
def save_config_file(root: str, config: Dict[str, str]) -> None: def save_config_file(root: str, config: Config) -> None:
path = config_path(root) path = config_path(root)
utils.ensure_file_directory_exists(path) utils.ensure_file_directory_exists(path)
with open(path, "w") as of: with open(path, "w") as of:

View File

@ -1,17 +1,14 @@
import codecs import codecs
from copy import deepcopy
import os import os
from typing import Dict, Any, Iterable, List, Optional, Type, Union from copy import deepcopy
from typing import Any, Iterable, List, Optional, Type, Union
import jinja2 import jinja2
import pkg_resources import pkg_resources
from . import exceptions from . import exceptions, fmt, plugins, utils
from . import fmt
from . import plugins
from . import utils
from .__about__ import __version__ from .__about__ import __version__
from .types import Config
TEMPLATES_ROOT = pkg_resources.resource_filename("tutor", "templates") TEMPLATES_ROOT = pkg_resources.resource_filename("tutor", "templates")
VERSION_FILENAME = "version" VERSION_FILENAME = "version"
@ -20,7 +17,7 @@ BIN_FILE_EXTENSIONS = [".ico", ".jpg", ".png", ".ttf", ".woff", ".woff2"]
class Renderer: class Renderer:
@classmethod @classmethod
def instance(cls: Type["Renderer"], config: Dict[str, Any]) -> "Renderer": def instance(cls: Type["Renderer"], config: Config) -> "Renderer":
# Load template roots: these are required to be able to use # Load template roots: these are required to be able to use
# {% include .. %} directives # {% include .. %} directives
template_roots = [TEMPLATES_ROOT] template_roots = [TEMPLATES_ROOT]
@ -32,7 +29,7 @@ class Renderer:
def __init__( def __init__(
self, self,
config: Dict[str, Any], config: Config,
template_roots: List[str], template_roots: List[str],
ignore_folders: Optional[List[str]] = None, ignore_folders: Optional[List[str]] = None,
): ):
@ -64,7 +61,10 @@ class Renderer:
The elements of `prefix` must contain only "/", and not os.sep. The elements of `prefix` must contain only "/", and not os.sep.
""" """
full_prefix = "/".join(prefix) full_prefix = "/".join(prefix)
for template in self.environment.loader.list_templates(): # type: ignore env_templates: List[
str
] = self.environment.loader.list_templates() # type:ignore[no-untyped-call]
for template in env_templates:
if template.startswith(full_prefix) and self.is_part_of_env(template): if template.startswith(full_prefix) and self.is_part_of_env(template):
yield template yield template
@ -171,7 +171,7 @@ class Renderer:
) )
def save(root: str, config: Dict[str, Any]) -> None: def save(root: str, config: Config) -> None:
""" """
Save the full environment, including version information. Save the full environment, including version information.
""" """
@ -206,7 +206,7 @@ def upgrade_obsolete(root: str) -> None:
def save_plugin_templates( def save_plugin_templates(
plugin: plugins.BasePlugin, root: str, config: Dict[str, Any] plugin: plugins.BasePlugin, root: str, config: Config
) -> None: ) -> None:
""" """
Save plugin templates to plugins/<plugin name>/*. Save plugin templates to plugins/<plugin name>/*.
@ -218,7 +218,7 @@ def save_plugin_templates(
save_all_from(subdir_path, plugins_root, config) save_all_from(subdir_path, plugins_root, config)
def save_all_from(prefix: str, root: str, config: Dict[str, Any]) -> None: def save_all_from(prefix: str, root: str, config: Config) -> None:
""" """
Render the templates that start with `prefix` and store them with the same Render the templates that start with `prefix` and store them with the same
hierarchy at `root`. Here, `prefix` can be the result of os.path.join(...). hierarchy at `root`. Here, `prefix` can be the result of os.path.join(...).
@ -240,7 +240,7 @@ def write_to(content: Union[str, bytes], path: str) -> None:
of_text.write(content) of_text.write(content)
def render_file(config: Dict[str, Any], *path: str) -> Union[str, bytes]: def render_file(config: Config, *path: str) -> Union[str, bytes]:
""" """
Return the rendered contents of a template. Return the rendered contents of a template.
""" """
@ -249,7 +249,7 @@ def render_file(config: Dict[str, Any], *path: str) -> Union[str, bytes]:
return renderer.render_template(template_name) return renderer.render_template(template_name)
def render_dict(config: Dict[str, Any]) -> None: def render_dict(config: Config) -> None:
""" """
Render the values from the dict. This is useful for rendering the default Render the values from the dict. This is useful for rendering the default
values from config.yml. values from config.yml.
@ -257,7 +257,7 @@ def render_dict(config: Dict[str, Any]) -> None:
Args: Args:
config (dict) config (dict)
""" """
rendered = {} rendered: Config = {}
for key, value in config.items(): for key, value in config.items():
if isinstance(value, str): if isinstance(value, str):
rendered[key] = render_str(config, value) rendered[key] = render_str(config, value)
@ -267,13 +267,13 @@ def render_dict(config: Dict[str, Any]) -> None:
config[k] = v config[k] = v
def render_unknown(config: Dict[str, Any], value: Any) -> Any: def render_unknown(config: Config, value: Any) -> Any:
if isinstance(value, str): if isinstance(value, str):
return render_str(config, value) return render_str(config, value)
return value return value
def render_str(config: Dict[str, Any], text: str) -> str: def render_str(config: Config, text: str) -> str:
""" """
Args: Args:
text (str) text (str)

View File

@ -1,11 +1,10 @@
from typing import Any, Dict from . import fmt, utils
from .types import Config, get_typed
from . import fmt
from . import utils
def get_tag(config: Dict[str, Any], name: str) -> Any: def get_tag(config: Config, name: str) -> str:
return config["DOCKER_IMAGE_" + name.upper().replace("-", "_")] key = "DOCKER_IMAGE_" + name.upper().replace("-", "_")
return get_typed(config, key, str)
def build(path: str, tag: str, *args: str) -> None: def build(path: str, tag: str, *args: str) -> None:

View File

@ -1,14 +1,14 @@
from typing import Any, Dict, List, Tuple from typing import List, Tuple
import click import click
from . import config as tutor_config from . import config as tutor_config
from . import env from . import env, exceptions, fmt
from . import exceptions
from . import fmt
from .__about__ import __version__ from .__about__ import __version__
from .types import Config, get_typed
def update(root: str, interactive: bool = True) -> Dict[str, Any]: def update(root: str, interactive: bool = True) -> Config:
""" """
Load and save the configuration. Load and save the configuration.
""" """
@ -18,9 +18,7 @@ def update(root: str, interactive: bool = True) -> Dict[str, Any]:
return config return config
def load_all( def load_all(root: str, interactive: bool = True) -> Tuple[Config, Config]:
root: str, interactive: bool = True
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
""" """
Load configuration and interactively ask questions to collect param values from the user. Load configuration and interactively ask questions to collect param values from the user.
""" """
@ -30,7 +28,7 @@ def load_all(
return config, defaults return config, defaults
def ask_questions(config: Dict[str, Any], defaults: Dict[str, Any]) -> None: def ask_questions(config: Config, defaults: Config) -> None:
run_for_prod = config.get("LMS_HOST") != "local.overhang.io" run_for_prod = config.get("LMS_HOST") != "local.overhang.io"
run_for_prod = click.confirm( run_for_prod = click.confirm(
fmt.question( fmt.question(
@ -40,7 +38,7 @@ def ask_questions(config: Dict[str, Any], defaults: Dict[str, Any]) -> None:
default=run_for_prod, default=run_for_prod,
) )
if not run_for_prod: if not run_for_prod:
dev_values = { dev_values: Config = {
"LMS_HOST": "local.overhang.io", "LMS_HOST": "local.overhang.io",
"CMS_HOST": "studio.local.overhang.io", "CMS_HOST": "studio.local.overhang.io",
"ENABLE_HTTPS": False, "ENABLE_HTTPS": False,
@ -54,7 +52,8 @@ def ask_questions(config: Dict[str, Any], defaults: Dict[str, Any]) -> None:
if run_for_prod: if run_for_prod:
ask("Your website domain name for students (LMS)", "LMS_HOST", config, defaults) ask("Your website domain name for students (LMS)", "LMS_HOST", config, defaults)
if "localhost" in config["LMS_HOST"]: lms_host = get_typed(config, "LMS_HOST", str)
if "localhost" in lms_host:
raise exceptions.TutorError( raise exceptions.TutorError(
"You may not use 'localhost' as the LMS domain name. To run a local platform for testing purposes you should answer 'n' to the previous question." "You may not use 'localhost' as the LMS domain name. To run a local platform for testing purposes you should answer 'n' to the previous question."
) )
@ -159,19 +158,18 @@ def ask_questions(config: Dict[str, Any], defaults: Dict[str, Any]) -> None:
) )
def ask( def ask(question: str, key: str, config: Config, defaults: Config) -> None:
question: str, key: str, config: Dict[str, Any], defaults: Dict[str, Any] default = get_typed(defaults, key, str)
) -> None: default = get_typed(config, key, str, default=default)
default = env.render_str(config, config.get(key, defaults[key])) default = env.render_str(config, default)
config[key] = click.prompt( config[key] = click.prompt(
fmt.question(question), prompt_suffix=" ", default=default, show_default=True fmt.question(question), prompt_suffix=" ", default=default, show_default=True
) )
def ask_bool( def ask_bool(question: str, key: str, config: Config, defaults: Config) -> None:
question: str, key: str, config: Dict[str, Any], defaults: Dict[str, Any] default = get_typed(defaults, key, bool)
) -> None: default = get_typed(config, key, bool, default=default)
default = config.get(key, defaults[key])
config[key] = click.confirm( config[key] = click.confirm(
fmt.question(question), prompt_suffix=" ", default=default fmt.question(question), prompt_suffix=" ", default=default
) )
@ -180,11 +178,11 @@ def ask_bool(
def ask_choice( def ask_choice(
question: str, question: str,
key: str, key: str,
config: Dict[str, Any], config: Config,
defaults: Dict[str, Any], defaults: Config,
choices: List[str], choices: List[str],
) -> None: ) -> None:
default = config.get(key, defaults[key]) default = str(config.get(key, defaults[key]))
answer = click.prompt( answer = click.prompt(
fmt.question(question), fmt.question(question),
type=click.Choice(choices), type=click.Choice(choices),

View File

@ -1,8 +1,7 @@
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from typing import Dict, Iterator, List, Optional, Tuple, Union
from . import env from . import env, fmt, plugins
from . import fmt from .types import Config
from . import plugins
BASE_OPENEDX_COMMAND = """ BASE_OPENEDX_COMMAND = """
export DJANGO_SETTINGS_MODULE=$SERVICE_VARIANT.envs.$SETTINGS export DJANGO_SETTINGS_MODULE=$SERVICE_VARIANT.envs.$SETTINGS
@ -11,7 +10,7 @@ echo "Loading settings $DJANGO_SETTINGS_MODULE"
class BaseJobRunner: class BaseJobRunner:
def __init__(self, root: str, config: Dict[str, Any]): def __init__(self, root: str, config: Config):
self.root = root self.root = root
self.config = config self.config = config

View File

@ -1,18 +1,15 @@
from collections import namedtuple
from copy import deepcopy
from glob import glob
import importlib import importlib
import os import os
from copy import deepcopy
from glob import glob
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
import appdirs import appdirs
import click import click
import pkg_resources import pkg_resources
from . import exceptions from . import exceptions, fmt, serialize
from . import fmt from .types import Config, get_typed
from . import serialize
CONFIG_KEY = "PLUGINS" CONFIG_KEY = "PLUGINS"
@ -69,7 +66,7 @@ class BasePlugin:
self.command: click.Command = command self.command: click.Command = command
@staticmethod @staticmethod
def load_config(obj: Any, plugin_name: str) -> Dict[str, Dict[str, Any]]: def load_config(obj: Any, plugin_name: str) -> Dict[str, Config]:
""" """
Load config and check types. Load config and check types.
""" """
@ -181,15 +178,15 @@ class BasePlugin:
return self.name.upper() + "_" + key return self.name.upper() + "_" + key
@property @property
def config_add(self) -> Dict[str, Any]: def config_add(self) -> Config:
return self.config.get("add", {}) return self.config.get("add", {})
@property @property
def config_set(self) -> Dict[str, Any]: def config_set(self) -> Config:
return self.config.get("set", {}) return self.config.get("set", {})
@property @property
def config_defaults(self) -> Dict[str, Any]: def config_defaults(self) -> Config:
return self.config.get("defaults", {}) return self.config.get("defaults", {})
@property @property
@ -269,16 +266,32 @@ class DictPlugin(BasePlugin):
os.environ.get(ROOT_ENV_VAR_NAME, "") os.environ.get(ROOT_ENV_VAR_NAME, "")
) or appdirs.user_data_dir(appname="tutor-plugins") ) or appdirs.user_data_dir(appname="tutor-plugins")
def __init__(self, data: Dict[str, Any]): def __init__(self, data: Config):
Module = namedtuple("Module", data.keys()) # type: ignore name = data["name"]
obj = Module(**data) # type: ignore if not isinstance(name, str):
super().__init__(data["name"], obj) raise exceptions.TutorError(
self._version = data["version"] "Invalid plugin name: '{}'. Expected str, got {}".format(
name, name.__class__
)
)
# Create a generic object (sort of a named tuple) which will contain all key/values from data
class Module:
pass
obj = Module()
for key, value in data.items():
setattr(obj, key, value)
super().__init__(name, obj)
version = data["version"]
if not isinstance(version, str):
raise TypeError("DictPlugin.__version__ must be str")
self._version: str = version
@property @property
def version(self) -> str: def version(self) -> str:
if not isinstance(self._version, str):
raise TypeError("DictPlugin.__version__ must be str")
return self._version return self._version
@classmethod @classmethod
@ -305,7 +318,7 @@ class Plugins:
DictPlugin, DictPlugin,
] ]
def __init__(self, config: Dict[str, Any]): def __init__(self, config: Config):
self.config = deepcopy(config) self.config = deepcopy(config)
# patches has the following structure: # patches has the following structure:
# {patch_name -> {plugin_name -> "content"}} # {patch_name -> {plugin_name -> "content"}}
@ -380,18 +393,17 @@ def iter_installed() -> Iterator[BasePlugin]:
yield from Plugins.iter_installed() yield from Plugins.iter_installed()
def enable(config: Dict[str, Any], name: str) -> None: def enable(config: Config, name: str) -> None:
if not is_installed(name): if not is_installed(name):
raise exceptions.TutorError("plugin '{}' is not installed.".format(name)) raise exceptions.TutorError("plugin '{}' is not installed.".format(name))
if is_enabled(config, name): if is_enabled(config, name):
return return
if CONFIG_KEY not in config: enabled = enabled_plugins(config)
config[CONFIG_KEY] = [] enabled.append(name)
config[CONFIG_KEY].append(name) enabled.sort()
config[CONFIG_KEY].sort()
def disable(config: Dict[str, Any], name: str) -> None: def disable(config: Config, name: str) -> None:
fmt.echo_info("Disabling plugin {}...".format(name)) fmt.echo_info("Disabling plugin {}...".format(name))
for plugin in Plugins(config).iter_enabled(): for plugin in Plugins(config).iter_enabled():
if name == plugin.name: if name == plugin.name:
@ -400,25 +412,32 @@ def disable(config: Dict[str, Any], name: str) -> None:
config.pop(key, None) config.pop(key, None)
fmt.echo_info(" Removed config entry {}={}".format(key, value)) fmt.echo_info(" Removed config entry {}={}".format(key, value))
# Remove plugin from list # Remove plugin from list
while name in config[CONFIG_KEY]: enabled = enabled_plugins(config)
config[CONFIG_KEY].remove(name) while name in enabled:
enabled.remove(name)
fmt.echo_info(" Plugin disabled") fmt.echo_info(" Plugin disabled")
def iter_enabled(config: Dict[str, Any]) -> Iterator[BasePlugin]: def iter_enabled(config: Config) -> Iterator[BasePlugin]:
yield from Plugins(config).iter_enabled() yield from Plugins(config).iter_enabled()
def is_enabled(config: Dict[str, Any], name: str) -> bool: def is_enabled(config: Config, name: str) -> bool:
plugin_list = config.get(CONFIG_KEY) or [] return name in enabled_plugins(config)
return name in plugin_list
def iter_patches(config: Dict[str, str], name: str) -> Iterator[Tuple[str, str]]: def enabled_plugins(config: Config) -> List[str]:
if not config.get(CONFIG_KEY):
config[CONFIG_KEY] = []
plugins = get_typed(config, CONFIG_KEY, list)
return plugins
def iter_patches(config: Config, name: str) -> Iterator[Tuple[str, str]]:
yield from Plugins(config).iter_patches(name) yield from Plugins(config).iter_patches(name)
def iter_hooks( def iter_hooks(
config: Dict[str, Any], hook_name: str config: Config, hook_name: str
) -> Iterator[Tuple[str, Union[Dict[str, str], List[str]]]]: ) -> Iterator[Tuple[str, Union[Dict[str, str], List[str]]]]:
yield from Plugins(config).iter_hooks(hook_name) yield from Plugins(config).iter_hooks(hook_name)

View File

@ -1,13 +1,12 @@
import re import re
from typing import Any, IO, Iterator, Tuple, Union from typing import IO, Any, Iterator, Tuple, Union
import click
import yaml import yaml
from _io import TextIOWrapper from _io import TextIOWrapper
from yaml.parser import ParserError from yaml.parser import ParserError
from yaml.scanner import ScannerError from yaml.scanner import ScannerError
import click
def load(stream: Union[str, IO[str]]) -> Any: def load(stream: Union[str, IO[str]]) -> Any:
return yaml.load(stream, Loader=yaml.SafeLoader) return yaml.load(stream, Loader=yaml.SafeLoader)
@ -21,6 +20,12 @@ def dump(content: Any, fileobj: TextIOWrapper) -> None:
yaml.dump(content, stream=fileobj, default_flow_style=False) yaml.dump(content, stream=fileobj, default_flow_style=False)
def dumps(content: Any) -> str:
result = yaml.dump(content, default_flow_style=False)
assert isinstance(result, str)
return result
def parse(v: Union[str, IO[str]]) -> Any: def parse(v: Union[str, IO[str]]) -> Any:
""" """
Parse a yaml-formatted string. Parse a yaml-formatted string.

39
tutor/types.py Normal file
View File

@ -0,0 +1,39 @@
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
from . import exceptions
ConfigValue = Union[
str, float, None, bool, List[str], List[Any], Dict[str, Any], Dict[Any, Any]
]
Config = Dict[str, ConfigValue]
def cast_config(config: Any) -> Config:
if not isinstance(config, dict):
raise exceptions.TutorError(
"Invalid configuration: expected dict, got {}".format(config.__class__)
)
for key in config.keys():
if not isinstance(key, str):
raise exceptions.TutorError(
"Invalid configuration: expected str, got {} for key '{}'".format(
key.__class__, key
)
)
return config
T = TypeVar("T")
def get_typed(
config: Config, key: str, expected_type: Type[T], default: Optional[T] = None
) -> T:
value = config.get(key, default)
if not isinstance(value, expected_type):
raise exceptions.TutorError(
"Invalid config entry: expected {}, got {} for key '{}'".format(
expected_type.__name__, value.__class__, key
)
)
return value