Created December 12, 2018 20:43
-
-
Save ezyang/7cf5918e6a15dca74bab0ada4d7ca665 to your computer and use it in GitHub Desktop.
HIPified Welford
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #pragma once | |
| #include <c10/macros/Macros.h> | |
| #if (defined(__HIPCC__) || defined(__HIPCC__)) | |
| #include <THH/THHDeviceUtils.cuh> | |
| #include <ATen/native/hip/DeviceSqrt.cuh> | |
| #else | |
| #include <cmath> | |
| #define device_sqrt std::sqrt | |
| #endif | |
| namespace at { namespace native { | |
| struct WelfordData { | |
| double mean; | |
| double m2; | |
| int64_t n; | |
| C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0) {} | |
| C10_DEVICE WelfordData(double mean, double m2, int64_t n) : mean(mean), m2(m2), n(n) {} | |
| }; | |
| template <typename scalar_t> | |
| struct WelfordOps { | |
| bool unbiased; | |
| public: | |
| inline C10_DEVICE WelfordData reduce(WelfordData acc, scalar_t data) const { | |
| double delta = data - acc.mean; | |
| double new_mean = acc.mean + delta / (acc.n + 1); | |
| double new_delta = data - new_mean; | |
| return { | |
| new_mean, | |
| acc.m2 + delta * new_delta, | |
| acc.n + 1 | |
| }; | |
| } | |
| inline C10_DEVICE WelfordData combine(WelfordData a, WelfordData b) const { | |
| if (a.n == 0) { | |
| return b; | |
| } | |
| if (b.n == 0) { | |
| return a; | |
| } | |
| double delta = b.mean - a.mean; | |
| int64_t new_count = a.n + b.n; | |
| double nb_over_n = (double)b.n / new_count; | |
| return { | |
| a.mean + delta * nb_over_n, | |
| a.m2 + b.m2 + delta * delta * a.n * nb_over_n, | |
| new_count | |
| }; | |
| } | |
| inline C10_DEVICE scalar_t project(WelfordData acc) const { | |
| int64_t divisor = unbiased ? (acc.n - 1) : acc.n; | |
| return (divisor > 0) ? device_sqrt(acc.m2 / divisor) : NAN; | |
| } | |
| #if defined(__HIPCC__) || defined(__HIPCC__) | |
| inline __device__ WelfordData warp_shfl_down(WelfordData acc, int offset) const { | |
| return { | |
| WARP_SHFL_DOWN(acc.mean, offset) | |
| , WARP_SHFL_DOWN(acc.m2, offset) | |
| , WARP_SHFL_DOWN(acc.n, offset) | |
| }; | |
| } | |
| #endif | |
| WelfordOps(bool unbiased) : unbiased(unbiased) { | |
| } | |
| }; | |
| }} // namespace at::native |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.