Skip to content

Instantly share code, notes, and snippets.

@vfdev-5
Last active March 29, 2023 16:34
Show Gist options
  • Save vfdev-5/1c0778904a07ce40401306548b9525e8 to your computer and use it in GitHub Desktop.
Save vfdev-5/1c0778904a07ce40401306548b9525e8 to your computer and use it in GitHub Desktop.
PyTorch: improved performance of vectorized uint8 interpolate for the RGB (3-channel) case
Description:
- 20230329-174512-pr
Torch version: 2.1.0a0+gitd6e220c
Torch config: PyTorch built with:
  - GCC 9.4
  - C++ Version: 201703
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - CPU capability usage: AVX2
  - Build settings: BUILD_TYPE=Release, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=1 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DLIBKINETO_NOROCTRACER -DUSE_PYTORCH_QNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.1.0, USE_CUDA=0, USE_CUDNN=OFF, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=0, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=0, USE_OPENMP=ON, USE_ROCM=OFF, 


- 20230327-111746-nightly
Torch version: 2.1.0a0+git2b75955
Torch config: PyTorch built with:
  - GCC 9.4
  - C++ Version: 201703
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - CPU capability usage: AVX2
  - Build settings: BUILD_TYPE=Release, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=1 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DLIBKINETO_NOROCTRACER -DUSE_PYTORCH_QNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.1.0, USE_CUDA=0, USE_CUDNN=OFF, USE_EIGEN_FOR_BLAS=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=OFF, USE_MKLDNN=0, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=0, USE_OPENMP=ON, USE_ROCM=OFF, 



[-------------------------------------------------------------------------------------------------- Resize -------------------------------------------------------------------------------------------------]
                                                                                 |  Pillow (9.0.0.post1)  |  torch (2.1.0a0+gitd6e220c) PR  |  torch (2.1.0a0+git2b75955) nightly  |  Speed-up: PR vs nightly
1 threads: --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      3 torch.uint8 channels_last bilinear (256, 256) -> (32, 32) aa=True        |    38.674 (+-0.323)    |         57.591 (+-0.244)        |          131.033 (+-1.448)           |      2.275 (+-0.000)    
      3 torch.uint8 channels_last bilinear (256, 256) -> (32, 32) aa=False       |                        |         39.471 (+-0.166)        |          113.911 (+-1.736)           |      2.886 (+-0.000)    
      3 torch.uint8 channels_last bilinear (256, 256) -> (224, 224) aa=True      |   128.512 (+-1.916)    |        161.592 (+-1.242)        |          299.679 (+-2.099)           |      1.855 (+-0.000)    
      3 torch.uint8 channels_last bilinear (256, 256) -> (224, 224) aa=False     |                        |        150.994 (+-1.180)        |          285.331 (+-1.919)           |      1.890 (+-0.000)    
      3 torch.uint8 channels_last bilinear (256, 256) -> (320, 320) aa=True      |   180.045 (+-2.223)    |        220.581 (+-1.363)        |          431.057 (+-3.536)           |      1.954 (+-0.000)    
      3 torch.uint8 channels_last bilinear (256, 256) -> (320, 320) aa=False     |                        |        219.391 (+-1.409)        |          429.410 (+-3.620)           |      1.957 (+-0.000)    
      3 torch.uint8 channels_last bilinear (520, 520) -> (32, 32) aa=True        |   113.911 (+-1.024)    |        129.457 (+-1.295)        |          459.610 (+-13.322)          |      3.550 (+-0.000)    
      3 torch.uint8 channels_last bilinear (520, 520) -> (32, 32) aa=False       |                        |         59.800 (+-0.199)        |          400.015 (+-11.815)          |      6.689 (+-0.000)    
      3 torch.uint8 channels_last bilinear (520, 520) -> (224, 224) aa=True      |   283.050 (+-2.664)    |        339.143 (+-1.209)        |          683.555 (+-4.466)           |      2.016 (+-0.000)    
      3 torch.uint8 channels_last bilinear (520, 520) -> (224, 224) aa=False     |                        |        250.601 (+-1.236)        |          603.545 (+-2.644)           |      2.408 (+-0.000)    
      3 torch.uint8 channels_last bilinear (712, 712) -> (32, 32) aa=True        |   186.723 (+-2.213)    |        199.960 (+-1.343)        |          860.867 (+-21.763)          |      4.305 (+-0.000)    
      3 torch.uint8 channels_last bilinear (712, 712) -> (32, 32) aa=False       |                        |         79.188 (+-0.261)        |          703.019 (+-25.805)          |      8.878 (+-0.000)    
      3 torch.uint8 channels_last bilinear (712, 712) -> (224, 224) aa=True      |   412.353 (+-4.476)    |        462.230 (+-1.983)        |         1101.673 (+-49.299)          |      2.383 (+-0.000)    
      3 torch.uint8 channels_last bilinear (712, 712) -> (224, 224) aa=False     |                        |        327.973 (+-1.852)        |          941.062 (+-5.549)           |      2.869 (+-0.000)    
      3 torch.uint8 channels_last bilinear (64, 64) -> (224, 224) aa=True        |    61.191 (+-0.926)    |         80.795 (+-0.518)        |          160.853 (+-1.506)           |      1.991 (+-0.000)    
      3 torch.uint8 channels_last bilinear (224, 224) -> (270, 268) aa=True      |   134.488 (+-2.129)    |        169.147 (+-1.324)        |          327.343 (+-2.846)           |      1.935 (+-0.000)    
      3 torch.uint8 channels_last bilinear (256, 256) -> (1024, 1024) aa=True    |  1037.045 (+-24.982)   |        938.623 (+-9.010)        |         2603.360 (+-20.530)          |      2.774 (+-0.000)    
      3 torch.uint8 channels_last bilinear (224, 224) -> (64, 64) aa=True        |    52.792 (+-0.613)    |         73.692 (+-0.264)        |          131.829 (+-1.333)           |      1.789 (+-0.000)    
      3 torch.uint8 channels_last bilinear (270, 268) -> (224, 224) aa=True      |   139.596 (+-1.944)    |        173.778 (+-1.039)        |          320.063 (+-2.562)           |      1.842 (+-0.000)    
      3 torch.uint8 channels_last bilinear (1024, 1024) -> (256, 256) aa=True    |   690.132 (+-10.946)   |        772.758 (+-2.864)        |         2036.860 (+-36.109)          |      2.636 (+-0.000)    
      3 torch.uint8 channels_last bilinear (64, 64) -> (224, 224) aa=False       |                        |         78.747 (+-0.799)        |          158.479 (+-1.702)           |      2.013 (+-0.000)    
      3 torch.uint8 channels_last bilinear (224, 224) -> (270, 268) aa=False     |                        |        167.046 (+-1.077)        |          322.104 (+-2.764)           |      1.928 (+-0.000)    
      3 torch.uint8 channels_last bilinear (256, 256) -> (1024, 1024) aa=False   |                        |        918.967 (+-5.251)        |         2611.388 (+-29.917)          |      2.842 (+-0.000)    
      3 torch.uint8 channels_last bilinear (224, 224) -> (64, 64) aa=False       |                        |         55.336 (+-0.251)        |          113.869 (+-1.243)           |      2.058 (+-0.000)    
      3 torch.uint8 channels_last bilinear (270, 268) -> (224, 224) aa=False     |                        |        156.505 (+-1.095)        |          299.861 (+-2.710)           |      1.916 (+-0.000)    
      3 torch.uint8 channels_last bilinear (1024, 1024) -> (256, 256) aa=False   |                        |        514.344 (+-1.905)        |         1776.796 (+-19.660)          |      3.454 (+-0.000)    
      4 torch.uint8 channels_last bilinear (256, 256) -> (32, 32) aa=True        |                        |         52.315 (+-0.253)        |           51.626 (+-0.364)           |      0.987 (+-0.000)    
      4 torch.uint8 channels_last bilinear (256, 256) -> (32, 32) aa=False       |                        |         35.464 (+-0.192)        |           34.793 (+-0.720)           |      0.981 (+-0.000)    
      4 torch.uint8 channels_last bilinear (256, 256) -> (224, 224) aa=True      |                        |        144.356 (+-1.126)        |          145.787 (+-1.424)           |      1.010 (+-0.000)    
      4 torch.uint8 channels_last bilinear (256, 256) -> (224, 224) aa=False     |                        |        131.043 (+-0.720)        |          131.900 (+-1.170)           |      1.007 (+-0.000)    
      4 torch.uint8 channels_last bilinear (256, 256) -> (320, 320) aa=True      |                        |        194.268 (+-1.057)        |          195.768 (+-1.758)           |      1.008 (+-0.000)    
      4 torch.uint8 channels_last bilinear (256, 256) -> (320, 320) aa=False     |                        |        191.441 (+-0.945)        |          193.152 (+-1.448)           |      1.009 (+-0.000)    
      4 torch.uint8 channels_last bilinear (520, 520) -> (32, 32) aa=True        |                        |        119.895 (+-1.256)        |          124.296 (+-1.251)           |      1.037 (+-0.000)    
      4 torch.uint8 channels_last bilinear (520, 520) -> (32, 32) aa=False       |                        |         56.785 (+-0.134)        |           56.383 (+-0.298)           |      0.993 (+-0.000)    
      4 torch.uint8 channels_last bilinear (520, 520) -> (224, 224) aa=True      |                        |        286.033 (+-1.629)        |          290.807 (+-2.869)           |      1.017 (+-0.000)    
      4 torch.uint8 channels_last bilinear (520, 520) -> (224, 224) aa=False     |                        |        211.898 (+-1.080)        |          214.602 (+-1.619)           |      1.013 (+-0.000)    
      4 torch.uint8 channels_last bilinear (712, 712) -> (32, 32) aa=True        |                        |        186.008 (+-1.196)        |          197.220 (+-1.500)           |      1.060 (+-0.000)    
      4 torch.uint8 channels_last bilinear (712, 712) -> (32, 32) aa=False       |                        |         74.411 (+-0.164)        |           73.541 (+-0.226)           |      0.988 (+-0.000)    
      4 torch.uint8 channels_last bilinear (712, 712) -> (224, 224) aa=True      |                        |        416.270 (+-1.898)        |          428.137 (+-4.839)           |      1.029 (+-0.000)    
      4 torch.uint8 channels_last bilinear (712, 712) -> (224, 224) aa=False     |                        |        270.673 (+-1.320)        |          276.558 (+-1.940)           |      1.022 (+-0.000)    
      4 torch.uint8 channels_last bilinear (64, 64) -> (224, 224) aa=True        |                        |         77.987 (+-0.390)        |           75.166 (+-0.391)           |      0.964 (+-0.000)    
      4 torch.uint8 channels_last bilinear (224, 224) -> (270, 268) aa=True      |                        |        151.690 (+-1.061)        |          152.750 (+-1.371)           |      1.007 (+-0.000)    
      4 torch.uint8 channels_last bilinear (256, 256) -> (1024, 1024) aa=True    |                        |        897.172 (+-3.482)        |          906.450 (+-5.702)           |      1.010 (+-0.000)    
      4 torch.uint8 channels_last bilinear (224, 224) -> (64, 64) aa=True        |                        |         67.887 (+-0.245)        |           67.724 (+-0.451)           |      0.998 (+-0.000)    
      4 torch.uint8 channels_last bilinear (270, 268) -> (224, 224) aa=True      |                        |        155.851 (+-1.478)        |          156.827 (+-1.646)           |      1.006 (+-0.000)    
      4 torch.uint8 channels_last bilinear (1024, 1024) -> (256, 256) aa=True    |                        |        638.757 (+-3.986)        |          680.733 (+-4.411)           |      1.066 (+-0.000)    
      4 torch.uint8 channels_last bilinear (64, 64) -> (224, 224) aa=False       |                        |         76.281 (+-0.418)        |           73.980 (+-0.507)           |      0.970 (+-0.000)    
      4 torch.uint8 channels_last bilinear (224, 224) -> (270, 268) aa=False     |                        |        149.798 (+-1.094)        |          150.360 (+-1.209)           |      1.004 (+-0.000)    
      4 torch.uint8 channels_last bilinear (256, 256) -> (1024, 1024) aa=False   |                        |        889.140 (+-3.676)        |          888.575 (+-6.037)           |      0.999 (+-0.000)    
      4 torch.uint8 channels_last bilinear (224, 224) -> (64, 64) aa=False       |                        |         49.680 (+-0.160)        |           47.945 (+-0.264)           |      0.965 (+-0.000)    
      4 torch.uint8 channels_last bilinear (270, 268) -> (224, 224) aa=False     |                        |        136.919 (+-1.019)        |          134.807 (+-0.937)           |      0.985 (+-0.000)    
      4 torch.uint8 channels_last bilinear (1024, 1024) -> (256, 256) aa=False   |                        |        411.040 (+-1.962)        |          421.084 (+-2.218)           |      1.024 (+-0.000)    
      3 torch.uint8 channels_first bilinear (256, 256) -> (32, 32) aa=True       |    38.370 (+-0.192)    |        129.711 (+-0.827)        |          129.264 (+-1.054)           |      0.997 (+-0.000)    
      3 torch.uint8 channels_first bilinear (256, 256) -> (32, 32) aa=False      |                        |        112.856 (+-0.798)        |          111.873 (+-1.004)           |      0.991 (+-0.000)    
      3 torch.uint8 channels_first bilinear (256, 256) -> (224, 224) aa=True     |   127.894 (+-1.034)    |        297.265 (+-2.226)        |          317.713 (+-3.049)           |      1.069 (+-0.000)    
      3 torch.uint8 channels_first bilinear (256, 256) -> (224, 224) aa=False    |                        |        283.683 (+-1.334)        |          294.468 (+-8.713)           |      1.038 (+-0.000)    
      3 torch.uint8 channels_first bilinear (256, 256) -> (320, 320) aa=True     |   178.460 (+-1.591)    |        425.950 (+-1.716)        |          471.941 (+-5.183)           |      1.108 (+-0.000)    
      3 torch.uint8 channels_first bilinear (256, 256) -> (320, 320) aa=False    |                        |        423.476 (+-1.880)        |          463.577 (+-6.143)           |      1.095 (+-0.000)    
      3 torch.uint8 channels_first bilinear (520, 520) -> (32, 32) aa=True       |   112.312 (+-1.850)    |        427.963 (+-2.081)        |          433.944 (+-2.470)           |      1.014 (+-0.000)    
      3 torch.uint8 channels_first bilinear (520, 520) -> (32, 32) aa=False      |                        |        365.342 (+-1.370)        |          366.145 (+-2.277)           |      1.002 (+-0.000)    
      3 torch.uint8 channels_first bilinear (520, 520) -> (224, 224) aa=True     |   281.371 (+-1.823)    |        669.098 (+-1.935)        |          695.277 (+-3.210)           |      1.039 (+-0.000)    
      3 torch.uint8 channels_first bilinear (520, 520) -> (224, 224) aa=False    |                        |        595.256 (+-2.767)        |          610.557 (+-3.784)           |      1.026 (+-0.000)    
      3 torch.uint8 channels_first bilinear (712, 712) -> (32, 32) aa=True       |   185.920 (+-1.374)    |        761.921 (+-3.720)        |          773.219 (+-3.421)           |      1.015 (+-0.000)    
      3 torch.uint8 channels_first bilinear (712, 712) -> (32, 32) aa=False      |                        |        650.812 (+-2.438)        |          651.833 (+-3.869)           |      1.002 (+-0.000)    
      3 torch.uint8 channels_first bilinear (712, 712) -> (224, 224) aa=True     |   410.439 (+-2.593)    |        1065.288 (+-9.492)       |         1097.446 (+-14.107)          |      1.030 (+-0.000)    
      3 torch.uint8 channels_first bilinear (712, 712) -> (224, 224) aa=False    |                        |        921.811 (+-4.258)        |          942.783 (+-4.818)           |      1.023 (+-0.000)    
      3 torch.uint8 channels_first bilinear (64, 64) -> (224, 224) aa=True       |    60.586 (+-0.246)    |        160.184 (+-1.325)        |          183.886 (+-3.023)           |      1.148 (+-0.000)    
      3 torch.uint8 channels_first bilinear (224, 224) -> (270, 268) aa=True     |   132.609 (+-1.773)    |        320.246 (+-1.989)        |          346.158 (+-3.725)           |      1.081 (+-0.000)    
      3 torch.uint8 channels_first bilinear (256, 256) -> (1024, 1024) aa=True   |   943.126 (+-3.348)    |       2577.556 (+-16.597)       |         2942.516 (+-116.401)         |      1.142 (+-0.000)    
      3 torch.uint8 channels_first bilinear (224, 224) -> (64, 64) aa=True       |    52.089 (+-0.389)    |        132.051 (+-1.316)        |          132.548 (+-0.931)           |      1.004 (+-0.000)    
      3 torch.uint8 channels_first bilinear (270, 268) -> (224, 224) aa=True     |   138.987 (+-1.232)    |        314.340 (+-2.130)        |          338.116 (+-2.309)           |      1.076 (+-0.000)    
      3 torch.uint8 channels_first bilinear (1024, 1024) -> (256, 256) aa=True   |   692.526 (+-7.640)    |       1960.090 (+-25.304)       |         2020.990 (+-16.370)          |      1.031 (+-0.000)    
      3 torch.uint8 channels_first bilinear (64, 64) -> (224, 224) aa=False      |                        |        158.146 (+-1.176)        |          181.866 (+-2.786)           |      1.150 (+-0.000)    
      3 torch.uint8 channels_first bilinear (224, 224) -> (270, 268) aa=False    |                        |        317.544 (+-2.225)        |          349.501 (+-4.249)           |      1.101 (+-0.000)    
      3 torch.uint8 channels_first bilinear (256, 256) -> (1024, 1024) aa=False  |                        |       2603.797 (+-17.057)       |         2988.936 (+-146.450)         |      1.148 (+-0.000)    
      3 torch.uint8 channels_first bilinear (224, 224) -> (64, 64) aa=False      |                        |        113.148 (+-1.709)        |          113.022 (+-0.837)           |      0.999 (+-0.000)    
      3 torch.uint8 channels_first bilinear (270, 268) -> (224, 224) aa=False    |                        |        296.490 (+-1.489)        |          309.067 (+-3.466)           |      1.042 (+-0.000)    
      3 torch.uint8 channels_first bilinear (1024, 1024) -> (256, 256) aa=False  |                        |       1737.213 (+-14.597)       |         1760.707 (+-18.867)          |      1.014 (+-0.000)    
      4 torch.uint8 channels_first bilinear (256, 256) -> (32, 32) aa=True       |                        |         71.435 (+-0.223)        |           70.344 (+-0.248)           |      0.985 (+-0.000)    
      4 torch.uint8 channels_first bilinear (256, 256) -> (32, 32) aa=False      |                        |         54.579 (+-0.255)        |           53.607 (+-0.192)           |      0.982 (+-0.000)    
      4 torch.uint8 channels_first bilinear (256, 256) -> (224, 224) aa=True     |                        |        251.264 (+-1.459)        |          262.699 (+-3.071)           |      1.046 (+-0.000)    
      4 torch.uint8 channels_first bilinear (256, 256) -> (224, 224) aa=False    |                        |        238.140 (+-1.602)        |          253.042 (+-3.699)           |      1.063 (+-0.000)    
      4 torch.uint8 channels_first bilinear (256, 256) -> (320, 320) aa=True     |                        |        394.355 (+-1.676)        |          429.776 (+-7.453)           |      1.090 (+-0.000)    
      4 torch.uint8 channels_first bilinear (256, 256) -> (320, 320) aa=False    |                        |        393.164 (+-2.190)        |          424.774 (+-9.492)           |      1.080 (+-0.000)    
      4 torch.uint8 channels_first bilinear (520, 520) -> (32, 32) aa=True       |                        |        186.814 (+-1.007)        |          192.311 (+-1.715)           |      1.029 (+-0.000)    
      4 torch.uint8 channels_first bilinear (520, 520) -> (32, 32) aa=False      |                        |        125.245 (+-0.778)        |          124.377 (+-0.689)           |      0.993 (+-0.000)    
      4 torch.uint8 channels_first bilinear (520, 520) -> (224, 224) aa=True     |                        |        441.408 (+-2.111)        |          455.736 (+-2.886)           |      1.032 (+-0.000)    
      4 torch.uint8 channels_first bilinear (520, 520) -> (224, 224) aa=False    |                        |        367.833 (+-1.909)        |          384.434 (+-3.405)           |      1.045 (+-0.000)    
      4 torch.uint8 channels_first bilinear (712, 712) -> (32, 32) aa=True       |                        |        308.348 (+-1.214)        |          324.199 (+-2.901)           |      1.051 (+-0.000)    
      4 torch.uint8 channels_first bilinear (712, 712) -> (32, 32) aa=False      |                        |        200.152 (+-1.071)        |          200.237 (+-1.403)           |      1.000 (+-0.000)    
      4 torch.uint8 channels_first bilinear (712, 712) -> (224, 224) aa=True     |                        |        631.259 (+-5.656)        |          672.493 (+-3.950)           |      1.065 (+-0.000)    
      4 torch.uint8 channels_first bilinear (712, 712) -> (224, 224) aa=False    |                        |        484.197 (+-4.397)        |          524.986 (+-5.093)           |      1.084 (+-0.000)    
      4 torch.uint8 channels_first bilinear (64, 64) -> (224, 224) aa=True       |                        |        170.100 (+-1.051)        |          180.047 (+-4.952)           |      1.058 (+-0.000)    
      4 torch.uint8 channels_first bilinear (224, 224) -> (270, 268) aa=True     |                        |        294.919 (+-2.060)        |          314.463 (+-4.386)           |      1.066 (+-0.000)    
      4 torch.uint8 channels_first bilinear (256, 256) -> (1024, 1024) aa=True   |                        |       2884.749 (+-17.014)       |         3196.146 (+-101.697)         |      1.108 (+-0.000)    
      4 torch.uint8 channels_first bilinear (224, 224) -> (64, 64) aa=True       |                        |         88.335 (+-0.415)        |           88.856 (+-0.530)           |      1.006 (+-0.000)    
      4 torch.uint8 channels_first bilinear (270, 268) -> (224, 224) aa=True     |                        |        263.454 (+-1.483)        |          277.140 (+-2.539)           |      1.052 (+-0.000)    
      4 torch.uint8 channels_first bilinear (1024, 1024) -> (256, 256) aa=True   |                        |        1014.410 (+-8.548)       |         1347.051 (+-11.137)          |      1.328 (+-0.000)    
      4 torch.uint8 channels_first bilinear (64, 64) -> (224, 224) aa=False      |                        |        168.690 (+-2.082)        |          185.902 (+-3.013)           |      1.102 (+-0.000)    
      4 torch.uint8 channels_first bilinear (224, 224) -> (270, 268) aa=False    |                        |        292.200 (+-1.753)        |          319.375 (+-3.580)           |      1.093 (+-0.000)    
      4 torch.uint8 channels_first bilinear (256, 256) -> (1024, 1024) aa=False  |                        |       2884.541 (+-33.086)       |         3245.043 (+-113.882)         |      1.125 (+-0.000)    
      4 torch.uint8 channels_first bilinear (224, 224) -> (64, 64) aa=False      |                        |         69.959 (+-0.257)        |           69.877 (+-0.367)           |      0.999 (+-0.000)    
      4 torch.uint8 channels_first bilinear (270, 268) -> (224, 224) aa=False    |                        |        244.994 (+-1.161)        |          259.398 (+-3.671)           |      1.059 (+-0.000)    
      4 torch.uint8 channels_first bilinear (1024, 1024) -> (256, 256) aa=False  |                        |        789.698 (+-3.819)        |         1067.531 (+-30.590)          |      1.352 (+-0.000)    

Times are in microseconds (us).

  • Install Pillow-SIMD and other deps
pip uninstall -y pillow && CC="cc -mavx2" pip install --no-cache-dir --force-reinstall pillow-simd
pip install fire
wget https://raw.githubusercontent.com/pytorch/vision/main/torchvision/transforms/functional_tensor.py -O torchvision_functional_tensor.py
  • To run benchmarks
### On nightly
python -u run_bench_interp.py "output/$(date "+%Y%m%d-%H%M%S")-nightly.pkl" --tag=nightly

### On PR
python -u run_bench_interp.py "output/$(date "+%Y%m%d-%H%M%S")-pr.pkl" --tag=PR


python -u make_results_table_from_pickles.py output/$(date "+%Y%m%d-%H%M%S")-pr_vs_nightly.md output/XYZ-pr.pkl output/ABC-nightly.pkl

python -u perf_results_compute_speedup.py output/20230320-160044-pr_vs_nightly-speedup.md "['output/XYZ-pr.pkl', 'output/ABC-nightly.pkl']" --col1="torch (2.1.0a0+gitc005105) PR" --col2="torch (2.1.0a0+git5309c44) nightly" --description="Speed-up: PR vs nightly"
import pickle
from pathlib import Path
from typing import List, Optional
import unittest.mock
import torch
import torch.utils.benchmark as benchmark
from torch.utils.benchmark.utils import common
from torch.utils.benchmark.utils.compare import Table
import fire
def patched_as_column_strings(self):
    """Render one table row as a list of strings, one per column.

    The first entry is the row header (optional environment plus row name).
    Every following entry is either the column's placeholder for a missing
    result or ``"<median> (+-<spread>)"``, both divided by ``self._time_scale``.
    The spread is the sample standard deviation of the measurement's raw
    ``times``; a measurement with exactly one timing reports a spread of 0.

    ``main`` installs this function in place of
    ``torch.utils.benchmark.utils.compare._Row.as_column_strings``.
    """
    # NOTE(review): the source lost its indentation; the nesting below
    # (sig-fig trimming only applied to a computed spread) is reconstructed.
    available = [res for res in self._results if res is not None]
    first = available[0]
    env_label = f"({first.env})" if self._render_env else ""
    cells = [" " + env_label.ljust(self._env_str_len + 4) + first.as_row_name]

    for measurement, column in zip(self._results, self._columns or ()):
        if measurement is None:
            # Let the column produce its own "no value" cell.
            cells.append(column.num_to_str(None, 1, None))
            continue

        n_times = len(measurement.times)
        if n_times == 1:
            spread = 0
        else:
            timings = torch.tensor(measurement.times, dtype=torch.float64)
            spread = float(timings.std(unbiased=n_times > 1))
            if column._trim_significant_figures:
                spread = benchmark.utils.common.trim_sigfig(spread, measurement.significant_figures)

        cells.append(f"{measurement.median / self._time_scale:>3.3f} (+-{spread / self._time_scale:>3.3f})")
    return cells
class Value(common.Measurement):
    """Marker subclass of ``Measurement`` for entries that are not timings.

    ``get_new_table`` wraps each speed-up ratio in a ``Value`` and
    ``CustomizedTable`` skips ``Value`` instances when picking the time unit.
    """
class CustomizedTable(Table):
    """``Table`` whose time unit is selected from real timings only.

    The constructor sets up the same attributes as it would for any results
    list, except that ``Value`` entries are left out when the smallest median
    is used to select the time unit and scale.
    """

    def __init__(self, results, colorize, trim_significant_figures, highlight_warnings):
        # All results must belong to one label.
        labels = {res.label for res in results}
        assert len(labels) == 1

        self.results = results
        self._colorize = colorize
        self._trim_significant_figures = trim_significant_figures
        self._highlight_warnings = highlight_warnings
        self.label = results[0].label

        # Only genuine measurements take part in choosing the unit.
        timed_medians = (res.median for res in results if not isinstance(res, Value))
        self.time_unit, self.time_scale = common.select_unit(min(timed_medians))

        self.row_keys = common.ordered_unique([self.row_fn(res) for res in results])
        self.row_keys.sort(key=lambda key: key[:2])  # preserve stmt order
        self.column_keys = common.ordered_unique([self.col_fn(res) for res in results])
        self.rows, self.columns = self.populate_rows_and_columns()
def get_new_table(compare, col1, col2, description, debug):
    """Build a table with an extra speed-up column (``col2`` time / ``col1`` time).

    Args:
        compare: a ``benchmark.Compare`` whose results all share one label.
        col1: description of the measurements used as the denominator.
        col2: description of the measurements used as the numerator.
        description: header of the added column; defaults to
            ``"Speed-up: {col1} vs {col2}"`` when ``None``.
        debug: print matching details while pairing measurements.

    Returns:
        A ``CustomizedTable`` over the original measurements plus one
        ``Value`` per matched (``col1``, ``col2``) pair with equal sub-label.

    Raises:
        AssertionError: if the results span several labels, or if no pair
            was matched (nothing was added).
    """
    # NOTE(review): the source lost its indentation; the loop nesting below
    # is reconstructed from the data flow.
    results = common.Measurement.merge(compare._results)
    grouped_results = compare._group_by_label(results)
    assert len(grouped_results.values()) == 1, grouped_results.values()
    group = next(iter(grouped_results.values()))

    if description is None:
        description = f"Speed-up: {col1} vs {col2}"

    # The ratio is multiplied by this scale before being stored; presumably
    # this cancels the division by the time scale applied at render time.
    _, scale = common.select_unit(min(m.median for m in group))

    # Add speed-up column into results:
    updated_group = []
    for measurement in group:
        spec = measurement.task_spec
        if debug:
            print("measurement.task_spec.description:", spec.description)

        # Look for the col2 measurement with the same sub-label.
        partner = None
        if spec.description == col1:
            if debug:
                print("Matched col1:", col1, measurement.median, spec.sub_label)
            partner = next(
                (
                    candidate
                    for candidate in group
                    if candidate.task_spec.description == col2
                    and candidate.task_spec.sub_label == spec.sub_label
                ),
                None,
            )
            if debug and partner is not None:
                print("Matched col2:", col2, partner.median)

        if measurement not in updated_group:
            updated_group.append(measurement)
        if partner is None:
            continue

        if partner not in updated_group:
            updated_group.append(partner)

        ratio = partner.median / measurement.median * scale
        if debug:
            print("ratio is: ", ratio)

        speedup_task = common.TaskSpec(
            "",
            setup="",
            label=measurement.label,
            sub_label=measurement.sub_label,
            num_threads=measurement.num_threads,
            env=measurement.env,
            description=description,
        )
        updated_group.append(Value(1, [ratio], speedup_task))

    assert len(updated_group) > len(group), "Seems like nothing was added. Run with --debug"

    return CustomizedTable(
        updated_group,
        compare._colorize,
        compare._trim_significant_figures,
        compare._highlight_warnings,
    )
def main(
    output_filepath: str,
    perf_files: List[str],
    *,
    col1: str,
    col2: str,
    description: Optional[str] = None,
    debug: bool = False
):
    """Merge pickled benchmark results and write a report with a speed-up column.

    Args:
        output_filepath: path of the report to create; must not exist yet.
        perf_files: paths of pickled result files. Each pickle must provide
            the keys ``torch_version``, ``torch_config`` and ``test_results``.
            A single path given as a plain string is accepted as well.
        col1: description of the column used as the speed-up denominator.
        col2: description of the column used as the speed-up numerator.
        description: header of the speed-up column (see ``get_new_table``).
        debug: print the parsed arguments and the unpatched table.

    Raises:
        FileExistsError: if ``output_filepath`` already exists.
        FileNotFoundError: if one of ``perf_files`` does not exist.
    """
    output_filepath = Path(output_filepath)
    if output_filepath.exists():
        raise FileExistsError(f"Output file '{output_filepath}' exists. Please provide a path to non-existing file")

    if debug:
        print("output_filepath:", output_filepath)
        print("perf_files:", perf_files, type(perf_files))
        print("col1:", col1, type(col1))
        print("col2:", col2, type(col2))
        print("description:", description, type(description))

    # A lone path arrives as a plain string; iterating it would walk over
    # its characters instead of over file paths.
    if isinstance(perf_files, (str, Path)):
        perf_files = [perf_files]

    ab_results = []
    ab_configs = []
    for perf_filepath in perf_files:
        # Explicit raise instead of assert: asserts disappear under ``python -O``.
        if not Path(perf_filepath).exists():
            raise FileNotFoundError(f"{perf_filepath} is not found")
        # NOTE: pickle.load can execute arbitrary code; only pass result
        # files produced by a trusted benchmark run.
        with open(perf_filepath, "rb") as handler:
            output = pickle.load(handler)
        ab_configs.append(
            f"Torch version: {output['torch_version']}\n"
            f"Torch config: {output['torch_config']}\n"
        )
        ab_results.extend(output["test_results"])

    compare = benchmark.Compare(ab_results)
    table = get_new_table(compare, col1=col1, col2=col2, description=description, debug=debug)
    if debug:
        print(table.render())

    # Render with the "median (+-spread)" row formatter before touching the
    # output file, so a rendering failure cannot leave a truncated report.
    with unittest.mock.patch(
        "torch.utils.benchmark.utils.compare._Row.as_column_strings", patched_as_column_strings
    ):
        rendered_table = table.render()

    parts = ["Description:\n"]
    for in_filepath, config in zip(perf_files, ab_configs):
        parts.append(f"- {Path(in_filepath).stem}\n")
        parts.append(f"{config}\n")
    parts.append("\n")
    parts.append(rendered_table)

    with output_filepath.open("w") as handler:
        handler.write("".join(parts))
if __name__ == "__main__":
    # Expose main() as a command-line interface; fire maps positional and
    # --flag arguments onto main's parameters.
    fire.Fire(main)
import pickle
from pathlib import Path
import unittest.mock
import numpy as np
import PIL.Image
import torch
import torch.utils.benchmark as benchmark
import fire
from torchvision_functional_tensor import resize
def pth_downsample_i8(img, mode, size, aa=True):
    """Resize ``img`` to ``size`` with ``torch.nn.functional.interpolate``.

    Args:
        img: input tensor, passed to ``interpolate`` unchanged.
        mode: interpolation mode name understood by ``interpolate``.
        size: target output size.
        aa: forwarded as ``antialias``.

    Returns:
        The tensor returned by ``interpolate``.
    """
    # interpolate() raises ValueError when align_corners is set (even to
    # False) for the non-interpolating modes, so it is cleared for all of
    # them rather than only for "nearest".
    non_interpolating_modes = ("nearest", "nearest-exact", "area")
    align_corners = None if mode in non_interpolating_modes else False
    return torch.nn.functional.interpolate(
        img,
        size=size,
        mode=mode,
        align_corners=align_corners,
        antialias=aa,
    )
def torchvision_resize(img, mode, size, aa=True):
    """Resize ``img`` via the ``resize`` helper imported from ``torchvision_functional_tensor``.

    Mirrors the signature of ``pth_downsample_i8`` so both can be timed with
    the same statement: ``mode`` is forwarded as ``interpolation`` and ``aa``
    as ``antialias``.
    """
    resized = resize(img, size=size, interpolation=mode, antialias=aa)
    return resized
# Pillow exposes the resampling filters either on the `PIL.Image.Resampling`
# enum or directly on the `PIL.Image` module; use whichever is available.
_resampling_source = PIL.Image.Resampling if hasattr(PIL.Image, "Resampling") else PIL.Image

# Maps the interpolation mode names used by the benchmark to Pillow filters.
resampling_map = {
    "bilinear": _resampling_source.BILINEAR,
    "nearest": _resampling_source.NEAREST,
    "bicubic": _resampling_source.BICUBIC,
}
def patched_as_column_strings(self):
    """Render one table row as strings, showing ``median (+-spread)`` per measurement.

    Written to be patched over
    ``torch.utils.benchmark.utils.compare._Row.as_column_strings``. The first
    cell is the row header (optional env tag plus row name); each following
    cell is either the column's placeholder for a missing result or the median
    together with the standard deviation of the measured times, both divided
    by the row's time scale.
    """
    present = [res for res in self._results if res is not None]
    first = present[0]

    env_tag = f"({first.env})" if self._render_env else ""
    cells = [" " + env_tag.ljust(self._env_str_len + 4) + first.as_row_name]

    for measurement, column in zip(self._results, self._columns or ()):
        if measurement is None:
            cells.append(column.num_to_str(None, 1, None))
            continue

        sample_count = len(measurement.times)
        if sample_count == 1:
            # A single sample has no spread.
            spread = 0
        else:
            times = torch.tensor(measurement.times, dtype=torch.float64)
            spread = float(times.std(unbiased=sample_count > 1))
            if column._trim_significant_figures:
                spread = benchmark.utils.common.trim_sigfig(spread, measurement.significant_figures)

        median_text = f"{measurement.median / self._time_scale:>3.3f}"
        spread_text = f"{spread / self._time_scale:>3.3f}"
        cells.append(f"{median_text} (+-{spread_text})")

    return cells
def run_benchmark(c, dtype, size, osize, aa, mode, mf="channels_first", min_run_time=10, tag="", with_torchvision=False):
    """Time one resize configuration and return the collected measurements.

    A random ``(c, size[0], size[1])`` tensor is generated with a fixed seed,
    resized once with ``pth_downsample_i8`` and, for 3-channel uint8 input with
    antialiasing, checked against Pillow's result. Afterwards the Pillow call
    (when a Pillow image was built), the torch call and optionally the
    torchvision call are timed with ``benchmark.Timer``.

    Args:
        c: Number of channels.
        dtype: Torch dtype of the generated input.
        size: Input ``(height, width)``.
        osize: Output ``(height, width)``.
        aa: Whether antialiasing is requested.
        mode: Interpolation mode name.
        mf: ``"channels_last"`` selects ``torch.channels_last``; anything else
            selects ``torch.contiguous_format``.
        min_run_time: Passed to ``blocked_autorange``.
        tag: Suffix appended to the torch measurement's description.
        with_torchvision: Also time ``torchvision_resize``.

    Returns:
        List of measurements, one per timed implementation.

    Raises:
        RuntimeError: If the Pillow comparison runs for a mode other than
            ``"bilinear"``.
    """
    torch.manual_seed(12)

    chw = (c, size[0], size[1])
    if dtype == torch.bool:
        tensor = torch.randint(0, 2, size=chw, dtype=dtype)
    elif dtype == torch.complex64:
        real = torch.randint(0, 256, size=chw, dtype=torch.float32)
        imag = torch.randint(0, 256, size=chw, dtype=torch.float32)
        tensor = torch.complex(real, imag)
    elif dtype == torch.int8:
        tensor = torch.randint(-127, 127, size=chw, dtype=dtype)
    else:
        tensor = torch.randint(0, 256, size=chw, dtype=dtype)

    # Build the Pillow reference only for the case the check below covers.
    pil_img = None
    expected_pil = None
    if dtype == torch.uint8 and c == 3 and aa:
        hwc_array = tensor.clone().permute(1, 2, 0).contiguous().numpy()
        pil_img = PIL.Image.fromarray(hwc_array)
        # Pillow takes (width, height), hence the reversed output size.
        resized_pil = pil_img.resize(osize[::-1], resample=resampling_map[mode])
        expected_pil = torch.from_numpy(np.asarray(resized_pil)).clone().permute(2, 0, 1).contiguous()

    if mf == "channels_last":
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format
    tensor = tensor[None, ...].contiguous(memory_format=memory_format)

    output = pth_downsample_i8(tensor, mode=mode, size=osize, aa=aa)[0, ...]

    if expected_pil is not None:
        abs_diff = (expected_pil.float() - output.float()).abs()
        mae = abs_diff.mean()
        max_abs_err = abs_diff.max()

        # Tolerances are only defined for bilinear interpolation.
        if mode != "bilinear":
            raise RuntimeError(f"Unsupported mode: {mode}")
        assert mae.item() < 1.0, mae.item()
        assert max_abs_err.item() < 2.0 + 1e-5, max_abs_err.item()

    sub_label = f"{c} {dtype} {mf} {mode} {size} -> {osize} aa={aa}"

    def measure(stmt, env, description):
        # All timers share label, sub_label and thread count so that
        # benchmark.Compare can line them up in one row.
        timer = benchmark.Timer(
            stmt=stmt,
            globals=env,
            num_threads=torch.get_num_threads(),
            label="Resize",
            sub_label=sub_label,
            description=description,
        )
        return timer.blocked_autorange(min_run_time=min_run_time)

    results = []

    # Pillow
    if pil_img is not None:
        results.append(
            measure(
                f"data.resize({osize[::-1]}, resample=resample_val)",
                {"data": pil_img, "resample_val": resampling_map[mode]},
                f"Pillow ({PIL.__version__})",
            )
        )

    # torch interpolate
    tensor_stmt = f"fn(data, mode='{mode}', size={osize}, aa={aa})"
    results.append(
        measure(
            tensor_stmt,
            {"data": tensor, "fn": pth_downsample_i8},
            f"torch ({torch.__version__}) {tag}",
        )
    )

    # torchvision resize
    if with_torchvision:
        results.append(
            measure(
                tensor_stmt,
                {"data": tensor, "fn": torchvision_resize},
                "torchvision resize",
            )
        )

    return results
def main(
    output_filepath: str,
    min_run_time: int = 10,
    tag: str = "",
    display: bool = True,
    with_torchvision: bool = False,
    extended_test_cases=True
):
    """Run the resize benchmark grid, pickle the results and optionally print them.

    For each memory format and each ``(channels, dtype)`` pair, a fixed grid of
    input/output sizes is benchmarked with ``run_benchmark``. When
    ``extended_test_cases`` is true, an extra set of shapes is added per pair.

    Args:
        output_filepath: File that receives the pickled results dict.
        min_run_time: Forwarded to ``run_benchmark``.
        tag: Forwarded to ``run_benchmark``.
        display: Print a ``benchmark.Compare`` table after saving.
        with_torchvision: Forwarded to ``run_benchmark``.
        extended_test_cases: Also run the additional shape combinations.
    """

    def as_hw(value):
        # A bare int denotes a square shape.
        return (value, value) if isinstance(value, int) else value

    def bench(**case):
        return run_benchmark(
            min_run_time=min_run_time, tag=tag, with_torchvision=with_torchvision, **case
        )

    output_filepath = Path(output_filepath)
    test_results = []

    extended_shapes = (
        (64, 224),
        (224, (270, 268)),
        (256, (1024, 1024)),
        (224, 64),
        ((270, 268), 224),
        (1024, 256),
    )

    for mf in ("channels_last", "channels_first"):
        for c, dtype in ((3, torch.uint8), (4, torch.uint8)):
            for side in (256, 520, 712):
                size = as_hw(side)
                cases = [
                    (32, True, "bilinear"),
                    (32, False, "bilinear"),
                    (224, True, "bilinear"),
                    (224, False, "bilinear"),
                ]
                # Upsampling is only exercised for the smallest input.
                if size == (256, 256):
                    cases.extend([
                        (320, True, "bilinear"),
                        (320, False, "bilinear"),
                    ])
                for out_side, aa, mode in cases:
                    test_results.extend(
                        bench(
                            c=c, dtype=dtype, size=size,
                            osize=as_hw(out_side), aa=aa, mode=mode, mf=mf,
                        )
                    )

            if extended_test_cases:
                for aa in (True, False):
                    for in_shape, out_shape in extended_shapes:
                        test_results.extend(
                            bench(
                                c=c, dtype=dtype, size=as_hw(in_shape),
                                osize=as_hw(out_shape), aa=aa, mode="bilinear", mf=mf,
                            )
                        )

    with open(output_filepath, "wb") as handler:
        payload = {
            "torch_version": torch.__version__,
            "torch_config": torch.__config__.show(),
            "num_threads": torch.get_num_threads(),
            "pil_version": PIL.__version__,
            "test_results": test_results,
        }
        pickle.dump(payload, handler)

    if not display:
        return

    # Swap in the row formatter that prints median with its spread.
    with unittest.mock.patch(
        "torch.utils.benchmark.utils.compare._Row.as_column_strings", patched_as_column_strings
    ):
        benchmark.Compare(test_results).print()
if __name__ == "__main__":
    from datetime import datetime

    # The thread count is fixed to one before anything is reported or measured.
    torch.set_num_threads(1)

    # Print the environment banner ahead of the benchmark output.
    timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    print(f"Timestamp: {timestamp}")
    print(f"Torch version: {torch.__version__}")
    print(f"Torch config: {torch.__config__.show()}")
    print(f"Num threads: {torch.get_num_threads()}")
    print("")
    print("PIL version: ", PIL.__version__)

    fire.Fire(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment