import glob
import math
import os
import random
import re
import shutil
import subprocess
import tempfile
import uuid
import warnings
import zipfile
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Optional, Tuple, Union
from haddock import log
from haddock.core.defaults import cns_exec_linux as CNS_EXEC
from haddock.libs.libsubprocess import CNSJob
from haddock.libs.libutil import parse_ncores
# NOTE: When creating payload.zip this warning appears because we replicate the
# `toppar_path` subdirectory structure. It can safely be ignored, but keep an eye
# on it to make sure operations are not being duplicated unnecessarily.
warnings.filterwarnings("ignore", message="Duplicate name:*", category=UserWarning)
# =====================================================================#
# EXPLANATION #
# =====================================================================#
# `JOB_TYPE` is an environment variable that can be set to specify the
# type of job associated with a given VO. This is usually important for
# accounting purposes, as different job types may have different resources
# associated with them.
# There is an explanation about this in the user manual.
# =====================================================================#
JOB_TYPE = os.getenv("HADDOCK3_GRID_JOB_TYPE", "WeNMR-DEV")
MAX_RETRIES = 10
# Important patterns to adjust the paths in the `.inp` file
OUTPUT_PATTERN = r'\$output_\w+\s*=\s*"([^"]+)"|\$output_\w+\s*=\s*([^\s;]+)'
VAR_PATTERN = r"\(\$\s*[^=]*=(?!.?\$)(.*)\)" # https://regex101.com/r/dYqlZP/1
AT_PATTERN = r"@@(?!\$)(.*)"
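# Illustrative (hypothetical) recipe lines these patterns are meant to match:
#   OUTPUT_PATTERN: $output_pdb_filename = "complex_1.pdb"
#   VAR_PATTERN:    ($par_nonbonded = /some/local/path/protein.param)
#   AT_PATTERN:     @@/some/local/path/protein.top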
def ping_dirac() -> bool:
"""Ping the Dirac server to check if it's reachable."""
if not validate_dirac():
return False
result = subprocess.run(["dirac-proxy-info"], capture_output=True)
if result.returncode != 0:
log.error(f"Dirac proxy info failed: {result.stderr.decode().strip()}")
return False
return True
def validate_dirac() -> bool:
"""Check if the DIRAC client is valid and configured."""
expected_cmds = [
"dirac-proxy-info",
"dirac-wms-job-submit",
"dirac-wms-job-status",
"dirac-wms-job-get-output",
]
for cmd in expected_cmds:
# Expect the commands to be in the PATH
which_cmd = shutil.which(cmd)
if not which_cmd:
log.error(f"Command '{cmd}' not found in PATH.")
return False
return True
class JobStatus(Enum):
WAITING = "Waiting"
RUNNING = "Running"
UNKNOWN = "Unknown"
DONE = "Done"
MATCHED = "Matched"
COMPLETING = "Completing"
FAILED = "Failed"
STAGED = "Staged"
@classmethod
def from_string(cls, value):
"""Convert string to JobStatus enum."""
value = value.strip().lower()
for status in cls:
if status.value.lower() == value:
return status
return cls.UNKNOWN
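# Usage sketch (illustrative): JobStatus.from_string("Done") returns JobStatus.DONE,
# while an unrecognised value such as "Deleted" falls back to JobStatus.UNKNOWN.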
class Tag(Enum):
PROBING = "Probing"
DEFAULT = "Default"
class GridInterface(ABC):
def __init__(
self,
input: Union[Path, str, list[str]],
toppar_path: Path,
module_path: Path,
) -> None:
# Common attributes for all grid jobs can be defined here
# Unique name for the job
self.name = str(uuid.uuid4())
# Create a temporary directory for the job
self.loc = Path(tempfile.mkdtemp(prefix="haddock_grid_"))
# Working directory where the GridJob was created; this is important
# so we know where to put the files coming back from the grid
self.wd = Path.cwd()
# `id` is given by DIRAC
self.id = None
# Site where the job is running
self.site = None
# Script that will be executed in the grid
self.job_script = self.loc / "job.sh"
# JDL file that describes the job to DIRAC
self.jdl = self.loc / "job.jdl"
# Internal status of the job
self.status = JobStatus.STAGED
# `expected_outputs` is what will be produced by this `.inp` file
self.expected_outputs = []
# Output and error files generated by the job.sh execution
self.stdout_f = None
self.stderr_f = None
# List of files to be included in the payload
self.payload_fnames = []
# Counter for retries
self.retries = 0
# Tracker for timings
self.timings: dict[JobStatus, datetime] = {}
# Tag for the job
self.tag = Tag.DEFAULT
# `input_str` is the content of the `.inp` file
self.input_str = ""
# `input_str_list` is a list of separate recipes
self.input_str_list = [] # Initialize the list to store separate recipes
# `module_path` is the path to the CNS module files
self.module_path = module_path
# `toppar_path` is the path to the TOPPAR files
self.toppar_path = toppar_path
# Indicator if this job has been packaged
self.packaged = False
# `parse_input` is an abstract compatibility method that will handle
# different types of input. It can be either a string, a path or
# a list of paths/strings
self.parse_input(input)
def package(self) -> None:
# We need to process the input file first to adjust paths and identify outputs
self.process_input_f()
# CREATE JOB FILE
self.create_job_script()
# CREATE THE JDL
self.create_jdl()
# CREATE PAYLOAD
self.prepare_payload(
cns_script_path=Path(self.module_path),
toppar_path=Path(self.toppar_path),
)
self.packaged = True
@abstractmethod
def create_job_script(self) -> None:
"""Create the job script that will be executed in the grid."""
pass
def create_jdl(self) -> None:
"""Create the JDL file that describes the job to DIRAC."""
output_sandbox = ["job.out", "job.err"]
output_sandbox.extend(self.expected_outputs)
output_sandbox_str = ", ".join(f'"{fname}"' for fname in output_sandbox)
jdl_lines = [
f'JobName = "{self.name}";',
'Executable = "job.sh";',
'Arguments = "";',
'StdOutput = "job.out";',
'StdError = "job.err";',
f'JobType = "{JOB_TYPE}";',
'InputSandbox = {"job.sh", "payload.zip"};',
"OutputSandbox = {" + f"{output_sandbox_str}" + "};",
]
jdl_string = "[\n " + "\n ".join(jdl_lines) + "\n]\n"
# log.debug(f"JDL for job {self.name}:\n{jdl_string}")
with open(self.jdl, "w") as f:
f.write(jdl_string)
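# For a single expected output (say "abc123.out") the generated JDL would look
# roughly like this (sketch, values are examples):
# [
#  JobName = "abc123";
#  Executable = "job.sh";
#  Arguments = "";
#  StdOutput = "job.out";
#  StdError = "job.err";
#  JobType = "WeNMR-DEV";
#  InputSandbox = {"job.sh", "payload.zip"};
#  OutputSandbox = {"job.out", "job.err", "abc123.out"};
# ]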
def update_status(self) -> None:
"""Update the status of the job by querying DIRAC."""
try:
result = subprocess.run(
["dirac-wms-job-status", str(self.id)],
shell=False,
capture_output=True,
text=True,
)
except subprocess.CalledProcessError as e:
log.error(
f"Updating the status failed: {e}\nStdout: {e.stdout}\nStderr: {e.stderr}"
)
raise
output_dict = self.parse_output(result.stdout)
self.id = output_dict["JobID"]
self.status = JobStatus.from_string(output_dict["Status"])
self.site = output_dict.get("Site", "Unknown")
log.debug(self)
if self.status == JobStatus.RUNNING and JobStatus.RUNNING not in self.timings:
# job is running and we have not tracked it yet, save the time
self.timings[JobStatus.RUNNING] = datetime.now()
elif self.status == JobStatus.DONE and JobStatus.DONE not in self.timings:
# job is done and we have not tracked it yet, save the time
self.timings[JobStatus.DONE] = datetime.now()
def prepare_payload(self, cns_script_path: Path, toppar_path: Path) -> None:
"""Prepare the payload.zip file containing all necessary files."""
# Find the CNS scripts that should be inside the payload
for f in cns_script_path.glob("*"):
self.payload_fnames.append(Path(f))
# Find the TOPPAR files that should be inside the payload
for f in toppar_path.rglob("*"):
if f.is_file(): # Only add files, not directories
self.payload_fnames.append(f)
# Create the payload.zip
with zipfile.ZipFile(f"{self.loc}/payload.zip", "w") as z:
# It must contain the CNS executable and the input file,
# z.write(self.input_f, arcname=f"{self.input_f.name}")
z.write(CNS_EXEC, arcname="cns")
for f in set(self.payload_fnames):
# NOTE: Preserve the relative path structure from `toppar_path`!
# This is important because some CNS scripts have relative paths
# hardcoded in them
if f.is_relative_to(toppar_path):
relative_path = f.relative_to(toppar_path)
z.write(f, arcname=str(relative_path))
else:
z.write(f, arcname=Path(f).name)
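# Resulting payload.zip layout (illustrative):
#   cns                        <- the CNS executable, at the archive root
#   <module script files>      <- CNS module files, flattened to the archive root
#   <subdir>/<toppar file>     <- TOPPAR files, keeping their path relative to `toppar_path`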
def submit(self) -> None:
"""Interface to submit the job to DIRAC."""
# If this is a re-submission the timings will already be set,
# make sure it is clean
self.clean_timings()
if not self.packaged:
self.package()
self.timings[JobStatus.WAITING] = datetime.now()
try:
# The submit command returns the job ID in stdout
result = subprocess.run(
["dirac-wms-job-submit", f"{self.loc}/job.jdl"],
shell=False,
capture_output=True,
text=True,
cwd=self.loc,
check=True,
)
except subprocess.CalledProcessError as e:
log.error(
f"Job submission failed: {e}\nStdout: {e.stdout}\nStderr: {e.stderr}"
)
# TODO: Add some sort of retry/fallback mechanism here?
raise
# Add the ID to the object
self.id = int(result.stdout.split()[-1])
# Update the status
self.update_status()
def retrieve_output(self) -> None:
"""Retrieve the output files from DIRAC.
The `dirac-wms-job-get-output` command downloads the output sandbox; anything
specified in the `OutputSandbox` section of the JDL file is placed in the
current working directory following the pattern `working_dir/job_id/`.
"""
try:
subprocess.run(
["dirac-wms-job-get-output", str(self.id)],
shell=False,
capture_output=True,
text=True,
cwd=self.loc,
)
except subprocess.CalledProcessError as e:
log.error(
f"Retrieving output failed: {e}\nStdout: {e.stdout}\nStderr: {e.stderr}"
)
raise
# NOTE: This is the output of the `job.sh`
self.stdout_f = Path(f"{self.loc}/{self.id}/job.out")
self.stderr_f = Path(f"{self.loc}/{self.id}/job.err")
# If the job produced a stderr file, keep a copy of it for debugging
if self.stderr_f.exists():
dst = Path(self.wd / f"{self.id}_dirac.err")
shutil.copy(self.stderr_f, dst)
log.debug(f"ID stderr: {self.stderr_f.read_text()}")
ls = glob.glob(f"{self.loc}/{self.id}/*")
log.debug(f"Files in the output sandbox: {ls}")
log.debug(f"Expected outputs: {self.expected_outputs}")
# Copy the output to the expected location
for output_f in self.expected_outputs:
src = Path(f"{self.loc}/{self.id}/{output_f}")
dst = Path(self.wd / f"{output_f}")
shutil.copy(src, dst)
def clean_timings(self) -> None:
"""Clean the timings dictionary."""
self.timings = {}
def clean(self) -> None:
"""Clean up the temporary directory where the job lives."""
shutil.rmtree(self.loc)
@staticmethod
def parse_output(output_str: str) -> dict[str, str]:
"""Parse the output string from DIRAC commands into a dictionary."""
items = output_str.replace(";", "")
status_dict = {}
for item in items.split(" "):
if "=" in item:
key, value = item.split("=", 1)
status_dict[key.strip()] = value.strip()
return status_dict
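# Sketch of the kind of status line this parser assumes (illustrative, not
# verbatim DIRAC output): "JobID=12345; Status=Done; Site=SOME.Site.org;"
# would be parsed into {"JobID": "12345", "Status": "Done", "Site": "SOME.Site.org"}.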
@staticmethod
def _process_line(line: str) -> Tuple[str, Optional[str]]:
"""Process a line to identify and adjust paths."""
match_var = re.findall(VAR_PATTERN, line)
match_at = re.findall(AT_PATTERN, line)
# NOTE: In CNS a line cannot match both patterns at the same time
if match_at:
item = match_at[0].strip('"').strip("'")
elif match_var:
item = match_var[0].strip('"').strip("'")
else:
# no match
return line, None
if Path(item).exists():
# This is a path
return line.replace(item, Path(item).name), item
else:
# This is not a path
return line, None
@staticmethod
def _find_output(line) -> Optional[str]:
"""Parse the line and identify if this contains an output file declaration."""
match = re.search(OUTPUT_PATTERN, line)
if match:
return match.group(1) if match.group(1) else match.group(2)
return None
def __repr__(self) -> str:
return f"ID: {self.id} Name: {self.name} Output: {self.expected_outputs} Status: {self.status.value} Site: {self.site}"
class GridJob(GridInterface):
"""GridJob is a class tha represents a job to be run on the GRID via DIRAC."""
def __init__(
self,
input: Union[Path, str],
toppar_path: Path,
module_path: Path,
) -> None:
super().__init__(
input=input,
toppar_path=toppar_path,
module_path=module_path,
) # initialize the base class
def create_job_script(self) -> None:
"""Create the job script that will be executed in the grid."""
# NOTE: We use `\n` instead of `os.linesep` because this will
# be executed in a Linux environment inside the grid
inp_name = f"{self.name}.inp"
cns_out_name = f"{self.name}.out"
self.expected_outputs.append(cns_out_name)
instructions = "#!/bin/bash\n"
instructions += "export MODULE=./\n"
instructions += "export TOPPAR=./\n"
instructions += "unzip payload.zip\n"
instructions += f"./cns < {inp_name} > {cns_out_name}\n"
# Remove the CNS output file if the run succeeded, otherwise exit with an error
instructions += f"[ $? -eq 0 ] && rm {cns_out_name} || exit 1\n"
with open(self.job_script, "w") as f:
f.write(instructions)
class CompositeGridJob(GridInterface):
def __init__(
self,
input: list[str],
toppar_path: Path,
module_path: Path,
) -> None:
super().__init__(
input=input,
toppar_path=toppar_path,
module_path=module_path,
) # initialize the base class
def create_job_script(self) -> None:
"""Create the job script that will be executed in the grid."""
# NOTE: We use `\n` instead of `os.linesep` because this will
# be executed in a Linux environment inside the grid
instructions = "#!/bin/bash\n"
instructions += "export MODULE=./\n"
instructions += "export TOPPAR=./\n"
instructions += "unzip payload.zip\n"
for idx, _ in enumerate(self.input_str_list):
inp_name = f"{idx}_{self.name}.inp"
cns_out_name = f"{idx}_{self.name}.out"
self.expected_outputs.append(cns_out_name)
instructions += f"./cns < {inp_name} > {cns_out_name}\n"
instructions += f"[ $? -eq 0 ] && rm {cns_out_name} || exit 1\n"
with open(self.job_script, "w") as f:
f.write(instructions)
class GRIDScheduler:
"""Scheduler to manage and run jobs on the GRID via DIRAC."""
def __init__(
self, tasks: list[CNSJob], params: dict, probing: float = 0.05
) -> None:
self.probing: bool = True
self.ncores = parse_ncores(params["ncores"])
self.batch_size = 1
self.workload: list[GridJob] = [
GridJob(
input=t.input_file,
toppar_path=t.envvars["TOPPAR"],
module_path=t.envvars["MODULE"],
)
for t in tasks
]
# ===============================================================#
#
# ! IMPORTANT !
#
# The `subset_size` is the number of jobs used to probe the grid.
# In theory we could send any number of jobs, but in practice the
# submission itself takes time, so we could end up spending more time
# submitting jobs than measuring them, which means we would not be able
# to properly evaluate the grid capacity.
# Here we set `subset_size` to the minimum between `ncores` and 5% of
# the total number of jobs. By capping it at the number of cores, we can
# be sure that we will be able to measure the metrics.
# We also need to ensure there is at least one job to probe the grid.
#
# HACK: This is not ideal, but it is a compromise between accuracy and speed.
subset_size = min(self.ncores, max(1, math.ceil(len(self.workload) * probing)))
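# Example (hypothetical numbers): with ncores=8, 200 jobs in the workload
# and probing=0.05, subset_size = min(8, max(1, ceil(10))) = 8.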
# ===============================================================#
# Randomly select that many jobs
if len(self.workload) >= subset_size:
for i in random.sample(range(len(self.workload)), subset_size):
self.workload[i].tag = Tag.PROBING
else:
log.warning("> Not enough jobs to probe the grid, skipping probing step <")
self.probing = False
def run(self) -> None:
"""Execute the tasks."""
log.info("#" * 42)
log.info("=== Running tasks with GRID Scheduler ===")
self.probe_grid_efficiency()
self.create_batches()
self.submit_jobs()
self.wait_for_completion()
log.info("#" * 42)
def create_batches(self) -> None:
"""Create batches of jobs to be submitted together."""
log.info("++ Concatenating jobs to increase efficiency...")
jobs = [j for j in self.workload if j.tag == Tag.DEFAULT]
toppar_path = jobs[0].toppar_path
module_path = jobs[0].module_path
composite_jobs = []
for i, batch in enumerate(range(0, len(jobs), self.batch_size), start=1):
input = [j.input_str for j in jobs[batch : batch + self.batch_size]]
log.debug(f" Payload {i}, n={len(input)} job(s)")
job = CompositeGridJob(
input=input,
toppar_path=toppar_path,
module_path=module_path,
)
composite_jobs.append(job)
log.info(
f"++ Created {len(composite_jobs)} payload(s) with up to {self.batch_size} job(s) each"
)
self.workload = composite_jobs
def wait_for_completion(self) -> None:
"""Wait for jobs with status WAITING or RUNNING to complete."""
log.info("++ Waiting...")
complete = False
while not complete:
jobs_to_check = [
job
for job in self.workload
if job.status
not in {JobStatus.STAGED, JobStatus.DONE, JobStatus.FAILED}
]
if jobs_to_check:
log.debug(f"+ Checking status of {len(jobs_to_check)} payload(s)...")
with ThreadPoolExecutor(max_workers=self.ncores) as executor:
executor.map(self.process_job, jobs_to_check)
else:
complete = True
def submit_jobs(self, tag: Tag = Tag.DEFAULT) -> None:
"""Submit jobs to the GRID in parallel."""
queue = [
job
for job in self.workload
if job.tag == tag and job.status == JobStatus.STAGED
]
log.info(f"++ Submitting {len(queue)} '{tag.value}' payloads to the grid...")
with ThreadPoolExecutor(max_workers=self.ncores) as executor:
executor.map(lambda job: job.package(), queue)
with ThreadPoolExecutor(max_workers=self.ncores) as executor:
executor.map(lambda job: job.submit(), queue)
def probe_grid_efficiency(self) -> None:
"""Submit a small number of jobs to probe the efficiency of the GRID."""
if not self.probing:
return
log.info("++ Probing grid efficiency...")
# Submit
self.submit_jobs(tag=Tag.PROBING)
# Wait
self.wait_for_completion()
# Calculate actual durations from timestamps
waiting_durations = []
running_durations = []
for job in self.workload:
wait_start = job.timings.get(JobStatus.WAITING)
run_start = job.timings.get(JobStatus.RUNNING)
done_time = job.timings.get(JobStatus.DONE)
if wait_start is None or run_start is None or done_time is None:
continue # Skip jobs without complete timing info
else:
log.debug(f"Job {job.expected_outputs}")
log.debug(f" timings: {job.timings}")
# Calculate waiting duration (WAITING to RUNNING)
waiting_duration = (run_start - wait_start).total_seconds()
waiting_durations.append(waiting_duration)
# Calculate running duration (RUNNING to DONE)
running_duration = (done_time - run_start).total_seconds()
running_durations.append(running_duration)
if not running_durations or not waiting_durations:
log.warning(
"> Average running time is zero, cannot calculate optimal batch size <"
)
return
# Calculate average durations
avg_waiting = sum(waiting_durations) / len(waiting_durations)
avg_running = sum(running_durations) / len(running_durations)
target_efficiency = 0.9
batch_size = self.calculate_optimal_batch_size(
N=self.batch_size,
W=avg_waiting,
R=avg_running,
T=target_efficiency,
)
# Make sure batch size is not larger than the number of default jobs
# TODO: Add a warning if this happens?
batch_size = min(
batch_size, len([j for j in self.workload if j.tag == Tag.DEFAULT])
)
self.batch_size = batch_size
@staticmethod
def calculate_optimal_batch_size(N: int, W: float, R: float, T: float) -> int:
"""Calculate the optimal batch size to achieve target efficiency."""
# The efficiency of a given batch can be described as:
#
# E = (N * R) / (W + N * R)
#
# Where E is the efficiency, N is the number of jobs running at the same time,
# R is the average running time and W is the average waiting time
#
# The current efficiency is then:
E = (N * R) / (W + N * R)
log.info(f"+ Current efficiency with {N} job(s) per payload: {E:.1%}")
# So, to achieve a target efficiency T,
# we solve for N, which is the batch size:
# E = (N * R) / (W + N * R)
# T * (W + N * R) = N * R
# T * W + T * N * R = N * R
# T * W = N * R - T * N * R
# T * W = N * R * (1 - T)
# N = (T * W) / (R * (1 - T))
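# Worked example (hypothetical numbers): with an average waiting time
# W = 600 s, an average running time R = 120 s and a target T = 0.9,
# N = (0.9 * 600) / (120 * 0.1) = 45 jobs per payload.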
batch_size = (T * W) / (R * (1 - T))
batch_size = max(1, round(batch_size)) # Ensure at least 1 job
return batch_size
@staticmethod
def process_job(job: GridJob) -> None:
"""Process a single job: update status, retrieve output if done, handle retries if failed.
NOTE: This function is parallelized, which is why things like cleaning, downloading
and retrying are handled here. If you are adding new functionality, consider whether
it should live here or in the sequential part of the code.
"""
job.update_status()
if job.status == JobStatus.FAILED:
# Jobs on the grid can fail for many reasons outside our control,
# so if the job failed, resubmit it up to MAX_RETRIES times
if job.retries < MAX_RETRIES:
expected_output_str = ",".join(job.expected_outputs)
job.retries += 1
log.warning(
f"> Job {job.name} ({expected_output_str}) failed on {job.site}, re-submitting - {job.retries}/{MAX_RETRIES} <"
)
log.debug(f"job {job.name} at {job.loc}")
job.submit()
# TODO: Add some sort of fallback mechanism here if it reaches the MAX_RETRIES
if job.status == JobStatus.DONE:
log.debug(f"Job {job.name} is done, retrieving output...")
job.retrieve_output()
log.debug(f"job {job.name} at {job.loc}")
job.clean()