Last active
February 26, 2024 01:19
-
-
Save fuzzie360/8652d7e45201edab060f5aeb2d8b4bb9 to your computer and use it in GitHub Desktop.
Document question answering (QA) sample Python programs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import pipeline
from PIL import Image

# Document-QA pipeline backed by the Donut DocVQA checkpoint.
pipe = pipeline("document-question-answering", model="naver-clova-ix/donut-base-finetuned-docvqa")

question = "For what period is this payslip for?"

# Image.open keeps the file handle open until the image is closed; the context
# manager closes it deterministically once inference has read the pixels.
with Image.open("./input.jpg") as image:
    output = pipe(image=image, question=question)

print(output)
# [{'answer': 'end-may 2022'}]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import LayoutLMv3FeatureExtractor, LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForSequenceClassification
import torch
from PIL import Image
from typing import List

# Pick the inference device: CUDA GPU first, then Apple MPS, otherwise CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda:0"
elif torch.backends.mps.is_available():
    DEVICE = "mps"

# NOTE(review): "nielsr/layoutlmv3-finetuned-cord" looks like a *token*-classification
# checkpoint (per-word CORD receipt labels). Loading it into a *sequence*-classification
# head presumably leaves that head randomly initialised, so the label printed below may
# be meaningless — confirm the checkpoint type before relying on the result.
model = LayoutLMv3ForSequenceClassification.from_pretrained(
    "nielsr/layoutlmv3-finetuned-cord"
)
model = model.eval().to(DEVICE)

# apply_ocr=True: the feature extractor extracts words/boxes from the image itself
# (presumably via pytesseract, which must then be installed — TODO confirm).
# NOTE(review): newer transformers releases deprecate FeatureExtractor classes in
# favour of LayoutLMv3ImageProcessor.
feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=True)
tokenizer = LayoutLMv3TokenizerFast.from_pretrained(
    "nielsr/layoutlmv3-finetuned-cord"
)
processor = LayoutLMv3Processor(feature_extractor, tokenizer)

image_path = "input.jpg"

# The context manager must wrap the *opened file*: convert() returns a new
# in-memory image, so `with Image.open(p).convert(...)` would only close the copy
# and leave the file handle open. The converted copy stays valid after closing.
with Image.open(image_path) as raw_image:
    image = raw_image.convert("RGB")

encoding = processor(
    image,
    max_length=512,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)

with torch.inference_mode():
    output = model(
        input_ids=encoding["input_ids"].to(DEVICE),
        attention_mask=encoding["attention_mask"].to(DEVICE),
        bbox=encoding["bbox"].to(DEVICE),
        pixel_values=encoding["pixel_values"].to(DEVICE),
    )

# Highest-scoring class along the last (label) axis for the single input image.
predicted_class = output.logits.argmax(-1)
predicted_class_label = model.config.id2label[predicted_class.item()]
print(predicted_class_label)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import AutoProcessor, AutoModelForQuestionAnswering
from PIL import Image
import torch
import easyocr

# Single source of truth for the input page; used both for the model image and for OCR.
IMAGE_PATH = "input.jpg"

# apply_ocr=False: words and boxes are supplied from EasyOCR below instead of the
# processor's built-in OCR.
processor = AutoProcessor.from_pretrained("rubentito/layoutlmv3-base-mpdocvqa", apply_ocr=False)
model = AutoModelForQuestionAnswering.from_pretrained("rubentito/layoutlmv3-base-mpdocvqa")

# Close the file handle deterministically; convert() returns an independent
# in-memory copy that remains usable after the file is closed.
with Image.open(IMAGE_PATH) as raw_image:
    image = raw_image.convert("RGB")

# Factors that map pixel coordinates onto a 0-1000 grid.
width, height = image.size
width_scale = 1000 / width
height_scale = 1000 / height

reader = easyocr.Reader(['en'])
ocr_result = reader.readtext(IMAGE_PATH)
def create_bounding_box(bbox_data):
    """Collapse a polygon of ``(x, y)`` points into an axis-aligned box.

    In this script ``bbox_data`` is the point list that ``reader.readtext``
    yields for each detected text region.

    Returns ``[left, top, right, bottom]``: the minimum and maximum x and y
    over all points, each truncated toward zero with ``int``. Raises
    ``ValueError`` when ``bbox_data`` is empty or a point is not a pair.
    """
    # Materialise the points once so a one-shot iterable can be scanned repeatedly.
    points = [(x, y) for x, y in bbox_data]
    left, right = min(x for x, _ in points), max(x for x, _ in points)
    top, bottom = min(y for _, y in points), max(y for _, y in points)
    return [int(left), int(top), int(right), int(bottom)]
def scale_bounding_box(box, width_scale=1.0, height_scale=1.0):
    """Rescale a ``[left, top, right, bottom]`` box.

    The x coordinates (indices 0 and 2) are multiplied by ``width_scale``.
    The y coordinates (indices 1 and 3) are multiplied by ``height_scale``.
    Each product is truncated toward zero with ``int``. Only the first four
    entries of ``box`` are read.

    This script passes ``1000 / width`` and ``1000 / height``, so pixel
    coordinates land on a 0-1000 grid.
    """
    factors = (width_scale, height_scale, width_scale, height_scale)
    return [int(box[i] * factors[i]) for i in range(4)]
# Pair every OCR'd word with the axis-aligned box of its region (confidence is unused).
ocr_page = [
    {"word": word, "bounding_box": create_bounding_box(bbox)}
    for bbox, word, _confidence in ocr_result
]

# Parallel lists, as expected by the processor: words plus boxes scaled to the 0-1000 grid.
words = [row["word"] for row in ocr_page]
boxes = [
    scale_bounding_box(row["bounding_box"], width_scale, height_scale)
    for row in ocr_page
]

question = "What is the bank account number?"
# Truncate so pages with many words cannot exceed the model's 512-token limit,
# matching the processor settings used by the apply_ocr=True example.
encoding = processor(
    image,
    question,
    words,
    boxes=boxes,
    max_length=512,
    truncation=True,
    return_tensors="pt",
)

with torch.inference_mode():
    outputs = model(**encoding)

# Most likely first and last token of the answer span (single-example batch).
predicted_start_idx = outputs.start_logits.argmax(-1).item()
predicted_end_idx = outputs.end_logits.argmax(-1).item()
# NOTE(review): if end < start the slice is empty and an empty answer is printed.
predicted_answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1]
predicted_answer = processor.tokenizer.decode(predicted_answer_tokens, skip_special_tokens=True)
print(predicted_answer)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import pipeline
from PIL import Image

# Document-QA pipeline backed by a LayoutLM checkpoint.
# NOTE(review): no word_boxes are passed, so the pipeline presumably OCRs the page itself
# with Tesseract (pytesseract) — confirm it is installed alongside the listed requirements.
pipe = pipeline("document-question-answering", model="impira/layoutlm-document-qa")

question = "For what period is this payslip for?"

# Close the image file deterministically once the pipeline has consumed it.
with Image.open("./input.jpg") as image:
    output = pipe(image=image, question=question)

print(output)
# [{'score': 0.9198229312896729, 'answer': 'END-MAY 2022', 'start': 12, 'end': 13}]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import LayoutLMv3Processor, LayoutLMv3ForQuestionAnswering
import torch
from PIL import Image
from typing import List

# apply_ocr=True: the processor extracts words and boxes from the image itself.
processor = LayoutLMv3Processor.from_pretrained("rubentito/layoutlmv3-base-mpdocvqa", apply_ocr=True)
model = LayoutLMv3ForQuestionAnswering.from_pretrained("rubentito/layoutlmv3-base-mpdocvqa")

image_path = "input.jpg"
question = "For what period is this payslip for?"

# The context manager must wrap the opened file, not the converted copy:
# `with Image.open(p).convert(...)` would leave the file handle open.
with Image.open(image_path) as raw_image:
    image = raw_image.convert("RGB")

encoding = processor(
    image,
    question,
    max_length=512,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)

with torch.inference_mode():
    output = model(**encoding)

# Plain Python ints for the most likely answer start/end token (single-example batch).
start_idx = output.start_logits.argmax(-1).item()
end_idx = output.end_logits.argmax(-1).item()
predict_answer_tokens = encoding.input_ids[0, start_idx : end_idx + 1]
print(processor.decode(predict_answer_tokens, skip_special_tokens=True))
# O1/MAY/2022 TO 31/MAY/2022
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
transformers
Pillow
sentencepiece
torch
easyocr
# pytesseract also requires the Tesseract OCR binary to be installed on the system
pytesseract
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment