diff --git a/changelog.d/20230214_105510_keith_fix_jobs_merge.md b/changelog.d/20230214_105510_keith_fix_jobs_merge.md new file mode 100644 index 0000000..4bd44fd --- /dev/null +++ b/changelog.d/20230214_105510_keith_fix_jobs_merge.md @@ -0,0 +1,12 @@ + + + + +[Bugfix] `patchStrategicMerge` can now be applied to jobs (by @keithgg) diff --git a/tutor/commands/k8s.py b/tutor/commands/k8s.py index 8df275c..d26d7cc 100644 --- a/tutor/commands/k8s.py +++ b/tutor/commands/k8s.py @@ -1,6 +1,6 @@ from datetime import datetime from time import sleep -from typing import Any, List, Optional, Type +from typing import Any, List, Optional, Type, Iterable import click @@ -64,11 +64,12 @@ class K8sTaskRunner(BaseTaskRunner): """ def run_task(self, service: str, command: str) -> int: - job_name = f"{service}-job" - job = self.load_job(job_name) + canonical_job_name = f"{service}-job" + all_jobs = list(self._load_jobs()) + job = self._find_job(canonical_job_name, all_jobs) # Create a unique job name to make it deduplicate jobs and make it easier to # find later. Logs of older jobs will remain available for some time. - job_name += "-" + datetime.now().strftime("%Y%m%d%H%M%S") + job_name = canonical_job_name + "-" + datetime.now().strftime("%Y%m%d%H%M%S") # Wait until all other jobs are completed while True: @@ -98,11 +99,12 @@ class K8sTaskRunner(BaseTaskRunner): job["spec"]["template"]["spec"]["containers"][0]["args"] = container_args job["spec"]["backoffLimit"] = 1 job["spec"]["ttlSecondsAfterFinished"] = 3600 - # Save patched job to "jobs.yml" file + with open( tutor_env.pathjoin(self.root, "k8s", "jobs.yml"), "w", encoding="utf-8" ) as job_file: - serialize.dump(job, job_file) + serialize.dump_all(all_jobs, job_file) + # We cannot use the k8s API to create the job: configMap and volume names need # to be found with the right suffixes. kubectl_apply( @@ -143,8 +145,15 @@ class K8sTaskRunner(BaseTaskRunner): """ Find a given job definition in the rendered k8s/jobs.yml template. """ - all_jobs = self.render("k8s", "jobs.yml") - for job in serialize.load_all(all_jobs): + return self._find_job(name, self._load_jobs()) + + def _find_job(self, name: str, all_jobs: Iterable[Any]) -> Any: + """ + Find the matching job definition in the in the list of jobs provided. + + Returns the found job's manifest. + """ + for job in all_jobs: job_name = job["metadata"]["name"] if not isinstance(job_name, str): raise exceptions.TutorError( @@ -154,6 +163,12 @@ class K8sTaskRunner(BaseTaskRunner): return job raise exceptions.TutorError(f"Could not find job '{name}'") + def _load_jobs(self) -> Iterable[Any]: + manifests = self.render("k8s", "jobs.yml") + for manifest in serialize.load_all(manifests): + if manifest["kind"] == "Job": + yield manifest + def active_job_names(self) -> List[str]: """ Return a list of active job names diff --git a/tutor/serialize.py b/tutor/serialize.py index b46d414..0833749 100644 --- a/tutor/serialize.py +++ b/tutor/serialize.py @@ -17,6 +17,10 @@ def load_all(stream: str) -> t.Iterator[t.Any]: return yaml.load_all(stream, Loader=yaml.SafeLoader) +def dump_all(documents: t.Sequence[t.Any], fileobj: TextIOWrapper) -> None: + yaml.safe_dump_all(documents, stream=fileobj, default_flow_style=False) + + def dump(content: t.Any, fileobj: TextIOWrapper) -> None: yaml.dump(content, stream=fileobj, default_flow_style=False)