Source code for Capricho.cli.chembl_data_pipeline

from inspect import signature
from pathlib import Path
from typing import Literal, Optional, Union

import pandas as pd
from chemFilters.chem.standardizers import ChemStandardizer, InchiHandling
from job_tqdflex import ParallelApplier
from rdkit import Chem

from ..chembl.data_flag_functions import (
    flag_censored_activity_comment,
    flag_insufficient_assay_overlap,
    flag_inter_document_duplication,
    flag_max_assay_size,
    flag_min_assay_size,
    flag_missing_canonical_smiles,
    flag_missing_document_date,
    flag_missing_standard_smiles,
    flag_salt_or_solvent_removal,
    flag_strict_mutant_assays,
    flag_to_remove_mixture_compounds,
    flag_undefined_stereochemistry,
    flag_unit_conversion,
)
from ..chembl.exceptions import BioactivitiesNotFoundError
from ..chembl.processing import get_bioactivities_workflow
from ..chembl.unit_conversions import (
    convert_dose_units,
    convert_mass_concentration_units,
    convert_molar_concentration_units,
    convert_permeability_units,
    convert_time_units,
)
from ..core.default_fields import (
    ASSAY_ID,
    DATA_DROPPING_COMMENT,
    DATA_PROCESSING_COMMENT,
    MOLECULE_ID,
    TARGET_ID,
)
from ..core.fp_utils import calculate_mixed_FPs
from ..core.pandas_helper import save_dataframe
from ..core.smiles_utils import clean_mixtures
from ..core.stats_make import process_repeat_mols, repeated_indices_from_array_series
from ..core.stereo import find_undefined_stereocenters
from ..logger import logger

# Sort keys applied to aggregated output before it is returned/saved.
# `activity_id` is deliberately left out: when aggregated, some `activity_id` values
# will be strings and sorting won't work properly.
AGGREGATE_SAVE_SORTED_BY: list[str] = ["target_chembl_id", "assay_chembl_id"]


def _count_flags(df: pd.DataFrame, column: str) -> dict[str, int]:
    """Tally how often each individual flag occurs in a flag column.

    A cell may hold several flags joined by " & ". Each flag is stripped and passed
    through ``normalize_comment_pattern`` so that dynamic variants (e.g.
    "Assay size < 20" -> "Assay size <") are counted under one key.

    Args:
        df: DataFrame holding the flag column.
        column: Name of the flag column to inspect.

    Returns:
        Mapping of normalized flag pattern -> number of occurrences, in first-seen order.
    """
    from ..analysis import normalize_comment_pattern

    tally: dict[str, int] = {}
    cells = df[column].fillna("").astype(str)
    for raw in cells:
        # empty cells (and stringified NaN) carry no flags
        if raw in ("", "nan"):
            continue
        pieces = (piece.strip() for piece in raw.split(" & "))
        for piece in pieces:
            if piece:
                key = normalize_comment_pattern(piece)
                tally[key] = tally.get(key, 0) + 1
    return tally


def _log_pipeline_summary(
    df: pd.DataFrame,
    pre_aggregation_count: int,
    post_aggregation_count: int,
) -> None:
    """Emit a single multi-line INFO record summarising a pipeline run.

    The report lists the two row counts, the number of unique compounds/targets/assays
    found in ``df`` and, for each comment column present, how often every flag occurs
    (as a count and as a percentage of ``len(df)``).

    Args:
        df: The pre-aggregation DataFrame (one row per measurement).
        pre_aggregation_count: Total rows fetched before any processing.
        post_aggregation_count: Rows after aggregation.
    """
    n_rows = len(df)

    def _n_unique(col: str) -> int:
        # a missing column is reported as zero unique values
        return df[col].nunique() if col in df.columns else 0

    def _flag_section(title: str, column: str) -> list[str]:
        # one report section per comment column; skipped when the column is absent or df is empty
        if column not in df.columns or n_rows <= 0:
            return []
        section = ["", title]
        tallies = _count_flags(df, column)
        if not tallies:
            section.append("    (none)")
            return section
        # most frequent flag first
        for name, hits in sorted(tallies.items(), key=lambda item: -item[1]):
            share = hits / n_rows * 100
            section.append(f"    {name + ':':<45s} {hits:>6,}  ({share:5.1f}%)")
        return section

    report = [
        "",
        "PIPELINE SUMMARY",
        f"  Rows fetched:              {pre_aggregation_count:>8,}",
        f"  Rows after aggregation:    {post_aggregation_count:>8,}",
    ]

    if n_rows > 0:
        # Pre-aggregation df uses molecule_chembl_id; post-aggregation uses connectivity
        compound_col = "connectivity" if "connectivity" in df.columns else MOLECULE_ID
        report.extend(
            [
                f"  Unique compounds:          {_n_unique(compound_col):>8,}",
                f"  Unique targets:            {_n_unique(TARGET_ID):>8,}",
                f"  Unique assays:             {_n_unique(ASSAY_ID):>8,}",
            ]
        )

    report.extend(_flag_section("  QUALITY FLAGS (data_dropping_comment)", DATA_DROPPING_COMMENT))
    report.extend(_flag_section("  PROCESSING FLAGS (data_processing_comment)", DATA_PROCESSING_COMMENT))

    logger.info("\n".join(report))


# Sort keys applied to the (non-aggregated) workflow output.
# After the workflow, `activity_id` is an integer, so we can sort by it to ensure consistent
# ordering on the aggregated datapoints -> assay1|assay2|...|assayN,activity_id1|...|activity_idN
WORKFLOW_SAVE_SORTED_BY: list[str] = [*AGGREGATE_SAVE_SORTED_BY, "activity_id"]


def _warn_info_post_aggregation_repeats(
    df: pd.DataFrame,
    extra_id_cols: list[str],
    aggregate_mutants: bool = False,
    value_col: str = "pchembl_value",
    _limit: int = 15,  # limit in the string length for the warning/info logging
) -> None:
    def _truncate_dataframe(df: pd.DataFrame, limit: int) -> pd.DataFrame:
        """Truncate DataFrame values to a specified length."""
        if pd.__version__ > "2.1.0":  # applymap got deprecated in 2.1.0
            return df.map(lambda x: str(x)[:limit] + "..." if len(str(x)) > limit else str(x))
        else:
            return df.applymap(lambda x: str(x)[:limit] + "..." if len(str(x)) > limit else str(x))

    if aggregate_mutants:
        col_subset_dupli_warning = ["connectivity", "target_chembl_id", *extra_id_cols]
    else:
        col_subset_dupli_warning = ["connectivity", "mutation", "target_chembl_id", *extra_id_cols]

    value_mean_col = f"{value_col}_mean"
    logging_subset = [  # subset of columns to be displayed on the warning/info logging
        *col_subset_dupli_warning,
        "molecule_chembl_id",
        "assay_chembl_id",
        value_mean_col,
    ]

    # Based on the ID columns, we shouldn't have any duplicates. This warning is a safeguard
    duplics_for_warning = df.duplicated(subset=col_subset_dupli_warning)
    if duplics_for_warning.any():
        dupli_subset = (
            df[duplics_for_warning]
            .loc[:, logging_subset]
            .sort_values(
                by=["target_chembl_id", "connectivity", value_mean_col],
                ascending=[True, True, False],
            )
        )
        truncated_df = _truncate_dataframe(dupli_subset, _limit)
        logger.warning(
            f"There two or more compounds matching the ID columns {col_subset_dupli_warning} "
            "This is not intentional, please further inspect the collected dataset. Here's a sample "
            "of the repeated entries:\n"
            f"{truncated_df.head(10).to_string(index=False)}"
        )

    # Additional safeguard to ensure proper handling of the output by the user prior to modeling
    target_cpd_col_subset = ["target_chembl_id", "connectivity"]
    duplics_for_info = df.duplicated(subset=target_cpd_col_subset)
    if duplics_for_info.any():
        dupli_subset = (
            df[duplics_for_info]
            .loc[:, logging_subset]
            .sort_values(by=target_cpd_col_subset + [value_mean_col], ascending=[True, True, False])
        )
        truncated_df = _truncate_dataframe(dupli_subset, _limit)
        logger.info(
            "There are two or more repeated compound-target readouts (based on `connectivity` & `target_chembl_id`) "
            "without considering other ID columns. This is a result of your aggregation criteria. Make "
            "sure to differ these data points in your modeling pipeline by including information of your other id_columns, "
            "or resolve these compound-target repeats prior to modeling. Here's a sample of the repeated entries:\n"
            f"{truncated_df.head(10).to_string(index=False)}"
        )

    return


def _tagged_output_path(output_path: Path, tag: str) -> Path:
    """Return a sibling of ``output_path`` whose base name carries ``tag`` (all suffixes kept)."""
    suffixes = "".join(output_path.suffixes)
    new_name = output_path.stem.split(".")[0] + tag + suffixes
    return output_path.with_name(new_name)


def get_standardize_and_clean_workflow(
    molecule_ids: Optional[list[str]] = None,
    target_ids: Optional[list[str]] = None,
    assay_ids: Optional[list[str]] = None,
    document_ids: Optional[list[str]] = None,
    chirality: bool = True,
    calculate_pchembl: bool = False,
    output_path: Optional[Union[str, Path]] = None,
    confidence_scores: list[int] = [7, 8, 9],
    bioactivity_type: Optional[list[str]] = None,
    standard_relation: list[str] = ["="],
    standard_units: Optional[list[str]] = None,
    assay_types: list[str] = ["B", "F"],
    chembl_release: Optional[int] = None,
    save_not_aggregated: bool = True,
    drop_unassigned_chiral: bool = False,
    version: Optional[Union[int, str]] = None,
    backend: Literal["downloader", "webresource"] = "downloader",
    curate_annotation_errors: bool = True,
    require_doc_date: bool = False,
    min_assay_size: Optional[int] = None,
    max_assay_size: Optional[int] = None,
    min_assay_overlap: int = 0,
    strict_mutant_removal: bool = False,
    value_col: str = "pchembl_value",
    enable_unit_conversion: bool = False,
) -> pd.DataFrame:
    """Fetch bioactivities from ChEMBL, standardize the structures and flag/clean the data.

    The data is fetched with ``get_bioactivities_workflow``, annotated by the various
    ``flag_*`` helpers, the SMILES are standardized and cleaned from mixtures, and rows
    that cannot be aggregated (missing/mixture SMILES, salt-only SMILES, missing
    ``value_col``) are split off into a "removed subset".

    Note: the list defaults in the signature are only read, never mutated, by this function.

    Args:
        molecule_ids: list of ChEMBL molecule IDs to filter data from. Empty lists are
            treated like None.
        target_ids: list of ChEMBL target IDs to filter data from.
        assay_ids: list of ChEMBL assay IDs to filter data from.
        document_ids: list of ChEMBL document IDs to filter data from.
        chirality: setting this to False will remove stereochemistry information from the
            SMILES on top of the standardization. Defaults to True.
        calculate_pchembl: whether to calculate pchembl values when not found for assay
            results reported in nanomolar/micromolar units. When True, every bioactivity
            type ``X`` is also requested as ``Log X`` and ``-Log X``.
        output_path: base path used to derive the output files. This function writes
            ``<name>_removed_subset<suffixes>`` and, if ``save_not_aggregated`` is True,
            ``<name>_not_aggregated<suffixes>`` next to it. Nothing is saved when None.
        confidence_scores: list of confidence scores (assay-related) to filter data from.
        bioactivity_type: list of bioactivity types (assay-related) to filter data from.
            None applies no filter on ``standard_type``.
        standard_relation: standard relations to keep. Rows whose relation is not in this
            list get a comment added to the data-dropping column. Defaults to ["="].
        standard_units: forwarded to ``get_bioactivities_workflow``.
        assay_types: forwarded to ``get_bioactivities_workflow``.
        chembl_release: latest ChEMBL release to retrieve data from.
        save_not_aggregated: whether to save the resulting (not yet aggregated) data
            before aggregation. Requires ``output_path``.
        drop_unassigned_chiral: when True, undefined stereocenters are counted for every
            compound and the result is handed to ``flag_undefined_stereochemistry``.
            Defaults to False.
        version: `backend=="downloader"` only! version of the ChEMBL database to be
            downloaded by chembl_downloader. If left as None, the latest version will be
            downloaded. Defaults to None.
        backend: the backend to be used for fetching the data. If downloader, the ChEMBL
            sql database is downloaded and extracted first. Defaults to "downloader".
        curate_annotation_errors: Whether to apply activity curation based on pChEMBL
            values diverging in exactly 3.0 (indicate possible annotation errors).
            Defaults to True.
        require_doc_date: Whether to filter out activities without a document year.
        min_assay_size: Minimum number of compounds in an assay. Assays smaller than this
            size will have their activities flagged for removal. Defaults to None (no
            filtering).
        max_assay_size: Maximum number of compounds in an assay. Assays exceeding this
            size will have their activities flagged for removal. Defaults to None (no
            filtering).
        min_assay_overlap: Minimum number of overlapping compounds between two assays for
            the same target for their activities to be considered. Defaults to 0 (no
            filtering).
        strict_mutant_removal: If True, assays with 'mutant', 'mutation', or 'variant' in
            their description will be flagged for removal. Defaults to False.
        value_col: name of the column holding the activity values; also the column on
            which unit conversions operate. Defaults to "pchembl_value".
        enable_unit_conversion: If True, the permeability, molar concentration, mass
            concentration, dose and time unit converters are applied in that order.

    Returns:
        pd.DataFrame: the filtered, standardized, and cleaned data. An empty DataFrame is
        returned if no rows are left after the stereochemistry or strict-mutant steps.

    Raises:
        BioactivitiesNotFoundError: if no rows are left after standardization.
    """
    if output_path is not None:
        output_path = Path(output_path)

    # -log | log transformed values reported as Log XC50, -Log XC50, etc, might not
    # have a pchembl value, but *could* still be used. If standard_type contains `Log`,
    # the standard_value will be transferred to pchembl_value.
    if bioactivity_type is None:
        # No filter on standard_type - fetch all
        biotypes = None
    elif calculate_pchembl:
        biotypes = []
        for act in bioactivity_type:
            biotypes.extend([f"Log {act}", f"-Log {act}", act])
    else:
        if standard_relation != ["="]:
            logger.error(
                "pchembl_values are only calculated for standard_relation='='. If you want "
                "to use censored data, please set `calculate_pchembl` to True with the flag "
                "--calculate-pchembl."
            )
        biotypes = bioactivity_type

    # get_bioactivities_workflow -> fetch with either webresource or downloader -> (minimally) process bioactivities
    # -> standardization is done here -> curate bioactivity errors (if curate_annotation_errors=True)
    full_df = get_bioactivities_workflow(
        molecule_chembl_ids=molecule_ids or None,
        target_chembl_ids=target_ids or None,
        assay_chembl_ids=assay_ids or None,
        document_chembl_ids=document_ids or None,
        confidence_scores=confidence_scores,
        assay_types=assay_types,
        standard_relation=standard_relation,
        standard_type=biotypes,
        standard_units=standard_units,
        calculate_pchembl=calculate_pchembl,
        curate_annotation_errors=curate_annotation_errors,
        require_document_date=require_doc_date,
        chembl_release=chembl_release,
        version=version,
        backend=backend,
        value_col=value_col,
    )

    # Flag activities without document dates for transparency
    # Note: if require_doc_date=True, these will be hard-filtered in process_bioactivities
    full_df = flag_missing_document_date(full_df)

    # Correct censored activity comments (inactive/inconclusive) with incorrect standard_relation='='
    full_df = flag_censored_activity_comment(full_df)

    # Convert units if requested
    if enable_unit_conversion:
        logger.info("Converting units to standard formats")
        unit_converters = (
            convert_permeability_units,  # Convert to 10^-6 cm/s
            convert_molar_concentration_units,  # Convert to nM
            convert_mass_concentration_units,  # Convert to ug/mL
            convert_dose_units,  # Convert to mg/kg
            convert_time_units,  # Convert to hr
        )
        for convert_units in unit_converters:
            full_df = convert_units(full_df, value_col=value_col, unit_col="standard_units")
            # flagging runs after every single conversion, as each step may touch other rows
            full_df = flag_unit_conversion(full_df)

    # Filter out activities with standard_relation not in the user-selected values
    # This is important because flag_censored_activity_comment may change '=' to '<'
    if "standard_relation" in full_df.columns and standard_relation is not None:
        excluded_relations = ~full_df["standard_relation"].isin(standard_relation)
        num_excluded = excluded_relations.sum()
        if num_excluded > 0:
            logger.info(
                f"Filtering out {num_excluded} activities with standard_relation not in {standard_relation}. "
                "These activities will be flagged for removal and saved to the _removed_subset file."
            )
            # Flag these activities for removal. The separator is only inserted when a
            # previous, non-blank comment exists; missing (NaN) comments count as empty
            # so the new comment never starts with a dangling "; ".
            # NOTE(review): flag columns are split on " & " elsewhere in this module;
            # confirm that "; " is the intended separator for this comment.
            reason = f"Standard relation not in selected values {standard_relation}"
            existing = full_df.loc[excluded_relations, DATA_DROPPING_COMMENT].fillna("").astype(str)
            with_separator = existing.where(existing.str.strip().eq(""), existing + "; ")
            full_df.loc[excluded_relations, DATA_DROPPING_COMMENT] = with_separator + reason

    # Filter assays by size
    if min_assay_size is not None or max_assay_size is not None:
        logger.info(
            f"Filtering assays by size: min={min_assay_size}, max={max_assay_size}. "
            "Assays with insufficient size will be flagged for removal."
        )
        full_df = full_df.pipe(flag_min_assay_size, min_assay_size=min_assay_size).pipe(
            flag_max_assay_size, max_assay_size=max_assay_size
        )

    # Filter by minimum assay overlap
    if min_assay_overlap > 0 and not full_df.empty:
        logger.info(f"Filtering assays based on minimum overlap of {min_assay_overlap} compounds.")
        full_df = flag_insufficient_assay_overlap(
            df=full_df,
            min_overlap=min_assay_overlap,
            molecule_col=MOLECULE_ID,
            assay_col=ASSAY_ID,
            target_col=TARGET_ID,
            comment_col=DATA_DROPPING_COMMENT,
        )

    # Columns to remove after standardization
    # Note: standard_value and standard_units are always preserved as multivalue columns
    cols_to_remove_post_standardization = ["type", "relation", "units", "value"]

    logger.debug(f"All fetched bioactivity types from ChEMBL: {full_df.standard_type.unique().tolist()}")
    logger.debug(f"Filtering for bioactivity types: {biotypes}")

    # drop rows without chemical structures
    no_smiles_mask = full_df.canonical_smiles.isna()
    if no_smiles_mask.any():
        _info = full_df[no_smiles_mask].iloc[:, :6]
        logger.info(f"Dropping rows with missing canonical smiles:\n{_info}")
        full_df = full_df.drop(index=_info.index).reset_index(drop=True)

    stdzer = ChemStandardizer(
        from_smi=True, n_jobs=8, verbose=False, isomeric=chirality, progress=True, chunk_size=1000
    )

    # Filter by bioactivity_type only if it's not None
    if bioactivity_type is not None:
        df = full_df.query("standard_type.isin(@bioactivity_type)")
    else:
        df = full_df.copy()

    df = (
        df  # standardize the smiles & clean possible solvents & salts from the string
        .pipe(flag_missing_canonical_smiles)
        .assign(standard_smiles=lambda x: stdzer(x["canonical_smiles"]))
        .dropna(subset=["standard_smiles"])  # drop if no structure is found
        .pipe(flag_salt_or_solvent_removal)
        .assign(final_smiles=lambda x: x["standard_smiles"].apply(clean_mixtures))
        .drop(columns="standard_smiles")
        .rename(columns={"final_smiles": "standard_smiles"})
        .drop(columns=[c for c in cols_to_remove_post_standardization if c in full_df.columns])
        .reset_index(drop=True)
        .copy()
    )
    # make sure we don't have Nan, can result from merging pChEMBL-lacking calculated values
    df[DATA_PROCESSING_COMMENT] = df[DATA_PROCESSING_COMMENT].fillna("")

    # Raise error if df is empty after critical processing steps
    if df.empty:
        func_params = signature(get_standardize_and_clean_workflow).parameters
        # `locals()` must be evaluated here, outside of the comprehension's own scope.
        local_vars = locals()
        # Only the call parameters are reported, which keeps the DataFrames out of the message
        error_params = {name: local_vars[name] for name in func_params if name in local_vars}
        raise BioactivitiesNotFoundError(parameters=error_params)

    # find mixtures in the data
    mask = df["standard_smiles"].str.contains(".", regex=False)
    n_mixtures = mask.sum()
    if n_mixtures > 0:
        df = flag_to_remove_mixture_compounds(df)
        logger.info(f"Number of mixtures: {n_mixtures}")

    # Search for undefined stereocenters within the remaining data
    if drop_unassigned_chiral:  # here we have the problem with the "." SMILES
        # Use parallel processing for finding undefined stereocenters
        logger.debug(f"Finding undefined stereocenters in {len(df)} SMILES strings using parallel processing")
        applier = ParallelApplier(
            find_undefined_stereocenters,
            df["standard_smiles"].tolist(),
            n_jobs=8,  # Use 8 cores by default
            backend="loky",
            custom_desc="Find undefined stereocenters",
            logger=logger,
            chunk_size=200,
        )
        undefined_stereo_lists = applier()
        undefined_stereo_counts = [len(x) for x in undefined_stereo_lists]

        df = df.assign(undefined_stereocenters=undefined_stereo_counts).pipe(flag_undefined_stereochemistry)
        logger.trace(f'Unassigned stereocenters: {df["undefined_stereocenters"].unique().tolist()}')

        undefined_stereo_mask = df["undefined_stereocenters"] > 0
        if undefined_stereo_mask.any():
            logger.info(f"Flagging {undefined_stereo_mask.sum()} rows with undefined stereocenters.")
            logger.debug(df[undefined_stereo_mask].iloc[:, :5])

        if df.empty:
            logger.warning("All data points have been dropped due to undefined stereocenters!!")
            return pd.DataFrame()

    # Strict mutant removal based on assay_description
    if strict_mutant_removal:
        df = flag_strict_mutant_assays(df, strict_mutant_removal=True)
        if df.empty:
            logger.warning(
                "All data points have been dropped after strict mutant removal based on assay_description."
            )
            return pd.DataFrame()

    # for the duplication we try to find the same molecule identifiers (molID, SMILES) and
    # activity outcomes (targetID, organismID, standard_value, standard_relation), but reported
    # in different ChEMBL documents (different papers) so we can keep same-readouts reported
    # by two assays performed in the same paper !
    df = flag_inter_document_duplication(df).sort_values(WORKFLOW_SAVE_SORTED_BY).reset_index(drop=True)

    # This part needs to be removed prior to data aggregation. Here, we have either
    # inorganic compounds (SMILES removed from the salt removal step), mixtures (SMILES with "."), or
    # compounds with missing activity values, which are needed for the statistics
    missing_smiles_patt = r"Missing Standard SMILES|Missing SMILES|Mixture in SMILES"
    only_salt_entry_patt = r"^\.+$"  # if only salts are present, SMILES will be just "."

    # Build the query dynamically based on which columns exist and which value_col is used
    query_parts = [
        "data_dropping_comment.str.contains(@missing_smiles_patt, na=False, regex=True)",
        # na=False: a missing SMILES must yield False, not NaN, inside the boolean mask
        r"standard_smiles.str.contains(@only_salt_entry_patt, na=False, regex=True)",
    ]
    # Only filter by value_col if it exists in the dataframe
    if value_col in df.columns:
        query_parts.append(f"{value_col}.isna()")

    removed_subset = df.query(" | ".join(query_parts)).copy()
    if output_path is not None:
        save_dataframe(removed_subset, _tagged_output_path(output_path, "_removed_subset"))

    df = df.drop(index=removed_subset.index)

    if save_not_aggregated and output_path is not None:
        save_dataframe(df, _tagged_output_path(output_path, "_not_aggregated"))

    return df
def aggregate_data(
    df: pd.DataFrame,
    chirality: bool,
    metadata_cols: list[str] = [],
    extra_id_cols: list[str] = [],
    extra_multival_cols: list[str] = [],
    aggregate_mutants: bool = False,
    output_path: Optional[Union[str, Path]] = None,
    compound_equality: Literal["mixed_fp", "connectivity", "smiles"] = "connectivity",
    value_col: str = "pchembl_value",
) -> pd.DataFrame:
    """Aggregate the data obtained from ChEMBL.

    Steps:
        1) Compute a per-row compound identifier (``id_array``) according to
           ``compound_equality``;
        2) Find rows sharing the same identifier and hand them to ``process_repeat_mols``,
           which aggregates them and computes the statistics on ``value_col``;
        3) Add a ``connectivity`` column, reorder/sort the columns and optionally save.

    The input DataFrame is not modified; all changes are made on copies.
    The list defaults in the signature are only read, never mutated.

    Args:
        df: dataframe output from `get_standardize_and_clean_workflow`. Must contain
            ``standard_smiles`` and ``standard_relation``.
        chirality: toggle chiral-sensitive fingerprints for identifying same molecules.
        metadata_cols: currently unused by this function; kept for backward compatibility.
        extra_id_cols: additional columns to use as identifiers for the aggregation.
            Passing `["assay_chembl_id"]` to this argument, for example, will only
            aggregate the data if the compound is the same and the assay is the same.
        extra_multival_cols: list of extra columns that you'd like to keep as aggregated
            values in the final dataframe. Caveat: these columns will be displayed as a
            single delimited string in the final dataframe. Defaults to [].
        aggregate_mutants: if true, will aggregate data solely based on the
            target_chembl_id, regardless of the mutation flag in ChEMBL. Defaults to False.
        output_path: path to save the aggregated data.
        compound_equality: How to identify same compounds in the dataset. If "mixed_fp",
            uses mixed fingerprints (ECFP4 + RDKitFP) to identify same compounds. If
            "connectivity", uses the first part of the InChI key (connectivity) to identify
            same compounds. If "smiles", uses standardized SMILES strings directly.
            Defaults to "connectivity".
        value_col: Column name containing the values to aggregate statistics on. Defaults
            to "pchembl_value". Use "standard_value" for non-pChEMBL data (e.g., %
            inhibition).

    Returns:
        pd.DataFrame: the aggregated data

    Raises:
        ValueError: if ``compound_equality`` is not one of the supported options.
    """
    current_extra_id_cols = list(extra_id_cols)  # mutable copy

    connectivity_writer = InchiHandling(
        convert_to="connectivity", n_jobs=4, progress=True, from_smi=True, chunk_size=None
    )

    # Track whether we pre-computed connectivity to avoid recalculating after aggregation
    precomputed_connectivity = None

    if compound_equality == "mixed_fp":
        fps = calculate_mixed_FPs(  # Fingerprints are calculated to identify same molecules in the dataset
            df["standard_smiles"].tolist(), n_jobs=8, morgan_kwargs={"useChirality": chirality}, chunk_size=50
        )
        df = df.assign(id_array=fps)
    elif compound_equality == "connectivity":

        def _strip_stereo(smi):
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                # unparsable SMILES are passed through untouched
                return smi
            Chem.RemoveStereochemistry(mol)
            return Chem.MolToSmiles(mol)

        if chirality:
            logger.warning(
                "Connectivity-based compound equality merges stereoisomers!!!"
                "Stripping stereochemistry from standard_smiles to avoid "
                "retaining an arbitrary enantiomer's SMILES in the output."
            )
            # `assign` returns a new frame, so the caller's DataFrame is left untouched
            df = df.assign(standard_smiles=df["standard_smiles"].apply(_strip_stereo))
        connectivities = connectivity_writer(df["standard_smiles"].tolist())
        df = df.assign(id_array=connectivities)
        # Store connectivity before censored modification so we can reuse it after aggregation
        precomputed_connectivity = pd.Series(connectivities, index=df.index)
    elif compound_equality == "smiles":
        df = df.assign(id_array=df["standard_smiles"].values)
    else:
        raise ValueError(
            f"Invalid compound_equality value: {compound_equality}. "
            "Expected 'mixed_fp', 'connectivity', or 'smiles'."
        )

    # For censored measurements (!=), include relation and value in the compound identifier
    # so they are only aggregated if they have the same value AND relation.
    # NaN standard_relation (e.g., AstraZeneca PPB assays) is treated as non-censored.
    censored_mask = df["standard_relation"].notna() & df["standard_relation"].ne("=")
    has_censored = censored_mask.any()
    if has_censored:
        logger.info(
            "Detected censored measurements (standard_relation != '='). "
            f"These will only be aggregated if they have identical relation AND {value_col}."
        )
        # Round value to some decimal places to avoid floating point precision issues.
        # pChEMBL values use 2 decimals; any other value column (including
        # "standard_value") uses 4, so a custom `value_col` no longer hits an unbound name.
        decimals = 2 if value_col == "pchembl_value" else 4
        rounded_value = df[value_col].round(decimals).astype(str)
        df.loc[censored_mask, "id_array"] = (
            df.loc[censored_mask, "id_array"].astype(str)
            + "_"
            + df.loc[censored_mask, "standard_relation"]
            + "_"
            + rounded_value[censored_mask]
        )

    # Treat NaN standard_relation as "=" (exact measurement) for aggregation.
    # process_repeat_mols groups by standard_relation, and pandas drops NaN groups.
    df = df.assign(standard_relation=df["standard_relation"].fillna("="))

    # Here we have a repeat index for compounds across all fetched data. Processing which repeats
    # get aggregated (e.g.: same target ID, same `extra_id_cols`, etc) is done in `process_repeat_mols`.
    repeats_idxs = repeated_indices_from_array_series(df["id_array"])

    # Build mapping from standard_smiles to connectivity before aggregation
    # (used to avoid recalculating connectivity after aggregation)
    if precomputed_connectivity is not None:
        smiles_to_connectivity = dict(zip(df["standard_smiles"], precomputed_connectivity))

    include_metadata = [
        "doc_type",
        "doi",
        "journal",
        "year",
        "chembl_release",
        *extra_multival_cols,
        DATA_DROPPING_COMMENT,
        DATA_PROCESSING_COMMENT,
    ]

    final_data = process_repeat_mols(
        df,
        repeats_idxs,
        solve_strat="keep",
        extra_id_cols=current_extra_id_cols,
        chirality=chirality,
        extra_multival_cols=include_metadata,
        aggregate_mutants=aggregate_mutants,
        value_col=value_col,
    )

    # Assign connectivity column - reuse precomputed values when available.
    # The mapping may miss molecules whose SMILES changed during canonicalization
    # (e.g., stereochemistry stripped when chirality=False), so recompute for any misses.
    if precomputed_connectivity is not None:
        final_data = final_data.assign(connectivity=final_data["smiles"].map(smiles_to_connectivity))
        missing_mask = final_data["connectivity"].isna()
        if missing_mask.any():
            recomputed = connectivity_writer(final_data.loc[missing_mask, "smiles"].tolist())
            final_data.loc[missing_mask, "connectivity"] = recomputed
    else:
        final_data = final_data.assign(connectivity=lambda x: connectivity_writer(x["smiles"].tolist()))

    # reorder the columns so that connectivity comes first and processing & dropping comes last
    xtra_cols = [DATA_PROCESSING_COMMENT, DATA_DROPPING_COMMENT]
    first_columns = ["connectivity", *current_extra_id_cols, "smiles"]
    last_columns = final_data.columns.difference(first_columns + xtra_cols).tolist() + xtra_cols
    cols = [*first_columns, *last_columns]

    final_data = final_data[cols].sort_values(AGGREGATE_SAVE_SORTED_BY).reset_index(drop=True)
    _warn_info_post_aggregation_repeats(
        final_data, extra_id_cols=extra_id_cols, aggregate_mutants=aggregate_mutants, value_col=value_col
    )

    if output_path is not None:
        save_dataframe(final_data, output_path)
    return final_data
def re_aggregate_data(
    df: pd.DataFrame,
    chirality: bool,
    extra_id_cols: list[str] = [],
    extra_multival_cols: list[str] = [],
    aggregate_mutants: bool = False,
    output_path: Optional[Union[str, Path]] = None,
    compound_equality: Literal["mixed_fp", "connectivity", "smiles"] = "connectivity",
    value_col: str = "pchembl_value",
) -> pd.DataFrame:
    """Re-aggregate the data obtained from the `aggregate_data` method after dataset explosion.

    Useful for exploring the effect of different `extra_id_cols` and other parameters.
    The input DataFrame is not modified. The list defaults in the signature are only
    read, never mutated.

    Args:
        df: dataframe output from `aggregate_data`. Must contain ``smiles``,
            ``standard_smiles`` (or ``processed_smiles``, which is renamed), the metadata
            columns ``doc_type``, ``doi``, ``journal``, ``year``, ``chembl_release``, both
            comment columns, and ``connectivity`` when ``compound_equality="connectivity"``.
        chirality: toggle chiral-sensitive fingerprints for identifying same molecules.
        extra_id_cols: additional columns to use as identifiers for the aggregation.
            Passing `["assay_chembl_id"]` to this argument, for example, will only
            aggregate the data if the compound is the same and the assay is the same.
        extra_multival_cols: list of extra columns that you'd like to keep as aggregated
            values in the final dataframe. Caveat: these columns will be displayed as (str)
            separated by `|` (pipe) in the final dataframe. Defaults to [].
        aggregate_mutants: if true, will aggregate data solely based on the
            target_chembl_id, regardless of the mutation flag in ChEMBL. Defaults to False.
        output_path: path to save the aggregated data.
        compound_equality: How to identify same compounds in the dataset. If "mixed_fp",
            uses mixed fingerprints (ECFP4 + RDKitFP) to identify same compounds. If
            "connectivity", uses the first part of the InChI key (connectivity) to identify
            same compounds. If "smiles", uses standardized SMILES strings directly.
            Defaults to "connectivity".
        value_col: Column name containing the values to aggregate statistics on, mirroring
            `aggregate_data`. Defaults to "pchembl_value".

    Returns:
        pd.DataFrame: the re-aggregated data

    Raises:
        ValueError: if a required column is missing or ``compound_equality`` is invalid.
    """
    if "processed_smiles" in df.columns:
        df = df.rename(columns={"processed_smiles": "standard_smiles"})

    if compound_equality == "connectivity" and "connectivity" not in df.columns:
        raise ValueError("Input DataFrame must contain a 'connectivity' column.")
    if "standard_smiles" not in df.columns:
        raise ValueError("Input DataFrame must contain a 'standard_smiles' column.")
    if "smiles" not in df.columns:
        raise ValueError(
            "This method expects the output from CompoundMapper's CLI, which includes a 'smiles' column."
        )

    if compound_equality == "mixed_fp":
        fps = calculate_mixed_FPs(
            df["standard_smiles"].tolist(), n_jobs=8, morgan_kwargs={"useChirality": chirality}
        )
        id_array = pd.Series(fps, index=df.index)
    elif compound_equality == "connectivity":
        id_array = df["connectivity"]
    elif compound_equality == "smiles":
        id_array = df["standard_smiles"]
    else:
        raise ValueError(
            f"Invalid compound_equality value: {compound_equality}. "
            "Expected 'mixed_fp', 'connectivity', or 'smiles'."
        )

    # For censored measurements (!=), include relation and value in the compound identifier
    # so they are only aggregated if they have the same value AND relation
    if "standard_relation" in df.columns:
        relation = df["standard_relation"]
        # As in `aggregate_data`, a missing relation counts as an exact ("=") measurement;
        # treating NaN as censored would produce NaN identifiers below.
        censored_mask = relation.notna() & relation.ne("=")
        if censored_mask.any():
            logger.info(
                "Detected censored measurements (standard_relation != '='). "
                f"These will only be aggregated if they have identical relation AND {value_col}."
            )
            # Round to avoid floating point precision issues (2 decimals for pChEMBL, else 4)
            decimals = 2 if value_col == "pchembl_value" else 4
            rounded_value = df[value_col].round(decimals).astype(str)

            # For censored measurements, append relation + value to the id_array
            id_array = id_array.copy()  # Create a copy to avoid modifying the original
            id_array.loc[censored_mask] = (
                id_array.loc[censored_mask].astype(str)
                + "_"
                + relation[censored_mask]
                + "_"
                + rounded_value[censored_mask]
            )
        # Same rationale as in `aggregate_data`: NaN relations are filled with "=" before
        # the rows are handed to process_repeat_mols. `assign` keeps the caller's frame intact.
        df = df.assign(standard_relation=relation.fillna("="))

    repeats_idxs = repeated_indices_from_array_series(id_array)

    include_metadata = [
        "doc_type",
        "doi",
        "journal",
        "year",
        "chembl_release",
        *extra_multival_cols,
        DATA_DROPPING_COMMENT,
        DATA_PROCESSING_COMMENT,
    ]

    for col in include_metadata:  # make sure columns exist
        if col not in df.columns:
            raise ValueError(
                f"Column '{col}' is required in the DataFrame but is missing. "
                "Please ensure that the DataFrame contains all necessary columns."
            )

    final_data = process_repeat_mols(  # recalculate the stats given new conditions
        df,
        repeats_idxs,
        solve_strat="keep",
        extra_id_cols=extra_id_cols,
        chirality=chirality,
        extra_multival_cols=include_metadata,
        aggregate_mutants=aggregate_mutants,
        value_col=value_col,
    )

    connectivity_writer = InchiHandling(
        convert_to="connectivity", n_jobs=8, progress=True, from_smi=True, chunk_size=50
    )
    final_data = final_data.assign(connectivity=lambda x: connectivity_writer(x["smiles"].tolist()))

    # Reorder columns as in the original aggregate_data function
    xtra_cols = [DATA_PROCESSING_COMMENT, DATA_DROPPING_COMMENT]
    for col in xtra_cols:
        if col not in final_data.columns:
            raise ValueError(
                f"Column '{col}' is required in the DataFrame but is missing. "
                "Please ensure that the DataFrame contains all necessary columns."
            )

    # reorder the columns so that connectivity comes first and processing & dropping comes last
    first_columns = ["connectivity", *extra_id_cols, "smiles"]
    last_columns = final_data.columns.difference(first_columns + xtra_cols).tolist() + xtra_cols
    cols = [*first_columns, *last_columns]

    final_data = final_data[cols].sort_values(AGGREGATE_SAVE_SORTED_BY).reset_index(drop=True)
    _warn_info_post_aggregation_repeats(
        final_data, extra_id_cols=extra_id_cols, aggregate_mutants=aggregate_mutants, value_col=value_col
    )

    if output_path is not None:
        save_dataframe(final_data, output_path)
    return final_data