Created
May 2, 2023 14:25
-
-
Save aimerneige/3bc8b1e85642f0efabb3d2b7bd3129d8 to your computer and use it in GitHub Desktop.
baidu ocr pyqt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
# -*- coding: utf-8 -*- | |
# 读取系统文件用 | |
import sys | |
# base64 编码 | |
import base64 | |
# 网络请求库 | |
import requests | |
# 枚举类型 | |
from enum import Enum | |
# 图像处理 | |
import cv2 | |
# PyQt 框架 | |
from PyQt5 import QtCore | |
from PyQt5.QtGui import QPixmap | |
from PyQt5.QtCore import QSize, pyqtSlot | |
from PyQt5.QtWidgets import QApplication, QDesktopWidget, QMainWindow, QPushButton, QLabel, QTextEdit, QComboBox, QFileDialog, QMessageBox | |
# Text passed to setWindowTitle() for the main window
window_title = "文本识别"

# API credentials handed to the OCR client; the values below are placeholders
API_KEY = "YOUR_API_KEY_HERE"
SECRET_KEY = "YOUR_SECRET_KEY_HERE"

# OAuth token endpoint, and the common prefix that every OCR route is appended to
TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
OCR_URL = "https://aip.baidubce.com/rest/2.0/ocr/v1/"
# OCR 语言接口参数封装 | |
class Language(Enum):
    """Language codes selectable for text recognition.

    The member *name* is what matters to the rest of the file:
    ``OCR.set_language`` stores ``language.name`` and ``accurate_basic``
    sends it as the ``language_type`` form field. The integer values simply
    number the members consecutively from 0 in declaration order.
    """

    # Unpacking range(21) assigns 0..20 in order; the count must match the
    # number of names or class creation fails with ValueError.
    (
        auto_detect,
        CHN_ENG,
        ENG,
        JAP,
        KOR,
        FRE,
        SPA,
        POR,
        GER,
        ITA,
        RUS,
        DAN,
        DUT,
        MAL,
        SWE,
        IND,
        POL,
        ROM,
        TUR,
        GRE,
        HUN,
    ) = range(21)
# OCR 调用封装类 | |
class OCR(object):
    """Client for the OCR REST routes under ``OCR_URL``.

    An access token is fetched once at construction time and appended to
    every request URL. The recognition options are kept as the strings
    ``"true"`` / ``"false"`` because they are sent verbatim as form fields.

    Attributes:
        OCR_TYPE: Label of the recognition mode used by ``get_ocr_result``.
        ACCESS_TOKEN: Token returned by ``fetch_token`` (``None`` on failure).
        LANGUAGE: Name of the selected ``Language`` member.
        DETECT_DIRECTION, PARAGRAPH, PROBABILITY: ``"true"`` / ``"false"``.
    """

    # Seconds to wait on the server before a request is abandoned; without a
    # timeout a stalled connection would block the caller indefinitely.
    REQUEST_TIMEOUT = 30

    def __init__(self, api_key, secret_key):
        """Store the credentials, fetch a token and set the default options."""
        super().__init__()
        self.API_KEY = api_key
        self.SECRET_KEY = secret_key
        self.OCR_TYPE = "通用文字"
        self.ACCESS_TOKEN = self.fetch_token()
        self.LANGUAGE = Language.auto_detect.name
        self.DETECT_DIRECTION = "false"
        self.PARAGRAPH = "false"
        self.PROBABILITY = "true"

    def set_language(self, language):
        """Select the recognition language; only ``language.name`` is kept."""
        self.LANGUAGE = language.name

    def set_detect_direction(self, detect_direction):
        """Set the ``detect_direction`` form field (``"true"`` / ``"false"``)."""
        self.DETECT_DIRECTION = detect_direction

    def set_paragraph(self, paragraph):
        """Set the ``paragraph`` form field (``"true"`` / ``"false"``)."""
        self.PARAGRAPH = paragraph

    def set_probability(self, probability):
        """Set the ``probability`` form field (``"true"`` / ``"false"``)."""
        self.PROBABILITY = probability

    # Fetch the access token
    def fetch_token(self):
        """Request an access token with the stored API key and secret key.

        Returns:
            The ``access_token`` value of the JSON reply, or ``None`` when the
            server answers with an error status or the field is absent.
        """
        response = requests.post(
            TOKEN_URL,
            data={
                'grant_type': 'client_credentials',
                'client_id': self.API_KEY,
                'client_secret': self.SECRET_KEY,
            },
            timeout=self.REQUEST_TIMEOUT,
        )
        # A Response is falsy when its status code signals an error.
        if not response:
            return None
        return response.json().get('access_token')

    @staticmethod
    def _encode_file(path):
        """Return the base64-encoded bytes of the file at ``path``."""
        with open(path, 'rb') as f:
            return base64.b64encode(f.read())

    # Encode an image file
    def encode_image(self, image_path):
        """Return the content of ``image_path`` as base64-encoded bytes."""
        return self._encode_file(image_path)

    # Encode a PDF file
    def encode_pdf(self, pdf_path):
        """Return the content of ``pdf_path`` as base64-encoded bytes.

        The file used to be read and then discarded, so this always returned
        ``None``; it now mirrors ``encode_image``.
        """
        return self._encode_file(pdf_path)

    def _post_image(self, endpoint, image_path, **options):
        """POST an encoded image to ``OCR_URL + endpoint`` and decode the reply.

        Args:
            endpoint: Route name appended to ``OCR_URL``.
            image_path: Path of the image to upload.
            **options: Extra form fields sent next to ``image``.

        Returns:
            The decoded JSON body.

        Raises:
            RuntimeError: If no access token is available, or if the reply
                carries a non-zero ``error_code``.
        """
        if not self.ACCESS_TOKEN:
            # Fail with a readable message rather than a TypeError from
            # concatenating None into the URL.
            raise RuntimeError(
                "OCR access token is unavailable; "
                "check the API key and secret key")
        request_url = OCR_URL + endpoint + '?access_token=' + self.ACCESS_TOKEN
        headers = {'Content-Type': 'application/x-www-form-urlencoded'}
        params = {'image': self.encode_image(image_path)}
        params.update(options)
        json_data = requests.post(
            request_url, headers=headers, data=params,
            timeout=self.REQUEST_TIMEOUT).json()
        # NOTE(review): assumes failures are reported through
        # error_code / error_msg fields — confirm against the API reference.
        if json_data.get("error_code"):
            raise RuntimeError("OCR request failed ({}): {}".format(
                json_data["error_code"], json_data.get("error_msg", "")))
        return json_data

    @staticmethod
    def _lines_from_words(json_data):
        """Put every ``words_result`` entry on its own line."""
        return "".join(
            entry["words"] + "\n" for entry in json_data["words_result"])

    @staticmethod
    def _text_from_paragraphs(json_data):
        """Build one line per paragraph, each word followed by a space."""
        words = json_data["words_result"]
        return "".join(
            "".join(words[index]["words"] + " "
                    for index in paragraph["words_result_idx"]) + "\n"
            for paragraph in json_data["paragraphs_result"])

    def accurate_basic(self, image_path):
        """Run general text recognition on ``image_path``.

        Returns:
            One line per paragraph when the reply contains
            ``paragraphs_result``; otherwise one line per ``words_result``
            entry (previously a missing ``paragraphs_result`` raised
            ``KeyError``).
        """
        json_data = self._post_image(
            "accurate_basic",
            image_path,
            language_type=self.LANGUAGE,
            detect_direction=self.DETECT_DIRECTION,
            paragraph=self.PARAGRAPH,
            probability=self.PROBABILITY,
        )
        if "paragraphs_result" in json_data:
            return self._text_from_paragraphs(json_data)
        return self._lines_from_words(json_data)

    # Number recognition
    def numbers(self, image_path):
        """Run number recognition; returns one recognised entry per line."""
        json_data = self._post_image(
            "numbers",
            image_path,
            detect_direction=self.DETECT_DIRECTION,
        )
        return self._lines_from_words(json_data)

    # Handwriting recognition
    def handwriting(self, image_path):
        """Run handwriting recognition; returns one recognised entry per line."""
        json_data = self._post_image(
            "handwriting",
            image_path,
            detect_direction=self.DETECT_DIRECTION,
            probability=self.PROBABILITY,
        )
        return self._lines_from_words(json_data)

    # Get the OCR result
    def get_ocr_result(self, image_path):
        """Recognise ``image_path`` with the mode named by ``OCR_TYPE``.

        Raises:
            ValueError: If ``OCR_TYPE`` is not one of the three known labels
                (this case used to fall through and return ``None``).
        """
        handlers = {
            "通用文字": self.accurate_basic,
            "数字识别": self.numbers,
            "手写文字": self.handwriting,
        }
        try:
            handler = handlers[self.OCR_TYPE]
        except KeyError:
            raise ValueError(
                "unsupported OCR type: {!r}".format(self.OCR_TYPE)) from None
        return handler(image_path)
# 窗口类 | |
class Window(QMainWindow):
    """Fixed-size main window of the OCR tool.

    The left half holds the recognition-type selector, the two image source
    buttons and the preview; the right half holds the recognised text and a
    copy-to-clipboard button.
    """

    def __init__(self):
        """Build the window, its widgets and the OCR client, then centre it."""
        super().__init__()
        self.initWindow()
        self.initUI()
        self.initOCR()
        self.center()

    # Window title and size
    def initWindow(self):
        """Set the title and lock the window to 1280x720."""
        self.setWindowTitle(window_title)
        self.setFixedWidth(1280)
        self.setFixedHeight(720)

    # Build the widgets
    def initUI(self):
        """Create the three widget groups."""
        self.initOCROptionSelection()
        self.initImageSelection()
        self.initTextResult()

    # Create the OCR client
    def initOCR(self):
        """Create the OCR client and apply this window's recognition options.

        Constructing ``OCR`` fetches an access token over the network, so
        this runs a blocking request during start-up.
        """
        self.OCR = OCR(API_KEY, SECRET_KEY)
        self.OCR.set_detect_direction("false")
        self.OCR.set_language(Language.CHN_ENG)
        self.OCR.set_paragraph("true")
        self.OCR.set_probability("false")

    # Recognition-type widgets
    def initOCROptionSelection(self):
        """Create the label and combo box used to pick the recognition type."""
        self.ocrOptionLabel = QLabel("选择识别类型", self)
        self.ocrOptionLabel.setFixedSize(QSize(220, 40))
        self.ocrOptionLabel.move(80, 80)
        self.ocrOptionSelection = QComboBox(self)
        self.ocrOptionSelection.move(320, 80)
        self.ocrOptionSelection.setFixedSize(QSize(240, 40))
        # The item texts double as the OCR_TYPE labels the client dispatches on.
        self.ocrOptionSelection.addItem("通用文字")
        self.ocrOptionSelection.addItem("数字识别")
        self.ocrOptionSelection.addItem("手写文字")
        # React to a changed selection
        self.ocrOptionSelection.currentIndexChanged.connect(
            self.selectionChange)

    # Image source buttons and preview
    def initImageSelection(self):
        """Create the file/camera buttons and the image preview label."""
        self.imageSelectionButton = QPushButton("选择需要识别的图片", self)
        self.imageSelectionButton.setFixedSize(QSize(220, 80))
        self.imageSelectionButton.move(80, 180)
        # Open the file dialog on click
        self.imageSelectionButton.clicked.connect(self.imageSelectionClicked)
        self.cameraSelectionButton = QPushButton("使用摄像头拍照", self)
        self.cameraSelectionButton.setFixedSize(QSize(220, 80))
        self.cameraSelectionButton.move(340, 180)
        # Grab a camera frame on click
        self.cameraSelectionButton.clicked.connect(self.cameraSelectionClicked)
        self.imagePreviewImage = QLabel(self)
        self.imagePreviewImage.setText("请选择要识别的图片")
        self.imagePreviewImage.setStyleSheet(
            "QLabel { background-color : gray; color : black; }")
        self.imagePreviewImage.setAlignment(QtCore.Qt.AlignCenter)
        self.imagePreviewImage.setFixedSize(QSize(480, 320))
        # Scale whatever pixmap is shown to the fixed preview size.
        self.imagePreviewImage.setScaledContents(True)
        self.imagePreviewImage.move(80, 320)

    # Result widgets
    def initTextResult(self):
        """Create the result text box and the copy button."""
        self.textResult = QTextEdit(self)
        self.textResult.setFixedSize(QSize(480, 440))
        self.textResult.move(720, 80)
        self.textCopy = QPushButton("复制到剪切板", self)
        self.textCopy.setFixedSize(QSize(480, 80))
        self.textCopy.move(720, 560)
        # Copy the text on click
        self.textCopy.clicked.connect(self.copyClicked)

    # Centre the window
    def center(self):
        """Move the window so its frame is centred on the available desktop."""
        qr = self.frameGeometry()
        cp = QDesktopWidget().availableGeometry().center()
        qr.moveCenter(cp)
        self.move(qr.topLeft())

    # Recognition type changed
    def selectionChange(self, i):
        """Store the combo box's current text as the client's ``OCR_TYPE``.

        Args:
            i: Index delivered by ``currentIndexChanged``; unused.
        """
        self.OCR.OCR_TYPE = self.ocrOptionSelection.currentText()

    # Pick an image file
    @pyqtSlot()
    def imageSelectionClicked(self):
        """Let the user pick an image file, preview it and recognise it."""
        # Start in the home directory; the literal "~/" used before is not a
        # usable directory path for the dialog.
        file_path, _ = QFileDialog.getOpenFileName(
            self, "选择你要识别的图片", QtCore.QDir.homePath())
        if not file_path:
            # Dialog cancelled: keep the current preview and skip the OCR
            # call, which would otherwise fail on an empty path.
            return
        self.imagePreviewImage.setPixmap(QPixmap(file_path))
        self.callOCR(file_path)

    # Take a picture with the camera
    @pyqtSlot()
    def cameraSelectionClicked(self):
        """Grab one frame from camera 0, preview it and recognise it."""
        cam = cv2.VideoCapture(0)
        try:
            ok, img = cam.read()
        finally:
            # Always hand the device back, even when the read fails.
            cam.release()
        if not ok:
            # No frame was delivered; writing it out would fail.
            QMessageBox.warning(self, window_title, "无法从摄像头读取图像")
            return
        cam_img_path = "./camera_temp.png"
        cv2.imwrite(cam_img_path, img)
        self.imagePreviewImage.setPixmap(QPixmap(cam_img_path))
        self.callOCR(cam_img_path)

    # Copy to the clipboard
    @pyqtSlot()
    def copyClicked(self):
        """Copy the result text to the clipboard and confirm with a dialog."""
        print("copy")
        QApplication.clipboard().setText(self.textResult.toPlainText())
        msg = QMessageBox(self)
        msg.setText('已复制到剪切板')
        msg.exec_()

    # Call the OCR client
    def callOCR(self, image_path):
        """Recognise ``image_path`` and show the text in the result box.

        Any failure (network, API, unreadable file) is reported in a warning
        dialog instead of propagating out of the Qt slot that called this.
        """
        try:
            ocr_result = self.OCR.get_ocr_result(image_path)
        except Exception as error:  # GUI boundary: report, do not crash
            QMessageBox.warning(
                self, window_title,
                "{}: {}".format(type(error).__name__, error))
            return
        # setPlainText rejects None, so fall back to an empty string.
        self.textResult.setPlainText(ocr_result or "")
def main():
    """Create the Qt application, show the main window and run the event loop.

    The process exits with the status code returned by the event loop.
    """
    application = QApplication(sys.argv)
    main_window = Window()
    main_window.show()
    exit_status = application.exec_()
    sys.exit(exit_status)


# Launch the GUI only when run as a script, not when imported
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment