Source code for haddock.modules.analysis.caprifilter.caprifilter

"""Helper functions for the caprifilter module."""

import os
from math import isnan
from pathlib import Path

from haddock.libs.libontology import PDBFile


# Metric attribute names read from each CAPRI object (see collect_metrics).
METRIC_NAMES = ("irmsd", "lrmsd", "ilrmsd", "fnat", "dockq", "rmsd")

# User-facing ``filter_by`` metric -> CAPRI computation parameter that must be
# switched on so libcapri computes that metric. Every metric shares its name
# with its parameter except "rmsd", whose flag is "global_rmsd".
_METRIC_TO_PARAM: dict[str, str] = dict(
    irmsd="irmsd",
    lrmsd="lrmsd",
    ilrmsd="ilrmsd",
    fnat="fnat",
    dockq="dockq",
    rmsd="global_rmsd",
)

# DockQ is derived from these metrics, so requesting it enables them as well.
_DOCKQ_DEPS: tuple[str, ...] = ("fnat", "irmsd", "lrmsd")

# Public, immutable set of the metric names accepted in ``filter_by``.
VALID_FILTER_METRICS: frozenset[str] = frozenset(_METRIC_TO_PARAM.keys())


def get_capri_params(filter_by: list[str]) -> dict[str, bool]:
    """Build the CAPRI computation flags needed to evaluate ``filter_by``.

    Each user-facing metric name in ``filter_by`` is translated into the CAPRI
    computation parameter it requires. The result always contains all six
    computation params (irmsd, lrmsd, ilrmsd, fnat, dockq, global_rmsd):
    those required by ``filter_by`` map to True, the others to False.
    Requesting ``dockq`` additionally enables fnat, irmsd and lrmsd, which
    DockQ depends on.

    Parameters
    ----------
    filter_by : list[str]
        Metrics to filter on. Valid values: irmsd, lrmsd, ilrmsd, fnat,
        dockq, rmsd.

    Returns
    -------
    params : dict[str, bool]
        Computation param flags ready to be injected into ``self.params``
        before CAPRI jobs are created. Keys: irmsd, lrmsd, ilrmsd, fnat,
        dockq, global_rmsd.

    Raises
    ------
    KeyError
        If ``filter_by`` contains a name that is not a valid filter metric.
    """
    enabled = {_METRIC_TO_PARAM[name] for name in filter_by}
    # Only the "dockq" metric maps to the "dockq" param, so testing the
    # enabled set is equivalent to testing filter_by (and works even if
    # filter_by is a one-shot iterable).
    if "dockq" in enabled:
        enabled.update(_DOCKQ_DEPS)
    return {flag: flag in enabled for flag in _METRIC_TO_PARAM.values()}
def collect_metrics(capri_objects: list) -> dict[PDBFile, dict[str, float]]:
    """Gather every CAPRI metric value, grouped by model.

    Parameters
    ----------
    capri_objects : list[CAPRI]
        CAPRI objects, one per model (if multiple references were given, the
        best reference has already been selected).

    Returns
    -------
    metrics_data : dict[PDBFile, dict[str, float]]
        Maps each object's ``model`` to ``{metric: value}`` for every name in
        ``METRIC_NAMES``, read via ``getattr``. If two objects share a model,
        the later one wins.
    """
    metrics_data = {}
    for capri in capri_objects:
        model = capri.model
        metrics_data[model] = {
            name: getattr(capri, name) for name in METRIC_NAMES
        }
    return metrics_data
def filter_models(
    metrics_data: dict[PDBFile, dict[str, float]],
    filter_specs: dict[str, tuple[float, str]],
) -> tuple[list[PDBFile], list[PDBFile]]:
    """Split models into kept and filtered_out based on multiple metric filters.

    All filters are applied simultaneously with AND logic. A model must pass
    every active filter to be kept. Models whose value for any filtered metric
    is missing, None or NaN are always removed.

    Parameters
    ----------
    metrics_data : dict[PDBFile, dict[str, float]]
        Metric values per model, as returned by collect_metrics.
    filter_specs : dict[str, tuple[float, str]]
        {metric: (cutoff, filter_out)} where filter_out is 'above' or 'below'.
        'above': filter out models with value > cutoff (keep value <= cutoff).
        'below': filter out models with value < cutoff (keep value >= cutoff).
        Any other filter_out value applies no cutoff for that metric; only the
        missing/None/NaN check is performed.

    Returns
    -------
    kept, filtered_out : tuple[list[PDBFile], list[PDBFile]]
        Models in ``metrics_data`` iteration order. 'filtered_out' is not
        used in the caprifilter module, it is kept for tests.
    """

    def passes(metric_vals: dict[str, float]) -> bool:
        # True only if metric_vals satisfies every filter in filter_specs.
        for metric, (cutoff, filter_out) in filter_specs.items():
            value = metric_vals.get(metric)
            # An absent or uncomputed value (None / NaN) cannot be compared
            # with a cutoff, so the model is dropped. Checking None first
            # also avoids the TypeError that isnan(None) would raise.
            if value is None or isnan(value):
                return False
            if filter_out == "above" and value > cutoff:
                return False
            if filter_out == "below" and value < cutoff:
                return False
        return True

    kept: list[PDBFile] = []
    filtered_out: list[PDBFile] = []
    for model, metric_vals in metrics_data.items():
        (kept if passes(metric_vals) else filtered_out).append(model)
    return kept, filtered_out
def write_caprifilter(
    kept: list[PDBFile],
    capri_objects: list,
    filter_specs: dict[str, tuple[float, str]],
    fname: str = "caprifilter.tsv",
) -> None:
    """Write TSV with kept models and only the user-requested metric columns.

    The file starts with a ``#`` comment line summarising the active filters
    and the number of kept models, followed by a header row (``model``,
    ``score``, then one column per metric in ``filter_specs``) and one row
    per kept model. Numbers are written with three decimals; missing values
    (no matching CAPRI object, None or NaN) are written as ``nan``.

    Parameters
    ----------
    kept : list[PDBFile]
        Models that passed all filters.
    capri_objects : list[CAPRI]
        CAPRI objects (one per model, best reference already selected).
    filter_specs : dict[str, tuple[float, str]]
        Active filter specifications {metric: (cutoff, filter_out)}.
    fname : str
        Output file name.
    """
    nan = float("nan")
    # CAPRI objects are matched to models by object identity, so ``kept``
    # must hold the same model objects referenced by ``capri_objects``.
    capri_by_model: dict[int, object] = {
        id(capri.model): capri for capri in capri_objects
    }
    metrics = list(filter_specs)

    filter_parts = []
    for metric, (cutoff, filter_out) in filter_specs.items():
        op = ">" if filter_out == "above" else "<"
        filter_parts.append(f"{metric}{op}{cutoff:.3f}")
    filter_summary = ", ".join(filter_parts) if filter_parts else "none"

    def fmt(v):
        # Three-decimal number, or "nan" for None / float NaN.
        if v is None or (isinstance(v, float) and isnan(v)):
            return "nan"
        return f"{v:.3f}"

    # Lines end with "\n": text mode already translates it to the platform
    # line ending, whereas writing os.linesep would yield "\r\r\n" on Windows.
    with open(fname, "w") as fh:
        fh.write(
            f"# caprifilter: filter=[{filter_summary}]; "
            f"{len(kept)} model(s) kept.\n"
        )
        fh.write("\t".join(["model", "score"] + metrics) + "\n")
        for model in kept:
            capri = capri_by_model.get(id(model))
            score_str = fmt(capri.score if capri is not None else nan)
            metric_strs = [
                fmt(getattr(capri, m) if capri is not None else nan)
                for m in metrics
            ]
            fh.write(
                "\t".join([str(model.rel_path), score_str] + metric_strs)
                + "\n"
            )
def write_caprifilter_full(
    capri_objects: list,
    metrics_data: dict[PDBFile, dict[str, float]],
    kept: list[PDBFile],
    filter_specs: dict[str, tuple[float, str]],
    fname: str = "caprifilter_all_models.tsv",
) -> None:
    """Write TSV with all models, user-requested metric columns, and a status column.

    The file starts with a ``#`` comment line summarising the active filters,
    the number of kept models and the percentage filtered out. It is followed
    by a header row (``model``, ``score``, one column per metric in
    ``filter_specs``, ``status``) and one row per model in ``metrics_data``,
    whose status is ``kept`` or ``filtered``. Numbers are written with three
    decimals; missing values (None, NaN or absent) are written as ``nan``.

    Parameters
    ----------
    capri_objects : list[CAPRI]
        CAPRI objects (one per model, best reference already selected).
    metrics_data : dict[PDBFile, dict[str, float]]
        Metric values per model, as returned by collect_metrics.
    kept : list[PDBFile]
        Models that passed all filters.
    filter_specs : dict[str, tuple[float, str]]
        Active filter specifications {metric: (cutoff, filter_out)}.
    fname : str
        Output file name.
    """
    nan = float("nan")
    metrics = list(filter_specs)
    # Models are matched by object identity: ``kept`` and ``capri_objects``
    # must reference the same model objects used as keys in ``metrics_data``.
    kept_ids = {id(m) for m in kept}
    capri_by_model: dict[int, object] = {
        id(capri.model): capri for capri in capri_objects
    }

    n_kept = len(kept)
    n_total = len(metrics_data)
    pct_filtered = (1 - n_kept / n_total) * 100 if n_total else 0.0

    filter_parts = []
    for metric, (cutoff, filter_out) in filter_specs.items():
        op = ">" if filter_out == "above" else "<"
        filter_parts.append(f"{metric}{op}{cutoff:.3f}")
    filter_summary = ", ".join(filter_parts) if filter_parts else "none"

    def fmt(v):
        # Three-decimal number, or "nan" for None / float NaN.
        if v is None or (isinstance(v, float) and isnan(v)):
            return "nan"
        return f"{v:.3f}"

    # Lines end with "\n": text mode already translates it to the platform
    # line ending, whereas writing os.linesep would yield "\r\r\n" on Windows.
    with open(fname, "w") as fh:
        fh.write(
            f"# caprifilter: filter=[{filter_summary}]; "
            f"{n_kept} model(s) kept; {pct_filtered:.2f}% filtered out.\n"
        )
        fh.write("\t".join(["model", "score"] + metrics + ["status"]) + "\n")
        for model, metric_vals in metrics_data.items():
            capri = capri_by_model.get(id(model))
            score_str = fmt(capri.score if capri is not None else nan)
            status = "kept" if id(model) in kept_ids else "filtered"
            # Like filter_models, treat an absent metric as missing (nan)
            # instead of raising KeyError.
            metric_strs = [fmt(metric_vals.get(m, nan)) for m in metrics]
            fh.write(
                "\t".join(
                    [str(model.rel_path), score_str] + metric_strs + [status]
                )
                + "\n"
            )