Dataset stats using Ray Data aggregators
from typing import List, Optional, Tuple

import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc

import ray
from ray.data import Dataset
from ray.data.aggregate import (
    AggregateFnV2,
    Count,
    Max,
    Mean,
    Min,
    Quantile,
    Std,
    Unique,
)
from ray.data.block import Block, BlockAccessor
from ray.data.context import DataContext, ShuffleStrategy


class MissingValuePercentage(AggregateFnV2):
    """Calculates the percentage of null values in a column."""

    def __init__(
        self,
        on: str,
        alias_name: Optional[str] = None,
    ):
        # Initialize with a list accumulator: [null_count, total_count].
        super().__init__(
            alias_name if alias_name else f"missing_pct({str(on)})",
            on=on,
            ignore_nulls=False,  # Nulls must be included for this calculation.
            zero_factory=lambda: [0, 0],  # Our AggType is a simple list.
        )

    def aggregate_block(self, block: Block) -> List[int]:
        # Use BlockAccessor to work with the block, and convert to Arrow
        # for efficient vectorized operations.
        block_acc = BlockAccessor.for_block(block)
        table = block_acc.to_arrow()
        column = table.column(self._target_col_name)

        # Use PyArrow compute for vectorized counting; treat NaN as null.
        # `or 0` guards the empty-block case, where pc.sum returns null.
        total_count = len(column)
        null_count = (
            pc.sum(pc.is_null(column, nan_is_null=True).cast("int32")).as_py() or 0
        )

        # Return the per-block accumulator.
        return [null_count, total_count]

    def combine(self, current_accumulator: List[int], new: List[int]) -> List[int]:
        # Merge two accumulators by summing their components.
        return [
            current_accumulator[0] + new[0],  # Sum null counts.
            current_accumulator[1] + new[1],  # Sum total counts.
        ]

    def _finalize(self, accumulator: List[int]) -> Optional[float]:
        # Calculate the final percentage.
        if accumulator[1] == 0:
            return None
        return (accumulator[0] / accumulator[1]) * 100.0
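
# A minimal standalone sketch of how this aggregator plugs into
# Dataset.aggregate(). The tiny in-memory dataset and the column name "x"
# are hypothetical, purely for illustration:
#
#   ds = ray.data.from_items([{"x": 1}, {"x": None}, {"x": 3}])
#   ds.aggregate(MissingValuePercentage(on="x"))
#   # -> {"missing_pct(x)": 33.33...}
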
class ZeroPercentage(AggregateFnV2):
    """Calculates the percentage of zero values in a numeric column."""

    def __init__(
        self,
        on: str,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        # Initialize with a list accumulator: [zero_count, denominator_count].
        super().__init__(
            alias_name if alias_name else f"zero_pct({str(on)})",
            on=on,
            ignore_nulls=ignore_nulls,
            zero_factory=lambda: [0, 0],
        )

    def aggregate_block(self, block: Block) -> List[int]:
        # Get a BlockAccessor and convert to Arrow.
        block_acc = BlockAccessor.for_block(block)
        table = block_acc.to_arrow()
        column = table.column(self._target_col_name)

        if self._ignore_nulls:
            # Denominator is the count of non-null values: sum the boolean
            # validity mask.
            denominator = pc.sum(pc.is_valid(column)).as_py() or 0
        else:
            # Denominator is every row, nulls included.
            denominator = len(column)
        if denominator == 0:
            return [0, 0]

        # Count zeros: build a boolean mask of zero values and sum it.
        # pc.sum skips nulls in the mask; `or 0` guards the all-null case.
        zero_mask = pc.equal(column, 0)
        zero_count = pc.sum(zero_mask).as_py() or 0
        return [zero_count, denominator]

    def combine(self, current_accumulator: List[int], new: List[int]) -> List[int]:
        return [
            current_accumulator[0] + new[0],  # Sum zero counts.
            current_accumulator[1] + new[1],  # Sum denominator counts.
        ]

    def _finalize(self, accumulator: List[int]) -> Optional[float]:
        if accumulator[1] == 0:
            return None
        return (accumulator[0] / accumulator[1]) * 100.0
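
# Same pattern for ZeroPercentage (again with hypothetical data). With
# ignore_nulls=True, the null row is excluded from the denominator:
#
#   ds = ray.data.from_items([{"x": 0}, {"x": None}, {"x": 2}, {"x": 0}])
#   ds.aggregate(ZeroPercentage(on="x"))
#   # -> {"zero_pct(x)": 66.66...}  (2 zeros out of 3 non-null values)
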
def get_numerical_aggregators(column: str) -> List[AggregateFnV2]:
    """Generate default metrics for numerical columns.

    This function returns a list of aggregators that compute the following
    metrics:
    - count
    - mean
    - min
    - max
    - median (via quantile with q=0.5)
    - standard deviation
    - missing-value percentage
    - zero percentage

    Args:
        column: The name of the numerical column to compute metrics for.

    Returns:
        A list of AggregateFnV2 instances that can be used with
        Dataset.aggregate().
    """
    return [
        Count(on=column, ignore_nulls=True),
        Mean(on=column, ignore_nulls=True),
        Min(on=column, ignore_nulls=True),
        Max(on=column, ignore_nulls=True),
        Quantile(on=column, q=0.5, ignore_nulls=True, alias_name=f"median({column})"),
        Std(on=column, ignore_nulls=True),
        MissingValuePercentage(on=column),
        ZeroPercentage(on=column, ignore_nulls=True),
    ]
def get_categorical_aggregators(column: str) -> List[AggregateFnV2]:
    """Generate default metrics for string columns.

    This function returns a list of aggregators that compute the following
    metrics:
    - count
    - missing-value percentage
    - unique values

    Args:
        column: The name of the string column to compute metrics for.

    Returns:
        A list of AggregateFnV2 instances that can be used with
        Dataset.aggregate().
    """
    return [
        Count(on=column, ignore_nulls=True),
        MissingValuePercentage(on=column),
        Unique(on=column),
    ]
def get_feature_aggregators_for_dataset(
    dataset: Dataset,
) -> Tuple[List[str], List[str], List[AggregateFnV2]]:
    """Generate aggregators for all columns in a dataset.

    Args:
        dataset: A Ray Dataset instance.

    Returns:
        A tuple of:
        - List of numerical column names
        - List of string column names
        - A list of AggregateFnV2 instances that can be used with
          Dataset.aggregate()
    """
    schema = dataset.schema()
    if not schema:
        raise ValueError("Dataset must have a schema to determine column types")

    numerical_columns = []
    str_columns = []
    all_aggs = []
    for name, ftype in zip(schema.names, schema.types):
        # Classify the field using PyArrow's type system.
        if (pa.types.is_integer(ftype) or
                pa.types.is_floating(ftype) or
                pa.types.is_decimal(ftype)):
            numerical_columns.append(name)
            all_aggs.extend(get_numerical_aggregators(name))
        elif pa.types.is_string(ftype):
            str_columns.append(name)
            all_aggs.extend(get_categorical_aggregators(name))
        else:
            print(f"Dropping field {name}: type {ftype} is neither numerical nor string")
    return (numerical_columns, str_columns, all_aggs)
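
# For a hypothetical schema {"price": float64, "city": string}, this would
# return (["price"], ["city"], aggs), where aggs holds the eight numerical
# aggregators for "price" followed by the three categorical ones for "city".
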
def convert_metrics_to_nested_dict(metrics_result: dict, columns_to_consider: List[str]) -> dict:
    """Convert flat metrics results into a nested dictionary structure.

    This function takes the flat dictionary returned by Dataset.aggregate()
    and converts it into a nested dictionary where:
    - First-level keys are column names
    - Second-level keys are metric names (count, mean, min, max, median, etc.)

    Args:
        metrics_result: Dictionary returned by Dataset.aggregate()
        columns_to_consider: Columns which need to be considered

    Returns:
        A nested dictionary with column names as first-level keys and metrics
        as second-level keys.

    Example:
        Input: {
            "count(col1)": 100,
            "mean(col1)": 5.5,
            "min(col1)": 1,
            "max(col1)": 10,
            "median(col1)": 5,
            "count(col2)": 100,
            "mean(col2)": 7.5,
            ...
        }
        Output: {
            "col1": {
                "count": 100,
                "mean": 5.5,
                "min": 1,
                "max": 10,
                "median": 5
            },
            "col2": {
                "count": 100,
                "mean": 7.5,
                ...
            }
        }
    """
    nested_dict = {}
    for key, value in metrics_result.items():
        # Extract the metric name and column name from the key.
        # Keys are in the format "metric_name(column_name)".
        metric_name = key.split("(")[0]
        column_name = key.split("(")[1].rstrip(")")
        if column_name not in columns_to_consider:
            continue
        # Create the nested structure lazily.
        if column_name not in nested_dict:
            nested_dict[column_name] = {}
        nested_dict[column_name][metric_name] = value
    return nested_dict
def convert_metrics_to_dataframe(metrics_result: dict, columns_to_consider: List[str]) -> "pd.DataFrame":
    """Convert flat metrics results to a pandas DataFrame.

    This function takes the flat dictionary returned by Dataset.aggregate()
    and converts it into a pandas DataFrame where:
    - Index: column names
    - Columns: metric names (count, mean, min, max, median, etc.)

    Args:
        metrics_result: Dictionary returned by Dataset.aggregate()
        columns_to_consider: Columns which need to be considered

    Returns:
        A pandas DataFrame with column names as index and metrics as columns.

    Example:
        Input: {
            "count(col1)": 100,
            "mean(col1)": 5.5,
            "min(col1)": 1,
            "max(col1)": 10,
            "median(col1)": 5,
            "count(col2)": 100,
            "mean(col2)": 7.5,
            ...
        }
        Output:
            DataFrame with:
            - Index: col1, col2
            - Columns: count, mean, min, max, median
    """
    nested_metrics = convert_metrics_to_nested_dict(metrics_result, columns_to_consider)
    # orient="index" puts column names on the index and metric names on the
    # columns, matching the docstring above.
    return pd.DataFrame.from_dict(nested_metrics, orient="index")
def run():
    DataContext.get_current().shuffle_strategy = ShuffleStrategy.HASH_SHUFFLE
    dataset_path = "<Your path>"
    ds = ray.data.read_parquet(dataset_path)

    # Run all aggregators at once, covering both numerical and string columns.
    numerical_cols, categorical_cols, all_aggs = get_feature_aggregators_for_dataset(ds)
    result = ds.aggregate(*all_aggs)

    # Write one table for numerical columns and one for string columns.
    # (Assumes a local "result/" directory already exists.)
    numerical_df = convert_metrics_to_dataframe(result, numerical_cols)
    numerical_df.to_csv("result/numerical_features.csv", index=True, header=True)
    string_features = convert_metrics_to_dataframe(result, categorical_cols)
    string_features.to_csv("result/categorical_features.csv", index=True, header=True)


if __name__ == "__main__":
    run()