Source code for Capricho.core.stats_make

"""Module containing helper functions for processing repeated elements in a DataFrame"""

from typing import List

import numpy as np
import pandas as pd
from chemFilters.chem.standardizers import ChemStandardizer

from ..logger import logger
from .default_fields import multiple_value_cols
from .pandas_helper import aggr_val_series, apply_func_grpd, assign_stats, format_value


def repeated_indices_from_IDs_df(df: pd.DataFrame, columns: list) -> List[List[int]]:
    """Find groups of rows that share the same values across the given columns.

    Every column is cast to ``str`` before comparison, so e.g. ``1`` and ``"1"``
    are considered equal. Rows are keyed by the *tuple* of their stringified
    values rather than by a joined string: joining with a separator makes
    distinct rows collide when a value contains the separator
    (``("a-b", "c")`` vs. ``("a", "b-c")``).

    Args:
        df (pd.DataFrame): The DataFrame to search for repeats.
        columns (list): List of column names with external (non-chembl) IDs to
            identify repeats across. E.g.: ["JUMP_ID", "target_chembl_id"]

    Returns:
        list: A list of lists with the index labels of repeated rows. Each inner
            list holds at least two labels, in their original row order; groups
            are ordered by their (stringified) key tuple.

    Raises:
        ValueError: If any of ``columns`` is not a column of ``df``.
    """
    # Validate input
    if not all(col in df for col in columns):
        raise ValueError("One or more specified columns do not exist in the DataFrame.")

    # One tuple of strings per row; tuples cannot collide the way joined strings can.
    row_keys = map(tuple, df[columns].astype(str).to_numpy())

    # dicts preserve insertion order, so labels inside a group stay in row order.
    groups = {}
    for label, key in zip(df.index.tolist(), row_keys):
        groups.setdefault(key, []).append(label)

    # Only keys seen more than once are repeats.
    return [groups[key] for key in sorted(groups) if len(groups[key]) > 1]
def repeated_indices_from_array_series(series: pd.Series) -> List[List[int]]:
    """Group the index labels of a Series whose elements are identical arrays.

    Adapted from https://stackoverflow.com/a/46629623

    Args:
        series (pd.Series): Series whose elements are 1D arrays (e.g.
            fingerprints). The elements are stacked with ``np.vstack``, so they
            must all have the same length.

    Returns:
        list[list[int]]: One list per group of duplicated arrays, holding the
            index labels of the rows in that group. Arrays that occur only once
            are not reported. An empty Series yields an empty list.
    """
    # np.vstack raises on an empty sequence; nothing can be repeated anyway.
    if series.empty:
        return []

    # Sort the rows lexicographically so identical rows become adjacent.
    stacked = np.vstack(series.tolist())
    order = np.lexsort(stacked.T)
    sorted_rows = stacked[order]

    # True where a sorted row equals its predecessor; padded with False on both
    # ends so every run of duplicates has a detectable start and stop edge.
    same_as_previous = np.concatenate(
        ([False], (sorted_rows[1:] == sorted_rows[:-1]).all(1), [False])
    )

    # Positions where the mask flips: even entries open a run, odd entries close it.
    edges = np.flatnonzero(same_as_previous[1:] != same_as_previous[:-1])

    # Translate sorted positions back into the Series' own index labels.
    sorted_labels = series.index[order].tolist()
    return [sorted_labels[start:stop] for start, stop in zip(edges[::2], edges[1::2] + 1)]
def process_repeat_mols(
    df: pd.DataFrame,
    repeat_element_idxs: List[List[int]],
    solve_strat: str = "keep",
    multiple_value_cols: List[str] = multiple_value_cols,
    extra_id_cols: List[str] = [],
    extra_multival_cols: List[str] = [],
    chirality: bool = False,
    aggregate_mutants: bool = False,
    value_col: str = "pchembl_value",
) -> pd.DataFrame:
    """Aggregate the rows of `df` that were identified as repeated elements.

    `repeat_element_idxs` is expected to come from one of the repeat finders in
    this module (e.g. `repeated_indices_from_array_series`). Rows that belong to
    the same repeat group and share the same identification columns are
    aggregated into a single row, and statistics columns
    (`<value_col>_mean`, `_std`, `_median`, `_counts`) are added. Groups whose
    max and min values differ by 1 or more are reported and, when
    `solve_strat == "drop"`, removed from the dataset. A boolean column named
    `might_rancemic` is added: it is False for rows that were not aggregated and,
    for aggregated rows, True only when `chirality` is False.

    Args:
        df: dataframe with the bioactivity data. It is not modified; a copy is
            processed instead.
        repeat_element_idxs: list of lists with the index labels of repeated
            elements in the dataframe.
        solve_strat: strategy to solve the repeated elements. If 'drop', the
            aggregated groups whose values differ by >= 1 log unit are dropped.
            If 'keep', no values are dropped.
        multiple_value_cols: columns whose values are aggregated for repeated
            elements. `pchembl_value` is skipped when it is not `value_col`, and
            columns that are identification columns or missing from `df` are
            skipped as well.
        extra_id_cols: list of extra identification columns you might have for
            your own compounds that you'd like to use to avoid mixing data & to
            keep in the final dataframe. Defaults to [].
        extra_multival_cols: list of extra columns that you'd like to keep as
            aggregated values in the final dataframe. Caveat: these columns will
            be displayed as (str) separated by `|` (pipe) in the final
            dataframe. Defaults to [].
        chirality: boolean flag to indicate whether the fingerprints used to
            check for identical compounds are chirality-sensitive or not.
            Defaults to False.
        aggregate_mutants: if True, `mutation` is treated as an aggregated
            (multi-value) column; otherwise it is used as an identification
            column. Defaults to False.
        value_col: name of the column holding the activity values to aggregate.
            Defaults to "pchembl_value".

    Returns:
        df: dataframe with the repeated elements processed.

    Raises:
        TypeError: re-raised (after logging) when aggregating the multi-value
            columns fails.
    """
    import re  # local import: only needed for the pandas version check below

    df = df.copy()

    # Map every row label to the position of the repeat group it belongs to.
    repeat_mapping = {
        row_label: group_id
        for group_id, members in enumerate(repeat_element_idxs)
        for row_label in members
    }
    df = df.assign(repeat_mapping=lambda x: x.index.map(repeat_mapping))
    repeat_subset = df.query("~repeat_mapping.isna()").assign(
        **{value_col: lambda df, vc=value_col: df[vc].apply(lambda val: format_value(val))}
    )

    if not repeat_subset.empty:
        numeric_activity = (  # concatenate grouped values and convert to numeric arrays
            repeat_subset.groupby(["repeat_mapping"])[value_col]
            .apply(lambda x: "|".join(x))
            .str.split("|")
            .apply(lambda x: np.array(x).astype(float))
        )
    else:
        logger.info(
            "Multiple readouts on the same compound not found within the dataset. Statistics "
            "columns (counts, mean, median) will be calculated solely for the sake of consistency."
        )
        numeric_activity = (
            repeat_subset.groupby(["repeat_mapping"])[value_col]
            .apply(lambda x: "|".join(x))
            .apply(lambda x: np.array(x).astype(float))
        )

    max_series = numeric_activity.apply(lambda x: np.max(x))
    min_series = numeric_activity.apply(lambda x: np.min(x))
    activity_divergence_series = max_series - min_series

    # Repeat groups with 1 or more log units between min & max. Use the group
    # labels (not positions) because they are matched against `repeat_mapping`.
    high_diff_repeats = activity_divergence_series.index[
        (activity_divergence_series >= 1).to_numpy()
    ]
    # Count only the rows that belong to a divergent group.
    points_dropped = int(repeat_subset["repeat_mapping"].isin(high_diff_repeats).sum())
    logger.info(f"Found {len(high_diff_repeats)} repeats with more than 1 log unit difference.")
    logger.info(f"Maximum difference between min & max values: {np.max(activity_divergence_series)}")
    if solve_strat == "drop":
        logger.info(f"{points_dropped} points will be removed from the dataset")

    id_cols = [*extra_id_cols, "repeat_mapping", "target_chembl_id", "standard_relation"]
    # NOTE(review): comparing `value_col` with "standard_type" looks unintended
    # (perhaps "standard_value" was meant) - confirm before changing.
    if value_col == "standard_type":
        id_cols.append("standard_units")

    # Build multivalue columns: include value_col and all multivalue columns except
    # pchembl_value when it's not the value_col (avoid aggregating unnecessary NaN values)
    effective_multival_cols = []
    for col in multiple_value_cols:
        if col == "pchembl_value" and value_col != "pchembl_value":
            # Skip pchembl_value when aggregating on other columns (e.g., standard_value)
            continue
        if col == value_col or (col not in id_cols and col in df.columns):
            # Always include the value_col; include other multivalue columns when present
            effective_multival_cols.append(col)

    multival_cols = [*effective_multival_cols, *extra_multival_cols]
    if aggregate_mutants:
        multival_cols = [*multival_cols, "mutation"]
    else:
        id_cols = [*id_cols, "mutation"]

    # NOTE(review): `.loc[...]` returns a new object, so this in-place replace does
    # not appear to modify `repeat_subset`. Kept as-is to avoid changing the
    # aggregated output; confirm whether None -> "None" should really be applied.
    repeat_subset.loc[:, multival_cols].replace({None: "None"}, inplace=True)

    # Compare the pandas version numerically; a plain string comparison would
    # mis-order e.g. "1.10" and "1.5". An unparseable version is assumed recent.
    version_match = re.match(r"(\d+)\.(\d+)", pd.__version__)
    pandas_is_recent = version_match is None or (
        int(version_match.group(1)),
        int(version_match.group(2)),
    ) >= (1, 5)
    if pandas_is_recent:
        grouped = repeat_subset.groupby(id_cols, group_keys=True)
    else:
        grouped = repeat_subset.groupby(id_cols)

    try:
        updated_vals = apply_func_grpd(grouped, aggr_val_series, id_cols, *multival_cols)
    except TypeError:
        logger.error(
            f"Error while aggregating values for columns {multival_cols}. "
            f"Data shouldn't contain NaN values, check the columns: {df.isna().sum()}"
        )
        raise

    updated_df = assign_stats(
        updated_vals, value_col=value_col, use_geometric=(value_col == "pchembl_value")
    ).merge(df, on=id_cols, how="left")

    # Keep the aggregated ("_x") version of overlapping columns and discard the
    # original ("_y") one. Slice the suffix off: `str.rstrip` strips a character
    # set and would also eat trailing "x"/"_" characters of the column name.
    rename_cols = {c: c[: -len("_x")] for c in updated_df.columns if c.endswith("_x")}
    todrop_cols = [c for c in updated_df.columns if c.endswith("_y")]
    updated_df = updated_df.drop(columns=todrop_cols).rename(columns=rename_cols)

    # Only the rows of divergent repeat groups are candidates for dropping.
    todrop_processed = updated_df.index[updated_df["repeat_mapping"].isin(high_diff_repeats)]

    smiles_canonizer = ChemStandardizer(
        method="canon", from_smi=True, n_jobs=8, progress=True, isomeric=chirality, chunk_size=None
    )

    if solve_strat == "drop":
        updated_df = updated_df.drop(index=todrop_processed)

    # Convert multival_cols (except value_col) to strings for non-aggregated rows
    non_aggregated_df = df.drop(index=repeat_subset.index).assign(
        might_rancemic=lambda x: [False] * len(x),
    )
    for col in multival_cols:
        if col in non_aggregated_df.columns and col != value_col:
            non_aggregated_df[col] = non_aggregated_df[col].apply(format_value)

    df = pd.concat(
        [
            non_aggregated_df,
            updated_df.assign(might_rancemic=lambda x: [not chirality] * len(x)),
        ],
        ignore_index=True,
    )

    # the `smiles` column will be the final smiles column; to be used for modeling
    smiles = df["standard_smiles"].apply(
        lambda smi: smi if pd.isna(smi) or "|" not in smi else smi.split("|")[0]
    )
    logger.info("Canonicalizing smiles...")
    df = df.assign(smiles=smiles_canonizer(smiles))

    stats_cols = [f"{value_col}{suffix}" for suffix in ["_mean", "_std", "_median", "_counts"]]
    final_cols = [*id_cols, "smiles", *multival_cols, "might_rancemic", *stats_cols]
    final_cols.remove("repeat_mapping")  # remove repeat_mapping from final_cols
    df = (
        df[final_cols]
        .rename(columns={"standard_smiles": "processed_smiles"})
        .reset_index(drop=True)
        .drop_duplicates()
    )
    logger.info(f"Final number of points: {len(df)}")

    # Also add the single-read points to the mean / median / counts values
    with pd.option_context("future.no_silent_downcasting", True):
        df[f"{value_col}_median"] = df[f"{value_col}_median"].fillna(df[value_col]).infer_objects(copy=False)
        df[f"{value_col}_mean"] = df[f"{value_col}_mean"].fillna(df[value_col]).infer_objects(copy=False)
        df[f"{value_col}_counts"] = df[f"{value_col}_counts"].fillna(1).infer_objects(copy=False)

    df[value_col] = df[value_col].apply(format_value)  # convert to str for consistency
    return df