diff --git a/bin/main.py b/bin/main.py index 6e70055..7f3ce22 100755 --- a/bin/main.py +++ b/bin/main.py @@ -12,7 +12,7 @@ for plugin_name in [ "xqueue", ]: try: - OfficialPlugin.INSTALLED.append(OfficialPlugin(plugin_name)) + OfficialPlugin.load(plugin_name) except ImportError: pass diff --git a/tests/test_plugins.py b/tests/test_plugins.py index c874fe1..889b5f3 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -20,12 +20,11 @@ class PluginsTests(unittest.TestCase): self.assertFalse(plugins.is_installed("dummy")) @patch.object(plugins.DictPlugin, "iter_installed", return_value=[]) - def test_extra_installed(self, _dict_plugin_iter_installed): - plugin1 = plugins.BasePlugin("plugin1", None) - plugin2 = plugins.BasePlugin("plugin2", None) - - plugins.OfficialPlugin.INSTALLED.append(plugin1) - plugins.OfficialPlugin.INSTALLED.append(plugin2) + def test_official_plugins(self, _dict_plugin_iter_installed): + with patch.object(plugins.importlib, "import_module", return_value=42): + plugin1 = plugins.OfficialPlugin.load("plugin1") + with patch.object(plugins.importlib, "import_module", return_value=43): + plugin2 = plugins.OfficialPlugin.load("plugin2") with patch.object( plugins.EntrypointPlugin, "iter_installed", return_value=[plugin1], ): diff --git a/tutor/env.py b/tutor/env.py index b137819..d364ccd 100644 --- a/tutor/env.py +++ b/tutor/env.py @@ -18,20 +18,16 @@ BIN_FILE_EXTENSIONS = [".ico", ".jpg", ".png", ".ttf", ".woff", ".woff2"] class Renderer: - INSTANCE = None - @classmethod def instance(cls, config): - if cls.INSTANCE is None or cls.INSTANCE.config != config: - # Load template roots: these are required to be able to use - # {% include .. %} directives - template_roots = [TEMPLATES_ROOT] - for plugin in plugins.iter_enabled(config): - if plugin.templates_root: - template_roots.append(plugin.templates_root) + # Load template roots: these are required to be able to use + # {% include .. %} directives + template_roots = [TEMPLATES_ROOT] + for plugin in plugins.iter_enabled(config): + if plugin.templates_root: + template_roots.append(plugin.templates_root) - cls.INSTANCE = cls(config, template_roots, ignore_folders=["partials"]) - return cls.INSTANCE + return cls(config, template_roots, ignore_folders=["partials"]) @classmethod def reset(cls): diff --git a/tutor/plugins.py b/tutor/plugins.py index f2ba978..52e5894 100644 --- a/tutor/plugins.py +++ b/tutor/plugins.py @@ -47,6 +47,9 @@ class BasePlugin: `command` (click.Command): if a plugin exposes a `command` attribute, users will be able to run it from the command line as `tutor pluginname`. """ + INSTALLED = [] + _IS_LOADED = False + def __init__(self, name, obj): self.name = name self.config = get_callable_attr(obj, "config", {}) @@ -79,6 +82,14 @@ class BasePlugin: @classmethod def iter_installed(cls): + if not cls._IS_LOADED: + for plugin in cls.iter_load(): + cls.INSTALLED.append(plugin) + cls._IS_LOADED = True + yield from cls.INSTALLED + + @classmethod + def iter_load(cls): raise NotImplementedError @@ -101,19 +112,22 @@ class EntrypointPlugin(BasePlugin): return self.entrypoint.dist.version @classmethod - def iter_installed(cls): + def iter_load(cls): for entrypoint in pkg_resources.iter_entry_points(cls.ENTRYPOINT): yield cls(entrypoint) class OfficialPlugin(BasePlugin): """ - Official plugins have a "plugin" module which exposes a __version__ - attribute. - Official plugins should be manually added to INSTALLED. + Official plugins have a "plugin" module which exposes a __version__ attribute. + Official plugins should be manually added by calling `OfficialPlugin.load()`. """ - INSTALLED = [] + @classmethod + def load(cls, name): + plugin = cls(name) + cls.INSTALLED.append(plugin) + return plugin def __init__(self, name): self.module = importlib.import_module("tutor{}.plugin".format(name)) @@ -124,8 +138,8 @@ class OfficialPlugin(BasePlugin): return self.module.__version__ @classmethod - def iter_installed(cls): - yield from cls.INSTALLED + def iter_load(cls): + yield from [] class DictPlugin(BasePlugin): @@ -145,7 +159,7 @@ class DictPlugin(BasePlugin): return self._version @classmethod - def iter_installed(cls): + def iter_load(cls): for path in glob(os.path.join(cls.ROOT, "*.yml")): with open(path) as f: data = serialize.load(f) @@ -162,8 +176,7 @@ class DictPlugin(BasePlugin): class Plugins: - - INSTANCE = None + PLUGIN_CLASSES = [OfficialPlugin, EntrypointPlugin, DictPlugin] def __init__(self, config): self.config = deepcopy(config) @@ -184,23 +197,17 @@ class Plugins: @classmethod def clear(cls): - cls.INSTANCE = None - OfficialPlugin.INSTALLED.clear() - - @classmethod - def instance(cls, config): - if cls.INSTANCE is None or cls.INSTANCE.config != config: - cls.INSTANCE = cls(config) - return cls.INSTANCE + for PluginClass in cls.PLUGIN_CLASSES: + PluginClass.INSTALLED.clear() @classmethod def iter_installed(cls): """ - Iterate on all installed plugins. Plugins are deduplicated by name. + Iterate on all installed plugins. Plugins are deduplicated by name. The list of installed plugins is cached to + prevent too many re-computations, which happens a lot. """ - classes = [OfficialPlugin, EntrypointPlugin, DictPlugin] installed_plugin_names = set() - for PluginClass in classes: + for PluginClass in cls.PLUGIN_CLASSES: for plugin in PluginClass.iter_installed(): if plugin.name not in installed_plugin_names: installed_plugin_names.add(plugin.name) @@ -252,7 +259,7 @@ def enable(config, name): def disable(config, name): fmt.echo_info("Disabling plugin {}...".format(name)) - for plugin in Plugins.instance(config).iter_enabled(): + for plugin in Plugins(config).iter_enabled(): if name == plugin.name: # Remove "set" config entries for key, value in plugin.config_set.items(): @@ -265,7 +272,7 @@ def disable(config, name): def iter_enabled(config): - yield from Plugins.instance(config).iter_enabled() + yield from Plugins(config).iter_enabled() def is_enabled(config, name): @@ -273,8 +280,8 @@ def is_enabled(config, name): def iter_patches(config, name): - yield from Plugins.instance(config).iter_patches(name) + yield from Plugins(config).iter_patches(name) def iter_hooks(config, hook_name): - yield from Plugins.instance(config).iter_hooks(hook_name) + yield from Plugins(config).iter_hooks(hook_name)