Dataset stats using aggregators
import os
from typing import List, Optional, Tuple

import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc

import ray
from ray.data import Dataset
from ray.data.aggregate import (
    AggregateFnV2,
    Count,
    Max,
    Mean,
    Min,
    Quantile,
    Std,
    Unique,
)
from ray.data.block import Block, BlockAccessor
from ray.data.context import DataContext, ShuffleStrategy

class MissingValuePercentage(AggregateFnV2):
    """Calculates the percentage of null values in a column."""

    def __init__(
        self,
        on: str,
        alias_name: Optional[str] = None,
    ):
        # Initialize with a list accumulator [null_count, total_count]
        super().__init__(
            alias_name if alias_name else f"missing_pct({str(on)})",
            on=on,
            ignore_nulls=False,  # Include nulls for this calculation
            zero_factory=lambda: [0, 0],  # Our AggType is a simple list
        )

    def aggregate_block(self, block: Block) -> List[int]:
        # Use BlockAccessor to work with the block
        block_acc = BlockAccessor.for_block(block)
        # Convert to Arrow for efficient operations
        table = block_acc.to_arrow()
        column = table.column(self._target_col_name)
        # Use PyArrow compute for vectorized counting; NaN is treated as null.
        # pc.sum over an empty column returns null, hence the `or 0` fallback.
        total_count = len(column)
        null_count = pc.sum(
            pc.is_null(column, nan_is_null=True).cast("int32")
        ).as_py() or 0
        # null_count = table[self._target_col_name].null_count
        # Return our accumulator
        return [null_count, total_count]

    def combine(self, current_accumulator: List[int], new: List[int]) -> List[int]:
        # Merge two accumulators by summing their components
        return [
            current_accumulator[0] + new[0],  # Sum null counts
            current_accumulator[1] + new[1],  # Sum total counts
        ]

    def _finalize(self, accumulator: List[int]) -> Optional[float]:
        # Calculate the final percentage
        if accumulator[1] == 0:
            return None
        return (accumulator[0] / accumulator[1]) * 100.0
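
# Illustrative only (not called by run() below): a minimal sketch of how
# MissingValuePercentage plugs into Dataset.aggregate(), assuming a local Ray
# runtime and a tiny in-memory dataset built with ray.data.from_items.
def _demo_missing_value_percentage():
    ds = ray.data.from_items([{"x": 1}, {"x": None}, {"x": 3}])
    result = ds.aggregate(MissingValuePercentage(on="x"))
    # One of three rows is null, so this prints roughly:
    # {'missing_pct(x)': 33.33...}
    print(result)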

class ZeroPercentage(AggregateFnV2):
    """Calculates the percentage of zero values in a numeric column."""

    def __init__(
        self,
        on: str,
        ignore_nulls: bool = True,
        alias_name: Optional[str] = None,
    ):
        # Initialize with a list accumulator [zero_count, denominator_count]
        super().__init__(
            alias_name if alias_name else f"zero_pct({str(on)})",
            on=on,
            ignore_nulls=ignore_nulls,
            zero_factory=lambda: [0, 0],
        )

    def aggregate_block(self, block: Block) -> List[int]:
        # Get BlockAccessor
        block_acc = BlockAccessor.for_block(block)
        # Convert to Arrow
        table = block_acc.to_arrow()
        column = table.column(self._target_col_name)
        if self._ignore_nulls:
            # Count only non-null values: build a boolean validity mask and
            # sum it (each True counts as 1).
            non_null_mask = pc.is_valid(column)
            denominator = pc.sum(non_null_mask).as_py() or 0
        else:
            # Include nulls in the denominator. Note that pc.count() skips
            # nulls by default, so use the raw column length instead.
            denominator = len(column)
        if denominator == 0:
            return [0, 0]
        # Count zeros with a vectorized comparison. pc.equal yields null for
        # null inputs and pc.sum skips nulls, so nulls never count as zeros.
        zero_mask = pc.equal(column, 0)
        zero_count = pc.sum(zero_mask).as_py() or 0
        return [zero_count, denominator]

    def combine(self, current_accumulator: List[int], new: List[int]) -> List[int]:
        return [
            current_accumulator[0] + new[0],  # Sum zero counts
            current_accumulator[1] + new[1],  # Sum denominator counts
        ]

    def _finalize(self, accumulator: List[int]) -> Optional[float]:
        if accumulator[1] == 0:
            return None
        return (accumulator[0] / accumulator[1]) * 100.0
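
# Illustrative only: ZeroPercentage with and without nulls in the denominator,
# again assuming a tiny in-memory dataset. With ignore_nulls=True (the
# default), nulls are excluded, so one zero out of two non-null values is 50%.
def _demo_zero_percentage():
    ds = ray.data.from_items([{"x": 0}, {"x": None}, {"x": 5}])
    print(ds.aggregate(ZeroPercentage(on="x")))                      # ~50.0
    print(ds.aggregate(ZeroPercentage(on="x", ignore_nulls=False)))  # ~33.3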

def get_numerical_aggregators(column: str) -> List[AggregateFnV2]:
    """Generate default metrics for numerical columns.

    This function returns a list of aggregators that compute the following metrics:
    - count
    - mean
    - min
    - max
    - median (using quantile with q=0.5)
    - std
    - missing value percentage
    - zero value percentage

    Args:
        column: The name of the numerical column to compute metrics for.

    Returns:
        A list of AggregateFnV2 instances that can be used with Dataset.aggregate().
    """
    return [
        Count(on=column, ignore_nulls=True),
        Mean(on=column, ignore_nulls=True),
        Min(on=column, ignore_nulls=True),
        Max(on=column, ignore_nulls=True),
        Quantile(on=column, q=0.5, ignore_nulls=True, alias_name=f"median({column})"),
        Std(on=column, ignore_nulls=True),
        MissingValuePercentage(on=column),
        ZeroPercentage(on=column, ignore_nulls=True),
    ]
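
# Illustrative only: running the numerical bundle on a single column. The
# result keys follow the "metric(column)" naming used by Ray's built-ins,
# e.g. "count(x)" and "mean(x)", plus the "median(x)" alias set above.
def _demo_numerical_aggregators():
    ds = ray.data.from_items([{"x": v} for v in [1, 2, 0, None, 4]])
    print(ds.aggregate(*get_numerical_aggregators("x")))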

def get_categorical_aggregators(column: str) -> List[AggregateFnV2]:
    """Generate default metrics for string columns.

    This function returns a list of aggregators that compute the following metrics:
    - count
    - missing value percentage
    - unique values

    Args:
        column: The name of the string column to compute metrics for.

    Returns:
        A list of AggregateFnV2 instances that can be used with Dataset.aggregate().
    """
    return [
        Count(on=column, ignore_nulls=True),
        MissingValuePercentage(on=column),
        Unique(on=column),
    ]

def get_feature_aggregators_for_dataset(dataset: Dataset) -> Tuple[List[str], List[str], List[AggregateFnV2]]:
    """Generate aggregators for all columns in a dataset.

    Args:
        dataset: A Ray Dataset instance.

    Returns:
        A tuple of:
        - List of numerical column names
        - List of string column names
        - A list of AggregateFnV2 instances that can be used with Dataset.aggregate()
    """
    schema = dataset.schema()
    if not schema:
        raise ValueError("Dataset must have a schema to determine column types")
    numerical_columns = []
    str_columns = []
    all_aggs = []
    fields_and_types = zip(schema.names, schema.types)
    for name, ftype in fields_and_types:
        # Check if the field type is numerical using PyArrow's type system
        if (pa.types.is_integer(ftype) or
                pa.types.is_floating(ftype) or
                pa.types.is_decimal(ftype)):
            numerical_columns.append(name)
            all_aggs.extend(get_numerical_aggregators(name))
        elif pa.types.is_string(ftype):
            str_columns.append(name)
            all_aggs.extend(get_categorical_aggregators(name))
        else:
            print(f"Dropping field {name} as its type {ftype} is neither numerical nor string")
    return (numerical_columns, str_columns, all_aggs)
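
# Illustrative only: schema-driven dispatch on a small mixed dataset,
# assuming from_items infers a float column for "x" and a string column for
# "label". The combined aggregator list can then be passed to a single
# Dataset.aggregate() call, as run() does below.
def _demo_feature_aggregators():
    ds = ray.data.from_items(
        [{"x": 1.0, "label": "a"}, {"x": 2.0, "label": None}]
    )
    num_cols, str_cols, aggs = get_feature_aggregators_for_dataset(ds)
    print(num_cols, str_cols)  # expected: ['x'] ['label']
    print(ds.aggregate(*aggs))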

def convert_metrics_to_nested_dict(metrics_result: dict, columns_to_consider: List[str]) -> dict:
    """Convert flat metrics results into a nested dictionary structure.

    This function takes the flat dictionary returned by Dataset.aggregate() and
    converts it into a nested dictionary where:
    - First-level keys are column names
    - Second-level keys are metric names (count, mean, min, max, median, ...)

    Args:
        metrics_result: Dictionary returned by Dataset.aggregate()
        columns_to_consider: Columns which need to be considered

    Returns:
        A nested dictionary with column names as first-level keys and metrics
        as second-level keys.

    Example:
        Input: {
            "count(col1)": 100,
            "mean(col1)": 5.5,
            "min(col1)": 1,
            "max(col1)": 10,
            "median(col1)": 5,
            "count(col2)": 100,
            "mean(col2)": 7.5,
            ...
        }
        Output: {
            "col1": {
                "count": 100,
                "mean": 5.5,
                "min": 1,
                "max": 10,
                "median": 5
            },
            "col2": {
                "count": 100,
                "mean": 7.5,
                ...
            }
        }
    """
    nested_dict = {}
    for key, value in metrics_result.items():
        # Extract metric name and column name from the key.
        # Keys are in the format "metric_name(column_name)".
        metric_name = key.split("(")[0]
        column_name = key.split("(")[1].rstrip(")")
        if column_name not in columns_to_consider:
            continue
        # Create nested structure
        if column_name not in nested_dict:
            nested_dict[column_name] = {}
        nested_dict[column_name][metric_name] = value
    return nested_dict
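
# Illustrative only: the key parser above assumes the "metric(column)" naming
# that Ray's built-ins and the aliases in this file produce; a column name
# that itself contains "(" would be split incorrectly.
def _demo_nested_dict():
    flat = {"count(x)": 3, "mean(x)": 2.0, "count(y)": 3}
    print(convert_metrics_to_nested_dict(flat, ["x"]))
    # -> {'x': {'count': 3, 'mean': 2.0}}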

def convert_metrics_to_dataframe(metrics_result: dict, columns_to_consider: List[str]) -> "pd.DataFrame":
    """Convert flat metrics results to a pandas DataFrame.

    This function takes the flat dictionary returned by Dataset.aggregate()
    and converts it into a pandas DataFrame where:
    - Index: Column names
    - Columns: Metric names (count, mean, min, max, median, etc.)

    Args:
        metrics_result: Dictionary returned by Dataset.aggregate()
        columns_to_consider: Columns which need to be considered

    Returns:
        A pandas DataFrame with column names as index and metrics as columns.

    Example:
        Input: {
            "count(col1)": 100,
            "mean(col1)": 5.5,
            "min(col1)": 1,
            "max(col1)": 10,
            "median(col1)": 5,
            "count(col2)": 100,
            "mean(col2)": 7.5,
            ...
        }
        Output:
        DataFrame with:
        - Index: col1, col2
        - Columns: count, mean, min, max, median
    """
    nested_metrics = convert_metrics_to_nested_dict(metrics_result, columns_to_consider)
    # Convert the nested dict to a DataFrame. With orient='index', the outer
    # keys (column names) become the index and the metric names become the
    # columns, matching the docstring above.
    df = pd.DataFrame.from_dict(nested_metrics, orient='index')
    # Optionally sort the metrics into a fixed order:
    # metric_order = ['count', 'mean', 'min', 'max', 'median']
    # df = df.reindex(columns=metric_order)
    return df

def run():
    DataContext.get_current().shuffle_strategy = ShuffleStrategy.HASH_SHUFFLE
    dataset_path = "<Your path>"
    ds = ray.data.read_parquet(dataset_path)
    # We want to run all aggregators at once (for both numerical and string columns)
    numerical_cols, categorical_cols, all_aggs = get_feature_aggregators_for_dataset(ds)
    result = ds.aggregate(*all_aggs)
    # Make sure the output directory exists, then write one table for
    # numerical columns and another for string columns.
    os.makedirs("result", exist_ok=True)
    numerical_df = convert_metrics_to_dataframe(result, numerical_cols)
    numerical_df.to_csv('result/numerical_features.csv', index=True, header=True)
    string_features = convert_metrics_to_dataframe(result, categorical_cols)
    string_features.to_csv('result/categorical_features.csv', index=True, header=True)


if __name__ == "__main__":
    run()