"""Contain functions for binarizing bioactivity data; handling censored data and validating agreement between discrete and censored measurements"""
import json
from pathlib import Path
from typing import Optional
import numpy as np
import pandas as pd
from scipy.stats import gmean, gstd
from ..core.pandas_helper import add_comment
from ..logger import logger
def _calculate_mcc(tp: int, tn: int, fp: int, fn: int) -> float:
"""Calculate Matthews Correlation Coefficient from confusion matrix values.
MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
Args:
tp: True positives
tn: True negatives
fp: False positives
fn: False negatives
Returns:
MCC value between -1.0 and 1.0, or 0.0 if denominator is zero
"""
numerator = (tp * tn) - (fp * fn)
denominator_parts = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
if denominator_parts == 0:
return 0.0
import math
denominator = math.sqrt(denominator_parts)
return numerator / denominator
def _calculate_assay_compatibility_mcc(
df: pd.DataFrame,
compound_id_col: str,
target_id_col: str,
output_binary_col: str,
) -> float:
"""Calculate MCC measuring agreement between measurements for same compound-target pairs.
Uses pairwise comparisons: for each compound-target pair with multiple measurements,
compares all pairs of binary labels to build confusion matrix, then calculates MCC.
Args:
df: DataFrame with binarized data
compound_id_col: Column name for compound IDs
target_id_col: Column name for target IDs
output_binary_col: Column name with binary labels
Returns:
MCC value between -1.0 and 1.0
"""
groupby_cols = [compound_id_col, target_id_col]
if "mutation" in df.columns:
groupby_cols.append("mutation")
tp = tn = fp = fn = 0
for _, group_df in df.groupby(groupby_cols):
binary_labels = group_df[output_binary_col].dropna().tolist()
# Skip if less than 2 measurements
if len(binary_labels) < 2:
continue
# Pairwise comparisons within this group
for i in range(len(binary_labels)):
for j in range(i + 1, len(binary_labels)):
label_i = binary_labels[i]
label_j = binary_labels[j]
# Treat label_i as "truth" and label_j as "prediction"
if label_i == 1 and label_j == 1:
tp += 1
elif label_i == 0 and label_j == 0:
tn += 1
elif label_i == 0 and label_j == 1:
fp += 1
elif label_i == 1 and label_j == 0:
fn += 1
return _calculate_mcc(tp, tn, fp, fn)
def _truncate_dataframe(df: pd.DataFrame, limit: int = 15) -> pd.DataFrame:
"""Truncate DataFrame values to a specified length for display purposes.
Args:
df: DataFrame to truncate
limit: Maximum character length for each cell value
Returns:
DataFrame with truncated string values
"""
if pd.__version__ > "2.1.0": # applymap got deprecated in 2.1.0
return df.map(lambda x: str(x)[:limit] + "..." if len(str(x)) > limit else str(x))
else:
return df.applymap(lambda x: str(x)[:limit] + "..." if len(str(x)) > limit else str(x))
[docs]
def invert_relation_for_pchembl(relation: str) -> str:
"""Inverts comparison relation for pchembl values.
Since pchembl = -log10(Molar), higher pchembl = more active (lower concentration).
Therefore, standard_relation directions must be inverted:
- standard_relation "<" (low concentration, active) → pchembl ">" (high value, active)
- standard_relation ">" (high concentration, inactive) → pchembl "<" (low value, inactive)
Args:
relation: Original standard_relation from ChEMBL ("=", "<", ">", "<=", ">=", "~", ">>", "<<")
Returns:
Inverted relation for pchembl comparison
"""
inversion_map = {
"<": ">",
">": "<",
"<=": ">=",
">=": "<=",
"=": "=",
"~": "~",
">>": "<<",
"<<": ">>",
}
if relation not in inversion_map:
raise ValueError(f"Unknown relation: {relation}")
return inversion_map[relation]
def _classify_by_relation(value: float, relation: str, threshold: float) -> int:
"""Classify a single measurement as active (1) or inactive (0) based on relation type.
Args:
value: pchembl_value to classify
relation: standard_relation ("=", "~", "<", ">", "<=", ">=", "<<", ">>")
threshold: Activity threshold for binarization
Returns:
1 for active, 0 for inactive
"""
if relation == "=":
return 1 if value >= threshold else 0
elif relation == "~":
lower_bound = value - 0.5
return 1 if lower_bound >= threshold else 0
elif relation in ["<", "<=", "<<"]:
return 1 if value >= threshold else 0
elif relation in [">", ">=", ">>"]:
return 0 if value <= threshold else 1
else:
raise ValueError(f"Unknown relation: {relation}")
VALID_CONFLICT_STRATEGIES = {"drop", "relation", "confidence", "majority"}
CENSORED_RELATIONS = {"<", ">", "<=", ">=", "<<", ">>"}
def _max_confidence_score(confidence_str) -> int:
"""Parse a pipe-separated confidence_score string and return the maximum value.
Args:
confidence_str: A single value like "9" or pipe-separated like "8|9", or NaN
Returns:
Maximum confidence score as integer, or 0 for NaN/empty
"""
if pd.isna(confidence_str) or str(confidence_str).strip() == "":
return 0
parts = str(confidence_str).split("|")
try:
return max(int(p.strip()) for p in parts if p.strip())
except ValueError:
return 0
def _resolve_conflicts(
df: pd.DataFrame,
conflict_indices: list,
strategy: str,
compound_id_col: str,
target_id_col: str,
relation_col: str,
output_binary_col: str,
groupby_cols: list[str],
value_column: str = "pchembl_value_mean",
threshold: Optional[float] = None,
) -> tuple[list, list[dict]]:
"""Apply a conflict resolution strategy to conflicting compound-target groups.
Args:
df: DataFrame with binarized data
conflict_indices: List of row indices involved in conflicts
strategy: One of "drop", "relation", "confidence", "majority"
compound_id_col: Column name for compound IDs
target_id_col: Column name for target IDs
relation_col: Column with standard_relation values
output_binary_col: Column with binary labels
groupby_cols: Columns to group by for conflict detection
value_column: Column used for binarization (used to derive counts column)
threshold: Binarization threshold (used for measurement-level majority voting)
Returns:
Tuple of (indices_to_drop, resolution_details) where resolution_details is
a list of dicts keyed by (compound_id, target_id) with resolution info
"""
if strategy == "confidence" and "confidence_score" not in df.columns:
raise ValueError(
"Strategy 'confidence' requires a 'confidence_score' column in the DataFrame"
)
# Derive raw value column and counts column from value_column
base = value_column.rsplit("_", 1)[0] if "_" in value_column else value_column
counts_col = None
raw_value_col = None
if strategy == "majority":
candidate_counts = f"{base}_counts"
if candidate_counts in df.columns:
counts_col = candidate_counts
if base in df.columns and base != value_column:
raw_value_col = base
conflict_subset = df.loc[conflict_indices]
indices_to_drop = []
resolution_details = []
for group_key, group_df in conflict_subset.groupby(groupby_cols):
has_conflict, _ = _detect_conflicts(group_df, output_binary_col)
if not has_conflict:
continue
if strategy == "drop":
drop_idx, detail = _resolve_drop_group(group_key, group_df)
elif strategy == "relation":
drop_idx, detail = _resolve_relation_group(group_key, group_df, relation_col)
elif strategy == "confidence":
drop_idx, detail = _resolve_confidence_group(group_key, group_df)
elif strategy == "majority":
drop_idx, detail = _resolve_majority_group(
group_key,
group_df,
output_binary_col,
counts_col=counts_col,
raw_value_col=raw_value_col,
relation_col=relation_col,
threshold=threshold,
)
else:
raise ValueError(f"Unknown conflict resolution strategy: '{strategy}'")
indices_to_drop.extend(drop_idx)
resolution_details.append(detail)
return indices_to_drop, resolution_details
def _resolve_drop_group(group_key, group_df) -> tuple[list, dict]:
"""Drop all rows in a conflicting group."""
return group_df.index.tolist(), {
"group_key": group_key,
"strategy": "drop",
"outcome": "dropped_all",
"rows_kept": 0,
"rows_dropped": len(group_df),
"detail": f"Dropped all {len(group_df)} rows",
}
def _resolve_relation_group(group_key, group_df, relation_col) -> tuple[list, dict]:
"""Keep '=' rows and drop censored rows. Fall back to dropping all if no '=' exists."""
exact_mask = group_df[relation_col] == "="
exact_rows = group_df[exact_mask]
if len(exact_rows) == 0:
logger.warning(
f"Conflict for {group_key}: no exact (=) measurements found, dropping all rows"
)
return group_df.index.tolist(), {
"group_key": group_key,
"strategy": "relation",
"outcome": "dropped_all",
"rows_kept": 0,
"rows_dropped": len(group_df),
"detail": "No exact (=) measurements found, dropped all rows",
}
censored_indices = group_df[~exact_mask].index.tolist()
return censored_indices, {
"group_key": group_key,
"strategy": "relation",
"outcome": "kept_exact",
"rows_kept": len(exact_rows),
"rows_dropped": len(censored_indices),
"detail": f"Kept {len(exact_rows)} exact (=) rows, dropped {len(censored_indices)} censored rows",
}
def _resolve_confidence_group(group_key, group_df) -> tuple[list, dict]:
"""Keep row(s) with highest confidence_score. Drop all on tie across binary labels."""
group_df = group_df.copy()
group_df["_max_conf"] = group_df["confidence_score"].apply(_max_confidence_score)
max_conf = group_df["_max_conf"].max()
winners = group_df[group_df["_max_conf"] == max_conf]
# If winners still have conflicting labels, it's a tie → drop all
if len(winners["activity_binary"].dropna().unique()) > 1:
return group_df.index.tolist(), {
"group_key": group_key,
"strategy": "confidence",
"outcome": "dropped_all",
"rows_kept": 0,
"rows_dropped": len(group_df),
"detail": f"Tied confidence scores ({max_conf}) with conflicting labels, dropped all rows",
}
losers = group_df[group_df["_max_conf"] != max_conf].index.tolist()
outcome_label = "active" if winners.iloc[0]["activity_binary"] == 1 else "inactive"
return losers, {
"group_key": group_key,
"strategy": "confidence",
"outcome": f"kept_{outcome_label}",
"rows_kept": len(winners),
"rows_dropped": len(losers),
"detail": f"Kept {len(winners)} rows with confidence_score={max_conf} ({outcome_label})",
}
def _resolve_majority_group(
group_key,
group_df,
output_binary_col,
counts_col: Optional[str] = None,
raw_value_col: Optional[str] = None,
relation_col: str = "standard_relation",
threshold: Optional[float] = None,
) -> tuple[list, dict]:
"""Keep rows matching the majority binary label, weighted by measurement count.
When the raw pipe-separated value column is available (e.g. pchembl_value), each
individual measurement is classified against the threshold and gets one vote.
Otherwise, falls back to count-weighted or row-based voting.
"""
valid = group_df[output_binary_col].dropna()
if len(valid.unique()) < 2:
return [], {
"group_key": group_key,
"strategy": "majority",
"outcome": "no_conflict",
"rows_kept": len(group_df),
"rows_dropped": 0,
"detail": "No conflict after dropping NaN labels",
}
# Measurement-level voting: split pipe-separated values, classify each individually
use_measurement_level = (
raw_value_col
and raw_value_col in group_df.columns
and relation_col in group_df.columns
and threshold is not None
)
if use_measurement_level:
active_weight = 0.0
inactive_weight = 0.0
for _, row in group_df.iterrows():
raw_str = str(row[raw_value_col])
relation = str(row[relation_col])
if pd.isna(row[raw_value_col]) or raw_str == "nan":
continue
values = raw_str.split("|")
for v in values:
try:
label = _classify_by_relation(float(v), relation, threshold)
if label == 1:
active_weight += 1
else:
inactive_weight += 1
except (ValueError, TypeError):
continue
weight_label = "individual measurements"
elif counts_col and counts_col in group_df.columns:
weights = group_df[counts_col].fillna(1).astype(float)
active_weight = float(weights[group_df[output_binary_col] == 1].sum())
inactive_weight = float(weights[group_df[output_binary_col] == 0].sum())
weight_label = "measurements"
else:
active_weight = float((group_df[output_binary_col] == 1).sum())
inactive_weight = float((group_df[output_binary_col] == 0).sum())
weight_label = "rows"
active_weight = float(active_weight)
inactive_weight = float(inactive_weight)
if active_weight == inactive_weight:
return group_df.index.tolist(), {
"group_key": group_key,
"strategy": "majority",
"outcome": "dropped_all",
"rows_kept": 0,
"rows_dropped": len(group_df),
"detail": f"Tied votes: {active_weight:.0f} active vs {inactive_weight:.0f} inactive "
f"{weight_label}, dropped all rows",
}
if active_weight > inactive_weight:
majority_label = 1
majority_weight = active_weight
minority_weight = inactive_weight
else:
majority_label = 0
majority_weight = inactive_weight
minority_weight = active_weight
minority_indices = group_df[group_df[output_binary_col] != majority_label].index.tolist()
kept = len(group_df) - len(minority_indices)
outcome_label = "active" if majority_label == 1 else "inactive"
return minority_indices, {
"group_key": group_key,
"strategy": "majority",
"outcome": f"kept_{outcome_label}",
"rows_kept": kept,
"rows_dropped": len(minority_indices),
"detail": f"Majority vote: {majority_weight:.0f} {outcome_label} vs {minority_weight:.0f} "
f"{weight_label}, dropped {len(minority_indices)} minority rows",
}
def _classify_conflict_pattern(measurements: list[dict]) -> str:
"""Classify a conflict as exact_vs_censored or censored_vs_censored.
Args:
measurements: List of measurement dicts with 'standard_relation' key
Returns:
"exact_vs_censored" or "censored_vs_censored"
"""
relations = {m.get("standard_relation", "=") for m in measurements}
has_exact = "=" in relations
has_censored = bool(relations & CENSORED_RELATIONS)
if has_exact and has_censored:
return "exact_vs_censored"
return "censored_vs_censored"
def _detect_conflicts(
group_df: pd.DataFrame,
output_binary_col: str,
) -> tuple[bool, str]:
"""Detect conflicts within a compound-target group based on binarization outcomes.
A conflict occurs when multiple measurements for the same compound-target pair
result in different binary labels (active vs inactive).
Args:
group_df: DataFrame subset for one compound-target pair
output_binary_col: Column name with binary labels
Returns:
Tuple of (has_conflict, conflict_type) where conflict_type is "binary_label_mismatch" or ""
"""
binary_labels = group_df[output_binary_col].dropna()
if len(binary_labels) > 1 and len(binary_labels.unique()) > 1:
return True, "binary_label_mismatch"
return False, ""
def _generate_conflict_details(
df: pd.DataFrame,
conflict_indices: list,
compound_id_col: str,
target_id_col: str,
pchembl_relation_col: str,
value_column: str,
output_binary_col: str,
threshold: float,
) -> list[dict]:
"""Generate detailed conflict information for each compound-target pair.
Args:
df: DataFrame with binarized data
conflict_indices: List of row indices with conflicts
compound_id_col: Column name for compound IDs
target_id_col: Column name for target IDs
pchembl_relation_col: Column name for pchembl_relation
value_column: Column name for pchembl values
output_binary_col: Column name for binary labels
threshold: Binarization threshold
Returns:
List of conflict detail dictionaries
"""
if not conflict_indices:
return []
conflict_subset = df.loc[conflict_indices].copy()
groupby_cols = [compound_id_col, target_id_col]
if "mutation" in conflict_subset.columns:
groupby_cols.append("mutation")
conflict_details = []
for group_key, group_df in conflict_subset.groupby(groupby_cols):
has_conflict, conflict_type = _detect_conflicts(group_df, output_binary_col)
if not has_conflict:
continue
measurements = []
for _, row in group_df.iterrows():
measurement = {
"value": float(row[value_column]) if not pd.isna(row[value_column]) else None,
"pchembl_relation": row[pchembl_relation_col],
"binary": int(row[output_binary_col]) if not pd.isna(row[output_binary_col]) else None,
}
if "standard_relation" in row and not pd.isna(row["standard_relation"]):
measurement["standard_relation"] = row["standard_relation"]
if "assay_chembl_id" in row:
measurement["assay"] = row["assay_chembl_id"]
if "molecule_chembl_id" in row:
measurement["molecule"] = row["molecule_chembl_id"]
measurements.append(measurement)
conflict_detail = {
"compound_id": group_key[0],
"target_id": group_key[1],
"conflict_type": conflict_type,
"measurements": measurements,
"threshold": threshold,
}
if len(groupby_cols) > 2:
conflict_detail["mutation"] = group_key[2]
# Group measurements by binary outcome
active_measurements = [m for m in measurements if m["binary"] == 1]
inactive_measurements = [m for m in measurements if m["binary"] == 0]
# Create vote summary
vote_summary = {
"active_votes": len(active_measurements),
"inactive_votes": len(inactive_measurements),
}
conflict_detail["vote_summary"] = vote_summary
# Build explanation with vote counts and measurement details
explanation_parts = []
if active_measurements:
active_strs = [f"{m['pchembl_relation']}{m['value']:.2f}" for m in active_measurements]
explanation_parts.append(f"Active ({len(active_measurements)} votes): {', '.join(active_strs)}")
if inactive_measurements:
inactive_strs = [f"{m['pchembl_relation']}{m['value']:.2f}" for m in inactive_measurements]
explanation_parts.append(
f"Inactive ({len(inactive_measurements)} votes): {', '.join(inactive_strs)}"
)
conflict_detail["explanation"] = " | ".join(explanation_parts) + f" | Threshold: {threshold}"
# Severity: distance and spread metrics
values = [m["value"] for m in measurements if m["value"] is not None]
if values:
spread = max(values) - min(values)
max_distance = max(abs(v - threshold) for v in values)
if spread < 1.0:
classification = "low"
elif spread <= 2.0:
classification = "medium"
else:
classification = "high"
conflict_detail["severity"] = {
"max_distance_from_threshold": round(max_distance, 4),
"measurement_spread": round(spread, 4),
"classification": classification,
}
# Recommendation based on relation types
relations = {m.get("standard_relation", "=") for m in measurements}
has_exact = "=" in relations
has_censored = bool(relations & CENSORED_RELATIONS)
if has_exact and has_censored:
conflict_detail["recommendation"] = (
"Exact measurement (=) is generally more reliable than censored bounds"
)
else:
conflict_detail["recommendation"] = (
"Manual review recommended -- all measurements are of the same type"
)
conflict_details.append(conflict_detail)
return conflict_details
def _log_and_flag_conflicts(
df: pd.DataFrame,
conflict_indices: list,
compound_id_col: str,
target_id_col: str,
relation_col: str,
value_column: str,
output_binary_col: str,
conflict_resolution: Optional[str] = None,
) -> pd.DataFrame:
"""Log conflict details and flag conflicting rows in the DataFrame.
Args:
df: DataFrame with binarized data
conflict_indices: List of row indices with conflicts
compound_id_col: Column name for compound IDs
target_id_col: Column name for target IDs
relation_col: Column name for relations
value_column: Column name for pchembl values
output_binary_col: Column name for binary labels
conflict_resolution: If set, suppresses the detail table since conflicts will be resolved
Returns:
DataFrame with conflicts flagged via add_comment()
"""
if not conflict_indices:
return df
if conflict_resolution:
logger.warning(
f"Found {len(conflict_indices)} measurements with disagreements. "
f"Resolving with strategy '{conflict_resolution}'."
)
else:
logger.warning(
f"Found {len(conflict_indices)} measurements with disagreements. "
"These compound-target pairs have inconsistent measurements across different relation types."
)
conflict_subset = df.loc[conflict_indices].copy()
logging_cols = [compound_id_col, target_id_col, relation_col, value_column, "mutation"]
optional_cols = ["molecule_chembl_id", "assay_chembl_id", output_binary_col, "standard_relation"]
for col in optional_cols:
if col in conflict_subset.columns:
logging_cols.append(col)
logging_cols = [col for col in logging_cols if col in conflict_subset.columns]
conflict_display = conflict_subset[logging_cols].sort_values(
by=[target_id_col, compound_id_col, value_column], ascending=[True, True, False]
)
truncated_df = _truncate_dataframe(conflict_display, limit=15)
logger.warning(
f"Sample of conflicting measurements (showing up to 20 rows):\n"
f"{truncated_df.head(20).to_string(index=False)}"
)
df = add_comment(
df=df,
comment="Non-agreeing discrete and censored values",
criteria_func=lambda x: x.index.isin(conflict_indices),
target_column=value_column,
comment_type="d",
)
return df
[docs]
def save_conflict_report(
conflict_details: list[dict],
output_path: str | Path,
threshold: float,
total_rows: int = 0,
active_count: int = 0,
inactive_count: int = 0,
mcc: float = 0.0,
resolution_details: Optional[list[dict]] = None,
conflict_resolution: Optional[str] = None,
) -> None:
"""Save conflict report to JSON file.
Args:
conflict_details: List of conflict detail dictionaries
output_path: Path to save the JSON file
threshold: Binarization threshold used
total_rows: Total number of rows in the DataFrame
active_count: Number of active rows
inactive_count: Number of inactive rows
mcc: Matthews Correlation Coefficient
resolution_details: Resolution details from _resolve_conflicts
conflict_resolution: Strategy name used for resolution
"""
# Count conflict patterns
pattern_counts = {"exact_vs_censored": 0, "censored_vs_censored": 0}
for conflict in conflict_details:
pattern = _classify_conflict_pattern(conflict.get("measurements", []))
pattern_counts[pattern] += 1
summary = {
"total_rows": total_rows,
"total_conflicts": len(conflict_details),
"threshold": threshold,
"active_count": active_count,
"inactive_count": inactive_count,
"mcc": round(mcc, 4),
"conflict_patterns": pattern_counts,
}
# Add resolution summary if a strategy was used
if conflict_resolution and resolution_details is not None:
resolved = sum(1 for d in resolution_details if d["rows_dropped"] > 0)
unresolved = sum(1 for d in resolution_details if d["rows_dropped"] == 0)
total_dropped = sum(d["rows_dropped"] for d in resolution_details)
summary["resolution_summary"] = {
"strategy": conflict_resolution,
"conflicts_resolved": resolved,
"conflicts_unresolved": unresolved,
"total_rows_dropped": total_dropped,
}
# Add per-conflict resolution info
if conflict_resolution and resolution_details is not None:
detail_by_key = {tuple(d["group_key"]): d for d in resolution_details}
for conflict in conflict_details:
key = (conflict["compound_id"], conflict["target_id"])
if "mutation" in conflict:
key = (*key, conflict["mutation"])
if key in detail_by_key:
d = detail_by_key[key]
conflict["resolution"] = {
"strategy": d["strategy"],
"outcome": d["outcome"],
"rows_kept": d["rows_kept"],
"rows_dropped": d["rows_dropped"],
"detail": d["detail"],
}
report = {
"summary": summary,
"conflicts": conflict_details,
}
output_path = Path(output_path)
with open(output_path, "w") as f:
json.dump(report, f, indent=2)
logger.info(f"Conflict report saved to {output_path}")
def _deduplicate_resolved_groups(
df: pd.DataFrame,
groupby_cols: list[str],
raw_value_col: str,
relation_col: str,
output_binary_col: str,
threshold: float,
use_geometric: bool = True,
) -> pd.DataFrame:
"""Merge compound-target groups into one row per group, filtering disagreeing measurements.
For each group with multiple rows (or rows with mixed individual measurements),
splits pipe-separated values, classifies each measurement against the threshold,
and keeps only measurements agreeing with the row's binary label. Then merges
all rows in the group into a single row.
Args:
df: DataFrame after conflict resolution
groupby_cols: Columns defining compound-target groups
raw_value_col: Column with pipe-separated raw values (e.g. "pchembl_value")
relation_col: Column with standard_relation
output_binary_col: Column with binary labels
threshold: Binarization threshold
use_geometric: If True, use geometric mean/std for stats recalculation
Returns:
DataFrame with one row per compound-target group
"""
if raw_value_col not in df.columns:
return df
# Detect pipe-separated columns
pipe_cols = []
for col in df.columns:
if df[col].astype(str).str.contains("|", regex=False).any():
pipe_cols.append(col)
# Derive stats column names
stats_suffixes = ["_mean", "_std", "_median", "_counts"]
stats_cols = [f"{raw_value_col}{s}" for s in stats_suffixes]
stats_cols = [c for c in stats_cols if c in df.columns]
merged_rows = []
indices_processed = set()
for _, group_df in df.groupby(groupby_cols):
if len(group_df) == 1:
row = group_df.iloc[0]
binary_label = row[output_binary_col]
raw_str = str(row[raw_value_col])
# Skip if no binary label or no raw values
if pd.isna(binary_label) or pd.isna(row[raw_value_col]) or raw_str == "nan":
continue
# Single row: still filter disagreeing measurements within it
if "|" in raw_str:
values = raw_str.split("|")
relation = str(row[relation_col])
keep_indices = []
for i, v in enumerate(values):
try:
label = _classify_by_relation(float(v), relation, threshold)
if label == int(binary_label):
keep_indices.append(i)
except (ValueError, TypeError):
keep_indices.append(i) # Keep on error
if len(keep_indices) == 0:
indices_processed.update(group_df.index.tolist())
continue # All measurements disagree, drop the group
if len(keep_indices) < len(values):
# Filter pipe-separated columns at the same positions
new_row = row.copy()
for col in pipe_cols:
col_values = str(new_row[col]).split("|")
if len(col_values) == len(values):
new_row[col] = "|".join(col_values[i] for i in keep_indices)
# Expand standard_relation to pipe-separated to match
new_row[relation_col] = "|".join(relation for _ in keep_indices)
# Recalculate stats
kept_values = [float(values[i]) for i in keep_indices]
new_row = _recalculate_stats(
new_row, raw_value_col, kept_values, use_geometric
)
merged_rows.append(new_row)
indices_processed.update(group_df.index.tolist())
continue
# No filtering needed for single-value rows
continue
# Multi-row group: merge all rows into one
binary_label = group_df[output_binary_col].dropna().iloc[0]
all_kept_values = []
all_kept_relations = []
# For each pipe-separated column, collect kept entries
pipe_col_kept = {col: [] for col in pipe_cols}
for _, row in group_df.iterrows():
raw_str = str(row[raw_value_col])
relation = str(row[relation_col])
if pd.isna(row[raw_value_col]) or raw_str == "nan":
continue
values = raw_str.split("|")
n_values = len(values)
relations = [relation] * n_values # Expand single relation to match count
keep_indices = []
for i, v in enumerate(values):
try:
label = _classify_by_relation(float(v), relations[i], threshold)
if label == int(binary_label):
keep_indices.append(i)
except (ValueError, TypeError):
keep_indices.append(i)
for i in keep_indices:
all_kept_values.append(values[i])
all_kept_relations.append(relations[i])
# Filter corresponding positions in all pipe-separated columns
for col in pipe_cols:
col_values = str(row[col]).split("|")
if len(col_values) == n_values:
pipe_col_kept[col].extend(col_values[i] for i in keep_indices)
else:
# Column doesn't align with raw values; just concatenate all
pipe_col_kept[col].extend(col_values)
if not all_kept_values:
indices_processed.update(group_df.index.tolist())
continue # All measurements filtered out, drop the group
# Build merged row from first row as template
new_row = group_df.iloc[0].copy()
new_row[output_binary_col] = binary_label
# Set pipe-separated columns to merged values
for col in pipe_cols:
if pipe_col_kept[col]:
new_row[col] = "|".join(pipe_col_kept[col])
# Set standard_relation to pipe-separated per-measurement relations
new_row[relation_col] = "|".join(all_kept_relations)
# Recalculate stats
kept_floats = [float(v) for v in all_kept_values]
new_row = _recalculate_stats(new_row, raw_value_col, kept_floats, use_geometric)
merged_rows.append(new_row)
indices_processed.update(group_df.index.tolist())
if not merged_rows and not indices_processed:
return df
# Build result: keep unprocessed rows as-is, append merged rows
unprocessed = df.loc[~df.index.isin(indices_processed)]
if merged_rows:
merged_df = pd.DataFrame(merged_rows)
result = pd.concat([unprocessed, merged_df], ignore_index=True)
else:
result = unprocessed.reset_index(drop=True)
return result
def _recalculate_stats(
row: pd.Series,
raw_value_col: str,
values: list[float],
use_geometric: bool,
) -> pd.Series:
"""Recalculate mean/std/median/counts stats for a row from a list of values.
Calculates directly instead of using assign_stats to avoid the censored-row
override (which doesn't work with mixed pipe-separated relations).
"""
n = len(values)
arr = np.array(values)
if use_geometric:
mean_val = float(gmean(10 ** (-arr)))
mean_val = -np.log10(mean_val)
if n > 1:
std_val = float(gstd(10 ** (-arr)))
else:
std_val = np.nan
median_val = float(np.median(arr))
else:
mean_val = float(np.mean(arr))
std_val = float(np.std(arr)) if n > 1 else np.nan
median_val = float(np.median(arr))
mean_col = f"{raw_value_col}_mean"
std_col = f"{raw_value_col}_std"
median_col = f"{raw_value_col}_median"
counts_col = f"{raw_value_col}_counts"
if mean_col in row.index:
row[mean_col] = mean_val
if std_col in row.index:
row[std_col] = std_val
if median_col in row.index:
row[median_col] = median_val
if counts_col in row.index:
row[counts_col] = n
return row
[docs]
def binarize_aggregated_data(
df: pd.DataFrame,
threshold: float = 6.0,
value_column: str = "pchembl_value_mean",
compound_id_col: str = "connectivity",
target_id_col: str = "target_chembl_id",
relation_col: str = "standard_relation",
output_binary_col: str = "activity_binary",
compare_across_mutants: bool = False,
conflict_report_path: Optional[str | Path] = None,
conflict_resolution: Optional[str] = None,
) -> pd.DataFrame:
"""Binarize aggregated bioactivity data based on activity threshold and standard_relation.
This function converts continuous pchembl values to binary labels (0=inactive, 1=active)
while properly handling censored measurements and approximate values, and validating
agreement between discrete and censored measurements for the same compound-target pair.
Key logic:
- standard_relation "=": compare value to threshold directly
- standard_relation "~": approximate (±0.5 log units); uses lower bound for conservative classification
- standard_relation "<", "<<" (low concentration): if pchembl >= threshold → active (1)
- standard_relation ">", ">>" (high concentration): if pchembl <= threshold → inactive (0)
- Mixed relations: validate agreement and flag conflicts
Args:
df: Aggregated DataFrame from aggregate_data() with pchembl statistics
threshold: Activity threshold for binarization (default 6.0 = 1 µM)
value_column: Which aggregated column to use (default: "pchembl_value_mean")
compound_id_col: Column identifying compounds (default: "connectivity")
target_id_col: Column identifying targets (default: "target_chembl_id")
relation_col: Column with standard_relation values (default: "standard_relation")
output_binary_col: Name for output binary column (default: "activity_binary")
compare_across_mutants: If False (default), different mutations are treated as separate
compound-target pairs for conflict detection. If True, measurements on different
mutants are compared and flagged if they disagree.
conflict_report_path: Optional path to save detailed conflict report as JSON
conflict_resolution: Strategy for resolving conflicts. One of:
- None (default): flag only, keep all rows
- "drop": remove all rows for conflicting pairs
- "relation": keep '=' rows, drop censored; fall back to drop if no '='
- "confidence": keep row with highest confidence_score; drop all on tie
- "majority": keep rows matching majority binary label; drop all on tie
Returns:
DataFrame with binary activity labels, pchembl_relation column, and conflict flags
"""
if conflict_resolution is not None and conflict_resolution not in VALID_CONFLICT_STRATEGIES:
raise ValueError(
f"Unknown conflict resolution strategy: '{conflict_resolution}'. "
f"Valid options: {sorted(VALID_CONFLICT_STRATEGIES)}"
)
required_cols = [compound_id_col, target_id_col, value_column]
missing_cols = [col for col in required_cols if col not in df.columns]
if missing_cols:
raise ValueError(f"Missing required columns: {missing_cols}")
df = df.copy()
if relation_col not in df.columns:
logger.warning(f"Column '{relation_col}' not found. Assuming all measurements have '=' relation.")
df[relation_col] = "="
pchembl_relation_col = "pchembl_relation"
df[pchembl_relation_col] = df[relation_col].apply(invert_relation_for_pchembl)
groupby_cols = [compound_id_col, target_id_col]
if not compare_across_mutants and "mutation" in df.columns:
groupby_cols.append("mutation")
df[output_binary_col] = np.nan
conflict_indices = []
for idx, row in df.iterrows():
value = row[value_column]
relation = row[relation_col]
if pd.isna(value):
continue
try:
df.loc[idx, output_binary_col] = _classify_by_relation(value, relation, threshold)
except ValueError as e:
logger.warning(f"{e} at index {idx}. Skipping binarization.")
for group_key, group_df in df.groupby(groupby_cols):
has_conflict, conflict_type = _detect_conflicts(group_df, output_binary_col)
if has_conflict:
conflict_indices.extend(group_df.index.tolist())
logger.debug(
f"Conflict ({conflict_type}) for {compound_id_col}={group_key[0]}, "
f"{target_id_col}={group_key[1]}"
)
df = _log_and_flag_conflicts(
df,
conflict_indices,
compound_id_col,
target_id_col,
pchembl_relation_col,
value_column,
output_binary_col,
conflict_resolution=conflict_resolution,
)
# Generate conflict details before resolution (needs all rows present)
conflict_details = None
if conflict_report_path and conflict_indices:
conflict_details = _generate_conflict_details(
df,
conflict_indices,
compound_id_col,
target_id_col,
pchembl_relation_col,
value_column,
output_binary_col,
threshold,
)
# Apply conflict resolution if requested
resolution_details = None
if conflict_resolution and conflict_indices:
indices_to_drop, resolution_details = _resolve_conflicts(
df,
conflict_indices,
conflict_resolution,
compound_id_col,
target_id_col,
relation_col,
output_binary_col,
groupby_cols,
value_column=value_column,
threshold=threshold,
)
if indices_to_drop:
df = df.drop(index=indices_to_drop).reset_index(drop=True)
logger.info(
f"Conflict resolution (strategy='{conflict_resolution}'): "
f"dropped {len(indices_to_drop)} rows"
)
# Deduplicate: merge compound-target groups into one row, filtering disagreeing measurements
base_value_col = value_column.rsplit("_", 1)[0] if "_" in value_column else value_column
if conflict_resolution and base_value_col in df.columns and base_value_col != value_column:
use_geometric = base_value_col == "pchembl_value"
rows_before = len(df)
df = _deduplicate_resolved_groups(
df, groupby_cols, base_value_col, relation_col,
output_binary_col, threshold, use_geometric=use_geometric,
)
# Regenerate pchembl_relation from the (now possibly pipe-separated) standard_relation
df[pchembl_relation_col] = df[relation_col].astype(str).apply(
lambda r: "|".join(invert_relation_for_pchembl(x) for x in r.split("|"))
)
rows_after = len(df)
if rows_before != rows_after:
logger.info(
f"Deduplication: merged {rows_before} rows into {rows_after} rows"
)
# Save conflict report after resolution (so stats reflect final state)
if conflict_report_path and conflict_details is not None:
n_active_report = int((df[output_binary_col].dropna() == 1).sum())
n_inactive_report = int((df[output_binary_col].dropna() == 0).sum())
mcc_report = _calculate_assay_compatibility_mcc(
df, compound_id_col, target_id_col, output_binary_col
)
save_conflict_report(
conflict_details,
conflict_report_path,
threshold,
total_rows=len(df),
active_count=n_active_report,
inactive_count=n_inactive_report,
mcc=mcc_report,
resolution_details=resolution_details,
conflict_resolution=conflict_resolution,
)
df[output_binary_col] = df[output_binary_col].astype("Int64")
# Calculate summary statistics
n_active = (df[output_binary_col] == 1).sum()
n_inactive = (df[output_binary_col] == 0).sum()
n_missing = df[output_binary_col].isna().sum()
n_total = len(df)
# Calculate conflict statistics
n_conflicting_measurements = len(conflict_indices)
n_rows_dropped = 0
if resolution_details:
n_rows_dropped = sum(d["rows_dropped"] for d in resolution_details)
n_conflicting_pairs = len(resolution_details)
elif conflict_indices:
remaining_conflict_idx = [i for i in conflict_indices if i in df.index]
if remaining_conflict_idx:
conflict_subset = df.loc[remaining_conflict_idx]
conflict_groupby_cols = [compound_id_col, target_id_col]
if not compare_across_mutants and "mutation" in conflict_subset.columns:
conflict_groupby_cols.append("mutation")
n_conflicting_pairs = conflict_subset.groupby(conflict_groupby_cols).ngroups
else:
n_conflicting_pairs = 0
else:
n_conflicting_pairs = 0
# Calculate MCC for assay compatibility (only for pairs with multiple measurements)
mcc = _calculate_assay_compatibility_mcc(df, compound_id_col, target_id_col, output_binary_col)
# Log comprehensive summary
if n_total > 0:
conflict_lines = (
f" Conflicting measurements detected: {n_conflicting_measurements:>6}\n"
f" Compound-target pairs affected: {n_conflicting_pairs:>6}\n"
)
if conflict_resolution:
conflict_lines += (
f" Resolution strategy: {conflict_resolution:>6}\n"
f" Rows dropped by resolution: {n_rows_dropped:>6}\n"
)
else:
conflict_lines += (
f" Rows at risk if conflicts dropped: {n_conflicting_measurements:>6} "
f"({n_conflicting_measurements/n_total*100:>5.1f}%)\n"
)
logger.info(
f"BINARIZATION SUMMARY\n"
f"Threshold: {threshold}\n"
f"\nBinary labels:\n"
f" Active (1): {n_active:>6} ({n_active/n_total*100:>5.1f}%)\n"
f" Inactive (0): {n_inactive:>6} ({n_inactive/n_total*100:>5.1f}%)\n"
f" Missing: {n_missing:>6} ({n_missing/n_total*100:>5.1f}%)\n"
f"\nConflict analysis:\n"
f"{conflict_lines}"
f"\nAssay compatibility:\n"
f" MCC (agreement between measurements): {mcc:>6.3f}\n"
)
return df