Created
March 31, 2021 16:03
-
-
Save ailzhang/86abc3b72da4b2ec598a983fa52b7476 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/fbcode/caffe2/c10/core/DispatchKeySet.h b/fbcode/caffe2/c10/core/DispatchKeySet.h | |
--- a/fbcode/caffe2/c10/core/DispatchKeySet.h | |
+++ b/fbcode/caffe2/c10/core/DispatchKeySet.h | |
@@ -82,6 +82,10 @@ | |
DispatchKeySet operator-(DispatchKeySet other) const { | |
return DispatchKeySet(repr_ & ~other.repr_); | |
} | |
+ // Compute self ^ other (symmetric difference); constexpr so it works in constant expressions. | |
+ constexpr DispatchKeySet operator^(DispatchKeySet other) const { | |
+ return DispatchKeySet(repr_ ^ other.repr_); | |
+ } | |
// Perform set equality | |
bool operator==(DispatchKeySet other) const { | |
return repr_ == other.repr_; | |
@@ -203,6 +207,11 @@ | |
DispatchKey::AutogradOther, | |
}); | |
+// Keys in the TLS included set by default (TLS stores them XORed); see Note [TLS Initialization] | |
+constexpr DispatchKeySet default_included_set = DispatchKeySet({ | |
+ DispatchKey::InplaceOrView, | |
+}); | |
+ | |
constexpr DispatchKeySet autograd_dispatch_keyset_with_InplaceOrView = | |
autograd_dispatch_keyset | DispatchKeySet(DispatchKey::InplaceOrView); | |
diff --git a/fbcode/caffe2/c10/core/InferenceMode.cpp b/fbcode/caffe2/c10/core/InferenceMode.cpp | |
--- a/fbcode/caffe2/c10/core/InferenceMode.cpp | |
+++ b/fbcode/caffe2/c10/core/InferenceMode.cpp | |
@@ -4,6 +4,7 @@ | |
namespace c10 { | |
bool InferenceMode::is_enabled() { | |
+ // Enabled iff InplaceOrView is absent from the TLS included set; see Note [Expected TLS state in InferenceMode] | |
return !c10::impl::tls_is_dispatch_key_included(DispatchKey::InplaceOrView); | |
} | |
diff --git a/fbcode/caffe2/c10/core/InferenceMode.h b/fbcode/caffe2/c10/core/InferenceMode.h | |
--- a/fbcode/caffe2/c10/core/InferenceMode.h | |
+++ b/fbcode/caffe2/c10/core/InferenceMode.h | |
@@ -8,58 +8,61 @@ | |
// A RAII, thread local (!) guard that enables or disables inference mode upon | |
// construction, and sets it back to the original value upon destruction. | |
struct TORCH_API InferenceMode { | |
+ // Note [Expected TLS state in InferenceMode]: | |
+ // InferenceMode: InplaceOrView not in raw_local_dispatch_key_set.included(), | |
+ // Autograd in raw_local_dispatch_key_set.excluded() | |
+ // NormalMode: InplaceOrView in raw_local_dispatch_key_set.included(), | |
+ // Autograd not in raw_local_dispatch_key_set.excluded() | |
+ // | |
+ // Invariant: | |
+ // - InplaceOrView is never in the excluded set | |
+ // - Autograd is never in the included set | |
+ // | |
+ // 1. Why do we put InplaceOrView in included set outside InferenceMode? | |
+ // | |
+ // For example: | |
+ // torch::Tensor a; | |
+ // { | |
+ // c10::InferenceMode guard(true); | |
+ // torch::Tensor in = torch::ones({2, 2}); | |
+ // a = in.view({1, 4}); | |
+ // } | |
+ // torch::Tensor c = a.view({4, 1}); // (*) | |
+ // If we don't add InplaceOrView to included set, (*) will skip its as_view | |
+ // setup entirely, `c` will be a Tensor that is not from Inference mode | |
+ // but has potentially wrong view metadata which should be forbidden. | |
+ // By going through InplaceOrView kernel, we can throw an error since it | |
+ // broke our invariant: "Autograd keys must be in excluded set before | |
+ // reaching InplaceOrView kernel". | |
+ // | |
+ // 2. Why not put InplaceOrView in the excluded set inside InferenceMode? | |
+ // | |
+ // For example: | |
+ // torch::Tensor a = torch::ones({1, 2, 3}).set_requires_grad(true); | |
+ // torch::Tensor k = a + 2; | |
+ // { | |
+ // c10::InferenceMode guard(true); | |
+ // k.add_(2); | |
+ // } | |
+ // `k.add_(2)` still needs to go through InplaceOrView kernel so that it's | |
+ // prepared for future autograd. | |
InferenceMode(bool enabled=true): prev_keyset(c10::impl::tls_local_dispatch_key_set()) { | |
- // Note [Expected TLS state in InferenceMode]: | |
- // InferenceMode: InplaceOrView not in included, Autograd in excluded | |
- // NormalMode: InplaceOrView in included, Autograd not in excluded | |
- // | |
- // Invariant: | |
- // - InplaceOrView is never in the excluded set | |
- // - Autograd is never in the included set | |
- // | |
- // 1. Why not put InplaceOrView in the excluded set inside InferenceMode? | |
- // | |
- // For example: | |
- // torch::Tensor a = torch::ones({1, 2, 3}).set_requires_grad(true); | |
- // torch::Tensor k = a + 2; | |
- // { | |
- // c10::InferenceMode guard(true); | |
- // k.add_(2); | |
- // } | |
- // `k.add_(2)` still need to go through InplaceOrView kernel so that it's | |
- // prepared for future autograd. | |
- // 2. Why do we need InplaceOrView in included set outside InferenceMode? | |
- // | |
- // For example: | |
- // torch::Tensor a; | |
- // { | |
- // c10::InferenceMode guard(true); | |
- // torch::Tensor in = torch::ones({2, 2}); | |
- // a = in.view({1, 4}); | |
- // } | |
- // torch::Tensor c = a.view({4, 1}); // (*) | |
- // If we don't add InplaceOrView to included set, (*) will skip its as_view | |
- // setup entirely, `c` will be a Tensor that is not from Inference mode | |
- // but has potentially wrong view metadata which should be forbidden.. | |
- // By going through InplaceOrView kernel, we can throw an error since it | |
- // broke our invariant: "Autograd keys must be in excluded set before | |
- // reaching InplaceOrView kernel". | |
- | |
DispatchKeySet included = enabled ? prev_keyset.included_.remove(c10::DispatchKey::InplaceOrView) | |
: prev_keyset.included_.add(c10::DispatchKey::InplaceOrView); | |
DispatchKeySet excluded = enabled ? (prev_keyset.excluded_ | c10::autograd_dispatch_keyset) | |
: (prev_keyset.excluded_ - c10::autograd_dispatch_keyset); | |
- c10::impl::PODLocalDispatchKeySet cur_keyset {included.raw_repr(), excluded.raw_repr()}; | |
+ c10::impl::PODLocalDispatchKeySet cur_keyset{}; // value-initialized (no indeterminate fields); both set below | |
+ cur_keyset.set_included(included); | |
+ cur_keyset.set_excluded(excluded); | |
c10::impl::_force_tls_local_dispatch_key_set(cur_keyset); | |
} | |
+ | |
~InferenceMode() { | |
c10::impl::_force_tls_local_dispatch_key_set(prev_keyset); | |
} | |
- | |
static bool is_enabled(); | |
private: | |
c10::impl::LocalDispatchKeySet prev_keyset; | |
}; | |
} // namespace c10 | |
- | |
diff --git a/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.cpp b/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.cpp | |
--- a/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.cpp | |
+++ b/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.cpp | |
@@ -5,7 +5,16 @@ | |
namespace c10 { | |
namespace impl { | |
-thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set {DispatchKeySet(DispatchKey::InplaceOrView).raw_repr()}; | |
+// NB: POD, must be zero initialized! | |
+// Note [TLS Initialization] | |
+// We wanted raw_local_dispatch_key_set to be initialized with a non-zero state, | |
+// e.g. InplaceOrView in included set. But certain Windows compilers (e.g. the one | |
+// used in ARVR tests) only allow TLS to be zero-initialized. | |
+// To preserve the invariant that raw TLS storage of the default state is zero, | |
+// we obtain the actual included keyset by XORing raw_local_dispatch_key_set.included_ | |
+// with c10::default_included_set. This logic is encapsulated in struct | |
+// PODLocalDispatchKeySet. | |
+thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set; | |
#if defined(_MSC_VER) || defined(C10_ANDROID) | |
LocalDispatchKeySet tls_local_dispatch_key_set() { | |
@@ -14,10 +23,8 @@ | |
#endif // defined(_MSC_VER) || defined(C10_ANDROID) | |
void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set) { | |
- raw_local_dispatch_key_set = PODLocalDispatchKeySet { | |
- key_set.included_.raw_repr(), | |
- key_set.excluded_.raw_repr() | |
- }; | |
+ raw_local_dispatch_key_set.set_included(key_set.included_); // setter stores it XORed with default_included_set | |
+ raw_local_dispatch_key_set.set_excluded(key_set.excluded_); // excluded set is stored as-is | |
} | |
// An RAII guard could snapshot and restore the entire state (entire DispatchKeySet) as | |
diff --git a/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.h b/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.h | |
--- a/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.h | |
+++ b/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.h | |
@@ -26,19 +26,24 @@ | |
// POD version of LocalDispatchKeySet. Declared here just so that | |
// we can put it in the guards. | |
+// This struct encapsulates the special TLS-initialization handling in its | |
+// set_included()/included() accessors so that they report the logical keyset. | |
+// To give a PODLocalDispatchKeySet a specific state, call set_included()/set_excluded() | |
+// rather than writing included_ directly (it is stored XORed with c10::default_included_set). | |
struct C10_API PODLocalDispatchKeySet { | |
uint64_t included_; | |
uint64_t excluded_; | |
+ // See Note [TLS Initialization] | |
DispatchKeySet included() const { | |
- return DispatchKeySet(DispatchKeySet::RAW, included_); | |
+ return DispatchKeySet(DispatchKeySet::RAW, included_) ^ c10::default_included_set; | |
} | |
DispatchKeySet excluded() const { | |
return DispatchKeySet(DispatchKeySet::RAW, excluded_); | |
} | |
void set_included(DispatchKeySet x) { | |
- included_ = x.raw_repr(); | |
+ included_ = (x ^ c10::default_included_set).raw_repr(); | |
} | |
void set_excluded(DispatchKeySet x) { | |
excluded_ = x.raw_repr(); | |
diff --git a/fbcode/caffe2/test/cpp/api/grad_mode.cpp b/fbcode/caffe2/test/cpp/api/grad_mode.cpp | |
--- a/fbcode/caffe2/test/cpp/api/grad_mode.cpp | |
+++ b/fbcode/caffe2/test/cpp/api/grad_mode.cpp | |
@@ -66,5 +66,3 @@ | |
assert_tensor_creation_meta(tmp, torch::autograd::CreationMeta::NO_GRAD_MODE); | |
} | |
} | |
- | |
- | |
diff --git a/xplat/caffe2/c10/core/DispatchKeySet.h b/xplat/caffe2/c10/core/DispatchKeySet.h | |
--- a/xplat/caffe2/c10/core/DispatchKeySet.h | |
+++ b/xplat/caffe2/c10/core/DispatchKeySet.h | |
@@ -82,6 +82,10 @@ | |
DispatchKeySet operator-(DispatchKeySet other) const { | |
return DispatchKeySet(repr_ & ~other.repr_); | |
} | |
+ // Compute self ^ other (symmetric difference); constexpr so it works in constant expressions. | |
+ constexpr DispatchKeySet operator^(DispatchKeySet other) const { | |
+ return DispatchKeySet(repr_ ^ other.repr_); | |
+ } | |
// Perform set equality | |
bool operator==(DispatchKeySet other) const { | |
return repr_ == other.repr_; | |
@@ -203,6 +207,11 @@ | |
DispatchKey::AutogradOther, | |
}); | |
+// Keys in the TLS included set by default (TLS stores them XORed); see Note [TLS Initialization] | |
+constexpr DispatchKeySet default_included_set = DispatchKeySet({ | |
+ DispatchKey::InplaceOrView, | |
+}); | |
+ | |
constexpr DispatchKeySet autograd_dispatch_keyset_with_InplaceOrView = | |
autograd_dispatch_keyset | DispatchKeySet(DispatchKey::InplaceOrView); | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment