diff --git a/tests/test_scripts.py b/tests/test_scripts.py new file mode 100644 index 0000000..6cf0954 --- /dev/null +++ b/tests/test_scripts.py @@ -0,0 +1,20 @@ +import unittest +import unittest.mock + +from tutor.config import load_defaults +from tutor import env +from tutor import scripts + + +class ScriptsTests(unittest.TestCase): + def test_run_script(self): + config = {} + load_defaults({}) + rendered_script = env.render_file("scripts", "create_databases.sh") + with unittest.mock.Mock() as run_func: + scripts.run_script( + "/tmp", config, "someservice", "create_databases.sh", run_func + ) + run_func.assert_called_once_with( + "/tmp", config, "someservice", rendered_script + ) diff --git a/tutor/scripts.py b/tutor/scripts.py index ae62469..0a79a6c 100644 --- a/tutor/scripts.py +++ b/tutor/scripts.py @@ -9,26 +9,26 @@ def migrate(root, run_func): config = tutor_config.load(root) click.echo(fmt.info("Creating all databases...")) - run_template(root, config, "mysql-client", "create_databases.sh", run_func) + run_script(root, config, "mysql-client", "create_databases.sh", run_func) if config["ACTIVATE_LMS"]: click.echo(fmt.info("Running lms migrations...")) - run_template(root, config, "lms", "migrate_lms.sh", run_func) + run_script(root, config, "lms", "migrate_lms.sh", run_func) if config["ACTIVATE_CMS"]: click.echo(fmt.info("Running cms migrations...")) - run_template(root, config, "cms", "migrate_cms.sh", run_func) + run_script(root, config, "cms", "migrate_cms.sh", run_func) if config["ACTIVATE_FORUM"]: click.echo(fmt.info("Running forum migrations...")) - run_template(root, config, "forum", "migrate_forum.sh", run_func) + run_script(root, config, "forum", "migrate_forum.sh", run_func) if config["ACTIVATE_NOTES"]: click.echo(fmt.info("Running notes migrations...")) - run_template(root, config, "notes", "migrate_django.sh", run_func) + run_script(root, config, "notes", "migrate_django.sh", run_func) if config["ACTIVATE_XQUEUE"]: click.echo(fmt.info("Running xqueue migrations...")) - run_template(root, config, "xqueue", "migrate_django.sh", run_func) + run_script(root, config, "xqueue", "migrate_django.sh", run_func) if config["ACTIVATE_LMS"]: click.echo(fmt.info("Creating oauth2 users...")) - run_template(root, config, "lms", "oauth2.sh", run_func) + run_script(root, config, "lms", "oauth2.sh", run_func) click.echo(fmt.info("Databases ready.")) @@ -38,23 +38,18 @@ def create_user(root, run_func, superuser, staff, name, email): config["OPTS"] += " --superuser" if staff: config["OPTS"] += " --staff" - run_template(root, config, "lms", "create_user.sh", run_func) + run_script(root, config, "lms", "create_user.sh", run_func) def import_demo_course(root, run_func): - run_template(root, {}, "cms", "import_demo_course.sh", run_func) + run_script(root, {}, "cms", "import_demo_course.sh", run_func) def index_courses(root, run_func): - run_template(root, {}, "cms", "index_courses.sh", run_func) + run_script(root, {}, "cms", "index_courses.sh", run_func) -def run_template(root, config, service, template, run_func): - command = render_template(config, template) +def run_script(root, config, service, template, run_func): + command = env.render_file(config, "script", template).strip() if command: run_func(root, service, command) - - -def render_template(config, template): - path = env.template_path("scripts", template) - return env.render_file(config, path).strip()