Created
February 3, 2025 08:01
-
-
Save spinningcat/6a90c8f2ac5e8b12692fc77766330196 to your computer and use it in GitHub Desktop.
classification.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import pipeline

# Biomedical named-entity recognition model from the Hugging Face Hub (example choice).
ner_model = "Helios9/BioMed_NER"
# aggregation_strategy="simple" merges word-piece tokens into whole entity spans;
# without it the pipeline emits one dict per sub-token with raw B-/I- tags.
ner_pipeline = pipeline("ner", model=ner_model, aggregation_strategy="simple")

# Zero-shot classification model from the Hugging Face Hub (example choice).
zero_shot_model = "MoritzLaurer/deberta-v3-large-zeroshot-v1.1-all-33"
classifier_pipeline = pipeline("zero-shot-classification", model=zero_shot_model)

# Sample medical note: a single string literal (it does contain embedded newlines).
text = """Patient John Doe, a 45-year-old male, was diagnosed with hypertension and prescribed
50mg of Lisinopril daily. He also has a history of type 2 diabetes and reports experiencing
occasional chest pain. During his last visit, his blood pressure was recorded at 160/100 mmHg.
The doctor advised him to monitor his blood sugar levels regularly and scheduled a follow-up
appointment in three months. Additionally, he was referred to a nutritionist for dietary counseling."""

# Run NER to extract entity spans from the note.
entities = ner_pipeline(text)

# Candidate labels for zero-shot classification. They are not mutually exclusive
# (the note mentions diseases, a medication and treatments at the same time), so
# multi_label=True scores each label independently instead of forcing all labels
# to compete for a single probability mass.
labels = ["Description", "Disease", "Medication", "Diagnosis", "Treatment"]
classification_result = classifier_pipeline(text, candidate_labels=labels, multi_label=True)

# Print results in a readable form rather than dumping raw lists/dicts.
print("Extracted Entities:")
if not entities:
    print("  (none)")
for entity in entities:
    print(f"  {entity['word']!r}: {entity['entity_group']} (score={entity['score']:.3f})")

print("\nClassification Result:")
for label, score in zip(classification_result["labels"], classification_result["scores"]):
    print(f"  {label}: {score:.3f}")
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.