Created
September 27, 2021 11:57
-
-
Save erap129/17bcea0616603c7f9473575f38ccd92e to your computer and use it in GitHub Desktop.
NASA RUL project - windowing the data
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Default number of consecutive time steps contained in each sliding window.
WINDOW_SIZE = 20


def get_windowed_dataframes(df, window_size=WINDOW_SIZE):
    """Split *df* into overlapping fixed-length windows, one unit at a time.

    The frame is sorted by ``unit_number`` and ``time`` and grouped by
    ``unit_number``; within each group every run of ``window_size``
    consecutive rows becomes one window (stride 1). Groups with fewer than
    ``window_size`` rows contribute no windows.

    Args:
        df: DataFrame containing at least the columns ``unit_number`` and
            ``time``.
        window_size: Number of rows per window. Defaults to ``WINDOW_SIZE``.

    Returns:
        A list of DataFrames, each holding exactly ``window_size`` rows of a
        single unit. Windows are ordered by unit, then by start position.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(
            f"window_size must be a positive integer, got {window_size!r}"
        )

    ordered = df.sort_values(['unit_number', 'time'])
    windows = []
    for _, unit_df in ordered.groupby('unit_number'):
        # Slice only the full-length windows directly; a negative
        # `last_start` (unit shorter than the window) yields an empty range.
        last_start = len(unit_df) - window_size
        windows.extend(
            unit_df.iloc[start:start + window_size]
            for start in range(last_start + 1)
        )
    return windows
def get_windowed_xy(all_rollings, max_rul=125):
    """Turn window DataFrames into a feature array and a target array.

    For every window the feature matrix consists of all columns whose name
    contains ``'sensor_'``; the target is the ``RUL`` value of the window's
    last row, capped at ``max_rul``.

    Args:
        all_rollings: Iterable of window DataFrames, each with a ``RUL``
            column. It is consumed in a single pass, so a generator works
            as well as a list.
        max_rul: Upper bound applied to the targets. Pass ``None`` to leave
            the targets uncapped. Defaults to 125.

    Returns:
        A tuple ``(X, y)`` where ``X`` stacks the per-window sensor values
        and ``y`` holds one (capped) target per window.
    """
    features = []
    targets = []
    # One pass over the windows so features and targets always stay aligned,
    # even when `all_rollings` is a one-shot iterator.
    for wnd in all_rollings:
        sensor_columns = [col for col in wnd.columns if 'sensor_' in col]
        features.append(wnd[sensor_columns].values)
        targets.append(wnd['RUL'].iloc[-1])

    all_rollings_X = np.array(features)
    all_rollings_y = np.array(targets)
    if max_rul is not None:
        all_rollings_y = all_rollings_y.clip(max=max_rul)
    return all_rollings_X, all_rollings_y
# Build the training windows and their feature/target arrays.
all_rollings = get_windowed_dataframes(train_df)
X_train_rolling, y_train_rolling = get_windowed_xy(all_rollings)

# For the test set keep, per unit, only the sensor readings of its last
# WINDOW_SIZE time steps (one window per unit, ordered by unit_number).
sensor_columns = [col for col in test_df.columns if 'sensor_' in col]
test_groups = test_df.sort_values(['unit_number', 'time']).groupby('unit_number')

# A unit with fewer than WINDOW_SIZE rows would produce a shorter window and
# make the stacked array ragged, so fail early with an explicit message.
short_units = [unit for unit, unit_df in test_groups if len(unit_df) < WINDOW_SIZE]
if short_units:
    raise ValueError(
        f"test units with fewer than {WINDOW_SIZE} rows: {short_units}"
    )

X_test_rolling = np.array([
    unit_df[sensor_columns].iloc[-WINDOW_SIZE:].values
    for _, unit_df in test_groups
])
# Tag every window with a running instance id and drop the unit identifier.
# New frames are built instead of mutating the windows in place: the windows
# are slices of the training frame, and assigning to / dropping from a slice
# in place triggers pandas' SettingWithCopyWarning.
all_rollings = [
    rolling_df.drop(columns=['unit_number']).assign(instance_id=i)
    for i, rolling_df in enumerate(all_rollings)
]
# Stack all tagged windows into one long DataFrame.
all_rollings_df = pd.concat(all_rollings)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment