"""Clean aggregated bioactivity data by removing quality flags and duplicates,
and pivot into activity matrices for multitask modeling.
"""
from typing import List, Optional
import pandas as pd
from ..core.pandas_helper import assign_stats
from ..logger import logger
def clean_data(
df: pd.DataFrame,
drop_flags: Optional[List[str]] = None,
deduplicate: bool = False,
value_col: str = "pchembl_value",
resolve_annotation_error: Optional[str] = None,
) -> pd.DataFrame:
"""Clean aggregated bioactivity data by deduplicating, resolving errors, and filtering flags.
Orchestrates cleaning steps in the correct order:
1. Deduplicate (if requested): removes identical values within aggregated rows
and recalculates statistics (mean, median, std, counts).
2. Resolve annotation errors (if requested): detects measurements differing by
exactly 3.0 or 6.0 log units (unit conversion errors), keeps the earliest
document's value, then re-aggregates.
3. Drop flags: removes individual flagged measurements from aggregated rows
and recalculates statistics; rows where all measurements are flagged are
removed entirely. Non-aggregated data is filtered at the row level.
Appropriate flags for dropping include unit errors, undefined stereochemistry,
assay size issues, and mixtures. For potential duplicates, prefer using
``deduplicate=True`` instead of dropping — this keeps one measurement while
removing extras, rather than discarding the entire row.
Example::
from Capricho.cli.prepare import clean_data, prepare_multitask_data
cleaned = clean_data(df, drop_flags=["Unit Annotation Error"], deduplicate=True)
matrix = prepare_multitask_data(cleaned, task_col="target_chembl_id", ...)
Args:
df: Aggregated DataFrame from aggregate_data().
drop_flags: List of quality flags to remove. Rows where data_dropping_comment
contains any of these flags will be filtered out.
deduplicate: If True, remove duplicate values within aggregated rows
and recalculate statistics.
value_col: Column containing the activity values (e.g., "pchembl_value"
or "standard_value"). Used for deduplication and stats recalculation.
resolve_annotation_error: Resolution strategy for unit annotation errors.
Currently only "first" is supported (keep earliest document).
Cannot be used together with dropping "Unit Annotation Error" flags.
Returns:
Cleaned DataFrame.
Raises:
ValueError: If both ``resolve_annotation_error`` and "Unit Annotation Error"
in ``drop_flags`` are set, since these are contradictory operations.
"""
from ..analysis import (
DroppingComment,
deaggregate_data,
deduplicate_aggregated_values,
filter_aggregated_dropping_flags,
resolve_annotation_errors,
)
# Validate: can't both resolve and drop annotation errors
if resolve_annotation_error is not None and drop_flags:
if DroppingComment.UNIT_ANNOTATION_ERROR.value in drop_flags:
raise ValueError(
"Cannot both resolve and drop unit annotation errors. "
"Use resolve_annotation_error to fix them, or drop_flags to remove them."
)
if resolve_annotation_error is not None and resolve_annotation_error != "first":
raise ValueError(
f"Unknown resolution strategy: {resolve_annotation_error}. Only 'first' is supported."
)
df = df.copy()
input_rows = len(df)
# Track steps for summary
dedup_removed = 0
annotation_removed = 0
rows_after_dedup = input_rows
rows_after_annotation = input_rows
# Step 1: Deduplicate
if deduplicate:
logger.info("Deduplicating identical values within aggregated rows...")
initial_total = df[value_col].apply(lambda x: len(str(x).split("|")) if pd.notna(x) else 0).sum()
df = deduplicate_aggregated_values(df, value_column=value_col)
final_total = df[value_col].apply(lambda x: len(str(x).split("|")) if pd.notna(x) else 0).sum()
dedup_removed = initial_total - final_total
logger.info(f"Deduplication removed {dedup_removed} duplicate values")
logger.info("Recalculating statistics after deduplication...")
df = assign_stats(df, value_col=value_col, use_geometric=(value_col == "pchembl_value"))
rows_after_dedup = len(df)
# Step 2: Resolve annotation errors
if resolve_annotation_error is not None:
logger.info("Resolving unit annotation errors (3.0 or 6.0 log unit differences)...")
initial_rows = len(df)
exploded = deaggregate_data(df)
logger.info(f"Exploded {initial_rows} aggregated rows into {len(exploded)} individual measurements")
resolved = resolve_annotation_errors(
exploded,
strategy=resolve_annotation_error,
value_col=value_col,
)
annotation_removed = len(exploded) - len(resolved)
logger.info(f"Removed {annotation_removed} measurements due to annotation error resolution")
# Re-aggregate the data
from .chembl_data_pipeline import re_aggregate_data
# Detect extra_id_cols from columns between connectivity and smiles
cols = list(df.columns)
if "connectivity" in cols and "smiles" in cols:
conn_idx = cols.index("connectivity")
smiles_idx = cols.index("smiles")
detected_id_cols = cols[conn_idx + 1 : smiles_idx]
logger.info(f"Detected id_columns for re-aggregation: {detected_id_cols}")
else:
detected_id_cols = []
df = re_aggregate_data(
resolved,
chirality=False,
extra_id_cols=detected_id_cols,
compound_equality="connectivity",
)
logger.info(f"Re-aggregated to {len(df)} rows")
rows_after_annotation = len(df)
# Step 3: Drop flags (measurement-level for aggregated data)
rows_before_flags = len(df)
if drop_flags:
df = filter_aggregated_dropping_flags(df, drop_flags, value_column=value_col)
# Log consolidated summary
lines = ["", "PREPARATION SUMMARY"]
lines.append(f" Input rows: {input_rows:>8,}")
if deduplicate:
lines.append(
f" After deduplication: {rows_after_dedup:>8,} (removed {dedup_removed} duplicate values)"
)
if resolve_annotation_error is not None:
lines.append(
f" After annotation resolution:{rows_after_annotation:>7,} (removed {annotation_removed} measurements)"
)
if drop_flags:
rows_removed_by_flags = rows_before_flags - len(df)
lines.append(f" After flag filtering: {len(df):>8,} (removed {rows_removed_by_flags} rows)")
lines.append(f" Final rows: {len(df):>8,}")
logger.info("\n".join(lines))
return df
[docs]
def prepare_multitask_data(
df: pd.DataFrame,
task_col: str,
value_col: str,
compound_col: str,
smiles_col: str,
id_columns: Optional[List[str]] = None,
) -> pd.DataFrame:
"""Transform aggregated data to multitask format (activity matrix).
This function pivots aggregated bioactivity data to create an activity matrix where:
- Rows represent unique compounds (identified by compound_col)
- Columns represent tasks (e.g., different targets)
- Values are the bioactivity measurements (e.g., pchembl_value_mean)
Use ``clean_data()`` before calling this function to filter quality flags
and deduplicate values.
Args:
df: Aggregated DataFrame from aggregate_data() with bioactivity statistics.
task_col: Column to use as task identifier (e.g., "target_chembl_id").
value_col: Column containing values to pivot (e.g., "pchembl_value_mean").
compound_col: Column for compound identity (e.g., "connectivity" or "smiles").
smiles_col: Column containing SMILES strings.
id_columns: List of additional columns to combine with task_col for creating
composite task identifiers. Use this when data was aggregated with
--id-columns (e.g., ["assay_tissue"]) to prevent losing information.
Returns:
DataFrame with compounds as rows (indexed by compound_col), tasks as columns,
and a smiles column. Missing values are represented as NaN.
"""
df = df.copy()
# Validate required columns exist
required_cols = [compound_col, task_col, value_col, smiles_col]
missing_cols = [col for col in required_cols if col not in df.columns]
if missing_cols:
raise ValueError(f"Missing required columns: {missing_cols}")
# Create composite task column if id_columns are provided
effective_task_col = task_col
if id_columns is not None and len(id_columns) > 0:
missing_id_cols = [col for col in id_columns if col not in df.columns]
if missing_id_cols:
raise ValueError(f"Missing id_columns: {missing_id_cols}")
# Create composite task identifier by joining task_col with id_columns
composite_name = "_composite_task"
df[composite_name] = df[task_col].astype(str)
for col in id_columns:
df[composite_name] = df[composite_name] + "-" + df[col].fillna("").astype(str)
effective_task_col = composite_name
logger.info(f"Created composite task column from: {task_col} + {id_columns}")
# Check for duplicates (multiple values per compound-task pair)
dup_check = df.groupby([compound_col, effective_task_col]).size()
duplicates = dup_check[dup_check > 1]
if len(duplicates) > 0:
n_dup_pairs = len(duplicates)
n_extra_rows = duplicates.sum() - n_dup_pairs
logger.warning(
f"Found {n_dup_pairs} compound-task pairs with multiple values ({n_extra_rows} extra rows). "
f"Only the first value will be kept. "
f"If your data was aggregated with --id-columns, use the same columns here via --id-columns."
)
# Pivot the data to create activity matrix
logger.info(
f"Creating activity matrix with {df[compound_col].nunique()} compounds "
f"and {df[effective_task_col].nunique()} tasks"
)
activity_matrix = df.pivot_table(
index=compound_col,
columns=effective_task_col,
values=value_col,
aggfunc="first",
)
# Reset the columns name to remove the task_col label
activity_matrix.columns.name = None
# Add SMILES column back by taking the first SMILES for each compound
if smiles_col != compound_col:
smiles_map = df.groupby(compound_col)[smiles_col].first()
activity_matrix[smiles_col] = smiles_map
# Compute sparsity (fraction of NaN cells, excluding smiles column)
task_cols = [c for c in activity_matrix.columns if c != smiles_col]
if task_cols:
n_cells = activity_matrix[task_cols].size
n_missing = activity_matrix[task_cols].isna().sum().sum()
sparsity = n_missing / n_cells * 100 if n_cells > 0 else 0.0
else:
sparsity = 0.0
logger.info(
f"Activity matrix: {activity_matrix.shape[0]} compounds x {len(task_cols)} tasks "
f"(sparsity: {sparsity:.1f}%)"
)
return activity_matrix