Skip to content

Instantly share code, notes, and snippets.

@nickfox-taterli
Created March 13, 2025 04:20
Show Gist options
  • Save nickfox-taterli/bd2dd5aed08443160fa84dc23f8aea5c to your computer and use it in GitHub Desktop.
Save nickfox-taterli/bd2dd5aed08443160fa84dc23f8aea5c to your computer and use it in GitHub Desktop.
RKNPU 训练,转换,测试
import os
import platform

import cv2
import numpy as np
from rknnlite.api import RKNNLite
from tqdm import tqdm

from synset_label import labels
def preprocess_image(image_path):
    """Load one image from disk and shape it for a single-image NHWC inference.

    The image is resized to 224x224, converted from OpenCV's BGR channel order
    to RGB, and given a leading batch axis. Pixel values are not normalised
    here (NOTE(review): presumably normalisation is baked into the converted
    RKNN model via its mean/std config -- confirm).

    Args:
        image_path: Path of the image file to read.

    Returns:
        numpy array of shape (1, 224, 224, 3), RGB channel order.

    Raises:
        OSError: If OpenCV cannot read/decode the file.
    """
    img = cv2.imread(image_path)
    # cv2.imread signals failure (missing, unreadable or unsupported file) by
    # returning None instead of raising; without this check the failure only
    # surfaces later as an opaque assertion inside cv2.resize.
    if img is None:
        raise OSError(f"could not read image: {image_path}")
    img = cv2.resize(img, (224, 224))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; model expects RGB
    return np.expand_dims(img, 0)  # add the batch dimension
def predict_image(rknn, img):
    """Run inference on one preprocessed image and return the top-1 class index.

    Args:
        rknn: Runtime object exposing ``inference(inputs=..., data_format=...)``
            (an initialised ``RKNNLite`` instance in this script).
        img: Batched NHWC image, e.g. as produced by ``preprocess_image``.

    Returns:
        Index of the highest score in the model's first output.
    """
    # Use the runtime passed in by the caller. The original referenced the
    # module-level ``rknn_lite`` here, ignoring this parameter and raising
    # NameError whenever that global did not exist.
    outputs = rknn.inference(inputs=[img], data_format=['nhwc'])
    return np.argmax(outputs[0])  # predicted class index
def evaluate_accuracy(rknn, X_test, y_test):
    """Compute top-1 accuracy of the model over a labelled set of images.

    Args:
        rknn: Runtime object forwarded to ``predict_image``.
        X_test: Sequence of image file paths.
        y_test: Sequence of integer class indices, parallel to ``X_test``.

    Returns:
        Fraction in [0.0, 1.0] of images whose predicted class equals its label.

    Raises:
        ValueError: If the inputs differ in length or are empty.
    """
    total = len(X_test)
    # Validate before doing any work: previously an empty set raised
    # ZeroDivisionError, and mismatched lengths either crashed mid-run with
    # IndexError or silently ignored surplus labels.
    if len(y_test) != total:
        raise ValueError(
            f"X_test and y_test differ in length ({total} vs {len(y_test)})"
        )
    if total == 0:
        raise ValueError("cannot compute accuracy on an empty test set")

    correct = 0
    for image_path, true_label in tqdm(zip(X_test, y_test), total=total):
        img = preprocess_image(image_path)
        pred_label = predict_image(rknn, img)
        if pred_label == true_label:
            correct += 1
    return correct / total
def load_dataset(data_dir='/mnt/kagglehub/datasets/alxmamaev/flowers-recognition/versions/2/flowers/'):
    """Collect image paths and integer class labels from a class-per-folder tree.

    Expects the layout ``data_dir/<class_name>/<image files>``. Class names
    that contain at least one file are sorted alphabetically and mapped to
    indices 0..N-1. No train/test split is performed here.

    Args:
        data_dir: Root directory of the dataset. Defaults to the kagglehub
            cache path of the flowers-recognition dataset used originally.

    Returns:
        ``[image_paths, labels]``: two parallel lists of file paths and their
        integer class indices.
    """
    image_paths = []
    class_of_image = []
    # os.listdir order is arbitrary; sort so the returned order is reproducible.
    for class_name in sorted(os.listdir(data_dir)):
        class_dir = os.path.join(data_dir, class_name)
        if not os.path.isdir(class_dir):
            continue
        for image_name in sorted(os.listdir(class_dir)):
            image_path = os.path.join(class_dir, image_name)
            # Skip nested directories and other non-files, which cv2.imread cannot read.
            if os.path.isfile(image_path):
                image_paths.append(image_path)
                class_of_image.append(class_name)

    # Map class names to integer indices in alphabetical order.
    # NOTE(review): these indices must match the class order the model was
    # trained with -- confirm.
    class_names = sorted(set(class_of_image))
    label_to_index = {name: i for i, name in enumerate(class_names)}
    int_labels = [label_to_index[name] for name in class_of_image]
    return [image_paths, int_labels]
if __name__ == '__main__':
    import sys  # sys.exit rather than the interactive-only `exit` helper

    rknn_lite = RKNNLite()
    try:
        # Load RKNN model
        print('--> Load RKNN model')
        ret = rknn_lite.load_rknn("mobilenet.rknn")
        if ret != 0:
            print('Load RKNN model failed')
            sys.exit(ret)
        print('done')

        # Init runtime environment
        print('--> Init runtime environment')
        # For RK3576 / RK3588, specify which NPU core the model runs on through the core_mask parameter.
        ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
        if ret != 0:
            print('Init runtime environment failed')
            sys.exit(ret)
        print('done')

        # Inference
        print('--> Running model')
        # Compute accuracy. NOTE(review): this uses every image returned by
        # load_dataset(), not a held-out split; if the model was trained on
        # this dataset the figure is optimistic.
        X_test, y_test = load_dataset()
        # Pass this script's runtime object; the original passed the undefined
        # name `rknn`, which raised NameError.
        accuracy = evaluate_accuracy(rknn_lite, X_test, y_test)
        print(f"模型在测试集上的准确率: {accuracy * 100:.2f}%")
        print('done')
    finally:
        # Release the runtime on every exit path, including the early sys.exit() calls.
        rknn_lite.release()
import numpy as np
import cv2
import os
from rknn.api import RKNN
from tqdm import tqdm # 用于显示进度条
import kagglehub
from sklearn.model_selection import train_test_split
def preprocess_image(image_path):
    """Load one image from disk and shape it for a single-image NHWC inference.

    The image is resized to 224x224, converted from OpenCV's BGR channel order
    to RGB, and given a leading batch axis. Pixel values are left as raw 0-255
    values; normalisation is configured on the RKNN side (mean/std 128 in the
    conversion config of this script).

    Args:
        image_path: Path of the image file to read.

    Returns:
        numpy array of shape (1, 224, 224, 3), RGB channel order.

    Raises:
        OSError: If OpenCV cannot read/decode the file.
    """
    img = cv2.imread(image_path)
    # cv2.imread signals failure (missing, unreadable or unsupported file) by
    # returning None instead of raising; without this check the failure only
    # surfaces later as an opaque assertion inside cv2.resize.
    if img is None:
        raise OSError(f"could not read image: {image_path}")
    img = cv2.resize(img, (224, 224))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; model expects RGB
    return np.expand_dims(img, 0)  # add the batch dimension
def predict_image(rknn, img):
    """Return the top-1 class index for one preprocessed, batched NHWC image.

    Args:
        rknn: Runtime object exposing ``inference(inputs=..., data_format=...)``.
        img: Batched image array, e.g. from ``preprocess_image``.

    Returns:
        Position of the maximum score in the first model output.
    """
    scores = rknn.inference(inputs=[img], data_format=['nhwc'])[0]
    predicted_index = np.argmax(scores)
    return predicted_index
def evaluate_accuracy(rknn, X_test, y_test):
    """Compute top-1 accuracy of the model over a labelled set of images.

    Args:
        rknn: Runtime object forwarded to ``predict_image``.
        X_test: Sequence of image file paths.
        y_test: Sequence of integer class indices, parallel to ``X_test``.

    Returns:
        Fraction in [0.0, 1.0] of images whose predicted class equals its label.

    Raises:
        ValueError: If the inputs differ in length or are empty.
    """
    total = len(X_test)
    # Validate before doing any work: previously an empty set raised
    # ZeroDivisionError, and mismatched lengths either crashed mid-run with
    # IndexError or silently ignored surplus labels.
    if len(y_test) != total:
        raise ValueError(
            f"X_test and y_test differ in length ({total} vs {len(y_test)})"
        )
    if total == 0:
        raise ValueError("cannot compute accuracy on an empty test set")

    correct = 0
    for image_path, true_label in tqdm(zip(X_test, y_test), total=total):
        img = preprocess_image(image_path)
        pred_label = predict_image(rknn, img)
        if pred_label == true_label:
            correct += 1
    return correct / total
def load_dataset(data_dir=None):
    """Collect image paths and integer class labels from a class-per-folder tree.

    Expects the layout ``data_dir/<class_name>/<image files>``. Class names
    that contain at least one file are sorted alphabetically and mapped to
    indices 0..N-1. No train/test split is performed here; the caller does it.

    Args:
        data_dir: Root directory of the dataset. When None (the default), the
            ``alxmamaev/flowers-recognition`` dataset is obtained via
            ``kagglehub.dataset_download`` and its ``flowers`` folder is used.

    Returns:
        ``[image_paths, labels]``: two parallel lists of file paths and their
        integer class indices.
    """
    if data_dir is None:
        dataset_root = kagglehub.dataset_download('alxmamaev/flowers-recognition')
        data_dir = os.path.join(dataset_root, 'flowers')

    image_paths = []
    class_of_image = []
    # os.listdir order is arbitrary; sort so the returned order is reproducible.
    for class_name in sorted(os.listdir(data_dir)):
        class_dir = os.path.join(data_dir, class_name)
        if not os.path.isdir(class_dir):
            continue
        for image_name in sorted(os.listdir(class_dir)):
            image_path = os.path.join(class_dir, image_name)
            # Skip nested directories and other non-files, which cv2.imread cannot read.
            if os.path.isfile(image_path):
                image_paths.append(image_path)
                class_of_image.append(class_name)

    # Map class names to integer indices in alphabetical order.
    # NOTE(review): these indices must match the class order the TFLite model
    # was trained with -- confirm.
    class_names = sorted(set(class_of_image))
    label_to_index = {name: i for i, name in enumerate(class_names)}
    int_labels = [label_to_index[name] for name in class_of_image]
    return [image_paths, int_labels]
if __name__ == '__main__':
    import sys  # sys.exit rather than the interactive-only `exit` helper

    image_paths, labels = load_dataset()
    # Hold out a stratified 20% of the images for the accuracy check below.
    # random_state pins the split so repeated runs report comparable numbers.
    # Only the held-out portion is used; the training portions are discarded.
    _, X_test, _, y_test = train_test_split(
        image_paths, labels, test_size=0.2, stratify=labels, random_state=42
    )

    # Create RKNN object
    rknn = RKNN(verbose=True)
    try:
        # Pre-process config: single allowed input shape 1x224x224x3 (NHWC),
        # per-channel normalisation (x - 128) / 128 applied by the model,
        # target chip RK3588.
        print('--> Config model')
        rknn.config(dynamic_input=[[[1, 224, 224, 3]]],
                    mean_values=[128, 128, 128], std_values=[128, 128, 128],
                    target_platform='rk3588')
        print('done')

        # Load model (from https://www.tensorflow.org/lite/guide/hosted_models?hl=zh-cn)
        print('--> Loading model')
        ret = rknn.load_tflite(model='mobilenet.tflite')
        if ret != 0:
            print('Load model failed!')
            sys.exit(ret)
        print('done')

        # Build a quantized model; "dataset.txt" supplies the calibration
        # images and must exist in the working directory.
        print('--> Building model')
        ret = rknn.build(do_quantization=True, dataset="dataset.txt")
        if ret != 0:
            print('Build model failed!')
            sys.exit(ret)
        print('done')

        # Export rknn model
        print('--> Export rknn model')
        ret = rknn.export_rknn('./mobilenet.rknn')
        if ret != 0:
            print('Export rknn model failed!')
            sys.exit(ret)
        print('done')

        # Init runtime environment. NOTE(review): with no `target` argument
        # this presumably runs on the toolkit's host simulator -- confirm.
        print('--> Init runtime environment')
        ret = rknn.init_runtime()
        if ret != 0:
            print('Init runtime environment failed!')
            sys.exit(ret)
        print('done')

        # Compute accuracy on the held-out split
        accuracy = evaluate_accuracy(rknn, X_test, y_test)
        print(f"模型在测试集上的准确率: {accuracy * 100:.2f}%")
    finally:
        # Release the toolkit on every exit path, including the early sys.exit() calls.
        rknn.release()
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment