Created
November 12, 2024 01:54
-
-
Save shaobin0604/4b9e650228abbcda456ceb63bed19128 to your computer and use it in GitHub Desktop.
JPEG decode speed test
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import timeit | |
import unittest | |
import cv2 | |
import torch | |
import torchvision.io as io | |
from turbojpeg import TurboJPEG | |
from nvjpeg import NvJpeg | |
import numpy as np | |
class TestJpegDecode(unittest.TestCase):
    """Benchmark single-image JPEG decoding with several backends.

    Backends: TurboJPEG, nvJPEG, OpenCV and torchvision. Each backend is
    called once as a warm-up, then timed over ``self._repeat`` calls. The
    mean per-call time in seconds is printed. Apart from checking that every
    decoder returns a result, this is a benchmark rather than a correctness
    test.
    """

    def setUp(self):
        """Create the decoder instances and load the test JPEG into memory."""
        # Number of timed decode calls per backend.
        self._repeat = 10
        self._turbo_jpeg = TurboJPEG()
        self._nv_jpeg = NvJpeg()
        # Directory containing this file.
        current_directory = os.path.dirname(os.path.abspath(__file__))
        # Absolute path of the test image that sits next to this file.
        test_jpg_path = os.path.join(current_directory, "jpeg_decode_test.jpg")
        with open(test_jpg_path, "rb") as f:
            self.jpeg_bytes = f.read()

    def _benchmark(self, label, decode):
        """Warm up *decode*, time it over ``self._repeat`` calls, print the mean.

        Args:
            label: Backend name used in the printed line and failure message.
            decode: Callable taking the raw JPEG bytes and returning the image.

        The warm-up call keeps one-time costs (library/handle initialization,
        CUDA context creation) out of the measured mean. Its result is also
        checked, so a decoder that fails without raising (``cv2.imdecode``
        returns ``None``) is not reported as a valid timing.
        """
        decoded = decode(self.jpeg_bytes)
        self.assertIsNotNone(decoded, f"{label} failed to decode the test image")
        total = timeit.timeit(lambda: decode(self.jpeg_bytes), number=self._repeat)
        print(f"{label} decode time: ", total / self._repeat)

    def test_jpeg_decode(self):
        """Time TurboJPEG, nvJPEG, OpenCV and torchvision on the test image."""

        def turbo_jpeg_decode(jpeg_bytes: bytes) -> np.ndarray:
            return self._turbo_jpeg.decode(jpeg_bytes)

        def nv_jpeg_decode(jpeg_bytes: bytes) -> np.ndarray:
            return self._nv_jpeg.decode(jpeg_bytes)

        def cv_jpeg_decode(jpeg_bytes: bytes) -> np.ndarray:
            # Wrap the bytes as a uint8 array without copying; imdecode
            # returns None (it does not raise) when decoding fails.
            return cv2.imdecode(
                np.frombuffer(jpeg_bytes, dtype=np.uint8), cv2.IMREAD_COLOR
            )

        # Decode on the GPU when available, otherwise on the CPU. This way the
        # benchmark always measures a real decode, and the printed label says
        # which device was used.
        torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        def torch_vision_jpeg_decode(jpeg_bytes: bytes) -> torch.Tensor:
            data = torch.frombuffer(jpeg_bytes, dtype=torch.uint8)
            image = io.decode_jpeg(data, device=torch_device)
            if torch_device == "cuda":
                # CUDA work is asynchronous: wait for the decode to finish so
                # the timing covers the work itself, not just the launch.
                torch.cuda.synchronize()
            return image

        self._benchmark("TurboJPEG", turbo_jpeg_decode)
        self._benchmark("NvJPEG", nv_jpeg_decode)
        self._benchmark("CV2", cv_jpeg_decode)
        self._benchmark(f"torch vision ({torch_device})", torch_vision_jpeg_decode)
if __name__ == "__main__":
    # When executed as a script, run every TestCase defined in this module.
    unittest.main(module="__main__")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment