diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..a55400e --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,53 @@ +import unittest +import unittest.mock +import tempfile + +from tutor.commands import config as tutor_config +from tutor import env + + +class ConfigTests(unittest.TestCase): + def setUp(self): + # This is necessary to avoid cached mocks + env.Renderer.reset() + + def test_merge(self): + config = {} + defaults = tutor_config.load_defaults() + with unittest.mock.patch.object( + tutor_config.utils, "random_string", return_value="abcd" + ): + tutor_config.merge(config, defaults) + + self.assertEqual("abcd", config["MYSQL_ROOT_PASSWORD"]) + + def test_save_twice(self): + with tempfile.TemporaryDirectory() as root: + tutor_config.save(root, silent=True) + config1 = tutor_config.load_user(root) + + tutor_config.save(root, silent=True) + config2 = tutor_config.load_user(root) + + self.assertEqual(config1, config2) + + def test_removed_entry_is_added_on_save(self): + with tempfile.TemporaryDirectory() as root: + with unittest.mock.patch.object( + tutor_config.utils, "random_string" + ) as mock_random_string: + mock_random_string.return_value = "abcd" + defaults = tutor_config.load_defaults() + config1 = tutor_config.load_current(root, defaults) + password1 = config1["MYSQL_ROOT_PASSWORD"] + + config1.pop("MYSQL_ROOT_PASSWORD") + tutor_config.save_config(root, config1) + + mock_random_string.return_value = "efgh" + defaults = tutor_config.load_defaults() + config2 = tutor_config.load_current(root, defaults) + password2 = config2["MYSQL_ROOT_PASSWORD"] + + self.assertEqual("abcd", password1) + self.assertEqual("efgh", password2) diff --git a/tests/test_env.py b/tests/test_env.py index 8344d7d..7dba871 100644 --- a/tests/test_env.py +++ b/tests/test_env.py @@ -1,5 +1,7 @@ +import tempfile import unittest +from tutor.commands import config as tutor_config from tutor import env from tutor import exceptions @@ -22,3 +24,15 @@ class EnvTests(unittest.TestCase): def test_render_str_missing_configuration(self): self.assertRaises(exceptions.TutorError, env.render_str, {}, "hello {{ name }}") + + def test_render_file(self): + config = {} + tutor_config.merge(config, tutor_config.load_defaults()) + config["MYSQL_ROOT_PASSWORD"] = "testpassword" + rendered = env.render_file(config, "scripts", "create_databases.sh") + self.assertIn("testpassword", rendered) + + def test_render_file_missing_configuration(self): + self.assertRaises( + exceptions.TutorError, env.render_file, {}, "local", "docker-compose.yml" + ) diff --git a/tests/test_scripts.py b/tests/test_scripts.py index 32cf647..dd0705c 100644 --- a/tests/test_scripts.py +++ b/tests/test_scripts.py @@ -1,7 +1,7 @@ import unittest import unittest.mock -from tutor.commands.config import load_defaults +from tutor.commands import config as tutor_config from tutor import env from tutor import scripts @@ -9,12 +9,14 @@ from tutor import scripts class ScriptsTests(unittest.TestCase): def test_run_script(self): config = {} - load_defaults(config) - rendered_script = env.render_file(config, "scripts", "create_databases.sh") - with unittest.mock.Mock() as run_func: - scripts.run_script( - "/tmp", config, "someservice", "create_databases.sh", run_func - ) - run_func.assert_called_once_with( - "/tmp", config, "someservice", rendered_script - ) + defaults = tutor_config.load_defaults() + tutor_config.merge(config, defaults) + + rendered_script = env.render_file( + config, "scripts", "create_databases.sh" + ).strip() + run_func = unittest.mock.Mock() + scripts.run_script( + "/tmp", config, "someservice", "create_databases.sh", run_func + ) + run_func.assert_called_once_with("/tmp", "someservice", rendered_script) diff --git a/tests/test_utils.py b/tests/test_utils.py index 1dd2b15..2e91a3c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -18,6 +18,9 @@ class UtilsTests(unittest.TestCase): "domain.com", utils.common_domain("sub.domain.com", "ub.domain.com") ) + def test_reverse_host(self): + self.assertEqual("com.google.www", utils.reverse_host("www.google.com")) + class SerializeTests(unittest.TestCase): def test_parse_value(self): diff --git a/tutor/commands/config.py b/tutor/commands/config.py index 15694b2..78cc276 100644 --- a/tutor/commands/config.py +++ b/tutor/commands/config.py @@ -33,14 +33,14 @@ def save_command(root, silent1, silent2, set_): def save(root, silent=False, keyvalues=None): keyvalues = keyvalues or [] - config = load_current(root) + defaults = load_defaults() + config = load_current(root, defaults) for k, v in keyvalues: config[k] = v if not silent: - load_interactive(config) + load_interactive(config, defaults) save_config(root, config) - - load_defaults(config) + merge(config, defaults) save_env(root, config) @@ -54,8 +54,9 @@ def printroot(root): @opts.root @click.argument("key") def printvalue(root, key): - config = load_current(root) - load_defaults(config) + defaults = load_defaults() + config = load_current(root, defaults) + merge(config, defaults) try: print(config[key]) except KeyError: @@ -67,15 +68,15 @@ def load(root): Load configuration, and generate it interactively if the file does not exist. """ - config = load_current(root) + defaults = load_defaults() + config = load_current(root, defaults) should_update_env = False if not os.path.exists(config_path(root)): - load_interactive(config) + load_interactive(config, defaults) should_update_env = True save_config(root, config) - load_defaults(config) if not env.is_up_to_date(root): should_update_env = True pre_upgrade_announcement(root) @@ -120,40 +121,55 @@ def pre_upgrade_announcement(root): ) -def load_current(root): +def load_current(root, defaults): convert_json2yml(root) - config = {} - load_base(config) - load_user(config, root) - load_env(config) + config = load_user(root) + load_env(config, defaults) + load_required(config, defaults) return config -def load_base(config): - base = serialize.load(env.read("config-base.yml")) - for k, v in base.items(): - config[k] = v +def load_user(root): + path = config_path(root) + if not os.path.exists(path): + return {} + + with open(path) as fi: + config = serialize.load(fi.read()) + upgrade_obsolete(config) + return config -def load_env(config): - base_config = serialize.load(env.read("config-base.yml")) - default_config = serialize.load(env.read("config-defaults.yml")) - keys = set(list(base_config.keys()) + list(default_config.keys())) - - for k in keys: +def load_env(config, defaults): + for k in defaults.keys(): env_var = "TUTOR_" + k if env_var in os.environ: config[k] = serialize.parse_value(os.environ[env_var]) -def load_user(config, root): - path = config_path(root) - if os.path.exists(path): - with open(path) as fi: - loaded = serialize.load(fi.read()) - for key, value in loaded.items(): - config[key] = value - upgrade_obsolete(config) +def load_required(config, defaults): + """ + All these keys must be present in the user's config.yml. This includes all important + values, such as LMS_HOST, and randomly-generated values, such as passwords. + """ + for key in [ + "LMS_HOST", + "CMS_HOST", + "CONTACT_EMAIL", + "SECRET_KEY", + "MYSQL_ROOT_PASSWORD", + "OPENEDX_MYSQL_PASSWORD", + "NOTES_MYSQL_PASSWORD", + "NOTES_SECRET_KEY", + "NOTES_OAUTH2_SECRET", + "XQUEUE_AUTH_PASSWORD", + "XQUEUE_MYSQL_PASSWORD", + "XQUEUE_SECRET_KEY", + "ANDROID_OAUTH2_SECRET", + "ID", + ]: + if key not in config: + config[key] = env.render_str(config, defaults[key]) def upgrade_obsolete(config): @@ -168,15 +184,16 @@ def upgrade_obsolete(config): config["OPENEDX_MYSQL_USERNAME"] = config.pop("MYSQL_USERNAME") -def load_interactive(config): - ask("Your website domain name for students (LMS)", "LMS_HOST", config) - ask("Your website domain name for teachers (CMS)", "CMS_HOST", config) - ask("Your platform name/title", "PLATFORM_NAME", config) - ask("Your public contact email address", "CONTACT_EMAIL", config) +def load_interactive(config, defaults): + ask("Your website domain name for students (LMS)", "LMS_HOST", config, defaults) + ask("Your website domain name for teachers (CMS)", "CMS_HOST", config, defaults) + ask("Your platform name/title", "PLATFORM_NAME", config, defaults) + ask("Your public contact email address", "CONTACT_EMAIL", config, defaults) ask_choice( "The default language code for the platform", "LANGUAGE_CODE", config, + defaults, [ "en", "am", @@ -264,47 +281,40 @@ def load_interactive(config): ), "ACTIVATE_HTTPS", config, + defaults, ) ask_bool( "Activate Student Notes service (https://open.edx.org/features/student-notes)?", "ACTIVATE_NOTES", config, + defaults, ) ask_bool( "Activate Xqueue for external grader services (https://github.com/edx/xqueue)?", "ACTIVATE_XQUEUE", config, + defaults, ) -def load_defaults(config): - defaults = serialize.load(env.read("config-defaults.yml")) - for k, v in defaults.items(): - if k not in config: - config[k] = v - - # Add extra configuration parameters that need to be computed separately - config["lms_cms_common_domain"] = utils.common_domain( - config["LMS_HOST"], config["CMS_HOST"] - ) - config["lms_host_reverse"] = ".".join(config["LMS_HOST"].split(".")[::-1]) +def load_defaults(): + return serialize.load(env.read("config.yml")) -def ask(question, key, config): - default = env.render_str(config, config[key]) +def ask(question, key, config, defaults): + default = env.render_str(config, config.get(key, defaults[key])) config[key] = click.prompt( fmt.question(question), prompt_suffix=" ", default=default, show_default=True ) -def ask_bool(question, key, config): - config[key] = click.confirm( - fmt.question(question), prompt_suffix=" ", default=config[key] - ) +def ask_bool(question, key, config, defaults): + default = config.get(key, defaults[key]) + config[key] = click.confirm(fmt.question(question), prompt_suffix=" ", default=default) -def ask_choice(question, key, config, choices): - default = config[key] +def ask_choice(question, key, config, defaults, choices): + default = config.get(key, defaults[key]) answer = click.prompt( fmt.question(question), type=click.Choice(choices), @@ -337,11 +347,8 @@ def convert_json2yml(root): def save_config(root, config): - env.render_dict(config) path = config_path(root) - directory = os.path.dirname(path) - if not os.path.exists(directory): - os.makedirs(directory) + utils.ensure_file_directory_exists(path) with open(path, "w") as of: serialize.dump(config, of) click.echo(fmt.info("Configuration saved to {}".format(path))) diff --git a/tutor/env.py b/tutor/env.py index 006e084..b688efd 100644 --- a/tutor/env.py +++ b/tutor/env.py @@ -19,20 +19,42 @@ class Renderer: @classmethod def environment(cls): if not cls.ENVIRONMENT: - cls.ENVIRONMENT = jinja2.Environment( + environment = jinja2.Environment( loader=jinja2.FileSystemLoader(TEMPLATES_ROOT), undefined=jinja2.StrictUndefined, ) + environment.filters["random_string"] = utils.random_string + environment.filters["common_domain"] = utils.random_string + environment.filters["reverse_host"] = utils.reverse_host + cls.ENVIRONMENT = environment + return cls.ENVIRONMENT + @classmethod + def reset(cls): + cls.ENVIRONMENT = None + @classmethod def render_str(cls, config, text): - template_globals = dict( - RAND8=utils.random_string(8), RAND24=utils.random_string(24), **config - ) - template = cls.environment().from_string(text, globals=template_globals) + template = cls.environment().from_string(text) + return cls.__render(template, config) + + @classmethod + def render_file(cls, config, path): + template = cls.environment().get_template(path) try: - return template.render() + return cls.__render(template, config) + except (jinja2.exceptions.TemplateError, exceptions.TutorError): + print("Error rendering template", path) + raise + except Exception: + print("Unknown error rendering template", path) + raise + + @classmethod + def __render(cls, template, config): + try: + return template.render(**config) except jinja2.exceptions.UndefinedError as e: raise exceptions.TutorError( "Missing configuration value: {}".format(e.args[0]) @@ -58,36 +80,16 @@ def render_subdir(subdir, root, config): for path in walk_templates(subdir): dst = pathjoin(root, path) rendered = render_file(config, path) - ensure_file_directory_exists(dst) + utils.ensure_file_directory_exists(dst) with open(dst, "w") as of: of.write(rendered) -def ensure_file_directory_exists(path): - """ - Create file's base directory if it does not exist. - """ - directory = os.path.dirname(path) - if not os.path.exists(directory): - os.makedirs(directory) - - def render_file(config, *path): """ Return the rendered contents of a template. - TODO refactor this and move it to Renderer """ - with codecs.open(template_path(*path), encoding="utf-8") as fi: - try: - return render_str(config, fi.read()) - except jinja2.exceptions.UndefinedError: - raise - except jinja2.exceptions.TemplateError: - print("Error rendering template", path) - raise - except Exception: - print("Unknown error rendering template", path) - raise + return Renderer.render_file(config, os.path.join(*path)) def render_dict(config): @@ -128,7 +130,7 @@ def copy_subdir(subdir, root): for path in walk_templates(subdir): src = os.path.join(TEMPLATES_ROOT, path) dst = pathjoin(root, path) - ensure_file_directory_exists(dst) + utils.ensure_file_directory_exists(dst) shutil.copy(src, dst) diff --git a/tutor/templates/android/gradle.properties b/tutor/templates/android/gradle.properties index 8703267..5c781bf 100644 --- a/tutor/templates/android/gradle.properties +++ b/tutor/templates/android/gradle.properties @@ -1,4 +1,4 @@ -APPLICATION_ID={{ lms_host_reverse }} +APPLICATION_ID={{ LMS_HOST|reverse_host }} RELEASE_STORE_FILE=/openedx/config/app.keystore RELEASE_STORE_PASSWORD={{ ANDROID_RELEASE_STORE_PASSWORD }} RELEASE_KEY_PASSWORD={{ ANDROID_RELEASE_KEY_PASSWORD }} diff --git a/tutor/templates/config-defaults.yml b/tutor/templates/config-defaults.yml index e1cd393..ba25f5f 100644 --- a/tutor/templates/config-defaults.yml +++ b/tutor/templates/config-defaults.yml @@ -1,13 +1,33 @@ --- +# These configuration values must be stored in the user's config.yml. +LMS_HOST: "www.myopenedx.com" +CMS_HOST: "studio.{{ LMS_HOST }}" +CONTACT_EMAIL: "contact@{{ LMS_HOST }}" +SECRET_KEY: "{{ 24|random_string }}" +MYSQL_ROOT_PASSWORD: "{{ 8|random_string }}" +OPENEDX_MYSQL_PASSWORD: "{{ 8|random_string }}" +NOTES_MYSQL_PASSWORD: "{{ 8|random_string }}" +NOTES_SECRET_KEY: "{{ 24|random_string }}" +NOTES_OAUTH2_SECRET: "{{ 24|random_string }}" +XQUEUE_AUTH_PASSWORD: "{{ 8|random_string }}" +XQUEUE_MYSQL_PASSWORD: "{{ 8|random_string }}" +XQUEUE_SECRET_KEY: "{{ 24|random_string }}" +ANDROID_OAUTH2_SECRET: "{{ 24|random_string }}" +ID: "{{ 24|random_string }}" + +# The following are default values ACTIVATE_LMS: true ACTIVATE_CMS: true ACTIVATE_FORUM: true ACTIVATE_ELASTICSEARCH: true +ACTIVATE_HTTPS: false ACTIVATE_MEMCACHED: true ACTIVATE_MONGODB: true ACTIVATE_MYSQL: true +ACTIVATE_NOTES: false ACTIVATE_RABBITMQ: true ACTIVATE_SMTP: true +ACTIVATE_XQUEUE: false ANDROID_RELEASE_STORE_PASSWORD: "android store password" ANDROID_RELEASE_KEY_PASSWORD: "android release key password" ANDROID_RELEASE_KEY_ALIAS: "android release key alias" @@ -28,6 +48,7 @@ LOCAL_PROJECT_NAME: "tutor_local" ELASTICSEARCH_HOST: "elasticsearch" ELASTICSEARCH_PORT: 9200 FORUM_HOST: "forum" +LANGUAGE_CODE: "en" MEMCACHED_HOST: "memcached" MEMCACHED_PORT: 11211 MONGODB_HOST: "mongodb" @@ -44,6 +65,7 @@ NGINX_HTTPS_PORT: 443 NOTES_HOST: "notes.{{ LMS_HOST }}" NOTES_MYSQL_DATABASE: "notes" NOTES_MYSQL_USERNAME: "notes" +PLATFORM_NAME: "My Open edX" RABBITMQ_HOST: "rabbitmq" RABBITMQ_USERNAME: "" RABBITMQ_PASSWORD: "" diff --git a/tutor/utils.py b/tutor/utils.py index 2971d04..174b657 100644 --- a/tutor/utils.py +++ b/tutor/utils.py @@ -1,3 +1,4 @@ +import os import random import shutil import string @@ -9,6 +10,15 @@ from . import exceptions from . import fmt +def ensure_file_directory_exists(path): + """ + Create file's base directory if it does not exist. + """ + directory = os.path.dirname(path) + if not os.path.exists(directory): + os.makedirs(directory) + + def random_string(length): return "".join( [random.choice(string.ascii_letters + string.digits) for _ in range(length)] @@ -32,6 +42,15 @@ def common_domain(d1, d2): return ".".join(common[::-1]) +def reverse_host(domain): + """ + Return the reverse domain name, java-style. + + Ex: "www.google.com" -> "com.google.www" + """ + return ".".join(domain.split(".")[::-1]) + + def docker_run(*command): return docker("run", "--rm", "-it", *command)