A simple human-in-the-loop workflow for document matching using Python, R, and GNOME's gedit
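In short: learning.R fits a tuned lasso (glmnet) classifier to the labeled article–ruling dyads, scores the unlabeled dyads, and writes the dyad ids of its least confident predictions to ambiguous.txt; the driver script below then passes those ids to human_in_the_loop.py, which opens each article/ruling pair in gedit and asks the coder to confirm or correct the label, so the next run retrains on the corrected data.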
The driver script (bash):
#!/bin/bash
## set the working dir to this file's location
cd "$(dirname "$0")"
## train the model
echo "[+] Fitting a glmnet model on labeled data"
Rscript 'learning.R'
echo "[+] Finished fitting the model"
## read in the ambiguous cases
TO_RECODE=$(cat ambiguous.txt)
echo "[+] Recoding ambiguous cases"
## run the recode script
python3 human_in_the_loop.py -c "$TO_RECODE"
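The driver assumes it sits in the same directory as learning.R, human_in_the_loop.py, and the ambiguous.txt that learning.R writes out; the cd "$(dirname "$0")" line makes the run location-independent.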
human_in_the_loop.py:
import argparse
import os
import subprocess
import sys
import tempfile
from typing import Optional

import pandas as pd

## constants: user-specific paths to the labeled data, the recoded output,
## and the raw ruling-article dyad data
ORIGINAL_DF = "/home/jr/Dropbox/Current projects/thesis_papers/transparency, media, and compliance with HR Rulings/ecthr_media&compliance/data/media_data/3_classify_ecthr_news/model_2/data/model_data.csv"
RECODED_DF = "/home/jr/Dropbox/Current projects/thesis_papers/transparency, media, and compliance with HR Rulings/ecthr_media&compliance/data/media_data/3_classify_ecthr_news/model_2/data/model_data_recoded.csv"
RULART_DYAD_DF = "/home/jr/Dropbox/Current projects/thesis_papers/transparency, media, and compliance with HR Rulings/ecthr_media&compliance/data/media_data/3_classify_ecthr_news/annotate_2/data/rulings_article_dyad_data_raw.csv.gz"
ALREADY_REVIEWED_DF_PATH = os.path.join(
    os.path.abspath(os.path.dirname(__file__)), "already_reviewed.csv"
)

## ANSI colors for the console notifications
_W = '\033[0m'   # reset (default)
_R = '\033[31m'  # red
_G = '\033[32m'  # green
_O = '\033[33m'  # orange
_B = '\033[34m'  # blue
_P = '\033[35m'  # purple

## read in the ruling-article dyad df (the texts shown to the coder)
rulart_dyad_df = pd.read_csv(RULART_DYAD_DF)[
    ["article_id", "case_id", "leading_paragraph_translated", "ruling_text"]
]

def parse_args() -> argparse.Namespace:
    ## parse CLI args
    parser = argparse.ArgumentParser(
        prog="human_in_the_loop.py",
        description="Recode the ecthr_label of ambiguous cases.",
    )
    # applicant match flag
    parser.add_argument(
        "-a",
        "--applicant-match",
        dest="applicant_match",
        action="store_true",
        default=False,
        help="Consider ambiguous if ecthr_label == '0' and applicant_has_match == True",
    )
    # above average similarity flag
    parser.add_argument(
        "-s",
        "--above-average-similarity",
        dest="above_average_similarity",
        action="store_true",
        default=False,
        help="Consider ambiguous if the similarity metrics are above the average scores for when ecthr_label == '1'",
    )
    # date/country match flag
    parser.add_argument(
        "-d",
        "--date-country-match",
        dest="date_country_match",
        action="store_true",
        default=False,
        help="Consider ambiguous if match_ecthr, match_ruling, and match_respondent all hold and -1 < date_distance < 6",
    )
    # user defined
    parser.add_argument(
        "-c",
        "--csv-dyad-id",
        dest="csv_dyad_id",
        nargs="?",
        type=str,
        default=None,
        help="User-defined ambiguous cases: a comma-separated string of the relevant dyad_ids",
    )
    # parse; if no args, display the help menu
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(0)
    return parser.parse_args()
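
## Example invocations (dyad ids are illustrative; the heuristic flags can be
## combined freely):
##   python3 human_in_the_loop.py -a -s -d
##   python3 human_in_the_loop.py -c "dyad_1,dyad_2"  # how the driver script calls it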

def read_current_labeled_dataset() -> pd.DataFrame:
    ## prefer the recoded dataset when it already exists
    to_read = ORIGINAL_DF
    if os.path.isfile(RECODED_DF):
        to_read = RECODED_DF
    df = pd.read_csv(to_read)
    df["ecthr_label"] = df["ecthr_label"].astype(pd.StringDtype())
    return df


def applicant_match_test(df: pd.DataFrame) -> Optional[pd.DataFrame]:
    ## dyads labeled '0' whose applicant nonetheless matched
    filtered = df.query("applicant_has_match == True & ecthr_label == '0'")
    if not filtered.empty:
        return filtered
    return None

def above_average_similarity_test(df: pd.DataFrame) -> Optional[pd.DataFrame]:
    ## thresholds come from the positive class: a '0' dyad that beats the
    ## average '1' dyad on all four similarity metrics looks suspicious
    avg_jaccard_sim = df[df.ecthr_label == "1"]["jaccard_similarity_top_tfidf"].mean()
    avg_pairwise_sim = df[df.ecthr_label == "1"]["rulart_pairwise_similarity"].mean()
    avg_jsd_33 = df[df.ecthr_label == "1"]["jsd_score_k_33"].mean()  # a divergence: minimize!
    avg_jsd_70 = df[df.ecthr_label == "1"]["jsd_score_k_70"].mean()  # idem
    # test
    filtered = df.query(
        "ecthr_label == '0'"
        " & jaccard_similarity_top_tfidf > @avg_jaccard_sim"
        " & rulart_pairwise_similarity > @avg_pairwise_sim"
        " & jsd_score_k_33 < @avg_jsd_33"
        " & jsd_score_k_70 < @avg_jsd_70"
    )
    if not filtered.empty:
        return filtered
    return None

def date_country_match_test(df: pd.DataFrame) -> Optional[pd.DataFrame]:
    ## dyads labeled '0' where the ECtHR, ruling, and respondent mentions all
    ## match and date_distance falls between 0 and 5
    filtered = df.query(
        "ecthr_label == '0' & (match_ecthr == True & match_ruling == True & match_respondent == True & date_distance > -1 & date_distance < 6)"
    )
    if not filtered.empty:
        return filtered
    return None


def filter_ambiguous(
    df: pd.DataFrame,
    run_applicant_test: bool = False,
    run_above_average_similarity_test: bool = False,
    run_date_country_match_test: bool = False,
    user_defined_ids: Optional[str] = None,
) -> Optional[pd.DataFrame]:
    ## collect the dyads flagged by each enabled heuristic
    ambiguous_list = []
    if run_applicant_test:
        app_test = applicant_match_test(df)
        if isinstance(app_test, pd.DataFrame):
            ambiguous_list.append(app_test)
    if run_above_average_similarity_test:
        above_test = above_average_similarity_test(df)
        if isinstance(above_test, pd.DataFrame):
            ambiguous_list.append(above_test)
    if run_date_country_match_test:
        dc_test = date_country_match_test(df)
        if isinstance(dc_test, pd.DataFrame):
            ambiguous_list.append(dc_test)
    if user_defined_ids:
        user_ids_list = [_.strip() for _ in user_defined_ids.split(",")]
        ambiguous_list.append(df[df.dyad_id.isin(user_ids_list)])
    if ambiguous_list:
        return pd.concat(ambiguous_list).drop_duplicates(
            subset=["article_id", "case_id"]
        )
    return None
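
## Illustration (hypothetical values): with run_applicant_test=True, a row with
## ecthr_label == '0' and applicant_has_match == True is flagged for review; a
## dyad flagged by several heuristics at once is still reviewed only once,
## thanks to the drop_duplicates on (article_id, case_id).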

def write_to_tempfile(txt: str) -> str:
    ## delete=False so the file survives for gedit to open after we close it
    with tempfile.NamedTemporaryFile(mode="w+t", delete=False) as my_tempfile:
        my_tempfile.write(txt)
    return my_tempfile.name


def fetch_and_write_texts(article_id: str, case_id: str) -> dict:
    ## article
    article_text = rulart_dyad_df[
        (rulart_dyad_df.article_id == article_id) & (rulart_dyad_df.case_id == case_id)
    ]["leading_paragraph_translated"].iloc[0]
    article_temp = write_to_tempfile(txt="\t\tARTICLE\n\n" + article_text)
    ## ruling
    ruling_text = rulart_dyad_df[
        (rulart_dyad_df.article_id == article_id) & (rulart_dyad_df.case_id == case_id)
    ]["ruling_text"].iloc[0]
    ruling_temp = write_to_tempfile(txt="\t\tRULING\n\n" + ruling_text)
    return {"article_file": article_temp, "ruling_file": ruling_temp}


def gedit_ambiguous_files(files_dict: dict) -> list:
    processes = []
    for _file in files_dict.values():
        process = subprocess.Popen(["gedit", _file])
        processes.append(process)
    return processes


def open_ambiguous_files(article_id: str, case_id: str) -> list:
    files = fetch_and_write_texts(article_id=article_id, case_id=case_id)
    processes = gedit_ambiguous_files(files_dict=files)
    return processes

def close_gedit(processes: list) -> None:
    for process in processes:
        process.terminate()
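
## Note: Popen is non-blocking, so both gedit windows stay open side by side
## while the console prompt waits for the coder's decision; close_gedit()
## terminates them once the dyad has been annotated.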

def write_recoded_df(df: pd.DataFrame) -> None:
    print(f"{_W}[+] Exporting recoded df to: {RECODED_DF}")
    df.to_csv(RECODED_DF, index=False)


def get_or_make_reviewed_df() -> pd.DataFrame:
    ## the review log lives next to this script
    if not os.path.isfile(ALREADY_REVIEWED_DF_PATH):
        already_reviewed_df = pd.DataFrame(columns=["article_id", "case_id"])
        already_reviewed_df.to_csv(ALREADY_REVIEWED_DF_PATH, index=False)
    else:
        already_reviewed_df = pd.read_csv(ALREADY_REVIEWED_DF_PATH, index_col=False)
    return already_reviewed_df


def log_reviewed(case_id: str, article_id: str) -> None:
    ## read or create the already-reviewed df, then append this dyad
    reviewed = get_or_make_reviewed_df()
    reviewed.loc[len(reviewed.index)] = [article_id, case_id]
    reviewed.to_csv(ALREADY_REVIEWED_DF_PATH, index=False)

def notify_already_reviewed() -> bool:
    review_again = ""
    while review_again.lower() not in ["y", "n"]:
        review_again = input(
            f"\t{_R}[!] This dyad was already reviewed. Do you want to review it again? [y/n]: "
        )
        if review_again.lower() not in ["y", "n"]:
            print(f"\t{_R}[!] Please reply with 'y' or 'n'.")
    return review_again.lower() == "y"


def _review(article_id: str, case_id: str) -> str:
    ## open gedit windows with the dyad's texts
    _processes = open_ambiguous_files(article_id=article_id, case_id=case_id)
    ## annotation
    decision = ""
    while decision.lower() not in ["y", "n"]:
        decision = input(f"\t{_O}[+] Is this article related to this ruling? [y/n]: ")
        if decision.lower() not in ["y", "n"]:
            print(f"\t{_R}[!] Please reply with 'y' or 'n'.")
    review_decision = "1" if decision.lower() == "y" else "0"
    # close gedit
    close_gedit(processes=_processes)
    return review_decision


def review(article_id: str, case_id: str) -> Optional[str]:
    print(f"{_W}[+] Opening the ambiguous dyad via gedit: {article_id} --> {case_id}")
    review_decision = None
    # check if this dyad was reviewed already
    reviewed = get_or_make_reviewed_df()
    if not reviewed[
        (reviewed.article_id == article_id) & (reviewed.case_id == case_id)
    ].empty:
        if notify_already_reviewed():
            review_decision = _review(case_id=case_id, article_id=article_id)
    else:
        review_decision = _review(case_id=case_id, article_id=article_id)
        log_reviewed(case_id=case_id, article_id=article_id)
    return review_decision
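
## The review log persists across runs, so a dyad annotated in an earlier
## session is only shown again if the coder explicitly asks to re-review it.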

def main() -> None:
    ## parse cli args
    cli_args = parse_args()
    ## load the labeled df
    current_labels = read_current_labeled_dataset()
    ## filter ambiguous cases
    print(f"{_W}[+] Looking for ambiguously labeled dyads...")
    ambiguous_df = filter_ambiguous(
        df=current_labels,
        run_applicant_test=cli_args.applicant_match,
        run_above_average_similarity_test=cli_args.above_average_similarity,
        run_date_country_match_test=cli_args.date_country_match,
        user_defined_ids=cli_args.csv_dyad_id,
    )
    if isinstance(ambiguous_df, pd.DataFrame):
        ## shuffle so dyads are reviewed in random order
        ambiguous_df = ambiguous_df.sample(frac=1)
        print(f"{_W}[+] Found {ambiguous_df.shape[0]} ambiguous labels")
        try:
            for _, row in ambiguous_df.iterrows():
                current_article_id = row["article_id"]
                current_case_id = row["case_id"]
                ## review
                review_decision = review(
                    article_id=current_article_id, case_id=current_case_id
                )
                ## compare with the previous label
                if review_decision:
                    if row["ecthr_label"] == review_decision:
                        print(f"\t\t{_W}[+] Same as before, skipping it")
                    else:
                        print(f"\t\t{_G}[+] Correcting the label to: {review_decision}")
                        ## assign the new label to the dataset
                        current_labels.loc[
                            (current_labels.article_id == current_article_id)
                            & (current_labels.case_id == current_case_id),
                            "ecthr_label",
                        ] = review_decision
        except KeyboardInterrupt:
            print(f"\n{_R}[!] Exiting...")
        ## export (after a complete pass or a keyboard interrupt)
        write_recoded_df(df=current_labels)
    else:
        print(f"{_R}[!] Found no ambiguous labels")


if __name__ == "__main__":
    main()
learning.R:
require(glmnet)
require(tidymodels)
require(tidyverse)
print(getwd())

### Load the dataset
LABELED_DF_PATH <- "/home/jr/Dropbox/Current projects/thesis_papers/transparency, media, and compliance with HR Rulings/ecthr_media&compliance/data/media_data/3_classify_ecthr_news/model_2/data/model_data_recoded.csv"
UNLABELED_DF_PATH <- "/home/jr/Dropbox/Current projects/thesis_papers/transparency, media, and compliance with HR Rulings/ecthr_media&compliance/data/media_data/3_classify_ecthr_news/model_2/data/to_predict.csv"
## number of least-confident dyads to send to the human coder
N_CASES <- 20
TO_RECODE_DOC <- "/home/jr/Dropbox/Current projects/thesis_papers/transparency, media, and compliance with HR Rulings/ecthr_media&compliance/data/media_data/3_classify_ecthr_news/model_2/scripts/human_in_the_loop/ambiguous.txt"

#### Prepare the data
#### -----------------------------------------------------------------------------------------------------------------
model_data <- read_csv(LABELED_DF_PATH) %>%
  select(-c(article_id, case_id, model_id)) %>%
  mutate(ecthr_label = as.factor(ecthr_label))
to_predict_data <- read_csv(UNLABELED_DF_PATH) %>%
  select(-c(article_id, case_id, model_id)) %>%
  mutate(ecthr_label = NA)

### data partition
# set.seed(123)
the_split <- initial_split(model_data, strata = "ecthr_label", prop = 0.7)
train_data <- training(the_split)
test_data <- testing(the_split)

### cv samples: 5-fold cross-validation, repeated 5 times
cv_samples <- rsample::vfold_cv(data = train_data, v = 5, repeats = 5, strata = "ecthr_label")
cv_samples

#### Prepare the model
#### -----------------------------------------------------------------------------------------------------------------
lasso_spec <- logistic_reg(penalty = tune(), mixture = tune()) %>%
  set_mode("classification") %>%
  set_engine("glmnet")

#### Prepare the data processing recipe
#### -----------------------------------------------------------------------------------------------------------------
## prep the recipe
recip <- recipe(ecthr_label ~ ., data = train_data) %>%
  ## dyad_id is kept as an identifier, not a predictor (other variables that
  ## hurt or did not help performance were already dropped upstream)
  update_role(dyad_id, new_role = "id variable") %>%
  step_mutate(
    ## discretize the date distance variable
    date_distance = date_distance %>%
      str_replace("-", "minus_") %>%
      as_factor()
  ) %>%
  step_dummy(date_distance) %>%  ## date as dummy
  step_normalize(all_numeric(), -all_outcomes())  ## center and scale
  # step_smote(ecthr_label, neighbors = 27, seed = 1234)  ## smote sampling, k = 27
recip

#### Finalize the workflow
#### -----------------------------------------------------------------------------------------------------------------
## prep the workflow
lasso_wf <- workflow() %>%
  add_recipe(recip) %>%
  add_model(lasso_spec)
lasso_wf

#### Tune
#### -----------------------------------------------------------------------------------------------------------------
## run a Bayesian hyperparameter search
lasso_bayes_tune <- tune_bayes(
  lasso_wf,
  resamples = cv_samples,
  # going with the default ranges for penalty (lambda) and mixture (alpha)
  param_info = parameters(penalty(), mixture()),
  # generate twenty-five semi-random candidates to start
  initial = 25,
  iter = 100,
  # performance metrics; the best model is picked on Cohen's kappa below
  metrics = metric_set(kap, bal_accuracy, roc_auc, f_meas, ppv, npv, recall, precision),
  control = control_bayes(no_improve = 25, verbose = FALSE, save_pred = TRUE, seed = 1234)
)
# desktop notification when tuning finishes
system("notify-send 'model tuned!'")

## plot the metrics
(p1 <- autoplot(lasso_bayes_tune) +
    theme_minimal() +
    ggtitle("wv-lasso model"))
ggsave(p1,
       filename = "current_tunning_metrics.pdf",
       width = 8,
       height = 6,
       device = "pdf")

### select the best model based on Cohen's kappa
best_lasso <- select_best(lasso_bayes_tune, metric = "kap")

#### Fit the selected model
#### -----------------------------------------------------------------------------------------------------------------
lasso_final <- lasso_wf %>%
  finalize_workflow(best_lasso) %>%
  fit(data = train_data)

## extract the coefficients
tidy(extract_model(lasso_final))

### extract predicted probs on the training set
train_probs <- predict(lasso_final, type = "prob", new_data = train_data) %>%
  bind_cols(ecthr_label = train_data$ecthr_label) %>%
  bind_cols(dyad_id = train_data$dyad_id) %>%
  bind_cols(predict(lasso_final, new_data = train_data))
train_probs
## confusion matrix
conf_mat(train_probs, ecthr_label, .pred_class)

#### predict on the test set
### extract predicted probs
test_probs <- predict(lasso_final, type = "prob", new_data = test_data) %>%
  bind_cols(ecthr_label = test_data$ecthr_label) %>%
  bind_cols(dyad_id = test_data$dyad_id) %>%
  bind_cols(predict(lasso_final, new_data = test_data))
test_probs
## confusion matrix
conf_mat(test_probs, ecthr_label, .pred_class)

### predict on the unlabeled data
### extract predicted probs
unlab_probs <- predict(lasso_final, type = "prob", new_data = to_predict_data) %>%
  bind_cols(dyad_id = to_predict_data$dyad_id) %>%
  bind_cols(predict(lasso_final, new_data = to_predict_data))
unlab_probs

#### Prepare the data for the human in the loop stage
#### -----------------------------------------------------------------------------------------------------------------
# combined <- bind_rows(train_probs, test_probs)
# ambiguous <- combined %>%
#   filter(.pred_class != ecthr_label) %>%
#   mutate(pred_class_abs_diff = abs(.pred_0 - .pred_1)) %>%
#   arrange(pred_class_abs_diff) %>%
#   slice(1:N_CASES)
ambiguous <- unlab_probs %>%
  mutate(ecthr_label = .pred_class) %>%
  mutate(pred_class_abs_diff = abs(.pred_0 - .pred_1)) %>%
  arrange(pred_class_abs_diff) %>%
  slice(1:N_CASES)
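
## This is least-confidence sampling: the dyads whose two class probabilities
## are closest together (smallest |.pred_0 - .pred_1|) are the ones the model
## is least sure about, so the N_CASES least certain predictions are routed
## to the human coder.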

ambiguous_csv <- paste(ambiguous$dyad_id, collapse = ",")
# export the comma-separated dyad ids for the recode script
sink(TO_RECODE_DOC)
cat(ambiguous_csv)
sink()

#### Update the dataframes
#### -----------------------------------------------------------------------------------------------------------------
### prepare the df; as.character first, since as.numeric on a factor returns level indices
to_add <- ambiguous %>%
  select(dyad_id, ecthr_label) %>%
  mutate(ecthr_label = as.numeric(as.character(ecthr_label)))
unlabeled_all <- read_csv(UNLABELED_DF_PATH)
df_2 <- inner_join(to_add, unlabeled_all)
### remove these rows from the unlabeled df
new_unlabeled <- anti_join(unlabeled_all, df_2)
write_csv(new_unlabeled, UNLABELED_DF_PATH)
### add the rows to the recoded df
labeled_all <- read_csv(LABELED_DF_PATH)
new_recoded <- bind_rows(labeled_all, df_2)
write_csv(new_recoded, LABELED_DF_PATH)