Created
March 31, 2021 00:02
-
-
Save ailzhang/992728d6964afd86a42ea56c6e43835b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/fbcode/caffe2/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h b/fbcode/caffe2/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h | |
--- a/fbcode/caffe2/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h | |
+++ b/fbcode/caffe2/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h | |
@@ -49,7 +49,8 @@ | |
// it's a bit troublesome, because fastpath TLS access requires the type of | |
// the TLS in question to be zero-initialized, so you don't actually win | |
// anyting in that case. | |
- return (((ks | local.included_ | always_included) - local.excluded_) & key_mask); | |
+ // For the additional XOR op, see Note [TLS Initialization] | |
+ return (((ks | (local.included_ ^ c10::InplaceOrView_keyset) | always_included) - local.excluded_) & key_mask); | |
} | |
} | |
diff --git a/fbcode/caffe2/c10/core/DispatchKeySet.h b/fbcode/caffe2/c10/core/DispatchKeySet.h | |
--- a/fbcode/caffe2/c10/core/DispatchKeySet.h | |
+++ b/fbcode/caffe2/c10/core/DispatchKeySet.h | |
@@ -82,6 +82,10 @@ | |
DispatchKeySet operator-(DispatchKeySet other) const { | |
return DispatchKeySet(repr_ & ~other.repr_); | |
} | |
+ // Compute self ^ other | |
+ DispatchKeySet operator^(DispatchKeySet other) const { | |
+ return DispatchKeySet(repr_ ^ other.repr_); | |
+ } | |
// Perform set equality | |
bool operator==(DispatchKeySet other) const { | |
return repr_ == other.repr_; | |
@@ -203,6 +207,11 @@ | |
DispatchKey::AutogradOther, | |
}); | |
+// See Note [TLS Initialization] | |
+constexpr DispatchKeySet InplaceOrView_keyset = DispatchKeySet({ | |
+ DispatchKey::InplaceOrView, | |
+}); | |
+ | |
constexpr DispatchKeySet autograd_dispatch_keyset_with_InplaceOrView = | |
autograd_dispatch_keyset | DispatchKeySet(DispatchKey::InplaceOrView); | |
diff --git a/fbcode/caffe2/c10/core/InferenceMode.cpp b/fbcode/caffe2/c10/core/InferenceMode.cpp | |
--- a/fbcode/caffe2/c10/core/InferenceMode.cpp | |
+++ b/fbcode/caffe2/c10/core/InferenceMode.cpp | |
@@ -4,7 +4,7 @@ | |
namespace c10 { | |
bool InferenceMode::is_enabled() { | |
- return !c10::impl::tls_is_dispatch_key_included(DispatchKey::InplaceOrView); | |
+ return c10::impl::tls_is_dispatch_key_included(DispatchKey::InplaceOrView); | |
} | |
} // namespace c10 | |
diff --git a/fbcode/caffe2/c10/core/InferenceMode.h b/fbcode/caffe2/c10/core/InferenceMode.h | |
--- a/fbcode/caffe2/c10/core/InferenceMode.h | |
+++ b/fbcode/caffe2/c10/core/InferenceMode.h | |
@@ -8,54 +8,59 @@ | |
// A RAII, thread local (!) guard that enables or disables inference mode upon | |
// construction, and sets it back to the original value upon destruction. | |
struct TORCH_API InferenceMode { | |
+ // Note [Expected TLS state in InferenceMode]: | |
+ // InferenceMode: InplaceOrView in included, Autograd in excluded | |
+ // NormalMode: InplaceOrView not in included, Autograd not in excluded | |
+ // | |
+ // Invariant: | |
+ // - InplaceOrView is never in the excluded set | |
+ // - Autograd is never in the included set | |
+ // | |
+ // 1. Why do we put InplaceOrView in the included set in InferenceMode? | |
+ // The behavior we actually want: InplaceOrView is in the included set by default | |
+ // and is removed from the included set in InferenceMode. | |
+ // But TLS can only be zero-initialized (see Note [TLS Initialization]), so | |
+ // we instead add InplaceOrView in InferenceMode and get the desired | |
+ // behavior after the XOR in DispatchKeyExtractor. | |
+ // | |
+ // For example: | |
+ // torch::Tensor a; | |
+ // { | |
+ // c10::InferenceMode guard(true); | |
+ // torch::Tensor in = torch::ones({2, 2}); | |
+ // a = in.view({1, 4}); | |
+ // } | |
+ // torch::Tensor c = a.view({4, 1}); // (*) | |
+ // If we don't add InplaceOrView to included set, (*) will skip its as_view | |
+ // setup entirely, `c` will be a Tensor that is not from Inference mode | |
+ // but has potentially wrong view metadata which should be forbidden. | |
+ // By going through InplaceOrView kernel, we can throw an error since it | |
+ // broke our invariant: "Autograd keys must be in excluded set before | |
+ // reaching InplaceOrView kernel". | |
+ // | |
+ // 2. Why not put InplaceOrView in the excluded set inside InferenceMode? | |
+ // | |
+ // For example: | |
+ // torch::Tensor a = torch::ones({1, 2, 3}).set_requires_grad(true); | |
+ // torch::Tensor k = a + 2; | |
+ // { | |
+ // c10::InferenceMode guard(true); | |
+ // k.add_(2); | |
+ // } | |
+ // `k.add_(2)` still needs to go through the InplaceOrView kernel so that it's | |
+ // prepared for future autograd. | |
InferenceMode(bool enabled=true): prev_keyset(c10::impl::tls_local_dispatch_key_set()) { | |
- // Note [Expected TLS state in InferenceMode]: | |
- // InferenceMode: InplaceOrView not in included, Autograd in excluded | |
- // NormalMode: InplaceOrView in included, Autograd not in excluded | |
- // | |
- // Invariant: | |
- // - InplaceOrView is never in the excluded set | |
- // - Autograd is never in the included set | |
- // | |
- // 1. Why not put InplaceOrView in the excluded set inside InferenceMode? | |
- // | |
- // For example: | |
- // torch::Tensor a = torch::ones({1, 2, 3}).set_requires_grad(true); | |
- // torch::Tensor k = a + 2; | |
- // { | |
- // c10::InferenceMode guard(true); | |
- // k.add_(2); | |
- // } | |
- // `k.add_(2)` still need to go through InplaceOrView kernel so that it's | |
- // prepared for future autograd. | |
- // 2. Why do we need InplaceOrView in included set outside InferenceMode? | |
- // | |
- // For example: | |
- // torch::Tensor a; | |
- // { | |
- // c10::InferenceMode guard(true); | |
- // torch::Tensor in = torch::ones({2, 2}); | |
- // a = in.view({1, 4}); | |
- // } | |
- // torch::Tensor c = a.view({4, 1}); // (*) | |
- // If we don't add InplaceOrView to included set, (*) will skip its as_view | |
- // setup entirely, `c` will be a Tensor that is not from Inference mode | |
- // but has potentially wrong view metadata which should be forbidden.. | |
- // By going through InplaceOrView kernel, we can throw an error since it | |
- // broke our invariant: "Autograd keys must be in excluded set before | |
- // reaching InplaceOrView kernel". | |
- | |
- DispatchKeySet included = enabled ? prev_keyset.included_.remove(c10::DispatchKey::InplaceOrView) | |
- : prev_keyset.included_.add(c10::DispatchKey::InplaceOrView); | |
+ DispatchKeySet included = enabled ? prev_keyset.included_.add(c10::DispatchKey::InplaceOrView) | |
+ : prev_keyset.included_.remove(c10::DispatchKey::InplaceOrView); | |
DispatchKeySet excluded = enabled ? (prev_keyset.excluded_ | c10::autograd_dispatch_keyset) | |
: (prev_keyset.excluded_ - c10::autograd_dispatch_keyset); | |
c10::impl::PODLocalDispatchKeySet cur_keyset {included.raw_repr(), excluded.raw_repr()}; | |
c10::impl::_force_tls_local_dispatch_key_set(cur_keyset); | |
} | |
+ | |
~InferenceMode() { | |
c10::impl::_force_tls_local_dispatch_key_set(prev_keyset); | |
} | |
- | |
static bool is_enabled(); | |
private: | |
diff --git a/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.cpp b/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.cpp | |
--- a/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.cpp | |
+++ b/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.cpp | |
@@ -5,7 +5,13 @@ | |
namespace c10 { | |
namespace impl { | |
-thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set {DispatchKeySet(DispatchKey::InplaceOrView).raw_repr()}; | |
+// NB: POD, must be zero initialized! | |
+// Note [TLS Initialization] | |
+ // We want raw_local_dispatch_key_set to be initialized with the InplaceOrView key | |
+ // in the included set. But certain Windows compilers (e.g. the one used in ARVR tests) | |
+ // only allow TLS to be zero-initialized. So we work around the problem by XOR-ing | |
+ // raw_local_dispatch_key_set.included() with InplaceOrView in DispatchKeyExtractor.h. | |
+thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set; | |
#if defined(_MSC_VER) || defined(C10_ANDROID) | |
LocalDispatchKeySet tls_local_dispatch_key_set() { | |
diff --git a/fbcode/caffe2/test/cpp/api/inference_mode.cpp b/fbcode/caffe2/test/cpp/api/inference_mode.cpp | |
--- a/fbcode/caffe2/test/cpp/api/inference_mode.cpp | |
+++ b/fbcode/caffe2/test/cpp/api/inference_mode.cpp | |
@@ -39,7 +39,7 @@ | |
ASSERT_FALSE(c10::impl::tls_is_dispatch_key_excluded(c10::DispatchKey::InplaceOrView)); | |
ASSERT_FALSE(c10::impl::tls_is_dispatch_keyset_included(c10::autograd_dispatch_keyset)); | |
ASSERT_EQ(c10::impl::tls_is_dispatch_keyset_excluded(c10::autograd_dispatch_keyset), inference_mode); | |
- ASSERT_EQ(c10::impl::tls_is_dispatch_key_included(c10::DispatchKey::InplaceOrView), !inference_mode); | |
+ ASSERT_EQ(c10::impl::tls_is_dispatch_key_included(c10::DispatchKey::InplaceOrView), inference_mode); | |
} | |
} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment