Source code for haddock.libs.libgrid

import os
import re
import shutil
import subprocess
import tempfile
import uuid
import warnings
import zipfile
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Optional, Tuple

from haddock import log
from haddock.libs.libsubprocess import CNSJob

# NOTE: When creating the payload.zip, this warning appears because we are
#  replicating the `toppar_path` subdirectory structure. It can be safely
#  ignored, but keep an eye on it to make sure operations are not being
#  needlessly duplicated
warnings.filterwarnings("ignore", message="Duplicate name:*", category=UserWarning)
# =====================================================================#
# EXPLANATION                                                          #
# =====================================================================#
# `JOB_TYPE` is an environment variable that can be set to specify the
#  type of job associated with a given VO. This is usually important for
#  accounting purposes, as different job types may have different
#  resources associated with them.
# There is an explanation about this in the user manual.
# =====================================================================#
JOB_TYPE = os.getenv("HADDOCK3_GRID_JOB_TYPE", "WeNMR-DEV")
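# For example, to account jobs under a different type one could set, before
#  running HADDOCK3 (the value shown is illustrative, not an official type):
#   export HADDOCK3_GRID_JOB_TYPE="WeNMR-PROD"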
MAX_RETRIES = 10

# Important patterns to adjust the paths in the `.inp` file
OUTPUT_PATTERN = r'\$output_\w+\s*=\s*"([^"]+)"|\$output_\w+\s*=\s*([^\s;]+)'
VAR_PATTERN = r"\(\$\s*[^=]*=(?!.?\$)(.*)\)"  # https://regex101.com/r/dYqlZP/1
AT_PATTERN = r"@@(?!\$)(.*)"
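# Illustrative matches for the patterns above (the .inp snippets are assumed
#  examples of CNS syntax, not taken from a real file):
#   OUTPUT_PATTERN: `$output_pdb_filename = "complex_1.pdb"` -> "complex_1.pdb"
#   VAR_PATTERN:    `($ambig_fname=/abs/path/ambig.tbl)`     -> "/abs/path/ambig.tbl"
#   AT_PATTERN:     `@@/abs/path/def_solv_param.cns`         -> "/abs/path/def_solv_param.cns"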


def ping_dirac() -> bool:
    """Ping the DIRAC server to check if it's reachable."""
    if not validate_dirac():
        return False
    result = subprocess.run(["dirac-proxy-info"], capture_output=True)
    if result.returncode != 0:
        log.error(f"Dirac proxy info failed: {result.stderr.decode().strip()}")
        return False
    return True


def validate_dirac() -> bool:
    """Check if the DIRAC client is valid and configured."""
    expected_cmds = [
        "dirac-proxy-info",
        "dirac-wms-job-submit",
        "dirac-wms-job-status",
        "dirac-wms-job-get-output",
    ]
    for cmd in expected_cmds:
        # Expect the commands to be in the PATH
        which_cmd = shutil.which(cmd)
        if not which_cmd:
            log.error(f"Command '{cmd}' not found in PATH.")
            return False
    return True
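

# Typical gate before any submission (illustrative; callers may handle a
#  failed ping differently):
#   if not ping_dirac():
#       raise RuntimeError("DIRAC client not available or proxy not initialized")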


class JobStatus(Enum):
    """Possible job states as reported by DIRAC."""

    WAITING = "Waiting"
    RUNNING = "Running"
    UNKNOWN = "Unknown"
    DONE = "Done"
    MATCHED = "Matched"
    COMPLETING = "Completing"
    FAILED = "Failed"

    @classmethod
    def from_string(cls, value):
        """Convert string to JobStatus enum."""
        value = value.strip().lower()
        for status in cls:
            if status.value.lower() == value:
                return status
        return cls.UNKNOWN
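
    # For instance (matching is case-insensitive and falls back to UNKNOWN):
    #   JobStatus.from_string("done")     -> JobStatus.DONE
    #   JobStatus.from_string("whatever") -> JobStatus.UNKNOWN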


class GridJob:
    """GridJob is a class that represents a job to be run on the GRID via DIRAC."""

    def __init__(self, cnsjob: CNSJob) -> None:
        # Unique name for the job
        self.name = str(uuid.uuid4())
        # Create a temporary directory for the job
        self.loc = Path(tempfile.mkdtemp(prefix="haddock_grid_"))
        # Working directory where the GridJob was created; this is important
        #  so we know where to put the files coming from the grid
        self.wd = Path.cwd()
        # Input `.inp` file that will be shipped with the payload
        self.input_f = self.loc / "input.inp"
        # `self.input_str` is the content of the `.inp` file
        self.input_str = ""
        # `cns_out_f` is the expected `.out` file that will be generated by this execution
        self.cns_out_f = cnsjob.output_file
        # `expected_outputs` is what will be produced by this `.inp` file
        self.expected_outputs = []
        # `id` is given by DIRAC
        self.id = None
        # Site where the job is running
        self.site = None
        # Output and error files generated by the job.sh execution
        self.stdout_f = None
        self.stderr_f = None
        # Script that will be executed in the grid
        self.job_script = self.loc / "job.sh"
        # JDL file that describes the job to DIRAC
        self.jdl = self.loc / "job.jdl"
        # Internal status of the job
        self.status = JobStatus.UNKNOWN
        # List of files to be included in the payload
        self.payload_fnames = []
        # Counter for retries
        self.retries = 0

        # NOTE: `cnsjob.input_file` can be a string or a path, handle the polymorphism here
        if isinstance(cnsjob.input_file, Path):
            self.input_str = cnsjob.input_file.read_text()
        elif isinstance(cnsjob.input_file, str):
            self.input_str = cnsjob.input_file

        # CREATE JOB FILE
        self.create_job_script()

        # CREATE PAYLOAD
        self.prepare_payload(
            cns_exec_path=Path(cnsjob.cns_exec),
            cns_script_path=Path(cnsjob.envvars["MODULE"]),
            toppar_path=Path(cnsjob.envvars["TOPPAR"]),
        )

        # CREATE THE JDL
        self.create_jdl()

    def prepare_payload(
        self, cns_exec_path: Path, cns_script_path: Path, toppar_path: Path
    ) -> None:
        """Prepare the payload.zip file containing all necessary files."""
        # We need to process the input file first to adjust paths and identify outputs
        self.process_input_f()

        # Find the CNS scripts that should be inside the payload
        for f in cns_script_path.glob("*"):
            self.payload_fnames.append(Path(f))

        # Find the TOPPAR files that should be inside the payload
        for f in toppar_path.rglob("*"):
            if f.is_file():  # Only add files, not directories
                self.payload_fnames.append(f)

        # Create the payload.zip
        with zipfile.ZipFile(f"{self.loc}/payload.zip", "w") as z:
            # It must contain the CNS executable and the input file
            z.write(self.input_f, arcname="input.inp")
            z.write(cns_exec_path, arcname="cns")
            for f in set(self.payload_fnames):
                # NOTE: Preserve the relative path structure from `toppar_path`!
                #  This is important because some CNS scripts have relative
                #  paths hardcoded in them
                if f.is_relative_to(toppar_path):
                    relative_path = f.relative_to(toppar_path)
                    z.write(f, arcname=str(relative_path))
                else:
                    z.write(f, arcname=Path(f).name)

    def create_job_script(self) -> None:
        """Create the job script that will be executed in the grid."""
        # NOTE: We use `\n` instead of `os.linesep` because this will
        #  be executed in a Linux environment inside the grid
        instructions = "#!/bin/bash\n"
        instructions += "export MODULE=./\n"
        instructions += "export TOPPAR=./\n"
        instructions += "unzip payload.zip\n"
        instructions += "./cns < input.inp > cns.log\n"
        # Remove `cns.log` if there is no error
        instructions += "[ $? -eq 0 ] && rm cns.log || exit 1\n"
        with open(self.job_script, "w") as f:
            f.write(instructions)

    def process_input_f(self) -> None:
        """Process the input file to adjust paths and identify outputs.

        ===========================================================================
                                      !! IMPORTANT !!
        ===========================================================================
        This is a very important step. The `.inp` files coming from the modules
        contain several absolute paths pointing to the local filesystem. When
        this job runs in the GRID, those paths will no longer be valid. The
        processing here identifies all the paths and input files in the `.inp`
        file and replaces them with a structure that will be valid in the GRID.
        ===========================================================================
                                      !! IMPORTANT !!
        ===========================================================================
        """
        # Write the modified lines back
        with open(self.input_f, "w") as f:
            for line in self.input_str.splitlines(keepends=True):
                # Parse this line and try to identify output files
                output = self._find_output(line)
                if output:
                    self.expected_outputs.append(output)

                # Process the line to adjust paths, if any
                new_line, found_fname = self._process_line(line)
                f.write(new_line)

                # Collect the files that need to be in the payload
                if found_fname:
                    src_path = Path(found_fname)
                    dst_path = self.loc / Path(found_fname).name
                    shutil.copy(src_path, dst_path)
                    self.payload_fnames.append(dst_path)
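
    # Illustrative transformation performed by `process_input_f` via
    #  `_process_line` (the .inp lines below are assumed examples, not taken
    #  from a real module; the file is also copied into the payload):
    #   before: ($ambig_fname=/home/user/run1/data/ambig.tbl)
    #   after:  ($ambig_fname=ambig.tbl)
    #   before: @@/home/user/cns_scripts/def_solv_param.cns
    #   after:  @@def_solv_param.cns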

    def submit(self) -> None:
        """Interface to submit the job to DIRAC."""
        try:
            # The submit command returns the job ID in stdout
            result = subprocess.run(
                ["dirac-wms-job-submit", f"{self.loc}/job.jdl"],
                shell=False,
                capture_output=True,
                text=True,
                cwd=self.loc,
                check=True,
            )
        except subprocess.CalledProcessError as e:
            log.error(
                f"Job submission failed: {e}\nStdout: {e.stdout}\nStderr: {e.stderr}"
            )
            raise

        # Add the ID to the object
        self.id = int(result.stdout.split()[-1])

        # Update the status
        self.update_status()

    def retrieve_output(self) -> None:
        """Retrieve the output files from DIRAC."""
        try:
            # `check=True` is needed so a non-zero exit code raises
            #  CalledProcessError and is caught below
            subprocess.run(
                ["dirac-wms-job-get-output", str(self.id)],
                shell=False,
                capture_output=True,
                text=True,
                cwd=self.loc,
                check=True,
            )
        except subprocess.CalledProcessError as e:
            log.error(
                f"Retrieving output failed: {e}\nStdout: {e.stdout}\nStderr: {e.stderr}"
            )
            raise

        # NOTE: These are the outputs of the `job.sh` execution
        self.stdout_f = Path(f"{self.loc}/{self.id}/job.out")
        self.stderr_f = Path(f"{self.loc}/{self.id}/job.err")

        # If the job failed for some reason, save the output
        if self.stderr_f.exists():
            dst = Path(self.wd / f"{self.id}_dirac.err")
            shutil.copy(self.stderr_f, dst)
            log.debug(f"ID stderr: {self.stderr_f.read_text()}")

        # Copy the output to the expected location
        for output_f in self.expected_outputs:
            src = Path(f"{self.loc}/{self.id}/{output_f}")
            dst = Path(self.wd / f"{output_f}")
            shutil.copy(src, dst)

        # `cns.log` is the `.out` file generated by CNS;
        #  it will only exist if there was an error
        src = Path(f"{self.loc}/{self.id}/cns.log")
        if src.exists():
            dst = Path(self.wd / f"{self.cns_out_f}")
            shutil.copy(src, dst)

    def update_status(self) -> None:
        """Update the status of the job by querying DIRAC."""
        try:
            # `check=True` is needed so a non-zero exit code raises
            #  CalledProcessError and is caught below
            result = subprocess.run(
                ["dirac-wms-job-status", str(self.id)],
                shell=False,
                capture_output=True,
                text=True,
                check=True,
            )
        except subprocess.CalledProcessError as e:
            log.error(
                f"Updating the status failed: {e}\nStdout: {e.stdout}\nStderr: {e.stderr}"
            )
            raise

        output_dict = self.parse_output(result.stdout)

        self.id = output_dict["JobID"]
        self.status = JobStatus.from_string(output_dict["Status"])
        self.site = output_dict.get("Site", "Unknown")

    def create_jdl(self) -> None:
        """Create the JDL file that describes the job to DIRAC."""
        output_sandbox = [
            "job.out",
            "job.err",
            "cns.log",
        ]
        output_sandbox.extend(self.expected_outputs)
        output_sandbox_str = ", ".join(f'"{fname}"' for fname in output_sandbox)
        jdl_lines = [
            f'JobName = "{self.name}";',
            'Executable = "job.sh";',
            'Arguments = "";',
            'StdOutput = "job.out";',
            'StdError = "job.err";',
            f'JobType = "{JOB_TYPE}";',
            'InputSandbox = {"job.sh", "payload.zip"};',
            "OutputSandbox = {" + output_sandbox_str + "};",
        ]
        jdl_string = "[\n " + "\n ".join(jdl_lines) + "\n]\n"
        with open(self.jdl, "w") as f:
            f.write(jdl_string)
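
    # A JDL produced by `create_jdl` looks roughly like this (values are
    #  illustrative, including the output file name):
    #   [
    #    JobName = "0b1c2d3e-...-uuid";
    #    Executable = "job.sh";
    #    Arguments = "";
    #    StdOutput = "job.out";
    #    StdError = "job.err";
    #    JobType = "WeNMR-DEV";
    #    InputSandbox = {"job.sh", "payload.zip"};
    #    OutputSandbox = {"job.out", "job.err", "cns.log", "complex_1.pdb"};
    #   ]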

    def clean(self) -> None:
        """Clean up the temporary directory where the job lives."""
        shutil.rmtree(self.loc)

    @staticmethod
    def parse_output(output_str: str) -> dict[str, str]:
        """Parse the output string from DIRAC commands into a dictionary."""
        items = output_str.replace(";", "")
        status_dict = {}
        for item in items.split(" "):
            if "=" in item:
                key, value = item.split("=", 1)
                status_dict[key.strip()] = value.strip()
        return status_dict
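
    # For example, assuming DIRAC reports in its usual `key=value;` form:
    #   parse_output("JobID=123; Status=Done; Site=LCG.Example.org;")
    #   -> {"JobID": "123", "Status": "Done", "Site": "LCG.Example.org"}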

    @staticmethod
    def _process_line(line: str) -> Tuple[str, Optional[str]]:
        """Process a line to identify and adjust paths."""
        match_var = re.findall(VAR_PATTERN, line)
        match_at = re.findall(AT_PATTERN, line)
        # NOTE: In CNS a line cannot match both patterns at the same time
        if match_at:
            item = match_at[0].strip('"').strip("'")
        elif match_var:
            item = match_var[0].strip('"').strip("'")
        else:
            # no match
            return line, None

        if Path(item).exists():
            # This is a path
            return line.replace(item, Path(item).name), item
        else:
            # This is not a path
            return line, None

    @staticmethod
    def _find_output(line) -> Optional[str]:
        """Parse the line and identify if it contains an output file declaration."""
        match = re.search(OUTPUT_PATTERN, line)
        if match:
            return match.group(1) if match.group(1) else match.group(2)
        return None

    def __repr__(self) -> str:
        return (
            f"ID: {self.id} Name: {self.name} "
            f"Output: {self.expected_outputs} "
            f"Status: {self.status.value} Site: {self.site}"
        )
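

# A minimal lifecycle sketch for a single job (illustrative; `cns_job` would
#  be a CNSJob prepared by a HADDOCK3 module):
#   job = GridJob(cns_job)
#   job.submit()
#   job.update_status()
#   if job.status is JobStatus.DONE:
#       job.retrieve_output()
#       job.clean()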


@dataclass
class GRIDScheduler:
    """Scheduler to manage and run jobs on the GRID via DIRAC."""

    tasks: list[CNSJob]
    params: dict

    def run(self) -> None:
        """Execute the tasks."""
        # Convert CNSJobs to GridJobs
        queue = {GridJob(t): False for t in self.tasks}
        log.info(f"Submitting {len(queue)} jobs to the grid...")

        # Submit jobs
        # =====================================================================#
        # EXPLANATION                                                          #
        # =====================================================================#
        # Use multiple threads to submit jobs in parallel. Each `.submit()`
        #  call can take a few seconds, so using multiple threads speeds up
        #  the submission significantly.
        #
        # `ThreadPoolExecutor` initializes a pool of threads, and
        #  `executor.map` applies the `submit` method to each job. This
        #  effectively submits all jobs in parallel, using
        #  `self.params["ncores"]` threads.
        # =====================================================================#
        with ThreadPoolExecutor(max_workers=self.params["ncores"]) as executor:
            executor.map(lambda job: job.submit(), queue.keys())
        log.info("All jobs submitted.")

        # Wait for jobs to finish
        log.info("Checking job status...")
        total = len(queue)
        complete = False
        while not complete:
            # Only process jobs that are not done yet
            pending_jobs = [job for job, done in queue.items() if not done]
            if pending_jobs:
                # =============================================================#
                # EXPLANATION                                                  #
                # =============================================================#
                # Check job status in parallel; each `.update_status()` call
                #  can take a few seconds, so multiple threads speed up the
                #  process.
                #
                # `process_job` returns a tuple of the job and a boolean,
                #  captured in the `results` list. The boolean is True if the
                #  job is done (either successfully or failed after retries),
                #  and False otherwise - meaning the job is still running or
                #  waiting.
                # =============================================================#
                with ThreadPoolExecutor(
                    max_workers=self.params["ncores"]
                ) as executor:
                    results = list(executor.map(self.process_job, pending_jobs))

                for job, is_done in results:
                    # Mark finished jobs in `queue` so they are not checked again
                    if is_done:
                        queue[job] = True

            # Do some simple logging to keep track of progress.
            n_done = sum(1 for is_done in queue.values() if is_done)
            log.info(f"{n_done}/{total} jobs completed.")

            complete = all(queue.values())

        log.info("All jobs completed.")

    @staticmethod
    def process_job(job: GridJob) -> Tuple[GridJob, bool]:
        """Process a single job.

        Update the status, retrieve output if done, and handle retries if
        failed.

        NOTE: This function is parallelized; that is why things like cleaning,
        downloading, and retrying happen here. If you are adding new
        functionality, consider whether it should be here or in the sequential
        part of the code.
        """
        job.update_status()
        is_complete = job.status in {JobStatus.DONE, JobStatus.FAILED}

        if job.status == JobStatus.FAILED:
            # Jobs on the grid can fail for many reasons outside our control,
            #  so if the job failed, resubmit it up to MAX_RETRIES times
            if job.retries < MAX_RETRIES:
                is_complete = False
                expected_output_str = ",".join(job.expected_outputs)
                job.retries += 1
                log.warning(
                    f"Job {job.name} ({expected_output_str}) failed, "
                    f"re-submitting - {job.retries}/{MAX_RETRIES}"
                )
                job.submit()

        if job.status == JobStatus.DONE:
            log.debug(f"Job {job.name} is done, retrieving output...")
            job.retrieve_output()
            job.clean()

        return job, is_complete
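

# A minimal usage sketch (illustrative; within HADDOCK3 the scheduler is
#  normally instantiated by the execution machinery with prepared CNSJobs):
#   scheduler = GRIDScheduler(tasks=cns_jobs, params={"ncores": 8})
#   scheduler.run()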