import re
import shutil
import subprocess
import tempfile
import uuid
import os
import zipfile
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Optional, Tuple
from haddock import log
from haddock.libs.libsubprocess import CNSJob
import warnings
# NOTE: When creating the payload.zip, this warning appears because we replicate the
# `toppar_path` subdirectory structure. It can safely be ignored, but keep an eye on it
# to make sure operations are not being duplicated unnecessarily.
warnings.filterwarnings("ignore", message="Duplicate name:", category=UserWarning)
# =====================================================================#
# EXPLANATION #
# =====================================================================#
# `JOB_TYPE` is an environment variable that can be set to specify the
# type of job associated with a given VO. This is usually important for
# accounting purposes, as different job types may have different resources
# associated with them.
# See the user manual for a more detailed explanation.
# =====================================================================#
JOB_TYPE = os.getenv("HADDOCK3_GRID_JOB_TYPE", "WeNMR-DEV")
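# Example (hypothetical value): running `export HADDOCK3_GRID_JOB_TYPE="WeNMR-PROD"` before
# starting haddock3 would tag submitted jobs as "WeNMR-PROD" instead of the default "WeNMR-DEV".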
MAX_RETRIES = 10
# Important patterns to adjust the paths in the `.inp` file
OUTPUT_PATTERN = r'\$output_\w+\s*=\s*"([^"]+)"|\$output_\w+\s*=\s*([^\s;]+)'
VAR_PATTERN = r"\(\$\s*[^=]*=(?!.?\$)(.*)\)" # https://regex101.com/r/dYqlZP/1
AT_PATTERN = r"@@(?!\$)(.*)"
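# Illustrative (hypothetical) input lines showing what each pattern captures:
#   OUTPUT_PATTERN:  $output_pdb_filename = "complex_1.pdb"   -> "complex_1.pdb"
#   VAR_PATTERN:     ($par_file=/abs/path/protein.param)      -> "/abs/path/protein.param"
#   AT_PATTERN:      @@toppar/protein.top                      -> "toppar/protein.top"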
def ping_dirac() -> bool:
"""Ping the Dirac server to check if it's reachable."""
if not validate_dirac():
return False
result = subprocess.run(["dirac-proxy-info"], capture_output=True)
if result.returncode != 0:
log.error(f"Dirac proxy info failed: {result.stderr.decode().strip()}")
return False
return True
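# Example: ping_dirac() returns False (and logs an error) when the DIRAC client commands
# are not in PATH or when `dirac-proxy-info` exits with a non-zero status.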
def validate_dirac() -> bool:
"""Check if the DIRAC client is valid and configured."""
expected_cmds = [
"dirac-proxy-info",
"dirac-wms-job-submit",
"dirac-wms-job-status",
"dirac-wms-job-get-output",
]
for cmd in expected_cmds:
# Expect the commands to be in the PATH
which_cmd = shutil.which(cmd)
if not which_cmd:
log.error(f"Command '{cmd}' not found in PATH.")
return False
return True
class JobStatus(Enum):
WAITING = "Waiting"
RUNNING = "Running"
UNKNOWN = "Unknown"
DONE = "Done"
MATCHED = "Matched"
COMPLETING = "Completing"
FAILED = "Failed"
@classmethod
def from_string(cls, value):
"""Convert string to JobStatus enum."""
value = value.strip().lower()
for status in cls:
if status.value.lower() == value:
return status
return cls.UNKNOWN
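# Example: JobStatus.from_string("running") returns JobStatus.RUNNING, while an
# unrecognised string such as "Stalled" falls back to JobStatus.UNKNOWN.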
class GridJob:
"""GridJob is a class tha represents a job to be run on the GRID via DIRAC."""
def __init__(self, cnsjob: CNSJob) -> None:
# Unique name for the job
self.name = str(uuid.uuid4())
# Create a temporary directory for the job
self.loc = Path(tempfile.mkdtemp(prefix="haddock_grid_"))
# Working directory where the GridJob was created; this is important
# so we know where to put the files coming back from the grid
self.wd = Path.cwd()
self.input_f = self.loc / "input.inp"
# `self.input_str` is the content of the `.inp` file
self.input_str = ""
# `cns_out_f` is the expected `.out` file that will be generated by this execution
self.cns_out_f = cnsjob.output_file
# `expected_outputs` is what will be produced by this `.inp` file
self.expected_outputs = []
# `id` is given by DIRAC
self.id = None
# Site where the job is running
self.site = None
# Output and error files generated by the job.sh execution
self.stdout_f = None
self.stderr_f = None
# Script that will be executed in the grid
self.job_script = self.loc / "job.sh"
# JDL file that describes the job to DIRAC
self.jdl = self.loc / "job.jdl"
# Internal status of the job
self.status = JobStatus.UNKNOWN
# List of files to be included in the payload
self.payload_fnames = []
# Counter for retries
self.retries = 0
# NOTE: `cnsjob.input_file` can be a string or a Path, handle the polymorphism here
if isinstance(cnsjob.input_file, Path):
self.input_str = cnsjob.input_file.read_text()
elif isinstance(cnsjob.input_file, str):
self.input_str = cnsjob.input_file
# CREATE JOB FILE
self.create_job_script()
# CREATE PAYLOAD
self.prepare_payload(
cns_exec_path=Path(cnsjob.cns_exec),
cns_script_path=Path(cnsjob.envvars["MODULE"]),
toppar_path=Path(cnsjob.envvars["TOPPAR"]),
)
# CREATE THE JDL
self.create_jdl()
def prepare_payload(
self, cns_exec_path: Path, cns_script_path: Path, toppar_path: Path
) -> None:
"""Prepare the payload.zip file containing all necessary files."""
# We need to process the input file first to adjust paths and identify outputs
self.process_input_f()
# Find the CNS scripts that should be inside the payload
for f in cns_script_path.glob("*"):
self.payload_fnames.append(Path(f))
# Find the TOPPAR files that should be inside the payload
for f in toppar_path.rglob("*"):
if f.is_file(): # Only add files, not directories
self.payload_fnames.append(f)
# Create the payload.zip
with zipfile.ZipFile(f"{self.loc}/payload.zip", "w") as z:
# It must contain the CNS executable and the input file,
z.write(self.input_f, arcname="input.inp")
z.write(cns_exec_path, arcname="cns")
for f in set(self.payload_fnames):
# NOTE: Preserve the relative path structure from `toppar_path`!
# This is important because some CNS scripts have relative paths
# hardcoded in them
if f.is_relative_to(toppar_path):
relative_path = f.relative_to(toppar_path)
z.write(f, arcname=str(relative_path))
else:
z.write(f, arcname=Path(f).name)
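# NOTE: `process_input_f` (called above) is not shown in this listing. A minimal
# sketch of what it is assumed to do, based on `_process_line`, `_find_output` and
# the attributes used elsewhere in this class (the real implementation may differ):
#
#   def process_input_f(self) -> None:
#       """Adjust paths in the input and collect the expected output files."""
#       processed_lines = []
#       for line in self.input_str.splitlines():
#           # Rewrite local paths to bare filenames and remember the original
#           # file so it can be shipped inside payload.zip
#           new_line, found_path = self._process_line(line)
#           if found_path:
#               self.payload_fnames.append(Path(found_path))
#           # Register declared output files so they end up in the OutputSandbox
#           output_f = self._find_output(new_line)
#           if output_f:
#               self.expected_outputs.append(output_f)
#           processed_lines.append(new_line)
#       # Write the adjusted input to the job's temporary directory
#       self.input_f.write_text("\n".join(processed_lines))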
def create_job_script(self) -> None:
"""Create the job script that will be executed in the grid."""
# NOTE: We use `\n` instead of `os.linesep` because this will
# be executed in a Linux environment inside the grid
instructions = "#!/bin/bash\n"
instructions += "export MODULE=./\n"
instructions += "export TOPPAR=./\n"
instructions += "unzip payload.zip\n"
instructions += "./cns < input.inp > cns.log\n"
# Remove `cns.log` if there is no error
instructions += "[ $? -eq 0 ] && rm cns.log || exit 1\n"
with open(self.job_script, "w") as f:
f.write(instructions)
def submit(self) -> None:
"""Interface to submit the job to DIRAC."""
try:
# The SUBMIT_CMD returns the job ID in stdout
result = subprocess.run(
["dirac-wms-job-submit", f"{self.loc}/job.jdl"],
shell=False,
capture_output=True,
text=True,
cwd=self.loc,
check=True,
)
except subprocess.CalledProcessError as e:
log.error(
f"Job submission failed: {e}\nStdout: {e.stdout}\nStderr: {e.stderr}"
)
raise
# Add the ID to the object
self.id = int(result.stdout.split()[-1])
# Update the status
self.update_status()
def retrieve_output(self) -> None:
"""Retrieve the output files from DIRAC."""
try:
subprocess.run(
["dirac-wms-job-get-output", str(self.id)],
shell=False,
capture_output=True,
text=True,
cwd=self.loc,
check=True,
)
except subprocess.CalledProcessError as e:
log.error(
f"Retrieving output failed: {e}\nStdout: {e.stdout}\nStderr: {e.stderr}"
)
raise
# NOTE: This is the output of the `job.sh`
self.stdout_f = Path(f"{self.loc}/{self.id}/job.out")
self.stderr_f = Path(f"{self.loc}/{self.id}/job.err")
# If a stderr file was produced, keep a copy for debugging
if self.stderr_f.exists():
dst = Path(self.wd / f"{self.id}_dirac.err")
shutil.copy(self.stderr_f, dst)
log.debug(f"ID stderr: {self.stderr_f.read_text()}")
# Copy the output to the expected location
for output_f in self.expected_outputs:
src = Path(f"{self.loc}/{self.id}/{output_f}")
dst = Path(self.wd / f"{output_f}")
shutil.copy(src, dst)
# `cns.log` is the `.out` file generated by CNS
# it will only exist if there was an error
src = Path(f"{self.loc}/{self.id}/cns.log")
if src.exists():
dst = Path(self.wd / f"{self.cns_out_f}")
shutil.copy(src, dst)
def update_status(self) -> None:
"""Update the status of the job by querying DIRAC."""
try:
result = subprocess.run(
["dirac-wms-job-status", str(self.id)],
shell=False,
capture_output=True,
text=True,
check=True,
)
except subprocess.CalledProcessError as e:
log.error(
f"Updating the status failed: {e}\nStdout: {e.stdout}\nStderr: {e.stderr}"
)
raise
output_dict = self.parse_output(result.stdout)
self.id = output_dict["JobID"]
self.status = JobStatus.from_string(output_dict["Status"])
self.site = output_dict.get("Site", "Unknown")
def create_jdl(self) -> None:
"""Create the JDL file that describes the job to DIRAC."""
output_sandbox = [
"job.out",
"job.err",
"cns.log",
]
output_sandbox.extend(self.expected_outputs)
output_sandbox_str = ", ".join(f'"{fname}"' for fname in output_sandbox)
jdl_lines = [
f'JobName = "{self.name}";',
'Executable = "job.sh";',
'Arguments = "";',
'StdOutput = "job.out";',
'StdError = "job.err";',
f'JobType = "{JOB_TYPE}";',
'InputSandbox = {"job.sh", "payload.zip"};',
"OutputSandbox = {" + f"{output_sandbox_str}" + "};",
]
jdl_string = "[\n " + "\n ".join(jdl_lines) + "\n]\n"
with open(self.jdl, "w") as f:
f.write(jdl_string)
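# For illustration, a job expecting a single output "complex_1.pdb" would produce a
# JDL roughly like this (JobName is the random UUID, JobType the value of JOB_TYPE):
#
#   [
#    JobName = "8a6f...";
#    Executable = "job.sh";
#    Arguments = "";
#    StdOutput = "job.out";
#    StdError = "job.err";
#    JobType = "WeNMR-DEV";
#    InputSandbox = {"job.sh", "payload.zip"};
#    OutputSandbox = {"job.out", "job.err", "cns.log", "complex_1.pdb"};
#   ]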
def clean(self) -> None:
"""Clean up the temporary directory where the job lives."""
shutil.rmtree(self.loc)
@staticmethod
def parse_output(output_str: str) -> dict[str, str]:
"""Parse the output string from DIRAC commands into a dictionary."""
items = output_str.replace(";", "")
status_dict = {}
for item in items.split(" "):
if "=" in item:
key, value = item.split("=", 1)
status_dict[key.strip()] = value.strip()
return status_dict
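# Example (the exact DIRAC output format is an assumption; shown for illustration only):
#   GridJob.parse_output("JobID=12345 Status=Done; Site=LCG.Example.org")
#   -> {"JobID": "12345", "Status": "Done", "Site": "LCG.Example.org"}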
@staticmethod
def _process_line(line: str) -> Tuple[str, Optional[str]]:
"""Process a line to identify and adjust paths."""
match_var = re.findall(VAR_PATTERN, line)
match_at = re.findall(AT_PATTERN, line)
# NOTE: In CNS a line cannot match both patterns at the same time
if match_at:
item = match_at[0].strip('"').strip("'")
elif match_var:
item = match_var[0].strip('"').strip("'")
else:
# no match
return line, None
if Path(item).exists():
# This is a path
return line.replace(item, Path(item).name), item
else:
# This is not a path
return line, None
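# Example: for a hypothetical line `@@/home/user/toppar/protein.top`, if that path exists
# locally the line is rewritten to `@@protein.top` and the original path is returned so the
# file can be added to the payload; otherwise the line is returned unchanged with None.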
@staticmethod
def _find_output(line: str) -> Optional[str]:
"""Parse the line and identify if this contains an output file declaration."""
match = re.search(OUTPUT_PATTERN, line)
if match:
return match.group(1) if match.group(1) else match.group(2)
return None
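# Example: `$output_pdb_filename = "complex_1.pdb"` returns "complex_1.pdb";
# lines without an `$output_*` assignment return None.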
def __repr__(self) -> str:
return f"ID: {self.id} Name: {self.name} Output: {self.expected_outputs} Status: {self.status.value} Site: {self.site}"
@dataclass
class GRIDScheduler:
"""Scheduler to manage and run jobs on the GRID via DIRAC."""
tasks: list[CNSJob]
params: dict
def run(self) -> None:
"""Execute the tasks."""
# Convert CNSJobs to GridJobs
queue = {GridJob(t): False for t in self.tasks}
log.info(f"Submitting {len(queue)} jobs to the grid...")
# Submit jobs
# =====================================================================#
# EXPLANATION #
# =====================================================================#
# Use multiple threads to submit jobs in parallel. Each `.submit()`
# call can take a few seconds, so using multiple threads can speed up
# the submission significantly.
#
# `ThreadPoolExecutor` is used to initialize a pool of threads, and
# `executor.map` then applies the `submit` method to each job; this
# effectively submits all jobs in parallel, using
# `self.params["ncores"]` threads.
# =====================================================================#
with ThreadPoolExecutor(max_workers=self.params["ncores"]) as executor:
executor.map(lambda job: job.submit(), queue.keys())
log.info("All jobs submitted.")
# Wait for jobs to finish
log.info("Checking job status...")
total = len(queue)
complete = False
while not complete:
# Only process jobs that are not done yet
pending_jobs = [job for job, done in queue.items() if not done]
if pending_jobs:
# =====================================================================#
# EXPLANATION #
# =====================================================================#
# Use multiple threads to check job status in parallel. Each
# `.update_status()` call can take a few seconds, so using multiple
# threads can speed up the process.
#
# `ThreadPoolExecutor` is used to initialize a pool of threads, and
# `executor.map` then applies the `process_job` method to each job;
# this effectively checks the status of all jobs in parallel, using
# `self.params["ncores"]` threads.
# =====================================================================#
with ThreadPoolExecutor(max_workers=self.params["ncores"]) as executor:
# =====================================================================#
# EXPLANATION #
# =====================================================================#
# The function `process_job` returns a tuple with the job and a
# boolean. These return values are captured in the `results` list.
# The boolean will be True if the job is done (either successfully or
# failed after exhausting its retries), and False otherwise - meaning
# the job is still running or waiting.
# =====================================================================#
results = list(executor.map(self.process_job, pending_jobs))
for job, is_done in results:
# =====================================================================#
# EXPLANATION #
# =====================================================================#
# Here we update the `queue` dictionary to avoid checking the jobs that
# are already done.
# =====================================================================#
if is_done:
queue[job] = True
# Do some simple logging to keep track of progress.
done = sum(1 for done in queue.values() if done)
log.info(f"{done}/{total} jobs completed.")
complete = all(queue.values())
log.info("All jobs completed.")
@staticmethod
def process_job(job: GridJob) -> Tuple[GridJob, bool]:
"""Process a single job: update status, retrieve output if done, handle retries if failed.
NOTE: This function is run in parallel; that is why operations such as cleaning, downloading
and retrying live here. If you are adding new functionality, consider whether it belongs
here or in the sequential part of the code.
"""
job.update_status()
is_complete = job.status in {JobStatus.DONE, JobStatus.FAILED}
if job.status == JobStatus.FAILED:
# Jobs on the grid can fail for many reasons outside our control
# So if the job is failed, resubmit it up to MAX_RETRIES times
if job.retries < MAX_RETRIES:
is_complete = False
expected_output_str = ",".join(job.expected_outputs)
job.retries += 1
log.warning(
f"Job {job.name} ({expected_output_str}) failed, re-submitting - {job.retries}/{MAX_RETRIES}"
)
job.submit()
if job.status == JobStatus.DONE:
log.debug(f"Job {job.name} is done, retrieving output...")
job.retrieve_output()
job.clean()
return job, is_complete
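# =====================================================================#
# USAGE SKETCH (illustrative only) #
# =====================================================================#
# A minimal sketch of how the scheduler is expected to be driven. The
# CNSJob construction below is hypothetical and only reflects the
# attributes accessed in this module (`input_file`, `output_file`,
# `cns_exec`, and the MODULE/TOPPAR entries of `envvars`); the actual
# CNSJob signature may differ.
# =====================================================================#
#
#   job = CNSJob(input_file=Path("rigidbody_1.inp"), output_file=Path("rigidbody_1.out"))
#   job.cns_exec = "/path/to/cns"
#   job.envvars = {"MODULE": "/path/to/cns/scripts", "TOPPAR": "/path/to/toppar"}
#   GRIDScheduler(tasks=[job], params={"ncores": 4}).run()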