from simpletransformers.classification import ClassificationModel
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score
import pandas as pd

# Toy dataset -- replace with your real data; n_splits must not exceed the number of samples
dataset = [["Example sentence belonging to class 1", 1],
           ["Example sentence belonging to class 0", 0],
           ["Example eval sentence belonging to class 1", 1],
           ["Example eval sentence belonging to class 0", 0]]
train_data = pd.DataFrame(dataset, columns=["text", "labels"])

# prepare cross validation
n = 5
seed = 42  # fixed seed so the folds are reproducible
kf = KFold(n_splits=n, random_state=seed, shuffle=True)

results = []

for train_index, val_index in kf.split(train_data):
    # split the DataFrame into the train and validation fold for this iteration
    train_df = train_data.iloc[train_index]
    val_df = train_data.iloc[val_index]
    # define a fresh model for each fold; overwrite_output_dir lets every fold
    # reuse the same output directory
    model = ClassificationModel('bert', 'bert-base-uncased',
                                args={'overwrite_output_dir': True})
    # train the model on the training fold
    model.train_model(train_df)
    # evaluate on the held-out fold, reporting accuracy
    result, model_outputs, wrong_predictions = model.eval_model(val_df, acc=accuracy_score)
    print(result['acc'])
    # collect the fold score
    results.append(result['acc'])

print("results", results)
print(f"Mean accuracy: {sum(results) / len(results)}")
# >>> Result
# 0.8535784635587655
# 0.8509520682862771
# 0.855548260013132
# 0.8272010512483574
# 0.8212877792378449
# results [0.8535784635587655, 0.8509520682862771, 0.855548260013132,
#          0.8272010512483574, 0.8212877792378449]
# Mean accuracy: 0.8417135244688754
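
If the label distribution is skewed, a stratified split keeps the class ratio roughly constant across folds. A minimal sketch, assuming the column names used above; only the way the fold indices are generated changes, while the training and evaluation inside the loop stay the same:

from sklearn.model_selection import StratifiedKFold

# StratifiedKFold needs the labels in order to balance the folds
skf = StratifiedKFold(n_splits=n, shuffle=True, random_state=seed)
for train_index, val_index in skf.split(train_data["text"], train_data["labels"]):
    train_df = train_data.iloc[train_index]
    val_df = train_data.iloc[val_index]
    # ... train and evaluate exactly as in the KFold loop above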