Created March 13, 2025 04:20
-
-
Save nickfox-taterli/bd2dd5aed08443160fa84dc23f8aea5c to your computer and use it in GitHub Desktop.
RKNPU 训练、转换、测试
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import cv2 | |
import numpy as np | |
import platform | |
from synset_label import labels | |
from rknnlite.api import RKNNLite | |
def preprocess_image(image_path, size=(224, 224)):
    """Load one image from disk and shape it as a single-sample model input.

    Args:
        image_path: Path to an image file readable by OpenCV.
        size: Target ``(width, height)`` passed to ``cv2.resize``; defaults to
            224x224.

    Returns:
        Array of shape ``(1, height, width, 3)`` in RGB channel order. No
        mean/std normalization is applied here.

    Raises:
        ValueError: If OpenCV cannot read or decode the file.
    """
    img = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising; without
    # this check the failure surfaces later as an obscure cv2.resize error.
    if img is None:
        raise ValueError(f'Could not read image: {image_path}')
    img = cv2.resize(img, size)
    # OpenCV decodes to BGR channel order; convert to RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Prepend the batch dimension -> (1, H, W, 3).
    return np.expand_dims(img, 0)
def predict_image(rknn, img):
    """Run a single inference and return the predicted class index.

    Args:
        rknn: Runtime object exposing ``inference(inputs=..., data_format=...)``
            (an RKNNLite instance in this script).
        img: One-image batch in NHWC layout, as built by ``preprocess_image``.

    Returns:
        Index of the largest value in the first model output.
    """
    # Use the runtime that was passed in. The original body called the
    # script-level ``rknn_lite`` global, which ignores this parameter and
    # raises NameError whenever that global is not defined.
    outputs = rknn.inference(inputs=[img], data_format=['nhwc'])
    # np.argmax flattens its input, so this is a class index only because the
    # batch holds exactly one image.
    return np.argmax(outputs[0])
def evaluate_accuracy(rknn, X_test, y_test):
    """Return the top-1 accuracy of the model over the given samples.

    Args:
        rknn: Runtime object forwarded to ``predict_image``.
        X_test: Sequence of image file paths.
        y_test: Sequence of integer class indices, aligned with ``X_test``.

    Returns:
        Fraction of samples whose predicted class equals the label, in [0, 1].

    Raises:
        ValueError: If ``X_test`` is empty or the two sequences differ in length.
    """
    # tqdm was used here without being imported in this script. The progress
    # bar is cosmetic, so fall back to plain iteration when tqdm is missing.
    try:
        from tqdm import tqdm
    except ImportError:
        def tqdm(iterable, **_kwargs):
            return iterable

    total = len(X_test)
    if total == 0:
        raise ValueError('X_test is empty; cannot compute accuracy')
    if len(y_test) != total:
        raise ValueError(
            f'X_test and y_test differ in length ({total} vs {len(y_test)})'
        )

    correct = 0
    for image_path, true_label in tqdm(zip(X_test, y_test), total=total):
        img = preprocess_image(image_path)
        pred_label = predict_image(rknn, img)
        if pred_label == true_label:
            correct += 1
    return correct / total
def load_dataset(data_dir='/mnt/kagglehub/datasets/alxmamaev/flowers-recognition/versions/2/flowers/'):
    """Collect every image path under ``data_dir`` with an integer class label.

    ``data_dir`` is expected to contain one sub-directory per class. Every
    entry inside a class directory is taken as one sample of that class, and
    non-directory entries at the top level are skipped. Class names are
    sorted alphabetically and mapped to indices 0..K-1.

    Args:
        data_dir: Root directory holding one sub-directory per class.

    Returns:
        Two-element list ``[image_paths, labels]``, where ``labels[i]`` is the
        integer class index of ``image_paths[i]``. No train/test split is
        performed.
    """
    import os  # Not imported at the top of this script.

    image_paths = []
    sample_classes = []
    # Sort the listings: os.listdir order is arbitrary and platform-dependent.
    for class_name in sorted(os.listdir(data_dir)):
        class_dir = os.path.join(data_dir, class_name)
        if not os.path.isdir(class_dir):
            continue
        for image_name in sorted(os.listdir(class_dir)):
            image_paths.append(os.path.join(class_dir, image_name))
            sample_classes.append(class_name)

    # Convert class-name labels to integer indices.
    class_names = sorted(set(sample_classes))
    label_to_index = {name: i for i, name in enumerate(class_names)}
    labels = [label_to_index[name] for name in sample_classes]
    return [image_paths, labels]
if __name__ == '__main__':
    import sys

    rknn_lite = RKNNLite()
    # Release the runtime on every exit path, including the sys.exit calls below.
    try:
        # Load RKNN model
        print('--> Load RKNN model')
        ret = rknn_lite.load_rknn("mobilenet.rknn")
        if ret != 0:
            print('Load RKNN model failed')
            sys.exit(ret)
        print('done')

        # Init runtime environment
        print('--> Init runtime environment')
        # For RK3576 / RK3588, specify which NPU core the model runs on through the core_mask parameter.
        ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
        if ret != 0:
            print('Init runtime environment failed')
            sys.exit(ret)
        print('done')

        # Inference
        print('--> Running model')
        # NOTE(review): load_dataset() returns every image in the dataset, not
        # a held-out split, so this accuracy may include images the model was
        # trained on.
        X_test, y_test = load_dataset()
        # Pass the runtime that actually exists here; the original passed an
        # undefined name ``rknn``.
        accuracy = evaluate_accuracy(rknn_lite, X_test, y_test)
        print(f"模型在测试集上的准确率: {accuracy * 100:.2f}%")
        print('done')
    finally:
        rknn_lite.release()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import cv2 | |
import os | |
from rknn.api import RKNN | |
from tqdm import tqdm # 用于显示进度条 | |
import kagglehub | |
from sklearn.model_selection import train_test_split | |
def preprocess_image(image_path, size=(224, 224)):
    """Load one image from disk and shape it as a single-sample model input.

    Args:
        image_path: Path to an image file readable by OpenCV.
        size: Target ``(width, height)`` passed to ``cv2.resize``; defaults to
            224x224.

    Returns:
        Array of shape ``(1, height, width, 3)`` in RGB channel order. No
        mean/std normalization is applied here.

    Raises:
        ValueError: If OpenCV cannot read or decode the file.
    """
    img = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising; without
    # this check the failure surfaces later as an obscure cv2.resize error.
    if img is None:
        raise ValueError(f'Could not read image: {image_path}')
    img = cv2.resize(img, size)
    # OpenCV decodes to BGR channel order; convert to RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Prepend the batch dimension -> (1, H, W, 3).
    return np.expand_dims(img, 0)
def predict_image(rknn, img):
    """Classify a single preprocessed image.

    Args:
        rknn: Runtime object whose ``inference`` method is called with NHWC data.
        img: One-image batch in NHWC layout.

    Returns:
        Position of the maximum score in the first output tensor.
    """
    first_output = rknn.inference(inputs=[img], data_format=['nhwc'])[0]
    predicted_index = np.argmax(first_output)
    return predicted_index
def evaluate_accuracy(rknn, X_test, y_test):
    """Return the top-1 accuracy of the model over the given samples.

    Args:
        rknn: Runtime object forwarded to ``predict_image``.
        X_test: Sequence of image file paths.
        y_test: Sequence of integer class indices, aligned with ``X_test``.

    Returns:
        Fraction of samples whose predicted class equals the label, in [0, 1].

    Raises:
        ValueError: If ``X_test`` is empty or the two sequences differ in length.
    """
    total = len(X_test)
    # Fail with a clear message instead of ZeroDivisionError (empty input) or
    # a stray IndexError / silently ignored labels (mismatched lengths).
    if total == 0:
        raise ValueError('X_test is empty; cannot compute accuracy')
    if len(y_test) != total:
        raise ValueError(
            f'X_test and y_test differ in length ({total} vs {len(y_test)})'
        )

    correct = 0
    for image_path, true_label in tqdm(zip(X_test, y_test), total=total):
        img = preprocess_image(image_path)
        pred_label = predict_image(rknn, img)
        if pred_label == true_label:
            correct += 1
    return correct / total
def load_dataset(data_dir=None):
    """Collect every image path of the flowers dataset with an integer label.

    The root directory is expected to contain one sub-directory per class.
    Every entry inside a class directory is taken as one sample of that class,
    and non-directory entries at the top level are skipped. Class names are
    sorted alphabetically and mapped to indices 0..K-1.

    Args:
        data_dir: Root directory holding one sub-directory per class. When
            None (the default), the Kaggle ``alxmamaev/flowers-recognition``
            dataset is fetched with ``kagglehub.dataset_download`` and its
            ``flowers`` folder is used.

    Returns:
        Two-element list ``[image_paths, labels]``, where ``labels[i]`` is the
        integer class index of ``image_paths[i]``. The train/test split is left
        to the caller.
    """
    if data_dir is None:
        dataset_root = kagglehub.dataset_download('alxmamaev/flowers-recognition')
        data_dir = os.path.join(dataset_root, 'flowers')

    image_paths = []
    sample_classes = []
    # Sort the listings: os.listdir order is arbitrary and platform-dependent.
    for class_name in sorted(os.listdir(data_dir)):
        class_dir = os.path.join(data_dir, class_name)
        if not os.path.isdir(class_dir):
            continue
        for image_name in sorted(os.listdir(class_dir)):
            image_paths.append(os.path.join(class_dir, image_name))
            sample_classes.append(class_name)

    # Convert class-name labels to integer indices.
    class_names = sorted(set(sample_classes))
    label_to_index = {name: i for i, name in enumerate(class_names)}
    labels = [label_to_index[name] for name in sample_classes]
    return [image_paths, labels]
if __name__ == '__main__':
    import sys

    image_paths, labels = load_dataset()
    # Hold out 20% for evaluation. stratify keeps the class proportions, and a
    # fixed random_state makes the split, and so the reported accuracy,
    # reproducible across runs.
    _, X_test, _, y_test = train_test_split(
        image_paths, labels, test_size=0.2, stratify=labels, random_state=42
    )

    # Create RKNN object
    rknn = RKNN(verbose=True)
    # Release the toolkit on every exit path, including the sys.exit calls below.
    try:
        # Pre-process config. With mean/std of 128 the toolkit normalizes inputs
        # as (x - 128) / 128, so preprocess_image feeds raw 0-255 RGB pixels.
        print('--> Config model')
        rknn.config(
            dynamic_input=[[[1, 224, 224, 3]]],
            mean_values=[128, 128, 128],
            std_values=[128, 128, 128],
            target_platform='rk3588',
        )
        print('done')

        # Load model (from https://www.tensorflow.org/lite/guide/hosted_models?hl=zh-cn)
        print('--> Loading model')
        ret = rknn.load_tflite(model='mobilenet.tflite')
        if ret != 0:
            print('Load model failed!')
            sys.exit(ret)
        print('done')

        # Build model; dataset.txt lists the calibration images used for quantization.
        print('--> Building model')
        ret = rknn.build(do_quantization=True, dataset="dataset.txt")
        if ret != 0:
            print('Build model failed!')
            sys.exit(ret)
        print('done')

        # Export rknn model
        print('--> Export rknn model')
        ret = rknn.export_rknn('./mobilenet.rknn')
        if ret != 0:
            print('Export rknn model failed!')
            sys.exit(ret)
        print('done')

        # Init runtime environment. No target is given, so this presumably runs
        # on the toolkit's simulator rather than a board; confirm against the
        # RKNN-Toolkit2 docs.
        print('--> Init runtime environment')
        ret = rknn.init_runtime()
        if ret != 0:
            print('Init runtime environment failed!')
            sys.exit(ret)
        print('done')

        # Compute accuracy on the held-out test set.
        accuracy = evaluate_accuracy(rknn, X_test, y_test)
        print(f"模型在测试集上的准确率: {accuracy * 100:.2f}%")
    finally:
        rknn.release()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment