# Source code for haddock.modules.analysis.caprifilter

"""Filter models by any combination of CAPRI metrics computed against a reference structure.

The user specifies a list of metrics via ``filter_by`` (e.g. ``['irmsd', 'fnat']``).
For each metric, a per-metric cutoff and filter direction are set independently:

- ``{metric}_filter_cutoff``: the numerical threshold.
- ``{metric}_filter_out``: ``above`` removes models where value > cutoff (keep ≤);
  ``below`` removes models where value < cutoff (keep ≥).

All active filters are applied simultaneously (AND logic): a model must pass
every filter to be kept. The following metrics are supported:

- RMSD
- iRMSD (interface RMSD)
- lRMSD (ligand RMSD)
- ilRMSD (interface ligand RMSD)
- FNAT (fraction of native contacts)
- DOCKQ

The following files are generated:

- **caprifilter.tsv**: kept models with score and the user-requested metric
  columns only.
- **caprifilter_all_models.tsv**: all input models, user-requested metrics,
  and a ``status`` column (``kept`` / ``filtered``). Written only when
  ``caprifilter_full = true``.
- **caprifilter_ss.tsv**: caprieval-style ranked table of all models.
- **caprifilter_multiref.tsv**: when multiple references are provided,
  one row per (model, reference) pair.

For more details about this module, please `refer to the haddock3 user manual
<https://www.bonvinlab.org/haddock3-user-manual/modules/analysis.html#caprifilter-module>`_
"""

from math import isnan
from pathlib import Path

from haddock.core.defaults import MODULE_DEFAULT_YAML
from haddock.core.typing import FilePath, Union
from haddock.libs.libaa2cg import martinize
from haddock.libs.libcapri import (
    CAPRI,
    extract_data_from_capri_class,
    extract_models_best_references,
)
from haddock.libs.libontology import PDBFile
from haddock.libs.libparallel import Scheduler
from haddock.libs.libpdb import handle_input_reference
from haddock.libs.libstructure import find_ff
from haddock.modules import BaseHaddockModule
from haddock.modules.analysis.caprifilter.caprifilter import (
    VALID_FILTER_METRICS,
    collect_metrics,
    filter_models,
    get_capri_params,
    write_caprifilter,
    write_caprifilter_full,
)


# Directory containing this module file; its folder name doubles as the
# module's ``name`` attribute on ``HaddockModule``.
RECIPE_PATH = Path(__file__).resolve().parents[0]
# Parameter file located next to this module; used as the default value of
# ``init_params`` when the module is instantiated.
DEFAULT_CONFIG = RECIPE_PATH / MODULE_DEFAULT_YAML


class HaddockModule(BaseHaddockModule):
    """HADDOCK3 module to filter models by CAPRI metrics.

    Every input model is evaluated against the reference structure(s) and
    only the models that satisfy all of the user-selected metric filters are
    forwarded to the next step.
    """

    name = RECIPE_PATH.name

    def __init__(
        self,
        order: int,
        path: Path,
        init_params: FilePath = DEFAULT_CONFIG,
    ) -> None:
        """Initialise the module by delegating to ``BaseHaddockModule``.

        Parameters
        ----------
        order : int
            Forwarded unchanged to the base class.
        path : Path
            Forwarded unchanged to the base class.
        init_params : FilePath
            Parameter file; defaults to the configuration shipped next to
            this module (``DEFAULT_CONFIG``).
        """
        super().__init__(order, path, init_params)

    @classmethod
    def confirm_installation(cls) -> None:
        """Confirm if module is installed (nothing is checked here)."""
        return None

    @staticmethod
    def is_nested(models: list[Union[PDBFile, list[PDBFile]]]) -> bool:
        """Tell whether at least one entry of ``models`` is itself a list.

        Parameters
        ----------
        models : list
            Models as retrieved from the previous step.

        Returns
        -------
        bool
            ``True`` if any element is a ``list``, ``False`` otherwise
            (including for an empty input).
        """
        return any(isinstance(entry, list) for entry in models)

    def _run(self) -> None:
        """Execute module.

        Steps, in order:

        1. Retrieve the models of the previous step and validate the
           ``reference_fname`` and ``filter_by`` parameters.
        2. Build one ``CAPRI`` job per (model, reference) pair and run them
           through the ``Scheduler``.
        3. Keep, for each model, the best reference (as selected by
           ``extract_models_best_references``) and write the caprieval-style
           table(s).
        4. Collect the metrics, apply every requested filter and report how
           many models were discarded.
        5. Write the caprifilter output table(s) and export the kept models.

        Raises
        ------
        ValueError
            If the previous step produced a nested list of models.

        Notes
        -----
        All other failure paths go through ``self.finish_with_error``; the
        code below is written as if that call does not return (presumably it
        aborts the module -- confirm against ``BaseHaddockModule``).
        """
        # Models handed over by the previous workflow step.
        input_models = self.previous_io.retrieve_models(individualize=True)
        if self.is_nested(input_models):
            raise ValueError(
                "[caprifilter] cannot be executed after modules that produce "
                "a nested list of models."
            )

        # --- Parameter validation -------------------------------------
        if not self.params["reference_fname"]:
            self.finish_with_error(
                "[caprifilter] No reference structure provided! "
                "Please set 'reference_fname' to a valid PDB file."
            )

        requested_metrics: list[str] = self.params["filter_by"]
        if not requested_metrics:
            self.finish_with_error(
                "[caprifilter] filter_by is empty, please specify at least one metric."
            )

        unsupported = [
            metric
            for metric in requested_metrics
            if metric not in VALID_FILTER_METRICS
        ]
        if unsupported:
            self.finish_with_error(
                f"[caprifilter] Unknown metric(s) in filter_by: {unsupported}. "
                f"Valid choices are: {sorted(VALID_FILTER_METRICS)}."
            )

        # Merge the CAPRI parameters derived from the requested metrics.
        self.params.update(get_capri_params(requested_metrics))

        # filter_specs maps metric -> (cutoff, filter_out)
        filter_specs: dict[str, tuple[float, str]] = {}
        for metric in requested_metrics:
            cutoff = self.params[f"{metric}_filter_cutoff"]
            direction = self.params[f"{metric}_filter_out"]
            filter_specs[metric] = (cutoff, direction)

        # --- Reference structures ---------------------------------------
        references = handle_input_reference(Path(self.params["reference_fname"]))

        # Coarse-grained (martini2) models need the references converted
        # with ``martinize`` before they can be compared.
        ff = find_ff(input_models)
        if ff == "martini2":
            parent_dir = self.path.resolve().parent
            references = [
                Path(martinize(ref, parent_dir, False)) for ref in references
            ]
        several_references = len(references) > 1

        # --- CAPRI evaluation: one job per (model, reference) pair ------
        capri_jobs: list[CAPRI] = [
            CAPRI(
                identificator=model_index,
                model=model,
                path=Path("."),
                reference=ref,
                params=self.params,
                ref_id=ref_index,
                ff=ff,
            )
            for model_index, model in enumerate(input_models, start=1)
            for ref_index, ref in enumerate(references, start=1)
        ]
        scheduler = Scheduler(
            tasks=capri_jobs,
            ncores=self.params["ncores"],
            max_cpus=self.params["max_cpus"],
        )
        scheduler.run()
        finished_jobs = sorted(
            scheduler.results,
            key=lambda job: job.identificator,
        )

        # Best reference per model
        best_ref_jobs = extract_models_best_references(finished_jobs)

        # Filtering ignores cluster assignments, so surviving models may
        # carry stale clt_id/clt_rank values: tell the user about it.
        has_cluster_info = any(
            getattr(model, "clt_id", None) is not None for model in input_models
        )
        if has_cluster_info:
            self.log(
                "Warning: input models have cluster assignments (clt_id/clt_rank). "
                "[caprifilter] filters on a per-model basis and does not update cluster info. "
                "Remaining models will retain their original cluster labels."
            )

        # --- caprieval-style tables -------------------------------------
        sort_options = {
            "sort_key": self.params["sortby"],
            "sort_ascending": self.params["sort_ascending"],
        }
        extract_data_from_capri_class(
            capri_objects=best_ref_jobs,
            output_fname=Path(".", "caprifilter_ss.tsv"),
            add_reference_id=several_references,
            **sort_options,
        )
        if several_references:
            # One row per (model, reference) pair.
            extract_data_from_capri_class(
                capri_objects=finished_jobs,
                output_fname=Path(".", "caprifilter_multiref.tsv"),
                add_reference_id=True,
                **sort_options,
            )

        # --- Filtering ----------------------------------------------------
        metrics_data = collect_metrics(best_ref_jobs)
        kept, _discarded = filter_models(metrics_data, filter_specs)

        # Report, per requested metric, the share of models with a NaN value.
        for metric in requested_metrics:
            n_nan = sum(
                isnan(per_model[metric]) for per_model in metrics_data.values()
            )
            if n_nan:
                self.log(
                    f"{100 * n_nan / len(input_models):6.2f}% of models had NaN {metric} "
                    "(alignment failed) and will be excluded."
                )

        nothing_usable = not metrics_data or all(
            all(isnan(value) for value in per_model.values())
            for per_model in metrics_data.values()
        )
        if nothing_usable:
            self.finish_with_error(
                "[caprifilter] All models have NaN metrics, i.e. alignment failed for every model."
            )

        # Human-readable summary of the active filters, shared by the error
        # message and the final log line.
        specs_str = ", ".join(
            f"{metric} filter_out={direction} cutoff={cutoff:.3f}"
            for metric, (cutoff, direction) in filter_specs.items()
        )
        if not kept:
            self.finish_with_error(
                f"[caprifilter] With filters [{specs_str}], "
                "ALL models were filtered out!!"
            )

        pct_filtered = (1 - len(kept) / len(input_models)) * 100
        self.log(
            f"Filters: [{specs_str}] — "
            f"{pct_filtered:.2f}% filtered out, {len(kept)} model(s) passed."
        )

        # --- Output -------------------------------------------------------
        # Kept models with the requested metric columns only.
        write_caprifilter(
            kept=kept,
            capri_objects=best_ref_jobs,
            filter_specs=filter_specs,
        )
        # Optional table with every model and its kept/filtered status.
        if self.params["caprifilter_full"]:
            write_caprifilter_full(
                capri_objects=best_ref_jobs,
                metrics_data=metrics_data,
                kept=kept,
                filter_specs=filter_specs,
            )

        # Forward the surviving models untouched. The ignore is acceptable
        # only because nested model lists were rejected at the top.
        self.output_models = kept  # type: ignore
        self.export_io_models()