Created
March 14, 2024 22:21
-
-
Save bsweger/f2da4cf8a0e1adb45fca7cb5019fb807 to your computer and use it in GitHub Desktop.
generate variant hub sample output data
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from itertools import product | |
import pandas as pd | |
import numpy as np | |
def make_sample(
    n_samples: int = 2,
    n_horizons: int = 3,
    n_variants: int = 3,
    n_locations: int = 2,
    samples_joint_across: list[str] = None
) -> pd.DataFrame:
    """Build a toy sample-output table for every sample/horizon/location/variant.

    One row is produced per element of the Cartesian product of sample,
    horizon, location and variant ids. Sample, location and variant ids count
    up from 0; horizons run from -1 to ``n_horizons - 2``.

    Args:
        n_samples: Number of sample ids to generate.
        n_horizons: Number of horizons to generate (starting at -1).
        n_variants: Number of variant ids to generate.
        n_locations: Number of location ids to generate.
        samples_joint_across: Up to three distinct names chosen from
            ``'horizon'``, ``'variant'`` and ``'location'``. ``None`` or an
            empty sequence means no joint dimensions.

    Returns:
        A DataFrame with the columns ``sample``, ``horizon``, ``location``,
        ``variant``, ``traj``, ``output_type_id``, ``output_type``,
        ``nowcast_date`` and ``target_date``. ``sample``, ``location`` and
        ``variant`` hold strings and ``horizon`` holds integers.
        ``output_type_id`` is ``'S<traj>'`` followed by one
        ``'_<INITIAL><value>'`` suffix per joint dimension.

    Raises:
        ValueError: If ``samples_joint_across`` has more than three entries,
            contains duplicates, or contains an unknown name.
    """
    # Copy into a list so tuples (or other sequences) can be concatenated below.
    joint_dims = list(samples_joint_across) if samples_joint_across else []

    if len(joint_dims) > 3:
        raise ValueError('Too many samples_joint_across!')
    if len(joint_dims) != len(set(joint_dims)):
        raise ValueError('Duplicate samples_joint_across!')
    allowed_dims = ('horizon', 'variant', 'location')
    if any(dim not in allowed_dims for dim in joint_dims):
        raise ValueError('Sample joint across must be horizon, variant, or location')

    # Keep the ids numeric while sorting: sorting them as strings would put
    # '10' before '2' as soon as a dimension has more than ten values.
    grid = pd.DataFrame.from_records(
        list(product(
            np.arange(n_samples),
            np.arange(-1, n_horizons - 1),
            np.arange(0, n_locations),
            np.arange(0, n_variants),
        )),
        columns=['sample', 'horizon', 'location', 'variant'],
    )

    # NOTE(review): traj is 1-based without joint dimensions but 0-based
    # (cumcount) with them -- confirm whether that difference is intended.
    if not joint_dims:
        grid = grid.sort_values(['sample', 'horizon', 'location', 'variant'])
        grid['traj'] = range(1, len(grid) + 1)
    else:
        grid = grid.sort_values(['sample'] + joint_dims)
        grid['traj'] = grid.groupby(joint_dims).cumcount()

    # Generate the output_type_id, e.g. 'S0_H-1_V2'.
    type_id = 'S' + grid['traj'].astype(str)
    for dim in joint_dims:
        type_id = type_id + '_' + dim[0].upper() + grid[dim].astype(str)
    grid['output_type_id'] = type_id

    # Add the constant columns.
    grid['output_type'] = 'sample'
    grid['nowcast_date'] = pd.to_datetime('2024-01-26')

    # Horizons are interpreted as days for this exercise; 'D' is the
    # documented pandas unit for days.
    grid['target_date'] = grid['nowcast_date'] + pd.to_timedelta(grid['horizon'], 'D')

    # The id columns are exposed as strings; horizon stays an integer.
    for col in ('sample', 'location', 'variant'):
        grid[col] = grid[col].astype(str)

    # TODO: add a 'value' column (e.g. random draws) once values are needed.
    return grid
def _show_and_save(frame, csv_path, joint=None):
    """Print a summary of ``frame`` and write it to ``csv_path`` without the index.

    When ``joint`` is given, it is printed first so the output shows which
    joint dimensions produced the frame.
    """
    if joint is not None:
        print(joint)
    print(frame.shape)
    print(frame.sort_values(['sample', 'traj']))
    print()
    frame.to_csv(csv_path, index=False)


# No joint dimensions.
df_none = make_sample()
_show_and_save(df_none, 'sja_none.csv')

# Joint across horizon only.
sja = ['horizon']
df_h = make_sample(samples_joint_across=sja)
_show_and_save(df_h, 'sja_horizon.csv', sja)

# Joint across horizon and variant.
sja = ['horizon', 'variant']
df_hv = make_sample(samples_joint_across=sja)
_show_and_save(df_hv, 'sja_horizon_variant.csv', sja)

# Joint across horizon, variant and location.
sja = ['horizon', 'variant', 'location']
df_hvl = make_sample(samples_joint_across=sja)
_show_and_save(df_hvl, 'sja_horizon_variant_location.csv', sja)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment