diff --git a/tests/test_config.py b/tests/test_config.py index 9fc7911..3fab30b 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,10 +1,10 @@ -from typing import Any, Dict import unittest from unittest.mock import Mock, patch import tempfile from tutor import config as tutor_config from tutor import interactive +from tutor.types import get_typed, Config class ConfigTests(unittest.TestCase): @@ -13,13 +13,13 @@ class ConfigTests(unittest.TestCase): self.assertNotIn("TUTOR_VERSION", defaults) def test_merge(self) -> None: - config1 = {"x": "y"} - config2 = {"x": "z"} + config1: Config = {"x": "y"} + config2: Config = {"x": "z"} tutor_config.merge(config1, config2) self.assertEqual({"x": "y"}, config1) def test_merge_render(self) -> None: - config: Dict[str, Any] = {} + config: Config = {} defaults = tutor_config.load_defaults() with patch.object(tutor_config.utils, "random_string", return_value="abcd"): tutor_config.merge(config, defaults) @@ -62,13 +62,13 @@ class ConfigTests(unittest.TestCase): config, defaults = interactive.load_all(rootdir, interactive=False) self.assertIn("MYSQL_ROOT_PASSWORD", config) - self.assertEqual(8, len(config["MYSQL_ROOT_PASSWORD"])) + self.assertEqual(8, len(get_typed(config, "MYSQL_ROOT_PASSWORD", str))) self.assertNotIn("LMS_HOST", config) self.assertEqual("www.myopenedx.com", defaults["LMS_HOST"]) self.assertEqual("studio.{{ LMS_HOST }}", defaults["CMS_HOST"]) def test_is_service_activated(self) -> None: - config = {"RUN_SERVICE1": True, "RUN_SERVICE2": False} + config: Config = {"RUN_SERVICE1": True, "RUN_SERVICE2": False} self.assertTrue(tutor_config.is_service_activated(config, "service1")) self.assertFalse(tutor_config.is_service_activated(config, "service2")) diff --git a/tests/test_env.py b/tests/test_env.py index c49316f..e01c220 100644 --- a/tests/test_env.py +++ b/tests/test_env.py @@ -1,6 +1,5 @@ import os import tempfile -from typing import Any, Dict import unittest from unittest.mock import patch, Mock @@ -8,6 +7,7 @@ from tutor import config as tutor_config from tutor import env from tutor import fmt from tutor import exceptions +from tutor.types import Config class EnvTests(unittest.TestCase): @@ -55,7 +55,7 @@ class EnvTests(unittest.TestCase): self.assertRaises(exceptions.TutorError, env.render_str, {}, "hello {{ name }}") def test_render_file(self) -> None: - config: Dict[str, Any] = {} + config: Config = {} tutor_config.merge(config, tutor_config.load_defaults()) config["MYSQL_ROOT_PASSWORD"] = "testpassword" rendered = env.render_file(config, "hooks", "mysql", "init") @@ -125,7 +125,7 @@ class EnvTests(unittest.TestCase): f.write("Hello my ID is {{ ID }}") # Create configuration - config = {"ID": "abcd"} + config: Config = {"ID": "abcd"} # Render templates with patch.object( @@ -162,7 +162,7 @@ class EnvTests(unittest.TestCase): f.write("some content") # Load env once - config: Dict[str, Any] = {"PLUGINS": []} + config: Config = {"PLUGINS": []} env1 = env.Renderer.instance(config).environment with patch.object( @@ -171,7 +171,7 @@ class EnvTests(unittest.TestCase): return_value=[plugin1], ): # Load env a second time - config["PLUGINS"].append("myplugin") + config["PLUGINS"] = ["myplugin"] env2 = env.Renderer.instance(config).environment self.assertNotIn("plugin1/myplugin.txt", env1.loader.list_templates()) # type: ignore diff --git a/tests/test_images.py b/tests/test_images.py index 149866d..5a0759d 100644 --- a/tests/test_images.py +++ b/tests/test_images.py @@ -1,10 +1,11 @@ import unittest from tutor import images +from tutor.types import Config class ImagesTests(unittest.TestCase): def test_get_tag(self) -> None: - config = { + config: Config = { "DOCKER_IMAGE_OPENEDX": "registry/openedx", "DOCKER_IMAGE_OPENEDX_DEV": "registry/openedxdev", } diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 3a1a41d..8b32450 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -1,4 +1,3 @@ -from typing import Any, Dict import unittest from unittest.mock import Mock, patch @@ -6,6 +5,7 @@ from tutor import config as tutor_config from tutor import exceptions from tutor import fmt from tutor import plugins +from tutor.types import get_typed, Config class PluginsTests(unittest.TestCase): @@ -37,26 +37,26 @@ class PluginsTests(unittest.TestCase): ) def test_enable(self) -> None: - config: Dict[str, Any] = {plugins.CONFIG_KEY: []} + config: Config = {plugins.CONFIG_KEY: []} with patch.object(plugins, "is_installed", return_value=True): plugins.enable(config, "plugin2") plugins.enable(config, "plugin1") self.assertEqual(["plugin1", "plugin2"], config[plugins.CONFIG_KEY]) def test_enable_twice(self) -> None: - config: Dict[str, Any] = {plugins.CONFIG_KEY: []} + config: Config = {plugins.CONFIG_KEY: []} with patch.object(plugins, "is_installed", return_value=True): plugins.enable(config, "plugin1") plugins.enable(config, "plugin1") self.assertEqual(["plugin1"], config[plugins.CONFIG_KEY]) def test_enable_not_installed_plugin(self) -> None: - config: Dict[str, Any] = {"PLUGINS": []} + config: Config = {"PLUGINS": []} with patch.object(plugins, "is_installed", return_value=False): self.assertRaises(exceptions.TutorError, plugins.enable, config, "plugin1") def test_disable(self) -> None: - config: Dict[str, Any] = {"PLUGINS": ["plugin1", "plugin2"]} + config: Config = {"PLUGINS": ["plugin1", "plugin2"]} with patch.object(fmt, "STDOUT"): plugins.disable(config, "plugin1") self.assertEqual(["plugin2"], config["PLUGINS"]) @@ -75,14 +75,14 @@ class PluginsTests(unittest.TestCase): ) ], ): - config = {"PLUGINS": ["plugin1"], "KEY": "value"} + config: Config = {"PLUGINS": ["plugin1"], "KEY": "value"} with patch.object(fmt, "STDOUT"): plugins.disable(config, "plugin1") self.assertEqual([], config["PLUGINS"]) self.assertNotIn("KEY", config) def test_none_plugins(self) -> None: - config = {plugins.CONFIG_KEY: None} + config: Config = {plugins.CONFIG_KEY: None} self.assertFalse(plugins.is_enabled(config, "myplugin")) def test_patches(self) -> None: @@ -107,11 +107,11 @@ class PluginsTests(unittest.TestCase): self.assertEqual([], patches) def test_configure(self) -> None: - config = {"ID": "id"} - defaults: Dict[str, Any] = {} + config: Config = {"ID": "id"} + defaults: Config = {} class plugin1: - config = { + config: Config = { "add": {"PARAM1": "value1", "PARAM2": "value2"}, "set": {"PARAM3": "value3"}, "defaults": {"PARAM4": "value4"}, @@ -136,10 +136,10 @@ class PluginsTests(unittest.TestCase): self.assertEqual({"PLUGIN1_PARAM4": "value4"}, defaults) def test_configure_set_does_not_override(self) -> None: - config = {"ID": "oldid"} + config: Config = {"ID": "oldid"} class plugin1: - config = {"set": {"ID": "newid"}} + config: Config = {"set": {"ID": "newid"}} with patch.object( plugins.Plugins, @@ -151,10 +151,10 @@ class PluginsTests(unittest.TestCase): self.assertEqual({"ID": "oldid"}, config) def test_configure_set_random_string(self) -> None: - config: Dict[str, Any] = {} + config: Config = {} class plugin1: - config = {"set": {"PARAM1": "{{ 128|random_string }}"}} + config: Config = {"set": {"PARAM1": "{{ 128|random_string }}"}} with patch.object( plugins.Plugins, @@ -162,14 +162,14 @@ class PluginsTests(unittest.TestCase): return_value=[plugins.BasePlugin("plugin1", plugin1)], ): tutor_config.load_plugins(config, {}) - self.assertEqual(128, len(config["PARAM1"])) + self.assertEqual(128, len(get_typed(config, "PARAM1", str))) def test_configure_default_value_with_previous_definition(self) -> None: - config: Dict[str, Any] = {} - defaults = {"PARAM1": "value"} + config: Config = {} + defaults: Config = {"PARAM1": "value"} class plugin1: - config = {"defaults": {"PARAM2": "{{ PARAM1 }}"}} + config: Config = {"defaults": {"PARAM2": "{{ PARAM1 }}"}} with patch.object( plugins.Plugins, @@ -180,10 +180,10 @@ class PluginsTests(unittest.TestCase): self.assertEqual("{{ PARAM1 }}", defaults["PLUGIN1_PARAM2"]) def test_configure_add_twice(self) -> None: - config: Dict[str, Any] = {} + config: Config = {} class plugin1: - config = {"add": {"PARAM1": "{{ 10|random_string }}"}} + config: Config = {"add": {"PARAM1": "{{ 10|random_string }}"}} with patch.object( plugins.Plugins, @@ -191,14 +191,14 @@ class PluginsTests(unittest.TestCase): return_value=[plugins.BasePlugin("plugin1", plugin1)], ): tutor_config.load_plugins(config, {}) - value1 = config["PLUGIN1_PARAM1"] + value1 = get_typed(config, "PLUGIN1_PARAM1", str) with patch.object( plugins.Plugins, "iter_enabled", return_value=[plugins.BasePlugin("plugin1", plugin1)], ): tutor_config.load_plugins(config, {}) - value2 = config["PLUGIN1_PARAM1"] + value2 = get_typed(config, "PLUGIN1_PARAM1", str) self.assertEqual(10, len(value1)) self.assertEqual(10, len(value2)) @@ -218,10 +218,10 @@ class PluginsTests(unittest.TestCase): ) def test_plugins_are_updated_on_config_change(self) -> None: - config: Dict[str, Any] = {"PLUGINS": []} + config: Config = {"PLUGINS": []} plugins1 = plugins.Plugins(config) self.assertEqual(0, len(list(plugins1.iter_enabled()))) - config["PLUGINS"].append("plugin1") + config["PLUGINS"] = ["plugin1"] with patch.object( plugins.Plugins, "iter_installed", diff --git a/tutor/bindmounts.py b/tutor/bindmounts.py index 56534a5..d733283 100644 --- a/tutor/bindmounts.py +++ b/tutor/bindmounts.py @@ -1,17 +1,18 @@ import os -from typing import Any, Callable, Dict, List, Tuple +from typing import Callable, List, Tuple import click from mypy_extensions import VarArg from .exceptions import TutorError +from .types import Config from .utils import get_user_id def create( root: str, - config: Dict[str, Any], - docker_compose_func: Callable[[str, Dict[str, Any], VarArg(str)], int], + config: Config, + docker_compose_func: Callable[[str, Config, VarArg(str)], int], service: str, path: str, ) -> str: diff --git a/tutor/commands/android.py b/tutor/commands/android.py index 2e4b2a8..90fddf6 100644 --- a/tutor/commands/android.py +++ b/tutor/commands/android.py @@ -1,5 +1,3 @@ -from typing import Dict - import click from .compose import ComposeJobRunner @@ -7,6 +5,7 @@ from .local import docker_compose as local_docker_compose from .. import config as tutor_config from .. import env as tutor_env from .. import fmt +from ..types import Config from .context import Context @@ -28,7 +27,7 @@ def build(context: Context, mode: str) -> None: ) -def build_command(config: Dict[str, str], target: str) -> str: +def build_command(config: Config, target: str) -> str: gradle_target = { "debug": "assembleProdDebuggable", "release": "assembleProdRelease", diff --git a/tutor/commands/compose.py b/tutor/commands/compose.py index 4540949..6860802 100644 --- a/tutor/commands/compose.py +++ b/tutor/commands/compose.py @@ -1,5 +1,5 @@ import os -from typing import Any, Callable, Dict, List +from typing import Callable, List import click from mypy_extensions import VarArg @@ -11,6 +11,7 @@ from ..exceptions import TutorError from .. import fmt from .. import jobs from .. import serialize +from ..types import Config from .. import utils from .context import Context @@ -19,8 +20,8 @@ class ComposeJobRunner(jobs.BaseJobRunner): def __init__( self, root: str, - config: Dict[str, Any], - docker_compose_func: Callable[[str, Dict[str, Any], VarArg(str)], int], + config: Config, + docker_compose_func: Callable[[str, Config, VarArg(str)], int], ): super().__init__(root, config) self.docker_compose_func = docker_compose_func diff --git a/tutor/commands/config.py b/tutor/commands/config.py index 17b9c62..320df26 100644 --- a/tutor/commands/config.py +++ b/tutor/commands/config.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import List import click @@ -8,6 +8,7 @@ from .. import exceptions from .. import fmt from .. import interactive as interactive_config from .. import serialize +from ..types import Config from .context import Context @@ -40,7 +41,7 @@ def config_command() -> None: ) @click.pass_obj def save( - context: Context, interactive: bool, set_vars: Dict[str, Any], unset_vars: List[str] + context: Context, interactive: bool, set_vars: Config, unset_vars: List[str] ) -> None: config, defaults = interactive_config.load_all( context.root, interactive=interactive @@ -91,7 +92,7 @@ def printvalue(context: Context, key: str) -> None: config = tutor_config.load(context.root) try: # Note that this will incorrectly print None values - fmt.echo(config[key]) + fmt.echo(str(config[key])) except KeyError as e: raise exceptions.TutorError( "Missing configuration value: {}".format(key) diff --git a/tutor/commands/context.py b/tutor/commands/context.py index 38530fb..3d24220 100644 --- a/tutor/commands/context.py +++ b/tutor/commands/context.py @@ -1,9 +1,7 @@ -from typing import Any, Dict +from ..types import Config -def unimplemented_docker_compose( - root: str, config: Dict[str, Any], *command: str -) -> int: +def unimplemented_docker_compose(root: str, config: Config, *command: str) -> int: raise NotImplementedError @@ -13,5 +11,5 @@ class Context: self.root = root self.docker_compose_func = unimplemented_docker_compose - def docker_compose(self, root: str, config: Dict[str, Any], *command: str) -> int: + def docker_compose(self, root: str, config: Config, *command: str) -> int: return self.docker_compose_func(root, config, *command) diff --git a/tutor/commands/dev.py b/tutor/commands/dev.py index 3935a8e..b867503 100644 --- a/tutor/commands/dev.py +++ b/tutor/commands/dev.py @@ -1,17 +1,18 @@ import os -from typing import Any, Dict, List +from typing import List import click from .. import config as tutor_config from .. import env as tutor_env from .. import fmt +from ..types import Config from .. import utils from . import compose from .context import Context -def docker_compose(root: str, config: Dict[str, Any], *command: str) -> int: +def docker_compose(root: str, config: Config, *command: str) -> int: """ Run docker-compose with dev arguments. """ diff --git a/tutor/commands/images.py b/tutor/commands/images.py index 3e01de4..2936efa 100644 --- a/tutor/commands/images.py +++ b/tutor/commands/images.py @@ -1,11 +1,13 @@ -from typing import cast, Any, Dict, Iterator, List, Tuple +from typing import Iterator, List, Tuple import click from .. import config as tutor_config from .. import env as tutor_env +from .. import exceptions from .. import images from .. import plugins +from ..types import Config from .. import utils from .context import Context @@ -105,7 +107,7 @@ def printtag(context: Context, image_names: List[str]) -> None: print(tag) -def build_image(root: str, config: Dict[str, Any], image: str, *args: str) -> None: +def build_image(root: str, config: Config, image: str, *args: str) -> None: # Build base images for img, tag in iter_images(config, image, BASE_IMAGE_NAMES): images.build(tutor_env.pathjoin(root, "build", img), tag, *args) @@ -122,14 +124,14 @@ def build_image(root: str, config: Dict[str, Any], image: str, *args: str) -> No images.build(tutor_env.pathjoin(root, "build", img), tag, *dev_build_arg, *args) -def pull_image(config: Dict[str, Any], image: str) -> None: +def pull_image(config: Config, image: str) -> None: for _img, tag in iter_images(config, image, all_image_names(config)): images.pull(tag) for _plugin, _img, tag in iter_plugin_images(config, image, "remote-image"): images.pull(tag) -def push_image(config: Dict[str, Any], image: str) -> None: +def push_image(config: Config, image: str) -> None: for _img, tag in iter_images(config, image, BASE_IMAGE_NAMES): images.push(tag) for _plugin, _img, tag in iter_plugin_images(config, image, "remote-image"): @@ -137,7 +139,7 @@ def push_image(config: Dict[str, Any], image: str) -> None: def iter_images( - config: Dict[str, Any], image: str, image_list: List[str] + config: Config, image: str, image_list: List[str] ) -> Iterator[Tuple[str, str]]: for img in image_list: if image in [img, "all"]: @@ -146,21 +148,26 @@ def iter_images( def iter_plugin_images( - config: Dict[str, Any], image: str, hook_name: str + config: Config, image: str, hook_name: str ) -> Iterator[Tuple[str, str, str]]: for plugin, hook in plugins.iter_hooks(config, hook_name): - hook = cast(Dict[str, str], hook) + if not isinstance(hook, dict): + raise exceptions.TutorError( + "Invalid hook '{}': expected dict, got {}".format( + hook_name, hook.__class__ + ) + ) for img, tag in hook.items(): if image in [img, "all"]: tag = tutor_env.render_str(config, tag) yield plugin, img, tag -def all_image_names(config: Dict[str, Any]) -> List[str]: +def all_image_names(config: Config) -> List[str]: return BASE_IMAGE_NAMES + vendor_image_names(config) -def vendor_image_names(config: Dict[str, Any]) -> List[str]: +def vendor_image_names(config: Config) -> List[str]: vendor_images = VENDOR_IMAGES[:] for image in VENDOR_IMAGES: if not config.get("RUN_" + image.upper(), True): diff --git a/tutor/commands/k8s.py b/tutor/commands/k8s.py index 2d859bf..7b03e05 100644 --- a/tutor/commands/k8s.py +++ b/tutor/commands/k8s.py @@ -1,6 +1,6 @@ from datetime import datetime from time import sleep -from typing import cast, Any, Dict, List, Optional, Type +from typing import Any, List, Optional, Type import click @@ -11,6 +11,7 @@ from .. import fmt from .. import interactive as interactive_config from .. import jobs from .. import serialize +from ..types import Config, get_typed from .. import utils from .context import Context @@ -50,7 +51,11 @@ class K8sJobRunner(jobs.BaseJobRunner): def load_job(self, name: str) -> Any: all_jobs = self.render("k8s", "jobs.yml") for job in serialize.load_all(all_jobs): - job_name = cast(str, job["metadata"]["name"]) + job_name = job["metadata"]["name"] + if not isinstance(job_name, str): + raise exceptions.TutorError( + "Invalid job name: '{}'. Expected str.".format(job_name) + ) if job_name == name: return job raise ValueError("Could not find job '{}'".format(name)) @@ -64,7 +69,7 @@ class K8sJobRunner(jobs.BaseJobRunner): api = K8sClients.instance().batch_api return [ job.metadata.name - for job in api.list_namespaced_job(self.config["K8S_NAMESPACE"]).items + for job in api.list_namespaced_job(k8s_namespace(self.config)).items if job.status.active ] @@ -139,7 +144,7 @@ class K8sJobRunner(jobs.BaseJobRunner): """ kubectl logs --namespace={namespace} --follow $(kubectl get --namespace={namespace} pods """ """--selector=job-name={job_name} -o=jsonpath="{{.items[0].metadata.name}}")\n\n""" "Waiting for job completion..." - ).format(job_name=job_name, namespace=self.config["K8S_NAMESPACE"]) + ).format(job_name=job_name, namespace=k8s_namespace(self.config)) fmt.echo_info(message) # Wait for completion @@ -257,14 +262,15 @@ def reboot(context: click.Context) -> None: context.invoke(start) -def resource_selector(config: Dict[str, str], *selectors: str) -> List[str]: +def resource_selector(config: Config, *selectors: str) -> List[str]: """ Convenient utility for filtering only the resources that belong to this project. """ selector = ",".join( - ["app.kubernetes.io/instance=openedx-" + config["ID"]] + list(selectors) + ["app.kubernetes.io/instance=openedx-" + get_typed(config, "ID", str)] + + list(selectors) ) - return ["--namespace", config["K8S_NAMESPACE"], "--selector=" + selector] + return ["--namespace", k8s_namespace(config), "--selector=" + selector] @click.command(help="Completely delete an existing platform") @@ -398,7 +404,7 @@ def upgrade(context: Context, from_version: str) -> None: running_version = "koa" -def upgrade_from_ironwood(config: Dict[str, Any]) -> None: +def upgrade_from_ironwood(config: Config) -> None: if not config["RUN_MONGODB"]: fmt.echo_info( "You are not running MongDB (RUN_MONGODB=false). It is your " @@ -425,7 +431,7 @@ your MongoDb cluster from v3.2 to v3.6. You should run something similar to: fmt.echo_info(message) -def upgrade_from_juniper(config: Dict[str, Any]) -> None: +def upgrade_from_juniper(config: Config) -> None: if not config["RUN_MYSQL"]: fmt.echo_info( "You are not running MySQL (RUN_MYSQL=false). It is your " @@ -446,7 +452,7 @@ your MySQL database from v5.6 to v5.7. You should run something similar to: def kubectl_exec( - config: Dict[str, Any], service: str, command: str, attach: bool = False + config: Config, service: str, command: str, attach: bool = False ) -> int: selector = "app.kubernetes.io/name={}".format(service) pods = K8sClients.instance().core_api.list_namespaced_pod( @@ -464,7 +470,7 @@ def kubectl_exec( "exec", *attach_opts, "--namespace", - config["K8S_NAMESPACE"], + k8s_namespace(config), pod_name, "--", "sh", @@ -474,7 +480,7 @@ def kubectl_exec( ) -def wait_for_pod_ready(config: Dict[str, str], service: str) -> None: +def wait_for_pod_ready(config: Config, service: str) -> None: fmt.echo_info("Waiting for a {} pod to be ready...".format(service)) utils.kubectl( "wait", @@ -485,6 +491,10 @@ def wait_for_pod_ready(config: Dict[str, str], service: str) -> None: ) +def k8s_namespace(config: Config) -> str: + return get_typed(config, "K8S_NAMESPACE", str) + + k8s.add_command(quickstart) k8s.add_command(start) k8s.add_command(stop) diff --git a/tutor/commands/local.py b/tutor/commands/local.py index 655c09e..fd89807 100644 --- a/tutor/commands/local.py +++ b/tutor/commands/local.py @@ -1,18 +1,18 @@ import os -from typing import Dict, Any import click from .. import config as tutor_config from .. import env as tutor_env from .. import fmt +from ..types import get_typed, Config from .. import utils from . import compose from .config import save as config_save_command from .context import Context -def docker_compose(root: str, config: Dict[str, Any], *command: str) -> int: +def docker_compose(root: str, config: Config, *command: str) -> int: """ Run docker-compose with local and production yml files. """ @@ -27,7 +27,7 @@ def docker_compose(root: str, config: Dict[str, Any], *command: str) -> int: tutor_env.pathjoin(root, "local", "docker-compose.prod.yml"), *args, "--project-name", - config["LOCAL_PROJECT_NAME"], + get_typed(config, "LOCAL_PROJECT_NAME", str), *command ) @@ -118,7 +118,7 @@ Are you sure you want to continue?""" running_version = "koa" -def upgrade_from_ironwood(context: click.Context, config: Dict[str, Any]) -> None: +def upgrade_from_ironwood(context: click.Context, config: Config) -> None: click.echo(fmt.title("Upgrading from Ironwood")) tutor_env.save(context.obj.root, config) @@ -166,7 +166,7 @@ def upgrade_from_ironwood(context: click.Context, config: Dict[str, Any]) -> Non context.invoke(compose.stop) -def upgrade_from_juniper(context: click.Context, config: Dict[str, Any]) -> None: +def upgrade_from_juniper(context: click.Context, config: Config) -> None: click.echo(fmt.title("Upgrading from Juniper")) tutor_env.save(context.obj.root, config) diff --git a/tutor/commands/webui.py b/tutor/commands/webui.py index 2d8b4da..9508b96 100644 --- a/tutor/commands/webui.py +++ b/tutor/commands/webui.py @@ -4,7 +4,7 @@ import platform import subprocess import sys import tarfile -from typing import Any, Dict, Optional +from typing import Dict, Optional from urllib.request import urlopen import click @@ -15,6 +15,7 @@ from .. import fmt from .. import env as tutor_env from .. import exceptions from .. import serialize +from ..types import Config from .context import Context @@ -135,7 +136,7 @@ def load_config(root: str) -> Dict[str, Optional[str]]: return config -def save_webui_config_file(root: str, config: Dict[str, Any]) -> None: +def save_webui_config_file(root: str, config: Config) -> None: path = config_path(root) directory = os.path.dirname(path) if not os.path.exists(directory): diff --git a/tutor/config.py b/tutor/config.py index 63d4e6c..c72de93 100644 --- a/tutor/config.py +++ b/tutor/config.py @@ -1,15 +1,11 @@ import os -from typing import cast, Dict, Any, Tuple +from typing import Tuple -from . import exceptions -from . import env -from . import fmt -from . import plugins -from . import serialize -from . import utils +from . import env, exceptions, fmt, plugins, serialize, utils +from .types import Config, cast_config -def update(root: str) -> Dict[str, Any]: +def update(root: str) -> Config: """ Load and save the configuration. """ @@ -19,7 +15,7 @@ def update(root: str) -> Dict[str, Any]: return config -def load(root: str) -> Dict[str, Any]: +def load(root: str) -> Config: """ Load full configuration. This will raise an exception if there is no current configuration in the project root. @@ -28,13 +24,13 @@ def load(root: str) -> Dict[str, Any]: return load_no_check(root) -def load_no_check(root: str) -> Dict[str, Any]: +def load_no_check(root: str) -> Config: config, defaults = load_all(root) merge(config, defaults) return config -def load_all(root: str) -> Tuple[Dict[str, Any], Dict[str, Any]]: +def load_all(root: str) -> Tuple[Config, Config]: """ Return: current (dict): params currently saved in config.yml @@ -46,9 +42,7 @@ def load_all(root: str) -> Tuple[Dict[str, Any], Dict[str, Any]]: return current, defaults -def merge( - config: Dict[str, str], defaults: Dict[str, str], force: bool = False -) -> None: +def merge(config: Config, defaults: Config, force: bool = False) -> None: """ Merge default values with user configuration and perform rendering of "{{...}}" values. @@ -58,23 +52,18 @@ def merge( config[key] = env.render_unknown(config, value) -def load_defaults() -> Dict[str, Any]: +def load_defaults() -> Config: config = serialize.load(env.read_template_file("config.yml")) - return cast(Dict[str, Any], config) + return cast_config(config) -def load_config_file(path: str) -> Dict[str, Any]: +def load_config_file(path: str) -> Config: with open(path) as f: config = serialize.load(f.read()) - if not isinstance(config, dict): - raise exceptions.TutorError( - "Invalid configuration: expected dict, got {}".format(config.__class__) - ) - - return config + return cast_config(config) -def load_current(root: str, defaults: Dict[str, str]) -> Dict[str, Any]: +def load_current(root: str, defaults: Config) -> Config: """ Load the configuration currently stored on disk. Note: this modifies the defaults with the plugin default values. @@ -87,7 +76,7 @@ def load_current(root: str, defaults: Dict[str, str]) -> Dict[str, Any]: return config -def load_user(root: str) -> Dict[str, Any]: +def load_user(root: str) -> Config: path = config_path(root) if not os.path.exists(path): return {} @@ -97,14 +86,14 @@ def load_user(root: str) -> Dict[str, Any]: return config -def load_env(config: Dict[str, str], defaults: Dict[str, str]) -> None: +def load_env(config: Config, defaults: Config) -> None: for k in defaults.keys(): env_var = "TUTOR_" + k if env_var in os.environ: config[k] = serialize.parse(os.environ[env_var]) -def load_required(config: Dict[str, str], defaults: Dict[str, str]) -> None: +def load_required(config: Config, defaults: Config) -> None: """ All these keys must be present in the user's config.yml. This includes all values that are generated once and must be kept after that, such as passwords. @@ -121,7 +110,7 @@ def load_required(config: Dict[str, str], defaults: Dict[str, str]) -> None: config[key] = env.render_unknown(config, defaults[key]) -def load_plugins(config: Dict[str, str], defaults: Dict[str, str]) -> None: +def load_plugins(config: Config, defaults: Config) -> None: """ Add, override and set new defaults from plugins. """ @@ -143,11 +132,11 @@ def load_plugins(config: Dict[str, str], defaults: Dict[str, str]) -> None: config[key] = env.render_unknown(config, value) -def is_service_activated(config: Dict[str, Any], service: str) -> bool: +def is_service_activated(config: Config, service: str) -> bool: return config["RUN_" + service.upper()] is not False -def upgrade_obsolete(config: Dict[str, Any]) -> None: +def upgrade_obsolete(config: Config) -> None: # Openedx-specific mysql passwords if "MYSQL_PASSWORD" in config: config["MYSQL_ROOT_PASSWORD"] = config["MYSQL_PASSWORD"] @@ -209,7 +198,7 @@ def convert_json2yml(root: str) -> None: ) -def save_config_file(root: str, config: Dict[str, str]) -> None: +def save_config_file(root: str, config: Config) -> None: path = config_path(root) utils.ensure_file_directory_exists(path) with open(path, "w") as of: diff --git a/tutor/env.py b/tutor/env.py index c990405..258a857 100644 --- a/tutor/env.py +++ b/tutor/env.py @@ -1,17 +1,14 @@ import codecs -from copy import deepcopy import os -from typing import Dict, Any, Iterable, List, Optional, Type, Union +from copy import deepcopy +from typing import Any, Iterable, List, Optional, Type, Union import jinja2 import pkg_resources -from . import exceptions -from . import fmt -from . import plugins -from . import utils +from . import exceptions, fmt, plugins, utils from .__about__ import __version__ - +from .types import Config TEMPLATES_ROOT = pkg_resources.resource_filename("tutor", "templates") VERSION_FILENAME = "version" @@ -20,7 +17,7 @@ BIN_FILE_EXTENSIONS = [".ico", ".jpg", ".png", ".ttf", ".woff", ".woff2"] class Renderer: @classmethod - def instance(cls: Type["Renderer"], config: Dict[str, Any]) -> "Renderer": + def instance(cls: Type["Renderer"], config: Config) -> "Renderer": # Load template roots: these are required to be able to use # {% include .. %} directives template_roots = [TEMPLATES_ROOT] @@ -32,7 +29,7 @@ class Renderer: def __init__( self, - config: Dict[str, Any], + config: Config, template_roots: List[str], ignore_folders: Optional[List[str]] = None, ): @@ -64,7 +61,10 @@ class Renderer: The elements of `prefix` must contain only "/", and not os.sep. """ full_prefix = "/".join(prefix) - for template in self.environment.loader.list_templates(): # type: ignore + env_templates: List[ + str + ] = self.environment.loader.list_templates() # type:ignore[no-untyped-call] + for template in env_templates: if template.startswith(full_prefix) and self.is_part_of_env(template): yield template @@ -171,7 +171,7 @@ class Renderer: ) -def save(root: str, config: Dict[str, Any]) -> None: +def save(root: str, config: Config) -> None: """ Save the full environment, including version information. """ @@ -206,7 +206,7 @@ def upgrade_obsolete(root: str) -> None: def save_plugin_templates( - plugin: plugins.BasePlugin, root: str, config: Dict[str, Any] + plugin: plugins.BasePlugin, root: str, config: Config ) -> None: """ Save plugin templates to plugins//*. @@ -218,7 +218,7 @@ def save_plugin_templates( save_all_from(subdir_path, plugins_root, config) -def save_all_from(prefix: str, root: str, config: Dict[str, Any]) -> None: +def save_all_from(prefix: str, root: str, config: Config) -> None: """ Render the templates that start with `prefix` and store them with the same hierarchy at `root`. Here, `prefix` can be the result of os.path.join(...). @@ -240,7 +240,7 @@ def write_to(content: Union[str, bytes], path: str) -> None: of_text.write(content) -def render_file(config: Dict[str, Any], *path: str) -> Union[str, bytes]: +def render_file(config: Config, *path: str) -> Union[str, bytes]: """ Return the rendered contents of a template. """ @@ -249,7 +249,7 @@ def render_file(config: Dict[str, Any], *path: str) -> Union[str, bytes]: return renderer.render_template(template_name) -def render_dict(config: Dict[str, Any]) -> None: +def render_dict(config: Config) -> None: """ Render the values from the dict. This is useful for rendering the default values from config.yml. @@ -257,7 +257,7 @@ def render_dict(config: Dict[str, Any]) -> None: Args: config (dict) """ - rendered = {} + rendered: Config = {} for key, value in config.items(): if isinstance(value, str): rendered[key] = render_str(config, value) @@ -267,13 +267,13 @@ def render_dict(config: Dict[str, Any]) -> None: config[k] = v -def render_unknown(config: Dict[str, Any], value: Any) -> Any: +def render_unknown(config: Config, value: Any) -> Any: if isinstance(value, str): return render_str(config, value) return value -def render_str(config: Dict[str, Any], text: str) -> str: +def render_str(config: Config, text: str) -> str: """ Args: text (str) diff --git a/tutor/images.py b/tutor/images.py index 19f0aa3..1939d3a 100644 --- a/tutor/images.py +++ b/tutor/images.py @@ -1,11 +1,10 @@ -from typing import Any, Dict - -from . import fmt -from . import utils +from . import fmt, utils +from .types import Config, get_typed -def get_tag(config: Dict[str, Any], name: str) -> Any: - return config["DOCKER_IMAGE_" + name.upper().replace("-", "_")] +def get_tag(config: Config, name: str) -> str: + key = "DOCKER_IMAGE_" + name.upper().replace("-", "_") + return get_typed(config, key, str) def build(path: str, tag: str, *args: str) -> None: diff --git a/tutor/interactive.py b/tutor/interactive.py index 79681e5..048e10d 100644 --- a/tutor/interactive.py +++ b/tutor/interactive.py @@ -1,14 +1,14 @@ -from typing import Any, Dict, List, Tuple +from typing import List, Tuple + import click from . import config as tutor_config -from . import env -from . import exceptions -from . import fmt +from . import env, exceptions, fmt from .__about__ import __version__ +from .types import Config, get_typed -def update(root: str, interactive: bool = True) -> Dict[str, Any]: +def update(root: str, interactive: bool = True) -> Config: """ Load and save the configuration. """ @@ -18,9 +18,7 @@ def update(root: str, interactive: bool = True) -> Dict[str, Any]: return config -def load_all( - root: str, interactive: bool = True -) -> Tuple[Dict[str, Any], Dict[str, Any]]: +def load_all(root: str, interactive: bool = True) -> Tuple[Config, Config]: """ Load configuration and interactively ask questions to collect param values from the user. """ @@ -30,7 +28,7 @@ def load_all( return config, defaults -def ask_questions(config: Dict[str, Any], defaults: Dict[str, Any]) -> None: +def ask_questions(config: Config, defaults: Config) -> None: run_for_prod = config.get("LMS_HOST") != "local.overhang.io" run_for_prod = click.confirm( fmt.question( @@ -40,7 +38,7 @@ def ask_questions(config: Dict[str, Any], defaults: Dict[str, Any]) -> None: default=run_for_prod, ) if not run_for_prod: - dev_values = { + dev_values: Config = { "LMS_HOST": "local.overhang.io", "CMS_HOST": "studio.local.overhang.io", "ENABLE_HTTPS": False, @@ -54,7 +52,8 @@ def ask_questions(config: Dict[str, Any], defaults: Dict[str, Any]) -> None: if run_for_prod: ask("Your website domain name for students (LMS)", "LMS_HOST", config, defaults) - if "localhost" in config["LMS_HOST"]: + lms_host = get_typed(config, "LMS_HOST", str) + if "localhost" in lms_host: raise exceptions.TutorError( "You may not use 'localhost' as the LMS domain name. To run a local platform for testing purposes you should answer 'n' to the previous question." ) @@ -159,19 +158,18 @@ def ask_questions(config: Dict[str, Any], defaults: Dict[str, Any]) -> None: ) -def ask( - question: str, key: str, config: Dict[str, Any], defaults: Dict[str, Any] -) -> None: - default = env.render_str(config, config.get(key, defaults[key])) +def ask(question: str, key: str, config: Config, defaults: Config) -> None: + default = get_typed(defaults, key, str) + default = get_typed(config, key, str, default=default) + default = env.render_str(config, default) config[key] = click.prompt( fmt.question(question), prompt_suffix=" ", default=default, show_default=True ) -def ask_bool( - question: str, key: str, config: Dict[str, Any], defaults: Dict[str, Any] -) -> None: - default = config.get(key, defaults[key]) +def ask_bool(question: str, key: str, config: Config, defaults: Config) -> None: + default = get_typed(defaults, key, bool) + default = get_typed(config, key, bool, default=default) config[key] = click.confirm( fmt.question(question), prompt_suffix=" ", default=default ) @@ -180,11 +178,11 @@ def ask_bool( def ask_choice( question: str, key: str, - config: Dict[str, Any], - defaults: Dict[str, Any], + config: Config, + defaults: Config, choices: List[str], ) -> None: - default = config.get(key, defaults[key]) + default = str(config.get(key, defaults[key])) answer = click.prompt( fmt.question(question), type=click.Choice(choices), diff --git a/tutor/jobs.py b/tutor/jobs.py index 4910718..1a3e49d 100644 --- a/tutor/jobs.py +++ b/tutor/jobs.py @@ -1,8 +1,7 @@ -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import Dict, Iterator, List, Optional, Tuple, Union -from . import env -from . import fmt -from . import plugins +from . import env, fmt, plugins +from .types import Config BASE_OPENEDX_COMMAND = """ export DJANGO_SETTINGS_MODULE=$SERVICE_VARIANT.envs.$SETTINGS @@ -11,7 +10,7 @@ echo "Loading settings $DJANGO_SETTINGS_MODULE" class BaseJobRunner: - def __init__(self, root: str, config: Dict[str, Any]): + def __init__(self, root: str, config: Config): self.root = root self.config = config diff --git a/tutor/plugins.py b/tutor/plugins.py index ee0a305..8bc1170 100644 --- a/tutor/plugins.py +++ b/tutor/plugins.py @@ -1,18 +1,15 @@ -from collections import namedtuple -from copy import deepcopy -from glob import glob import importlib import os +from copy import deepcopy +from glob import glob from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union import appdirs import click import pkg_resources -from . import exceptions -from . import fmt -from . import serialize - +from . import exceptions, fmt, serialize +from .types import Config, get_typed CONFIG_KEY = "PLUGINS" @@ -69,7 +66,7 @@ class BasePlugin: self.command: click.Command = command @staticmethod - def load_config(obj: Any, plugin_name: str) -> Dict[str, Dict[str, Any]]: + def load_config(obj: Any, plugin_name: str) -> Dict[str, Config]: """ Load config and check types. """ @@ -181,15 +178,15 @@ class BasePlugin: return self.name.upper() + "_" + key @property - def config_add(self) -> Dict[str, Any]: + def config_add(self) -> Config: return self.config.get("add", {}) @property - def config_set(self) -> Dict[str, Any]: + def config_set(self) -> Config: return self.config.get("set", {}) @property - def config_defaults(self) -> Dict[str, Any]: + def config_defaults(self) -> Config: return self.config.get("defaults", {}) @property @@ -269,16 +266,32 @@ class DictPlugin(BasePlugin): os.environ.get(ROOT_ENV_VAR_NAME, "") ) or appdirs.user_data_dir(appname="tutor-plugins") - def __init__(self, data: Dict[str, Any]): - Module = namedtuple("Module", data.keys()) # type: ignore - obj = Module(**data) # type: ignore - super().__init__(data["name"], obj) - self._version = data["version"] + def __init__(self, data: Config): + name = data["name"] + if not isinstance(name, str): + raise exceptions.TutorError( + "Invalid plugin name: '{}'. Expected str, got {}".format( + name, name.__class__ + ) + ) + + # Create a generic object (sort of a named tuple) which will contain all key/values from data + class Module: + pass + + obj = Module() + for key, value in data.items(): + setattr(obj, key, value) + + super().__init__(name, obj) + + version = data["version"] + if not isinstance(version, str): + raise TypeError("DictPlugin.__version__ must be str") + self._version: str = version @property def version(self) -> str: - if not isinstance(self._version, str): - raise TypeError("DictPlugin.__version__ must be str") return self._version @classmethod @@ -305,7 +318,7 @@ class Plugins: DictPlugin, ] - def __init__(self, config: Dict[str, Any]): + def __init__(self, config: Config): self.config = deepcopy(config) # patches has the following structure: # {patch_name -> {plugin_name -> "content"}} @@ -380,18 +393,17 @@ def iter_installed() -> Iterator[BasePlugin]: yield from Plugins.iter_installed() -def enable(config: Dict[str, Any], name: str) -> None: +def enable(config: Config, name: str) -> None: if not is_installed(name): raise exceptions.TutorError("plugin '{}' is not installed.".format(name)) if is_enabled(config, name): return - if CONFIG_KEY not in config: - config[CONFIG_KEY] = [] - config[CONFIG_KEY].append(name) - config[CONFIG_KEY].sort() + enabled = enabled_plugins(config) + enabled.append(name) + enabled.sort() -def disable(config: Dict[str, Any], name: str) -> None: +def disable(config: Config, name: str) -> None: fmt.echo_info("Disabling plugin {}...".format(name)) for plugin in Plugins(config).iter_enabled(): if name == plugin.name: @@ -400,25 +412,32 @@ def disable(config: Dict[str, Any], name: str) -> None: config.pop(key, None) fmt.echo_info(" Removed config entry {}={}".format(key, value)) # Remove plugin from list - while name in config[CONFIG_KEY]: - config[CONFIG_KEY].remove(name) + enabled = enabled_plugins(config) + while name in enabled: + enabled.remove(name) fmt.echo_info(" Plugin disabled") -def iter_enabled(config: Dict[str, Any]) -> Iterator[BasePlugin]: +def iter_enabled(config: Config) -> Iterator[BasePlugin]: yield from Plugins(config).iter_enabled() -def is_enabled(config: Dict[str, Any], name: str) -> bool: - plugin_list = config.get(CONFIG_KEY) or [] - return name in plugin_list +def is_enabled(config: Config, name: str) -> bool: + return name in enabled_plugins(config) -def iter_patches(config: Dict[str, str], name: str) -> Iterator[Tuple[str, str]]: +def enabled_plugins(config: Config) -> List[str]: + if not config.get(CONFIG_KEY): + config[CONFIG_KEY] = [] + plugins = get_typed(config, CONFIG_KEY, list) + return plugins + + +def iter_patches(config: Config, name: str) -> Iterator[Tuple[str, str]]: yield from Plugins(config).iter_patches(name) def iter_hooks( - config: Dict[str, Any], hook_name: str + config: Config, hook_name: str ) -> Iterator[Tuple[str, Union[Dict[str, str], List[str]]]]: yield from Plugins(config).iter_hooks(hook_name) diff --git a/tutor/serialize.py b/tutor/serialize.py index 3831d93..1dbbf95 100644 --- a/tutor/serialize.py +++ b/tutor/serialize.py @@ -1,13 +1,12 @@ import re -from typing import Any, IO, Iterator, Tuple, Union +from typing import IO, Any, Iterator, Tuple, Union +import click import yaml from _io import TextIOWrapper from yaml.parser import ParserError from yaml.scanner import ScannerError -import click - def load(stream: Union[str, IO[str]]) -> Any: return yaml.load(stream, Loader=yaml.SafeLoader) @@ -21,6 +20,12 @@ def dump(content: Any, fileobj: TextIOWrapper) -> None: yaml.dump(content, stream=fileobj, default_flow_style=False) +def dumps(content: Any) -> str: + result = yaml.dump(content, default_flow_style=False) + assert isinstance(result, str) + return result + + def parse(v: Union[str, IO[str]]) -> Any: """ Parse a yaml-formatted string. diff --git a/tutor/types.py b/tutor/types.py new file mode 100644 index 0000000..70193d8 --- /dev/null +++ b/tutor/types.py @@ -0,0 +1,39 @@ +from typing import Any, Dict, List, Optional, Type, TypeVar, Union + +from . import exceptions + +ConfigValue = Union[ + str, float, None, bool, List[str], List[Any], Dict[str, Any], Dict[Any, Any] +] +Config = Dict[str, ConfigValue] + + +def cast_config(config: Any) -> Config: + if not isinstance(config, dict): + raise exceptions.TutorError( + "Invalid configuration: expected dict, got {}".format(config.__class__) + ) + for key in config.keys(): + if not isinstance(key, str): + raise exceptions.TutorError( + "Invalid configuration: expected str, got {} for key '{}'".format( + key.__class__, key + ) + ) + return config + + +T = TypeVar("T") + + +def get_typed( + config: Config, key: str, expected_type: Type[T], default: Optional[T] = None +) -> T: + value = config.get(key, default) + if not isinstance(value, expected_type): + raise exceptions.TutorError( + "Invalid config entry: expected {}, got {} for key '{}'".format( + expected_type.__name__, value.__class__, key + ) + ) + return value