Skip to content

Instantly share code, notes, and snippets.

@ailzhang
Created March 31, 2021 00:02
Show Gist options
  • Save ailzhang/992728d6964afd86a42ea56c6e43835b to your computer and use it in GitHub Desktop.
Save ailzhang/992728d6964afd86a42ea56c6e43835b to your computer and use it in GitHub Desktop.
diff --git a/fbcode/caffe2/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h b/fbcode/caffe2/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h
--- a/fbcode/caffe2/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h
+++ b/fbcode/caffe2/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h
@@ -49,7 +49,8 @@
// it's a bit troublesome, because fastpath TLS access requires the type of
// the TLS in question to be zero-initialized, so you don't actually win
// anyting in that case.
- return (((ks | local.included_ | always_included) - local.excluded_) & key_mask);
+ // The additional XOR un-flips the InplaceOrView TLS bit; see Note [TLS Initialization]
+ return (((ks | (local.included_ ^ c10::InplaceOrView_keyset) | always_included) - local.excluded_) & key_mask);
}
}
diff --git a/fbcode/caffe2/c10/core/DispatchKeySet.h b/fbcode/caffe2/c10/core/DispatchKeySet.h
--- a/fbcode/caffe2/c10/core/DispatchKeySet.h
+++ b/fbcode/caffe2/c10/core/DispatchKeySet.h
@@ -82,6 +82,10 @@
DispatchKeySet operator-(DispatchKeySet other) const {
return DispatchKeySet(repr_ & ~other.repr_);
}
+ // Compute self ^ other (symmetric difference: keys present in exactly one set)
+ DispatchKeySet operator^(DispatchKeySet other) const {
+ return DispatchKeySet(repr_ ^ other.repr_);
+ }
// Perform set equality
bool operator==(DispatchKeySet other) const {
return repr_ == other.repr_;
@@ -203,6 +207,11 @@
DispatchKey::AutogradOther,
});
+// Single-key set XOR-ed with the TLS included set to flip InplaceOrView; see Note [TLS Initialization]
+constexpr DispatchKeySet InplaceOrView_keyset = DispatchKeySet({
+ DispatchKey::InplaceOrView,
+});
+
constexpr DispatchKeySet autograd_dispatch_keyset_with_InplaceOrView =
autograd_dispatch_keyset | DispatchKeySet(DispatchKey::InplaceOrView);
diff --git a/fbcode/caffe2/c10/core/InferenceMode.cpp b/fbcode/caffe2/c10/core/InferenceMode.cpp
--- a/fbcode/caffe2/c10/core/InferenceMode.cpp
+++ b/fbcode/caffe2/c10/core/InferenceMode.cpp
@@ -4,7 +4,7 @@
namespace c10 {
bool InferenceMode::is_enabled() {
- return !c10::impl::tls_is_dispatch_key_included(DispatchKey::InplaceOrView);
+ return c10::impl::tls_is_dispatch_key_included(DispatchKey::InplaceOrView);  // raw TLS bit is stored flipped; see Note [TLS Initialization]
}
} // namespace c10
diff --git a/fbcode/caffe2/c10/core/InferenceMode.h b/fbcode/caffe2/c10/core/InferenceMode.h
--- a/fbcode/caffe2/c10/core/InferenceMode.h
+++ b/fbcode/caffe2/c10/core/InferenceMode.h
@@ -8,54 +8,59 @@
// A RAII, thread local (!) guard that enables or disables inference mode upon
// construction, and sets it back to the original value upon destruction.
struct TORCH_API InferenceMode {
+ // Note [Expected TLS state in InferenceMode]:
+ // InferenceMode: InplaceOrView in included, Autograd in excluded
+ // NormalMode: InplaceOrView not in included, Autograd not in excluded
+ //
+ // Invariant:
+ // - InplaceOrView is never in the excluded set
+ // - Autograd is never in the included set
+ //
+ // 1. Why do we put InplaceOrView in the included set in InferenceMode?
+ // The behavior we actually want: InplaceOrView is in the included set by default
+ // and is removed from the included set in InferenceMode.
+ // But TLS can only be zero-initialized (see Note [TLS Initialization]), so
+ // we add InplaceOrView in InferenceMode instead, and get the desired
+ // behavior after the XOR in DispatchKeyExtractor.
+ //
+ // For example:
+ // torch::Tensor a;
+ // {
+ // c10::InferenceMode guard(true);
+ // torch::Tensor in = torch::ones({2, 2});
+ // a = in.view({1, 4});
+ // }
+ // torch::Tensor c = a.view({4, 1}); // (*)
+ // If InplaceOrView were not (effectively) included here, (*) would skip its as_view
+ // setup entirely, and `c` would be a Tensor that is not from InferenceMode
+ // but has potentially wrong view metadata, which should be forbidden.
+ // By going through the InplaceOrView kernel, we can throw an error since it
+ // broke our invariant: "Autograd keys must be in excluded set before
+ // reaching InplaceOrView kernel".
+ //
+ // 2. Why not put InplaceOrView in the excluded set inside InferenceMode?
+ //
+ // For example:
+ // torch::Tensor a = torch::ones({1, 2, 3}).set_requires_grad(true);
+ // torch::Tensor k = a + 2;
+ // {
+ // c10::InferenceMode guard(true);
+ // k.add_(2);
+ // }
+ // `k.add_(2)` still needs to go through the InplaceOrView kernel so that it's
+ // prepared for future autograd.
InferenceMode(bool enabled=true): prev_keyset(c10::impl::tls_local_dispatch_key_set()) {
- // Note [Expected TLS state in InferenceMode]:
- // InferenceMode: InplaceOrView not in included, Autograd in excluded
- // NormalMode: InplaceOrView in included, Autograd not in excluded
- //
- // Invariant:
- // - InplaceOrView is never in the excluded set
- // - Autograd is never in the included set
- //
- // 1. Why not put InplaceOrView in the excluded set inside InferenceMode?
- //
- // For example:
- // torch::Tensor a = torch::ones({1, 2, 3}).set_requires_grad(true);
- // torch::Tensor k = a + 2;
- // {
- // c10::InferenceMode guard(true);
- // k.add_(2);
- // }
- // `k.add_(2)` still need to go through InplaceOrView kernel so that it's
- // prepared for future autograd.
- // 2. Why do we need InplaceOrView in included set outside InferenceMode?
- //
- // For example:
- // torch::Tensor a;
- // {
- // c10::InferenceMode guard(true);
- // torch::Tensor in = torch::ones({2, 2});
- // a = in.view({1, 4});
- // }
- // torch::Tensor c = a.view({4, 1}); // (*)
- // If we don't add InplaceOrView to included set, (*) will skip its as_view
- // setup entirely, `c` will be a Tensor that is not from Inference mode
- // but has potentially wrong view metadata which should be forbidden..
- // By going through InplaceOrView kernel, we can throw an error since it
- // broke our invariant: "Autograd keys must be in excluded set before
- // reaching InplaceOrView kernel".
-
- DispatchKeySet included = enabled ? prev_keyset.included_.remove(c10::DispatchKey::InplaceOrView)
- : prev_keyset.included_.add(c10::DispatchKey::InplaceOrView);
+ DispatchKeySet included = enabled ? prev_keyset.included_.add(c10::DispatchKey::InplaceOrView)
+ : prev_keyset.included_.remove(c10::DispatchKey::InplaceOrView);
DispatchKeySet excluded = enabled ? (prev_keyset.excluded_ | c10::autograd_dispatch_keyset)
: (prev_keyset.excluded_ - c10::autograd_dispatch_keyset);
c10::impl::PODLocalDispatchKeySet cur_keyset {included.raw_repr(), excluded.raw_repr()};
c10::impl::_force_tls_local_dispatch_key_set(cur_keyset);
}
+
~InferenceMode() {
c10::impl::_force_tls_local_dispatch_key_set(prev_keyset);
}
-
static bool is_enabled();
private:
diff --git a/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.cpp b/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.cpp
--- a/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.cpp
+++ b/fbcode/caffe2/c10/core/impl/LocalDispatchKeySet.cpp
@@ -5,7 +5,13 @@
namespace c10 {
namespace impl {
-thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set {DispatchKeySet(DispatchKey::InplaceOrView).raw_repr()};
+// NB: POD, must be zero-initialized!
+// Note [TLS Initialization]
+// We want raw_local_dispatch_key_set to be initialized with the InplaceOrView key
+// in the included set, but certain Windows compilers (e.g. the one used in ARVR tests)
+// only allow TLS to be zero-initialized. As a workaround, the InplaceOrView bit is stored
+// flipped here and XOR-ed back with InplaceOrView in DispatchKeyExtractor.h.
+thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set;
#if defined(_MSC_VER) || defined(C10_ANDROID)
LocalDispatchKeySet tls_local_dispatch_key_set() {
diff --git a/fbcode/caffe2/test/cpp/api/inference_mode.cpp b/fbcode/caffe2/test/cpp/api/inference_mode.cpp
--- a/fbcode/caffe2/test/cpp/api/inference_mode.cpp
+++ b/fbcode/caffe2/test/cpp/api/inference_mode.cpp
@@ -39,7 +39,7 @@
ASSERT_FALSE(c10::impl::tls_is_dispatch_key_excluded(c10::DispatchKey::InplaceOrView));
ASSERT_FALSE(c10::impl::tls_is_dispatch_keyset_included(c10::autograd_dispatch_keyset));
ASSERT_EQ(c10::impl::tls_is_dispatch_keyset_excluded(c10::autograd_dispatch_keyset), inference_mode);
- ASSERT_EQ(c10::impl::tls_is_dispatch_key_included(c10::DispatchKey::InplaceOrView), !inference_mode);
+ ASSERT_EQ(c10::impl::tls_is_dispatch_key_included(c10::DispatchKey::InplaceOrView), inference_mode);  // raw TLS bit is flipped; see Note [TLS Initialization]
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment