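- subtract() is row-based: a row is removed only when every column exactly matches a row in the other DataFrame.
- subtract() has EXCEPT DISTINCT semantics. If df has duplicate rows, it does not remove just one instance: every matching copy is dropped and the rows that remain are de-duplicated, so the holdout and target counts can fail to add up to the original count.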
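To make the duplicate-row behaviour concrete, here is a minimal plain-Python model of EXCEPT DISTINCT. It is only an illustration of the semantics, not Spark code: the `except_distinct` helper and the sample rows are hypothetical.

```python
# Plain-Python model of EXCEPT DISTINCT, the semantics behind DataFrame.subtract().
# Rows are tuples; the helper name and the data are made up for illustration.

def except_distinct(left, right):
    """Return the distinct rows of `left` that do not appear in `right`."""
    right_rows = set(right)
    seen = set()
    result = []
    for row in left:
        if row not in right_rows and row not in seen:
            seen.add(row)
            result.append(row)
    return result

# Hypothetical data: rows for "A" and "C" each appear twice with identical values.
rows = [("A", 10), ("A", 10), ("B", 20), ("C", 30), ("C", 30)]
holdout = [("A", 10)]  # pretend the sample picked one copy of A

target = except_distinct(rows, holdout)
print(target)                                  # [('B', 20), ('C', 30)]
print(len(holdout) + len(target), len(rows))   # 3 5
```

Both copies of A disappear even though only one was sampled, and the duplicate C collapses to a single row, so the two groups no longer sum to the original row count.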
# Original approach: sample roughly 20% of the rows for the holdout group (sample() does not return an exact fraction)
df_holdout = df.sample(fraction=0.2, seed=42)
df_holdout.display()
logger.info(f"Total holdout loyalty member id: {df_holdout.count():,}") #263,232
# Remove the holdout group from the original DataFrame to get the remaining 80%
df_target = df.subtract(df_holdout)
df_target.display()
logger.info(f"Total remaining loyalty member id: {df_target.count():,}") #1,051,368
# Fix: sample roughly 20% of the rows for the holdout group, then remove it by key instead of with subtract()
df_holdout = df.sample(fraction=0.2, seed=42)
df_holdout.display()
logger.info(f"Total holdout loyalty member id: {df_holdout.count():,}") #263,232
# Remove the holdout group by key with a left anti join, which does not de-duplicate the remaining rows
df_target = df.join(df_holdout, on="BR_ID", how="left_anti")
df_target.display()
logger.info(f"Total remaining loyalty member id: {df_target.count():,}") #1,051,368
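As a rough sketch of what the left anti join does, the plain-Python model below keeps every row whose key is absent from the holdout. It is not Spark code: the `left_anti_join` helper, the `tier` column, and the sample rows are assumptions made for illustration, and BR_ID is assumed to be unique per row.

```python
# Plain-Python model of a left anti join on a key column.
# Rows are dicts; the helper and the data are hypothetical.

def left_anti_join(left, right, key):
    """Keep the rows of `left` whose `key` value never appears in `right`."""
    right_keys = {row[key] for row in right}
    return [row for row in left if row[key] not in right_keys]

rows = [
    {"BR_ID": 1, "tier": "gold"},
    {"BR_ID": 2, "tier": "silver"},
    {"BR_ID": 3, "tier": "silver"},
    {"BR_ID": 4, "tier": "gold"},
]
holdout = [rows[0]]  # pretend the sample picked BR_ID 1

target = left_anti_join(rows, holdout, "BR_ID")
print([row["BR_ID"] for row in target])          # [2, 3, 4]
print(len(holdout) + len(target) == len(rows))   # True
```

The counts add up here only because each BR_ID appears once. If a key repeats, the anti join removes every row carrying a sampled key, so it is worth checking key uniqueness first.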
# Alternative: let Spark split the data once into two disjoint groups of roughly 80% and 20%
df_target, df_holdout = df.randomSplit([0.8, 0.2], seed=42)
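The reason a single split avoids both overlap and lost rows can be sketched in plain Python: each row gets one random draw and lands in exactly one bucket. This is a simplified model under that assumption, not Spark's actual implementation, and the `random_split` helper and data are hypothetical.

```python
import random
from itertools import accumulate

def random_split(rows, weights, seed):
    """Assign every row to exactly one split using a single random draw per row."""
    total = sum(weights)
    # Cumulative upper bounds of the normalized weights, e.g. about [0.8, 1.0].
    bounds = list(accumulate(w / total for w in weights))
    rng = random.Random(seed)
    splits = [[] for _ in weights]
    for row in rows:
        draw = rng.random()
        # First bucket whose upper bound exceeds the draw; the last bucket is
        # the fallback in case floating-point rounding leaves the final bound below 1.
        index = next((i for i, upper in enumerate(bounds) if draw < upper), len(bounds) - 1)
        splits[index].append(row)
    return splits

# Hypothetical data containing duplicate rows.
rows = [("A", 10), ("A", 10), ("B", 20), ("C", 30), ("C", 30)] * 200
target, holdout = random_split(rows, [0.8, 0.2], seed=42)
print(len(target) + len(holdout) == len(rows))  # True
```

Duplicates are preserved because rows are never compared with each other; the split sizes are only approximately 80% and 20%.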
Root Cause | Does sample() cause it? | Fix |
---|---|---|
Duplicate rows | ✅ Yes | Use randomSplit() or add row IDs |
subtract() not safe | ✅ Yes | Use join(..., how='left_anti') |
Overlap in splits | ✅ Possible | Prefer randomSplit() |