Source code for Capricho.analysis

"""Analysis utilities for CAPRICHO bioactivity data comparability studies"""

from enum import Enum
from itertools import combinations
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from loguru import logger as log
from matplotlib import colormaps
from scipy import stats

def r2_score(y_true, y_pred):
    ss_res = np.sum((np.asarray(y_true) - np.asarray(y_pred)) ** 2)
    ss_tot = np.sum((np.asarray(y_true) - np.mean(y_true)) ** 2)
    return 1 - ss_res / ss_tot


class ProcessingComment(str, Enum):
    """Processing comments added during data curation (non-dropping flags).

    Note: UNIT_CONVERTED is a pattern-based comment. Actual comments include the
    original and target units, e.g., "Unit converted to nM from uM".
    """

    CALCULATED_PCHEMBL = "Calculated pChEMBL"
    SALT_SOLVENT_REMOVED = "Salt/solvent removed"
    PCHEMBL_DUPLICATION_ACROSS_DOCUMENTS = "pChEMBL Duplication Across Documents"
    CORRECTED_STANDARD_RELATION = "Corrected standard_relation from = to < (censored activity_comment)"
    UNIT_CONVERTED = "Unit converted to"  # Example: "Unit converted to nM from uM"


class DroppingComment(str, Enum):
    """Dropping comments indicating data quality issues (dropping flags).

    Note: Assay size and insufficient overlap values shown are examples. Actual thresholds
    depend on user parameters and should be matched using pattern-based functions.
    """

    DATA_VALIDITY_COMMENT = "Data Validity Comment Present"
    POTENTIAL_DUPLICATE = "Potential Duplicate"
    UNDEFINED_STEREOCHEMISTRY = "Undefined Stereochemistry"
    MUTATION_KEYWORD = "Mutation keyword in assay description"
    ASSAY_SIZE_TOO_LARGE = "Assay size >"  # Example: "Assay size > 100"
    ASSAY_SIZE_TOO_SMALL = "Assay size <"  # Example: "Assay size < 20"
    UNIT_ANNOTATION_ERROR = "Unit Annotation Error"
    MISSING_DOCUMENT_DATE = "Missing document date"
    MIXTURE_IN_SMILES = "Mixture in SMILES"
    INSUFFICIENT_ASSAY_OVERLAP = (
        "Insufficient assay overlap"  # Example: "Insufficient assay overlap (min_overlap=5)"
    )
    INSUFFICIENT_ASSAY_OVERLAP_WITH_METADATA = "Insufficient assay overlap with metadata matching"  # Example: "Insufficient assay overlap with metadata matching (min_overlap=5)"


def normalize_comment_pattern(comment: str) -> str:
    """Normalize comment by removing dynamic values (e.g., threshold numbers).

    Args:
        comment: Raw comment string from data.

    Returns:
        Normalized pattern for matching. For example:
        - "Assay size < 20" -> "Assay size <"
        - "Assay size > 100" -> "Assay size >"
        - "Insufficient assay overlap (min_overlap=5)" -> "Insufficient assay overlap"
        - "Insufficient assay overlap with metadata matching (min_overlap=5)" -> "Insufficient assay overlap with metadata matching"
        - "Unit converted to nM from uM" -> "Unit converted to"
        - "Unit Annotation Error" -> "Unit Annotation Error"
    """
    # Handle assay size patterns by removing the threshold number
    if comment.startswith("Assay size <"):
        return "Assay size <"
    elif comment.startswith("Assay size >"):
        return "Assay size >"
    # Handle insufficient assay overlap patterns by removing the parameter
    elif comment.startswith("Insufficient assay overlap with metadata matching"):
        return "Insufficient assay overlap with metadata matching"
    elif comment.startswith("Insufficient assay overlap"):
        return "Insufficient assay overlap"
    # Handle unit conversion patterns by removing the specific units
    elif comment.startswith("Unit converted to"):
        return "Unit converted to"
    return comment


def extract_assay_threshold(comment: str) -> str:
    """Extract the threshold value from an assay size comment.

    Args:
        comment: Comment like "Assay size < 20" or "Assay size > 100".

    Returns:
        The threshold number as string, or empty string if not found.
    """
    if comment.startswith("Assay size <"):
        return comment.split("<")[1].strip()
    elif comment.startswith("Assay size >"):
        return comment.split(">")[1].strip()
    return ""


[docs] def get_all_comments() -> list[str]: """Get all comment types in order for plotting. Returns: List of all comment patterns (dropping and processing combined). Assay size, insufficient overlap, and unit conversion comments are patterns without specific thresholds/units. """ return [ DroppingComment.DATA_VALIDITY_COMMENT.value, DroppingComment.POTENTIAL_DUPLICATE.value, DroppingComment.UNDEFINED_STEREOCHEMISTRY.value, DroppingComment.MUTATION_KEYWORD.value, DroppingComment.ASSAY_SIZE_TOO_LARGE.value, DroppingComment.ASSAY_SIZE_TOO_SMALL.value, DroppingComment.INSUFFICIENT_ASSAY_OVERLAP.value, DroppingComment.INSUFFICIENT_ASSAY_OVERLAP_WITH_METADATA.value, DroppingComment.UNIT_ANNOTATION_ERROR.value, DroppingComment.MISSING_DOCUMENT_DATE.value, DroppingComment.MIXTURE_IN_SMILES.value, ProcessingComment.SALT_SOLVENT_REMOVED.value, ProcessingComment.CALCULATED_PCHEMBL.value, ProcessingComment.CORRECTED_STANDARD_RELATION.value, ProcessingComment.PCHEMBL_DUPLICATION_ACROSS_DOCUMENTS.value, ProcessingComment.UNIT_CONVERTED.value, ]
def deaggregate_data(data: pd.DataFrame, sep_str: str = "|") -> pd.DataFrame: """De-aggregate data by splitting pipe-delimited columns back into individual rows. Takes aggregated data where multiple values are stored in pipe-delimited strings and explodes them back into separate rows, effectively reversing the aggregation. Only rows containing the separator in at least one column are exploded. Args: data: DataFrame with potentially aggregated (pipe-delimited) columns. sep_str: Separator string used to delimit multiple values in columns. Returns: DataFrame with de-aggregated data where each measurement is a separate row. """ cols_with_pipe = data.apply(lambda col: col.astype(str).str.contains(sep_str, regex=False)).any() cols_with_pipe = np.compress(cols_with_pipe.values, cols_with_pipe.index).tolist() if not cols_with_pipe: return data.copy() first_pipe_col = cols_with_pipe[0] mask = data[first_pipe_col].astype(str).str.contains(sep_str, regex=False) aggregated_rows = data[mask] if len(aggregated_rows) == 0: return data.copy() aggregated_rows = aggregated_rows.copy() for col in cols_with_pipe: aggregated_rows[col] = aggregated_rows[col].fillna("").astype(str).str.split(sep_str) exploded = aggregated_rows.explode(column=cols_with_pipe) deaggregated = pd.concat([data.drop(index=aggregated_rows.index), exploded]) return deaggregated.reset_index(drop=True) def deduplicate_aggregated_values( data: pd.DataFrame, value_column: str = "pchembl_value", sep_str: str = "|", ) -> pd.DataFrame: """Remove duplicate values within aggregated rows based on the value column. When data is aggregated, identical values from different documents may appear multiple times (e.g., "8.00|8.00|8.00|6.92"). This function removes duplicate values, keeping only the first occurrence of each unique value within each row. This is useful for removing cross-document duplicates before recalculating statistics (mean, median, std), as duplicate values can skew the distribution. Args: data: DataFrame with aggregated (pipe-delimited) columns. value_column: Column containing the values to deduplicate on. Defaults to "pchembl_value". sep_str: Separator string used to delimit multiple values in columns. Returns: DataFrame with deduplicated values. Rows that had duplicates will have fewer pipe-delimited entries. Single-value rows are unchanged. """ if len(data) == 0: return data.copy() data = data.copy() # Find columns with pipe separators (multi-value columns) cols_with_pipe = data.apply(lambda col: col.astype(str).str.contains(sep_str, regex=False)).any() cols_with_pipe = np.compress(cols_with_pipe.values, cols_with_pipe.index).tolist() if not cols_with_pipe or value_column not in cols_with_pipe: return data # Process each row that has aggregated values mask = data[value_column].astype(str).str.contains(sep_str, regex=False) if not mask.any(): return data def dedupe_row(row): """Deduplicate a single row based on value_column, keeping first occurrence.""" values = str(row[value_column]).split(sep_str) # Find indices of first occurrence of each unique value seen = {} keep_indices = [] for i, val in enumerate(values): if val not in seen: seen[val] = i keep_indices.append(i) # If no duplicates found, return row unchanged if len(keep_indices) == len(values): return row # Apply the same indices to all pipe-delimited columns new_row = row.copy() for col in cols_with_pipe: col_values = str(row[col]).split(sep_str) # Handle case where column has fewer values than value_column if len(col_values) >= len(values): new_values = [col_values[i] for i in keep_indices] new_row[col] = sep_str.join(new_values) return new_row # Apply deduplication to rows with aggregated values deduplicated_rows = data.loc[mask].apply(dedupe_row, axis=1) data.loc[mask] = deduplicated_rows return data def filter_aggregated_dropping_flags( data: pd.DataFrame, flags: list[str], comment_column: str = "data_dropping_comment", value_column: str = "pchembl_value", sep_str: str = "|", ) -> pd.DataFrame: """Remove individual flagged measurements from aggregated rows and recalculate statistics. For aggregated data (pipe-separated columns), this function checks each measurement position individually. Only measurements whose comment matches a flag are removed. Rows where ALL measurements are flagged are removed entirely. Non-aggregated rows are filtered at the row level (same as filter_dropping_flags). Args: data: DataFrame with potentially aggregated (pipe-delimited) columns. flags: List of flag strings to match against (substring match). comment_column: Column containing quality flags. Defaults to "data_dropping_comment". value_column: Column containing activity values for stats recalculation. Defaults to "pchembl_value". sep_str: Separator string used to delimit multiple values in columns. Returns: DataFrame with flagged measurements removed and statistics recalculated. """ from .core.pandas_helper import assign_stats, filter_dropping_flags if not flags: return data if comment_column not in data.columns: log.warning(f"Column '{comment_column}' not found in DataFrame. No filtering applied.") return data if len(data) == 0: return data.copy() # Detect pipe-separated columns cols_with_pipe = data.apply(lambda col: col.astype(str).str.contains(sep_str, regex=False)).any() cols_with_pipe = np.compress(cols_with_pipe.values, cols_with_pipe.index).tolist() if not cols_with_pipe or comment_column not in cols_with_pipe: # Non-aggregated data: fall back to row-level filtering return filter_dropping_flags(data, flags, column=comment_column) data = data.copy() # Per-flag logging (how many rows contain each flag) for flag in flags: flag_mask = data[comment_column].fillna("").astype(str).str.contains(flag, regex=False) n_flagged = flag_mask.sum() if n_flagged > 0: log.info(f"Rows with flag '{flag}': {n_flagged}") # Identify aggregated rows (value_column or comment_column has pipes) check_col = value_column if value_column in cols_with_pipe else comment_column is_aggregated = data[check_col].astype(str).str.contains(sep_str, regex=False) # Track rows to drop entirely rows_to_drop = [] # Track rows that were modified (need stats recalculation) rows_modified = [] for idx in data.index: if is_aggregated.loc[idx]: # Aggregated row: measurement-level filtering comments = str(data.at[idx, comment_column]).split(sep_str) keep_indices = [] for i, comment in enumerate(comments): if not _measurement_has_flag(comment, flags): keep_indices.append(i) if len(keep_indices) == 0: # All measurements flagged -> remove row rows_to_drop.append(idx) elif len(keep_indices) < len(comments): # Partial filtering: keep only clean measurements for col in cols_with_pipe: col_values = str(data.at[idx, col]).split(sep_str) if len(col_values) >= len(comments): new_values = [col_values[i] for i in keep_indices] data.at[idx, col] = sep_str.join(new_values) rows_modified.append(idx) # else: all clean, no change needed else: # Non-aggregated row: row-level filtering comment = str(data.at[idx, comment_column]) if pd.notna(data.at[idx, comment_column]) else "" if _measurement_has_flag(comment, flags): rows_to_drop.append(idx) n_removed = len(rows_to_drop) n_pruned = len(rows_modified) n_total = n_removed + n_pruned if n_removed > 0: log.info(f"Rows fully removed (all measurements flagged): {n_removed}") if n_pruned > 0: log.info(f"Rows with individual measurements pruned: {n_pruned}") if n_total > 0: log.info(f"Total rows affected: {n_total}/{len(data)} ({n_total / len(data) * 100:.1f}%)") data = data.drop(index=rows_to_drop) if len(data) == 0: return data.reset_index(drop=True) # Recalculate stats for modified rows (and all aggregated rows for consistency) if rows_modified: use_geometric = value_column == "pchembl_value" data = assign_stats(data, value_col=value_column, use_geometric=use_geometric) return data.reset_index(drop=True) def _measurement_has_flag(comment: str, flags: list[str]) -> bool: """Check if a single measurement's comment matches any of the given flags. Uses substring matching, consistent with filter_dropping_flags semantics. Args: comment: Comment string for a single measurement. flags: List of flag patterns to check. Returns: True if any flag is found in the comment. """ if not comment or comment == "nan": return False for flag in flags: if flag in comment: return True return False def resolve_annotation_errors( data: pd.DataFrame, strategy: str = "first", mol_id_col: str = "molecule_chembl_id", assay_id_col: str = "assay_chembl_id", value_col: str = "pchembl_value", year_col: str = "year", ) -> pd.DataFrame: """Resolve unit annotation errors by keeping measurements from earliest documents. Unit annotation errors occur when measurements for the same molecule across different assays differ by exactly 3.0 or 6.0 log units (suggesting 1000x or 1,000,000x unit conversion errors). This function detects such pairs and keeps only the measurement from the earliest document, assuming later measurements copied the value incorrectly. This function operates on exploded (non-aggregated) data where each row is a single measurement. Args: data: DataFrame with exploded bioactivity data (one measurement per row). strategy: Resolution strategy. Currently only "first" is supported, which keeps the measurement from the earliest document year. mol_id_col: Column name for molecule identifiers. assay_id_col: Column name for assay identifiers. value_col: Column name for activity values (e.g., pchembl_value). year_col: Column name for document year. Returns: DataFrame with annotation errors resolved (problematic measurements removed). """ if len(data) == 0: return data.copy() if strategy != "first": raise ValueError(f"Unknown strategy: {strategy}. Only 'first' is supported.") data = data.copy() # Convert value column to numeric for comparison data["__value_numeric__"] = pd.to_numeric(data[value_col], errors="coerce") # Convert year to numeric, treating 'nan' string as NaN data["__year_numeric__"] = data[year_col].replace("nan", np.nan) data["__year_numeric__"] = pd.to_numeric(data["__year_numeric__"], errors="coerce") # Track indices to remove indices_to_remove = set() pairs_detected = 0 log.trace(f"Starting annotation error detection for {len(data)} measurements") # Group by molecule and find problematic pairs for mol_id, group in data.groupby(mol_id_col): if len(group) < 2: continue # Get all pairs of measurements from different assays for i, (idx_a, row_a) in enumerate(group.iterrows()): for idx_b, row_b in list(group.iterrows())[i + 1 :]: # Skip if same assay if row_a[assay_id_col] == row_b[assay_id_col]: continue # Check if values differ by ~3.0 or ~6.0 val_a = row_a["__value_numeric__"] val_b = row_b["__value_numeric__"] if pd.isna(val_a) or pd.isna(val_b): continue diff = abs(val_a - val_b) is_annotation_error = np.isclose(diff, 3.0, rtol=1e-9, atol=1e-9) or np.isclose( diff, 6.0, rtol=1e-9, atol=1e-9 ) if not is_annotation_error: continue pairs_detected += 1 # Found a problematic pair - resolve by keeping earliest year_a = row_a["__year_numeric__"] year_b = row_b["__year_numeric__"] assay_a = row_a[assay_id_col] assay_b = row_b[assay_id_col] log.trace( f"Detected annotation error pair for {mol_id}: " f"assay {assay_a} (value={val_a:.2f}, year={year_a}) vs " f"assay {assay_b} (value={val_b:.2f}, year={year_b}), " f"diff={diff:.2f}" ) # Handle NaN years: remove the one with NaN year, keep valid year if pd.isna(year_a) and pd.isna(year_b): # Both NaN - can't resolve, remove both log.trace(" -> Both years NaN, removing BOTH measurements") indices_to_remove.add(idx_a) indices_to_remove.add(idx_b) elif pd.isna(year_a): # A has NaN year, keep B log.trace( f" -> Year A is NaN, keeping B (assay {assay_b}, value={val_b:.2f}, year={year_b})" ) indices_to_remove.add(idx_a) elif pd.isna(year_b): # B has NaN year, keep A log.trace( f" -> Year B is NaN, keeping A (assay {assay_a}, value={val_a:.2f}, year={year_a})" ) indices_to_remove.add(idx_b) elif year_a <= year_b: # A is earlier or same, remove B log.trace( f" -> A is earlier/same ({year_a} <= {year_b}), " f"keeping A (assay {assay_a}, value={val_a:.2f}), removing B" ) indices_to_remove.add(idx_b) else: # B is earlier, remove A log.trace( f" -> B is earlier ({year_b} < {year_a}), " f"keeping B (assay {assay_b}, value={val_b:.2f}), removing A" ) indices_to_remove.add(idx_a) log.trace( f"Annotation error detection complete: found {pairs_detected} problematic pairs, " f"removing {len(indices_to_remove)} unique measurements" ) # Remove problematic rows result = data.drop(index=list(indices_to_remove)) # Clean up temporary columns result = result.drop(columns=["__value_numeric__", "__year_numeric__"]) return result.reset_index(drop=True) def recalculate_aggregated_stats( data: pd.DataFrame, value_column: str = "pchembl_value", sep_str: str = "|", ) -> pd.DataFrame: """Recalculate statistics (mean, median, std, counts) from pipe-delimited values. After deduplication or other modifications to the pipe-delimited value column, the pre-computed statistics (mean, median, std, counts) may be stale. This function recalculates them from the current pipe-delimited values. Args: data: DataFrame with aggregated data containing value_column and corresponding stats columns ({value_column}_mean, _median, _std, _counts). value_column: Column containing pipe-delimited values. Defaults to "pchembl_value". sep_str: Separator string used to delimit multiple values. Returns: DataFrame with recalculated statistics columns. """ if len(data) == 0: return data.copy() data = data.copy() # Stats column names mean_col = f"{value_column}_mean" median_col = f"{value_column}_median" std_col = f"{value_column}_std" counts_col = f"{value_column}_counts" def calc_stats(val_str): """Calculate stats from pipe-delimited string.""" if pd.isna(val_str) or val_str == "": return np.nan, np.nan, np.nan, 0 vals = str(val_str).split(sep_str) vals = [float(v) for v in vals if v.strip()] if len(vals) == 0: return np.nan, np.nan, np.nan, 0 elif len(vals) == 1: return vals[0], vals[0], 0.0, 1 else: return np.mean(vals), np.median(vals), np.std(vals, ddof=1), len(vals) # Apply calculation to each row stats = data[value_column].apply(lambda x: pd.Series(calc_stats(x))) stats.columns = [mean_col, median_col, std_col, counts_col] # Update the stats columns for col in [mean_col, median_col, std_col, counts_col]: if col in data.columns: data[col] = stats[col] else: data[col] = stats[col] return data
[docs] def explode_assay_comparability( subset: pd.DataFrame, sep_str: str = "|", extra_multival_cols: Optional[list[str]] = None, value_column: str = "pchembl_value", ) -> pd.DataFrame: """Explode dataset to create pairwise comparisons between assays for the same compound. Takes a subset of data where compounds have measurements across multiple assays (indicated by separator in columns) and creates all pairwise combinations for comparability analysis. Args: subset: DataFrame with multi-valued columns separated by sep_str. sep_str: Separator string used to delimit multiple values in columns. extra_multival_cols: Additional columns to treat as multi-valued. value_column: The column containing activity values to compare. Defaults to "pchembl_value" but can be set to "standard_value" for non-pChEMBL data (e.g., Caco-2 permeability, percent inhibition). Returns: DataFrame with exploded pairwise comparisons, with _x and _y suffixes for each pair. """ if extra_multival_cols is None: extra_multival_cols = [] elif not isinstance(extra_multival_cols, list): raise ValueError("extra_multival_cols must be a list of column names or None") singleval_cols = [ "connectivity", "target_chembl_id", "repeat", ] multival_cols = [ "activity_id", "assay_chembl_id", value_column, "data_processing_comment", "data_dropping_comment", "standard_type", "canonical_smiles", *extra_multival_cols, ] # Prevent overlap between singleval and multival columns for col in extra_multival_cols: if col in singleval_cols: singleval_cols.remove(col) # Auto-detect which columns are actually multi-valued (contain the separator). # Columns used as --id-columns during aggregation will be single-valued. # Skip this check for empty DataFrames to preserve original column classification. if len(subset) > 0: for col in list(multival_cols): if col in subset.columns: if not subset[col].astype(str).str.contains(sep_str, regex=False).any(): multival_cols.remove(col) singleval_cols.append(col) # Filter to only include columns that exist in the input dataframe singleval_cols = [col for col in singleval_cols if col in subset.columns] multival_cols = [col for col in multival_cols if col in subset.columns] exploded_subset = subset[ [ *singleval_cols, *multival_cols, ] ].apply(lambda x: x.str.split(sep_str) if x.name not in singleval_cols else x) for col in multival_cols: exploded_subset[col] = exploded_subset[col].apply( lambda x: [sep_str.join(y) for y in combinations(x, 2)] ) exploded_subset = exploded_subset.explode(multival_cols) # Handle empty DataFrame case if len(exploded_subset) == 0: # Create empty result with correct columns result_cols = [*singleval_cols] for col in multival_cols: result_cols.extend([f"{col}_x", f"{col}_y"]) result_cols.extend(["processing_comment", "dropping_comment"]) return pd.DataFrame(columns=result_cols) suffixes = ["_x", "_y"] for col in multival_cols: values = exploded_subset[col].apply(lambda x: x.split(sep_str)).values for idx, s in enumerate(suffixes): exploded_subset[f"{col}{s}"] = [v[idx] for v in values] exploded_subset.drop(columns=[col], inplace=True) # Combine processing and dropping comments from both assays exploded_subset = exploded_subset.assign( processing_comment=lambda x: x["data_processing_comment_x"].fillna("") + sep_str + x["data_processing_comment_y"].fillna(""), dropping_comment=lambda x: x["data_dropping_comment_x"].fillna("") + sep_str + x["data_dropping_comment_y"].fillna(""), ).query("assay_chembl_id_x != assay_chembl_id_y") # Clean up comment strings by removing duplicates and extra separators exploded_subset["dropping_comment"] = ( exploded_subset.dropping_comment.str.rstrip(sep_str) .str.lstrip(sep_str) .apply(lambda x: sep_str.join(np.unique(x.split(sep_str)))) ) exploded_subset["processing_comment"] = ( exploded_subset.processing_comment.str.rstrip(sep_str) .str.lstrip(sep_str) .apply(lambda x: sep_str.join(np.unique(x.split(sep_str)))) ) return exploded_subset
def format_units_latex(units: str) -> str: """Convert common unit patterns to LaTeX format. Handles scientific notation patterns like "10^-6 cm/s" and converts them to proper LaTeX math notation. Units already containing LaTeX ($) are passed through unchanged. Args: units: Unit string (e.g., "10^-6 cm/s", "nM", "$10^{-6}$ cm/s"). Returns: Unit string with LaTeX math notation where applicable. """ if not units: return "" # Already contains LaTeX - pass through if "$" in units: return units # Handle 10^N patterns (e.g., "10^-6 cm/s" → "$10^{-6}$ cm/s") import re pattern = r"10\^(-?\d+)" if re.search(pattern, units): return re.sub(pattern, r"$10^{\1}$", units) return units def format_axis_label( property_name: str = "value", log_transform: bool = False, units: Optional[str] = None, ) -> str: """Format axis label with LaTeX math notation for publication-quality figures. Args: property_name: Name of the measured property (e.g., "Permeability", "pChEMBL"). log_transform: If True, prepend -log10 in LaTeX format. units: Unit string. Supports patterns like "10^-6 cm/s" which are converted to LaTeX. Already-LaTeX units ("$...$") pass through. Returns: Formatted label string with LaTeX math notation. """ parts = [property_name] if log_transform and units: # e.g., "Permeability ($-\log_{10}$ [$10^{-6}$ cm/s])" formatted_units = format_units_latex(units) parts.append(f"($-\\log_{{10}}$ [{formatted_units}])") elif log_transform: # e.g., "Standard Value ($-\log_{10}$)" parts.append(r"($-\log_{10}$)") elif units: # e.g., "Permeability (10^-6 cm/s)" formatted_units = format_units_latex(units) parts.append(f"({formatted_units})") return " ".join(parts) def _log_comparability_metrics( xp: np.ndarray, yp: np.ndarray, label: str = "", is_log_scale: bool = True, rho: Optional[float] = None, r2: Optional[float] = None, tau: Optional[float] = None, ) -> None: """Log quantitative comparability metrics for pairwise assay comparisons. For log-scale data (pChEMBL or -log10 transformed), reports the fraction of pairs within ±0.3 and ±1.0 log units. For all data, reports Spearman rho, Kendall tau, and R². Args: xp: Array of x-values (possibly log-transformed). yp: Array of y-values (possibly log-transformed). label: Description of the data subset (e.g., flag name). is_log_scale: If True, compute ±0.3/±1.0 statistics. rho: Pre-computed Spearman rho. Computed from data if None. r2: Pre-computed R² score. Computed from data if None. tau: Pre-computed Kendall tau. Computed from data if None. """ n = len(xp) if n == 0: return abs_diff = np.abs(np.asarray(xp) - np.asarray(yp)) try: if rho is None: rho, _ = stats.spearmanr(xp, yp) if r2 is None: r2 = r2_score(xp, yp) if tau is None: tau, _ = stats.kendalltau(xp, yp) except ValueError: return lines = [f"COMPARABILITY: {label}{n:,} pairwise comparisons"] if is_log_scale: within_03 = int(np.sum(abs_diff <= 0.3)) within_10 = int(np.sum(abs_diff <= 1.0)) outside_10 = n - within_10 lines.append(f" Within ±0.3 log units: {within_03:>8,} ({within_03 / n * 100:5.1f}%)") lines.append(f" Within ±1.0 log units: {within_10:>8,} ({within_10 / n * 100:5.1f}%)") lines.append(f" Outside ±1.0 log units:{outside_10:>8,} ({outside_10 / n * 100:5.1f}%)") lines.append(f" Spearman rho: {rho:.3f} | Kendall tau: {tau:.3f} | R²: {r2:.3f}") log.info("\n".join(lines))
[docs] def plot_subset( subset: pd.DataFrame, title: str = "", color: str = "slategray", alpha: float = 0.3, figsize: Tuple[float, float] = (5, 5), value_column: str = "pchembl_value", log_transform: bool = False, log_scale_factor: float = 1.0, axis_label: Optional[str] = None, axis_limits: Optional[Tuple[float, float]] = None, reference_lines: bool = True, units: Optional[str] = None, ) -> Tuple[plt.Figure, plt.Axes]: """Create scatter plot comparing values across assays with correlation metrics. Args: subset: DataFrame with {value_column}_x and {value_column}_y columns. title: Plot title. color: Color for scatter points. alpha: Transparency for scatter points. figsize: Figure size as (width, height) tuple. value_column: Base name of the value column (without _x/_y suffix). Defaults to "pchembl_value". log_transform: If True, apply -log10 transformation to values before plotting. Use for concentration values (nM, µM, etc.) but NOT for percentages or pChEMBL. log_scale_factor: Scale factor for the unit of measurement (e.g., 1e-6 for values in 10^-6 cm/s units). The transformation becomes -log10(value * factor). For example, a value of 5 in 10^-6 cm/s units with factor=1e-6 gives -log10(5e-6) ≈ 5.3, producing pChEMBL-like positive values. axis_label: Custom axis label. If None, uses format_axis_label() to generate a LaTeX-formatted label based on value_column, log_transform, and units. axis_limits: Tuple of (min, max) for both axes. If None, defaults to (3, 12) for pchembl_value, (0, 100) for percentage data, or auto-detects from data. reference_lines: If True, draw identity and ±1/±0.3 reference lines. These are most meaningful for pChEMBL-scale data. units: Unit string for axis labels (e.g., "10^-6 cm/s"). Converted to LaTeX format automatically. Only used when axis_label is None. Returns: Tuple of (figure, axes) objects. """ fig, ax = plt.subplots(figsize=figsize) subset = subset.copy() x_col = f"{value_column}_x" y_col = f"{value_column}_y" subset[x_col] = subset[x_col].astype(float) subset[y_col] = subset[y_col].astype(float) xp = subset[x_col].copy() yp = subset[y_col].copy() if log_transform: # Apply -log10(value * scale_factor) transformation # Math: -log10(x * factor) = -log10(x) - log10(factor) = -log10(x) + offset # For factor=1e-6: offset = -log10(1e-6) = 6, giving positive pChEMBL-like values offset = -np.log10(log_scale_factor) if log_scale_factor != 1.0 else 0 # Use epsilon of 0.001 to handle near-zero values sensibly # (zeros would otherwise produce extreme outliers) xp = -np.log10(xp + 0.001) + offset yp = -np.log10(yp + 0.001) + offset ax.scatter( xp, yp, alpha=alpha, edgecolors="none", color=color, ) ax.set_title(title) # Determine axis limits if axis_limits is not None: lim_min, lim_max = axis_limits elif value_column == "pchembl_value": # pChEMBL values typically range from 3-12 lim_min, lim_max = 3, 12 else: # Auto-detect from (possibly transformed) data all_vals = np.concatenate([xp, yp]) data_min, data_max = np.nanmin(all_vals), np.nanmax(all_vals) margin = (data_max - data_min) * 0.1 lim_min = data_min - margin lim_max = data_max + margin if reference_lines and (value_column == "pchembl_value" or log_transform): # Reference lines are most meaningful for pChEMBL-scale data ax.plot((lim_min, lim_max), (lim_min, lim_max), "k-", label="Identity $(y=x)$") ax.plot((lim_min, lim_max), (lim_min - 1, lim_max - 1), "k--") ax.plot((lim_min, lim_max), (lim_min + 1, lim_max + 1), "k--", label="$y=x±1$") ax.plot((lim_min, lim_max), (lim_min - 0.3, lim_max - 0.3), "k-.") ax.plot((lim_min, lim_max), (lim_min + 0.3, lim_max + 0.3), "k-.", label="$y=x±0.3$") elif reference_lines: # For percentage or other linear data, just draw identity line ax.plot((lim_min, lim_max), (lim_min, lim_max), "k-", label="Identity $(y=x)$") ax.set_xlim(lim_min, lim_max) ax.set_ylim(lim_min, lim_max) ax.spines["right"].set_visible(False) ax.spines["top"].set_visible(False) # Calculate correlation metrics r, _ = stats.spearmanr(xp, yp) tau, _ = stats.kendalltau(xp, yp) r2 = r2_score(xp, yp) ax.text( 1.0, 0.175, rf"$R^2: {r2:.2f}$" + "\n" + rf"Spearman $\rho: {r:.2f}$" + "\n" + rf"Kendall $\tau: {tau:.2f}$", transform=ax.transAxes, verticalalignment="top", horizontalalignment="right", ) # Log quantitative comparability metrics is_log = value_column == "pchembl_value" or log_transform _log_comparability_metrics(xp.values, yp.values, label=title or "Overall", is_log_scale=is_log, rho=r, r2=r2, tau=tau) # Determine axis labels if axis_label is not None: label_base = axis_label elif value_column == "pchembl_value": label_base = "pChEMBL value" else: property_name = value_column.replace("_", " ").title() label_base = format_axis_label(property_name, log_transform=log_transform, units=units) ax.set_ylabel(f"Assay 2 {label_base}") ax.set_xlabel(f"Assay 1 {label_base}") handles, labels = ax.get_legend_handles_labels() ax.legend( handles, labels, title="", loc="upper center", bbox_to_anchor=(0.45, -0.15), ncol=3, frameon=False, ) return fig, ax
def build_query_string(comment: str, value_column: str = "pchembl_value") -> str: """Build query string for filtering data by specific comment flags. This function replicates the original notebook logic for querying exploded datasets. It handles both exact matches and pattern-based matches (e.g., for assay size). Args: comment: The comment pattern to search for (may be normalized pattern like "Assay size <"). value_column: Base name of the value column (without _x/_y suffix). Used for duplicate checking queries. Defaults to "pchembl_value". Returns: Query string suitable for pd.DataFrame.query(). """ # Since this data duplication is a flagged introduced by capricho, we'll just show the # data points that have x == y for this flag if comment == ProcessingComment.PCHEMBL_DUPLICATION_ACROSS_DOCUMENTS.value: return ( "(data_processing_comment_x.str.contains('pChEMBL Duplication Across Documents', regex=False) & " "data_processing_comment_y.str.contains('pChEMBL Duplication Across Documents', regex=False) & " f"{value_column}_x == {value_column}_y)" ) # Special handling for calculated pChEMBL (uses combined processing_comment column) if comment == ProcessingComment.CALCULATED_PCHEMBL.value: return "processing_comment.str.contains('Calculated pChEMBL', regex=False) & (dropping_comment == '')" # For other processing comments (uses combined processing_comment column) if comment in [ ProcessingComment.SALT_SOLVENT_REMOVED.value, ProcessingComment.CORRECTED_STANDARD_RELATION.value, ]: return f"processing_comment.str.contains('{comment}', regex=False) & dropping_comment == ''" # Special handling for unit conversion (pattern-based, uses combined processing_comment column) if comment == ProcessingComment.UNIT_CONVERTED.value: return "processing_comment.str.contains('Unit converted to', regex=False) & dropping_comment == ''" # Special handling for potential duplicates (needs to be in both assays, exclude data validity issues) if comment == DroppingComment.POTENTIAL_DUPLICATE.value: return ( "((data_dropping_comment_x.str.contains('Potential Duplicate', regex=False) & " "data_dropping_comment_y.str.contains('Potential Duplicate', regex=False)) | " "(data_dropping_comment_y.str.contains('Potential Duplicate', regex=False) & " "data_dropping_comment_x.str.contains('Potential Duplicate', regex=False))) & " "~dropping_comment.str.contains('Data Validity Comment Present', regex=False)" ) # Special handling for unit annotation errors (needs to be in both assays, exclude data validity issues) if comment == DroppingComment.UNIT_ANNOTATION_ERROR.value: return ( "((data_dropping_comment_x.str.contains('Unit Annotation Error', regex=False) & " "data_dropping_comment_y.str.contains('Unit Annotation Error', regex=False)) | " "(data_dropping_comment_y.str.contains('Unit Annotation Error', regex=False) & " "data_dropping_comment_x.str.contains('Unit Annotation Error', regex=False))) & " "~dropping_comment.str.contains('Data Validity Comment Present', regex=False)" ) # we need to compare the documents with dates to the ones without to see if it's an issue if comment == DroppingComment.MISSING_DOCUMENT_DATE.value: return ( "((data_dropping_comment_x.str.contains('Missing document date', regex=False) | " "data_dropping_comment_y.str.contains('Missing document date', regex=False)) & " "~dropping_comment.str.contains('Data Validity Comment Present', regex=False))" ) # For Data Validity Comment, allow any processing comment if comment == DroppingComment.DATA_VALIDITY_COMMENT.value: return f"dropping_comment.str.contains('{comment}', regex=False)" # For other dropping comments, exclude Data Validity Comment return ( f"dropping_comment.str.contains('{comment}', regex=False) & " "~dropping_comment.str.contains('Data Validity Comment Present', regex=False)" )
[docs] def plot_multi_panel_comparability( exploded_subset: pd.DataFrame, comments: List[str], title: str = "Comparability Across Flagged Data", figsize: Tuple[float, float] = (20, 8), ncols: int = 5, value_column: str = "pchembl_value", log_transform: bool = False, log_scale_factor: float = 1.0, axis_label: Optional[str] = None, axis_limits: Optional[Tuple[float, float]] = None, reference_lines: bool = True, units: Optional[str] = None, alpha: float = 0.3, ) -> Tuple[plt.Figure, np.ndarray]: """Create multi-panel plot showing comparability for different data quality flags. Args: exploded_subset: DataFrame from explode_assay_comparability(). comments: List of comment strings to plot. title: Overall figure title. figsize: Figure size as (width, height) tuple. ncols: Number of columns in subplot grid. value_column: Base name of the value column (without _x/_y suffix). Defaults to "pchembl_value". log_transform: If True, apply -log10 transformation to values before plotting. Use for concentration values (nM, µM, etc.) but NOT for percentages or pChEMBL. log_scale_factor: Scale factor for the unit of measurement (e.g., 1e-6 for values in 10^-6 cm/s units). The transformation becomes -log10(value * factor). axis_label: Custom axis label. If None, uses format_axis_label() to generate a LaTeX-formatted label based on value_column, log_transform, and units. axis_limits: Tuple of (min, max) for both axes. If None, defaults to (3, 12) for pchembl_value, or auto-detects from data. reference_lines: If True, draw identity and ±1/±0.3 reference lines. units: Unit string for axis labels (e.g., "10^-6 cm/s"). Converted to LaTeX format automatically. Only used when axis_label is None. Returns: Tuple of (figure, axes array). """ x_col = f"{value_column}_x" y_col = f"{value_column}_y" comments_with_data = [] # only display the comments that have data n_data = [] # debugging info only for comment in comments: if comment == "Corrected standard_relation from = to < (censored activity_comment)": continue # skip this comment as it contains discrete data only query_str = build_query_string(comment, value_column=value_column) subset = exploded_subset.query(query_str) if len(subset) > 0: comments_with_data.append(comment) n_data.append(len(subset)) if len(comments_with_data) == 0: # No data to plot fig, ax = plt.subplots(1, 1, figsize=figsize) ax.text(0.5, 0.5, "No data available for any flags", ha="center", va="center") ax.axis("off") return fig, np.array([ax]) nrows = int(np.ceil(len(comments_with_data) / ncols)) colors = [tuple([*col] + [1]) for col in colormaps["tab20"].colors] fig, axs = plt.subplots(nrows, ncols, figsize=figsize) axs_flat = axs.flatten() if nrows > 1 else [axs] if ncols == 1 else axs # Determine axis limits if axis_limits is not None: lim_min, lim_max = axis_limits elif value_column == "pchembl_value": # pChEMBL values typically range from 3-12 lim_min, lim_max = 3, 12 else: # Auto-detect from (possibly transformed) data all_x = exploded_subset[x_col].astype(float) all_y = exploded_subset[y_col].astype(float) if log_transform: offset = -np.log10(log_scale_factor) if log_scale_factor != 1.0 else 0 all_x = -np.log10(all_x + 0.001) + offset all_y = -np.log10(all_y + 0.001) + offset all_vals = np.concatenate([all_x, all_y]) data_min, data_max = np.nanmin(all_vals), np.nanmax(all_vals) margin = (data_max - data_min) * 0.1 lim_min = data_min - margin lim_max = data_max + margin # Determine axis labels if axis_label is not None: label_base = axis_label elif value_column == "pchembl_value": label_base = "pChEMBL value" else: property_name = value_column.replace("_", " ").title() label_base = format_axis_label(property_name, log_transform=log_transform, units=units) for idx, color, obs, ax in zip( range(1, len(comments_with_data) + 1), colors, comments_with_data, axs_flat ): query_str = build_query_string(obs, value_column=value_column) subset = exploded_subset.query(query_str) # Extract actual title with dynamic threshold if it's a parametric comment title_str = obs if obs.startswith("Assay size") or obs.startswith("Insufficient assay overlap"): # Find first occurrence with actual threshold from data for col in ["data_dropping_comment_x", "data_dropping_comment_y"]: if col in subset.columns: sample_comments = subset[col].dropna() if len(sample_comments) > 0: # Split by " & " since data_dropping_comment uses " & " as separator for comment in sample_comments.iloc[0].split(" & "): comment = comment.strip() if normalize_comment_pattern(comment) == obs: title_str = comment break if title_str != obs: break subset = subset.copy() subset[x_col] = subset[x_col].astype(float) subset[y_col] = subset[y_col].astype(float) xp = subset[x_col].copy() yp = subset[y_col].copy() if log_transform: offset = -np.log10(log_scale_factor) if log_scale_factor != 1.0 else 0 xp = -np.log10(xp + 0.001) + offset yp = -np.log10(yp + 0.001) + offset ax.scatter( xp, yp, alpha=alpha, edgecolors="none", label=title_str, color=color, ) ax.set_title(f"{idx}. {title_str}") # Add reference lines if reference_lines and (value_column == "pchembl_value" or log_transform): ax.plot((lim_min, lim_max), (lim_min, lim_max), "k-", label="Identity $(y=x)$") ax.plot((lim_min, lim_max), (lim_min - 1, lim_max - 1), "k--") ax.plot((lim_min, lim_max), (lim_min + 1, lim_max + 1), "k--", label="$y=x±1$") ax.plot((lim_min, lim_max), (lim_min - 0.3, lim_max - 0.3), "k-.") ax.plot((lim_min, lim_max), (lim_min + 0.3, lim_max + 0.3), "k-.", label="$y=x±0.3$") elif reference_lines: ax.plot((lim_min, lim_max), (lim_min, lim_max), "k-", label="Identity $(y=x)$") ax.set_xlim(lim_min, lim_max) ax.set_ylim(lim_min, lim_max) ax.spines["right"].set_visible(False) ax.spines["top"].set_visible(False) try: r, _ = stats.spearmanr(xp, yp) tau, _ = stats.kendalltau(xp, yp) r2 = r2_score(xp, yp) except ValueError: continue ax.text( 1.0, 0.175, rf"$R^2: {r2:.2f}$" + "\n" + rf"Spearman $\rho: {r:.2f}$" + "\n" + rf"Kendall $\tau: {tau:.2f}$", transform=ax.transAxes, verticalalignment="top", horizontalalignment="right", ) # Log per-panel comparability metrics is_log = value_column == "pchembl_value" or log_transform _log_comparability_metrics(xp.values, yp.values, label=title_str, is_log_scale=is_log, rho=r, r2=r2, tau=tau) if idx in [1, ncols + 1]: ax.set_ylabel(f"Assay 2 {label_base}") if idx > (nrows - 1) * ncols: ax.set_xlabel(f"Assay 1 {label_base}") if idx == len(comments_with_data): handles, labels = ax.get_legend_handles_labels() if len(handles) > 1: ax.legend( handles[1:], labels[1:], title="", bbox_to_anchor=(1.05, 0, 1, 0.2), loc="lower left", mode="expand", frameon=False, ) # Hide unused subplots for ax in axs_flat[len(comments_with_data) :]: ax.set_visible(False) fig.tight_layout() fig.suptitle(title, fontsize=16, y=1.02) return fig, axs