diff --git a/tutor/commands/webui.py b/tutor/commands/webui.py index b210bd7..2d8b4da 100644 --- a/tutor/commands/webui.py +++ b/tutor/commands/webui.py @@ -4,7 +4,7 @@ import platform import subprocess import sys import tarfile -from typing import Any, Dict +from typing import Any, Dict, Optional from urllib.request import urlopen import click @@ -13,6 +13,7 @@ import click # the web ui can be launched even where there is no configuration. from .. import fmt from .. import env as tutor_env +from .. import exceptions from .. import serialize from .context import Context @@ -121,12 +122,17 @@ def check_gotty_binary(root: str) -> None: compressed.extract("./gotty", dirname) -def load_config(root: str) -> Dict[str, Any]: +def load_config(root: str) -> Dict[str, Optional[str]]: path = config_path(root) if not os.path.exists(path): save_webui_config_file(root, {"user": None, "password": None}) with open(config_path(root)) as f: - return serialize.load(f) + config = serialize.load(f) + if not isinstance(config, dict): + raise exceptions.TutorError( + "Invalid webui: expected dict, got {}".format(config.__class__) + ) + return config def save_webui_config_file(root: str, config: Dict[str, Any]) -> None: diff --git a/tutor/config.py b/tutor/config.py index 7e9def7..63d4e6c 100644 --- a/tutor/config.py +++ b/tutor/config.py @@ -1,5 +1,5 @@ import os -from typing import Dict, Any, Tuple +from typing import cast, Dict, Any, Tuple from . import exceptions from . import env @@ -59,12 +59,19 @@ def merge( def load_defaults() -> Dict[str, Any]: - return serialize.load(env.read_template_file("config.yml")) + config = serialize.load(env.read_template_file("config.yml")) + return cast(Dict[str, Any], config) def load_config_file(path: str) -> Dict[str, Any]: with open(path) as f: - return serialize.load(f.read()) + config = serialize.load(f.read()) + if not isinstance(config, dict): + raise exceptions.TutorError( + "Invalid configuration: expected dict, got {}".format(config.__class__) + ) + + return config def load_current(root: str, defaults: Dict[str, str]) -> Dict[str, Any]: diff --git a/tutor/serialize.py b/tutor/serialize.py index 9b29d9c..3831d93 100644 --- a/tutor/serialize.py +++ b/tutor/serialize.py @@ -1,5 +1,5 @@ import re -from typing import cast, Any, Dict, IO, Iterator, Tuple, Union +from typing import Any, IO, Iterator, Tuple, Union import yaml from _io import TextIOWrapper @@ -9,15 +9,15 @@ from yaml.scanner import ScannerError import click -def load(stream: Union[str, IO[str]]) -> Dict[str, str]: - return cast(Dict[str, str], yaml.load(stream, Loader=yaml.SafeLoader)) +def load(stream: Union[str, IO[str]]) -> Any: + return yaml.load(stream, Loader=yaml.SafeLoader) def load_all(stream: str) -> Iterator[Any]: return yaml.load_all(stream, Loader=yaml.SafeLoader) -def dump(content: Dict[str, str], fileobj: TextIOWrapper) -> None: +def dump(content: Any, fileobj: TextIOWrapper) -> None: yaml.dump(content, stream=fileobj, default_flow_style=False)