diff --git a/CHANGELOG.md b/CHANGELOG.md index 75d1beb..1779493 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Latest +- [Improvement] Safer environment overwrite. Thanks @silviot! 👐 - [Security] Fix Jinja2 vulnerability - [Improvement] Improve CLI cold start performance - [Improvement] Allow uppercase "Y" and "N" as answers to boolean questions diff --git a/tutor/config.py b/tutor/config.py index dc43b18..ffb7fc7 100644 --- a/tutor/config.py +++ b/tutor/config.py @@ -20,6 +20,7 @@ from .__about__ import __version__ def config(): pass + @click.command(help="Create and save configuration interactively") @opts.root @click.option("-y", "--yes", "silent1", is_flag=True, help="Run non-interactively") @@ -29,6 +30,7 @@ def save_command(root, silent1, silent2, set_): silent = silent1 or silent2 save(root, silent=silent, keyvalues=set_) + def save(root, silent=False, keyvalues=None): keyvalues = keyvalues or [] config = {} @@ -42,6 +44,7 @@ def save(root, silent=False, keyvalues=None): load_defaults(config) save_env(root, config) + @click.command( help="Print the project root", ) @@ -49,6 +52,7 @@ def save(root, silent=False, keyvalues=None): def printroot(root): click.echo(root) + @click.command(help="Print a configuration value") @opts.root @click.argument("key") @@ -61,6 +65,7 @@ def printvalue(root, key): except KeyError: raise exceptions.TutorError("Missing configuration value: {}".format(key)) + def load(root): """ Load configuration, and generate it interactively if the file does not @@ -78,42 +83,55 @@ def load(root): load_defaults(config) if not env.is_up_to_date(root): should_update_env = True - message = ( - "The current environment stored at {} is not up-to-date: it is at " - "v{} while the 'tutor' binary is at v{}.".format( - env.base_dir(root), env.version(root), __version__ - ) - ) - if os.isatty(sys.stdin.fileno()): - # Interactive mode: ask the user permission to proceed - confirmation_msg = ("If you choose Y The environment will be " - "upgraded now.\nAny change you might have made will be overwritten.\nProceed?") - click.confirm(fmt.alert(message + '\n' + confirmation_msg), abort=True) - elif os.environ.get('TUTOR_OVERWRITE_ENV'): - pass # Non-interactive mode with environment variable authorizing us to go - else: - # Non-interactive mode with no authorization: abort - post_message = ("Set the TUTOR_OVERWRITE_ENV variable " - "to allow tutor to rewrite the environment" - "in a non-interactive run.") - raise click.UsageError(message + "\n" + post_message) + pre_upgrade_announcement(root) if should_update_env: save_env(root, config) return config + +def pre_upgrade_announcement(root): + """ + Inform the user that the current environment is not up-to-date. Crash if running in + non-interactive mode. + """ + click.echo(fmt.alert( + "The current environment stored at {} is not up-to-date: it is at " + "v{} while the 'tutor' binary is at v{}.".format( + env.base_dir(root), env.version(root), __version__ + ) + )) + if os.isatty(sys.stdin.fileno()): + # Interactive mode: ask the user permission to proceed + click.confirm(fmt.question( + # every patch you take, every change you make, I'll be watching you + "Would you like to upgrade the environment? If you do, any change you" + " might have made will be overwritten." + ), default=True, abort=True) + else: + # Non-interactive mode with no authorization: abort + raise exceptions.TutorError( + "Running in non-interactive mode, the environment will not be upgraded" + " automatically. To upgrade the environment manually, run:\n" + "\n" + " tutor config save -y" + ) + + def load_current(config, root): convert_json2yml(root) load_base(config, root) load_user(config, root) load_env(config, root) + def load_base(config, root): base = serialize.load(env.read("config-base.yml")) for k, v in base.items(): config[k] = v + def load_env(config, root): base_config = serialize.load(env.read("config-base.yml")) default_config = serialize.load(env.read("config-defaults.yml")) @@ -124,6 +142,7 @@ def load_env(config, root): if env_var in os.environ: config[k] = serialize.parse_value(os.environ[env_var]) + def load_user(config, root): path = config_path(root) if os.path.exists(path): @@ -133,6 +152,7 @@ def load_user(config, root): config[key] = value upgrade_obsolete(config) + def upgrade_obsolete(config): # Openedx-specific mysql passwords if "MYSQL_PASSWORD" in config: @@ -144,6 +164,7 @@ def upgrade_obsolete(config): if "MYSQL_USERNAME" in config: config["OPENEDX_MYSQL_USERNAME"] = config.pop("MYSQL_USERNAME") + def load_interactive(config): ask("Your website domain name for students (LMS)", "LMS_HOST", config) ask("Your website domain name for teachers (CMS)", "CMS_HOST", config) @@ -176,6 +197,7 @@ def load_interactive(config): "ACTIVATE_XQUEUE", config ) + def load_defaults(config): defaults = serialize.load(env.read("config-defaults.yml")) for k, v in defaults.items(): @@ -186,6 +208,7 @@ def load_defaults(config): config["lms_cms_common_domain"] = utils.common_domain(config["LMS_HOST"], config["CMS_HOST"]) config["lms_host_reverse"] = ".".join(config["LMS_HOST"].split(".")[::-1]) + def ask(question, key, config): default = env.render_str(config, config[key]) config[key] = click.prompt( @@ -193,25 +216,27 @@ def ask(question, key, config): prompt_suffix=" ", default=default, show_default=True, ) + def ask_bool(question, key, config): - default = "y" if config[key] else "n" - suffix = " [Yn]" if config[key] else " [yN]" - answer = click.prompt( - fmt.question(question) + suffix, - type=click.Choice(["y", "Y", "n", "N"]), - prompt_suffix=" ", default=default, show_default=False, show_choices=False, - ).lower() - config[key] = answer == "y" + return click.confirm( + fmt.question(question), + prompt_suffix=' ', + default=config[key], + ) + def ask_choice(question, key, config, choices): default = config[key] answer = click.prompt( fmt.question(question), type=click.Choice(choices), - prompt_suffix=" ", default=default, show_choices=False, + prompt_suffix=" ", + default=default, + show_choices=False, ) config[key] = answer + def convert_json2yml(root): json_path = os.path.join(root, "config.json") if not os.path.exists(json_path): @@ -226,6 +251,7 @@ def convert_json2yml(root): os.remove(json_path) click.echo(fmt.info("File config.json detected in {} and converted to config.yml".format(root))) + def save_config(root, config): env.render_dict(config) path = config_path(root) @@ -236,13 +262,16 @@ def save_config(root, config): serialize.dump(config, of) click.echo(fmt.info("Configuration saved to {}".format(path))) + def save_env(root, config): env.render_full(root, config) click.echo(fmt.info("Environment generated in {}".format(env.base_dir(root)))) + def config_path(root): return os.path.join(root, "config.yml") + config.add_command(save_command, name="save") config.add_command(printroot) config.add_command(printvalue) diff --git a/tutor/env.py b/tutor/env.py index 5387884..62a7bb8 100644 --- a/tutor/env.py +++ b/tutor/env.py @@ -12,6 +12,7 @@ from .__about__ import __version__ TEMPLATES_ROOT = os.path.join(os.path.dirname(__file__), "templates") VERSION_FILENAME = "version" + def render_full(root, config): """ Render the full environment, including version information. @@ -22,6 +23,7 @@ def render_full(root, config): with open(pathjoin(root, VERSION_FILENAME), 'w') as f: f.write(__version__) + def render_target(root, config, target): """ Render the templates located in `target` and store them with the same @@ -32,6 +34,7 @@ def render_target(root, config, target): with open(dst, "w") as of: of.write(rendered) + def render_file(config, path): with codecs.open(path, encoding='utf-8') as fi: try: @@ -43,6 +46,7 @@ def render_file(config, path): print("Unknown error rendering template", path) raise + def render_dict(config): """ Render the values from the dict. This is useful for rendering the default @@ -61,6 +65,7 @@ def render_dict(config): config[k] = v pass + def render_str(config, text): """ Args: @@ -80,6 +85,7 @@ def render_str(config, text): except jinja2.exceptions.UndefinedError as e: raise exceptions.TutorError("Missing configuration value: {}".format(e.args[0])) + def copy_target(root, target): """ Copy the templates located in `path` and store them with the same hierarchy @@ -88,9 +94,11 @@ def copy_target(root, target): for src, dst in walk_templates(root, target): shutil.copy(src, dst) + def is_up_to_date(root): return version(root) == __version__ + def version(root): """ Return the current environment version. @@ -100,6 +108,7 @@ def version(root): return "0" return open(path).read().strip() + def read(*path): """ Read template content located at `path`. @@ -108,6 +117,7 @@ def read(*path): with codecs.open(src, encoding='utf-8') as fi: return fi.read() + def walk_templates(root, target): """ Iterate on the template files from `templates/target`. @@ -132,6 +142,7 @@ def walk_templates(root, target): if is_part_of_env(src): yield src, dst + def is_part_of_env(path): basename = os.path.basename(path) return not ( @@ -140,14 +151,18 @@ def is_part_of_env(path): basename == "__pycache__" ) + def template_path(*path): return os.path.join(TEMPLATES_ROOT, *path) + def data_path(root, *path): return os.path.join(os.path.abspath(root), "data", *path) + def pathjoin(root, target, *path): return os.path.join(base_dir(root), target, *path) + def base_dir(root): return os.path.join(root, "env")