Skip to content

Instantly share code, notes, and snippets.

@ailzhang
Created March 31, 2021 16:03
Show Gist options
  • Save ailzhang/86abc3b72da4b2ec598a983fa52b7476 to your computer and use it in GitHub Desktop.
Save ailzhang/86abc3b72da4b2ec598a983fa52b7476 to your computer and use it in GitHub Desktop.
diff --git a/fbcode/caffe2/c10/core/DispatchKeySet.h b/fbcode/caffe2/c10/core/DispatchKeySet.h
--- a/fbcode/caffe2/c10/core/DispatchKeySet.h
+++ b/fbcode/caffe2/c10/core/DispatchKeySet.h
@@ -82,6 +82,10 @@
DispatchKeySet operator-(DispatchKeySet other) const {
return DispatchKeySet(repr_ & ~other.repr_);
}
+ // Compute self ^ other
+ DispatchKeySet operator^(DispatchKeySet other) const {
+ return DispatchKeySet(repr_ ^ other.repr_);
+ }
// Perform set equality
bool operator==(DispatchKeySet other) const {
return repr_ == other.repr_;
@@ -203,6 +207,11 @@
DispatchKey::AutogradOther,
});
+// See Note [TLS Initialization]
+constexpr DispatchKeySet default_included_set = DispatchKeySet({
+ DispatchKey::InplaceOrView,
+});
+
constexpr DispatchKeySet autograd_dispatch_keyset_with_InplaceOrView =
autograd_dispatch_keyset | DispatchKeySet(DispatchKey::InplaceOrView);
diff --git a/fbcode/caffe2/c10/core/InferenceMode.cpp b/fbcode/caffe2/c10/core/InferenceMode.cpp
--- a/fbcode/caffe2/c10/core/InferenceMode.cpp
+++ b/fbcode/caffe2/c10/core/InferenceMode.cpp
@@ -4,6 +4,7 @@
namespace c10 {
bool InferenceMode::is_enabled() {
+ // See Note [Expected TLS state in InferenceMode]
return !c10::impl::tls_is_dispatch_key_included(DispatchKey::InplaceOrView);
}
diff --git a/fbcode/caffe2/c10/core/InferenceMode.h b/fbcode/caffe2/c10/core/InferenceMode.h
--- a/fbcode/caffe2/c10/core/InferenceMode.h
+++ b/fbcode/caffe2/c10/core/InferenceMode.h
@@ -8,58 +8,61 @@
// A RAII, thread local (!) guard that enables or disables inference mode upon
// construction, and sets it back to the original value upon destruction.
struct TORCH_API InferenceMode {
+ // Note [Expected TLS state in InferenceMode]:
+ // InferenceMode: InplaceOrView not in raw_local_dispatch_key_set.included(),
+ // Autograd in raw_local_dispatch_key_set.excluded()
+ // NormalMode: InplaceOrView in raw_local_dispatch_key_set.included(),
+ // Autograd not in raw_local_dispatch_key_set.excluded()
+ //
+ // Invariant:
+ // - InplaceOrView is never in the excluded set
+ // - Autograd is never in the included set
+ //
+ // 1. Why do we put InplaceOrView in included set outside InferenceMode?
+ //
+ // For example:
+ // torch::Tensor a;
+ // {
+ // c10::InferenceMode guard(true);
+ // torch::Tensor in = torch::ones({2, 2});
+ // a = in.view({1, 4});
+ // }
+ // torch::Tensor c = a.view({4, 1}); // (*)
+ // If we don't add InplaceOrView to included set, (*) will skip its as_view
+ // setup entirely, `c` will be a Tensor that is not from Inference mode
+ // but has potentially wrong view metadata, which should be forbidden.
+ // By going through the InplaceOrView kernel, we can throw an error, since
+ // this breaks our invariant: "Autograd keys must be in excluded set before
+ // reaching InplaceOrView kernel".
+ //
+ // 2. Why not put InplaceOrView in the excluded set inside InferenceMode?
+ //
+ // For example:
+ // torch::Tensor a = torch::ones({1, 2, 3}).set_requires_grad(true);
+ // torch::Tensor k = a + 2;
+ // {
+ // c10::InferenceMode guard(true);
+ // k.add_(2);
+ // }
+ // `k.add_(2)` still needs to go through the InplaceOrView kernel so that it's
+ // prepared for future autograd.
InferenceMode(bool enabled=true): prev_keyset(c10::impl::tls_local_dispatch_key_set()) {
- // Note [Expected TLS state in InferenceMode]:
- // InferenceMode: InplaceOrView not in included, Autograd in excluded
- // NormalMode: InplaceOrView in included, Autograd not in excluded
- //
- // Invariant:
- // - InplaceOrView is never in the excluded set
- // - Autograd is never in the included set
- //
- // 1. Why not put InplaceOrView in the excluded set inside InferenceMode?
- //
- // For example:
- // torch::Tensor a = torch::ones({1, 2, 3}).set_requires_grad(true);
- // torch::Tensor k = a + 2;
- // {
- // c10::InferenceMode guard(true);
- // k.add_(2);
- // }
- // `k.add_(2)` still need to go through InplaceOrView kernel so that it's
- // prepared for future autograd.
- // 2. Why do we need InplaceOrView in included set outside InferenceMode?
- //
- // For example:
- // torch::Tensor a;
- // {
- // c10::InferenceMode guard(true);
- // torch::Tensor in = torch::ones({2, 2});
- // a = in.view({1, 4});
- // }
- // torch::Tensor c = a.view({4, 1}); // (*)
- // If we don't add InplaceOrView to included set, (*) will skip its as_view
- // setup entirely, `c` will be a Tensor that is not from Inference mode
- // but has potentially wrong view metadata which should be forbidden..
- // By going through InplaceOrView kernel, we can throw an error since it
- // broke our invariant: "Autograd keys must be in excluded set before
- // reaching InplaceOrView kernel".
-
DispatchKeySet included = enabled ? prev_keyset.included_.remove(c10::DispatchKey::InplaceOrView)
: prev_keyset.included_.add(c10::DispatchKey::InplaceOrView);
DispatchKeySet excluded = enabled ? (prev_keyset.excluded_ | c10::autograd_dispatch_keyset)
: (prev_keyset.excluded_ - c10::autograd_dispatch_keyset);
- c10::impl::PODLocalDispatchKeySet cur_keyset {included.raw_repr(), excluded.raw_repr()};
+ c10::impl::PODLocalDispatchKeySet cur_keyset;
+ cur_keyset.set_included(included);
+ cur_keyset.set_excluded(excluded);
c10::impl::_force_tls_local_dispatch_key_set(cur_keyset);
}
+
~InferenceMode() {
c10::impl::_force_tls_local_dispatch_key_set(prev_keyset);
}
-
static bool is_enabled();
private:
c10::impl::LocalDispatchKeySet prev_keyset;
};
} // namespace c10
-
diff --git a/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.cpp b/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.cpp
--- a/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.cpp
+++ b/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.cpp
@@ -5,7 +5,16 @@
namespace c10 {
namespace impl {
-thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set {DispatchKeySet(DispatchKey::InplaceOrView).raw_repr()};
+// NB: POD, must be zero initialized!
+// Note [TLS Initialization]
+// We wanted raw_local_dispatch_key_set to be initialized with non-zero state
+// e.g. InplaceOrView in the included set. But certain Windows compilers (e.g. the one
+// used in ARVR tests) only allow TLS to be zero-initialized.
+// To preserve the invariant that raw TLS storage of the default state is zero,
+// we obtain the actual included keyset by XORing raw_local_dispatch_key_set.included_
+// with c10::default_included_set. This logic is encapsulated in struct
+// PODLocalDispatchKeySet.
+thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set;
#if defined(_MSC_VER) || defined(C10_ANDROID)
LocalDispatchKeySet tls_local_dispatch_key_set() {
@@ -14,10 +23,8 @@
#endif // defined(_MSC_VER) || defined(C10_ANDROID)
void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set) {
- raw_local_dispatch_key_set = PODLocalDispatchKeySet {
- key_set.included_.raw_repr(),
- key_set.excluded_.raw_repr()
- };
+ raw_local_dispatch_key_set.set_included(key_set.included_);
+ raw_local_dispatch_key_set.set_excluded(key_set.excluded_);
}
// An RAII guard could snapshot and restore the entire state (entire DispatchKeySet) as
diff --git a/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.h b/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.h
--- a/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.h
+++ b/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.h
@@ -26,19 +26,24 @@
// POD version of LocalDispatchKeySet. Declared here just so that
// we can put it in the guards.
+// This struct encapsulates special handling for TLS initialization
+// in the set_included()/included() APIs so that they return the logical included set.
+// If you want to create PODLocalDispatchKeySet with non-zero state, use
+// set_included()/set_excluded() rather than aggregate-initializing the raw fields.
struct C10_API PODLocalDispatchKeySet {
uint64_t included_;
uint64_t excluded_;
+ // See Note [TLS Initialization]
DispatchKeySet included() const {
- return DispatchKeySet(DispatchKeySet::RAW, included_);
+ return DispatchKeySet(DispatchKeySet::RAW, included_) ^ c10::default_included_set;
}
DispatchKeySet excluded() const {
return DispatchKeySet(DispatchKeySet::RAW, excluded_);
}
void set_included(DispatchKeySet x) {
- included_ = x.raw_repr();
+ included_ = (x ^ c10::default_included_set).raw_repr();
}
void set_excluded(DispatchKeySet x) {
excluded_ = x.raw_repr();
diff --git a/fbcode/caffe2/test/cpp/api/grad_mode.cpp b/fbcode/caffe2/test/cpp/api/grad_mode.cpp
--- a/fbcode/caffe2/test/cpp/api/grad_mode.cpp
+++ b/fbcode/caffe2/test/cpp/api/grad_mode.cpp
@@ -66,5 +66,3 @@
assert_tensor_creation_meta(tmp, torch::autograd::CreationMeta::NO_GRAD_MODE);
}
}
-
-
diff --git a/xplat/caffe2/c10/core/DispatchKeySet.h b/xplat/caffe2/c10/core/DispatchKeySet.h
--- a/xplat/caffe2/c10/core/DispatchKeySet.h
+++ b/xplat/caffe2/c10/core/DispatchKeySet.h
@@ -82,6 +82,10 @@
DispatchKeySet operator-(DispatchKeySet other) const {
return DispatchKeySet(repr_ & ~other.repr_);
}
+ // Compute self ^ other
+ DispatchKeySet operator^(DispatchKeySet other) const {
+ return DispatchKeySet(repr_ ^ other.repr_);
+ }
// Perform set equality
bool operator==(DispatchKeySet other) const {
return repr_ == other.repr_;
@@ -203,6 +207,11 @@
DispatchKey::AutogradOther,
});
+// See Note [TLS Initialization]
+constexpr DispatchKeySet default_included_set = DispatchKeySet({
+ DispatchKey::InplaceOrView,
+});
+
constexpr DispatchKeySet autograd_dispatch_keyset_with_InplaceOrView =
autograd_dispatch_keyset | DispatchKeySet(DispatchKey::InplaceOrView);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment