[mlir][python] Add debug helpers

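This PR lets users enable LLVM `-debug-only` debug types from Python, either
directly via `_GlobalDebug.push_debug_only_flags`/`pop_debug_only_flags` or
with the context-manager/decorator helpers added to `mlir.utils` (e.g.
`debug_conversion`). This greatly simplifies debugging MLIR from Python.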

Co-authored-by: Tres <tpopp@users.noreply.github.com>
diff --git a/llvm/include/llvm/Support/Debug.h b/llvm/include/llvm/Support/Debug.h
index 924d7b2..6720ef5 100644
--- a/llvm/include/llvm/Support/Debug.h
+++ b/llvm/include/llvm/Support/Debug.h
@@ -54,6 +54,28 @@
 ///
 void setCurrentDebugTypes(const char **Types, unsigned Count);
 
+/// appendDebugType - Add a debug type to the current set, as if the
+/// -debug-only=CurX,CurY,NewZ option were specified when called with NewZ.
+/// The previous state is tracked, so popAppendedDebugTypes can be called to
+/// restore the previous state. Note that DebugFlag also needs to be set to true
+/// for debug output to be produced.
+///
+void appendDebugType(const char *Type);
+
+/// appendDebugTypes - Add debug types to the current set, as if the
+/// -debug-only=CurX,CurY,NewX,NewY option were specified when called with
+/// [NewX, NewY]. The previous state is tracked, so one popAppendedDebugTypes
+/// call restores the state from before this call. Note that DebugFlag also
+/// needs to be set to true for debug output to be produced.
+///
+void appendDebugTypes(const char **Types, unsigned Count);
+
+/// popAppendedDebugTypes - Restores the current debug types to the state
+/// before the last not-yet-popped call to appendDebugType(s). Asserts, and
+/// otherwise does nothing, if there is no such call.
+///
+void popAppendedDebugTypes();
+
 /// DEBUG_WITH_TYPE macro - This macro should be used by passes to emit debug
 /// information.  If the '-debug' option is specified on the commandline, and if
 /// this is a debug build, then the code specified as the option to the macro
@@ -77,6 +99,19 @@
 #define DEBUG_WITH_TYPE(TYPE, ...)                                             \
   do {                                                                         \
   } while (false)
+// Compiled-out forms: only evaluate (and discard) the arguments.
+#define appendDebugType(X)                                                     \
+  do {                                                                         \
+    (void)(X);                                                                 \
+  } while (false)
+#define appendDebugTypes(X, N)                                                 \
+  do {                                                                         \
+    (void)(X);                                                                 \
+    (void)(N);                                                                 \
+  } while (false)
+#define popAppendedDebugTypes()                                                \
+  do {                                                                         \
+  } while (false)
 #endif
 
 /// This boolean is set to true if the '-debug' command line option
diff --git a/llvm/lib/Support/Debug.cpp b/llvm/lib/Support/Debug.cpp
index 5bb04d0..2582be7 100644
--- a/llvm/lib/Support/Debug.cpp
+++ b/llvm/lib/Support/Debug.cpp
@@ -35,6 +35,9 @@
 #undef isCurrentDebugType
 #undef setCurrentDebugType
 #undef setCurrentDebugTypes
+#undef appendDebugType
+#undef appendDebugTypes
+#undef popAppendedDebugTypes
 
 using namespace llvm;
 
@@ -45,6 +48,32 @@
 bool DebugFlag = false;
 
 static ManagedStatic<std::vector<std::string>> CurrentDebugType;
+static ManagedStatic<std::vector<size_t>> AppendedDebugTypeSizes;
+
+/// Appends to CurrentDebugType, first recording its current size so that
+/// popAppendedDebugTypes can restore the previous state.
+void appendDebugTypes(const char **Types, unsigned Count) {
+  AppendedDebugTypeSizes->push_back(CurrentDebugType->size());
+  llvm::append_range(*CurrentDebugType, ArrayRef(Types, Count));
+}
+
+/// Appends a single type to CurrentDebugType; undone by one call to
+/// popAppendedDebugTypes.
+void appendDebugType(const char *Type) { appendDebugTypes(&Type, 1); }
+
+/// Restore to the state before the latest call to appendDebugType(s). This can
+/// be done multiple times, once per outstanding append.
+void popAppendedDebugTypes() {
+  assert(!AppendedDebugTypeSizes->empty() &&
+         "Popping from DebugTypes without any previous appending.");
+
+  // Without assertions, an unmatched pop is ignored.
+  if (AppendedDebugTypeSizes->empty())
+    return;
+
+  CurrentDebugType->resize(AppendedDebugTypeSizes->back());
+  AppendedDebugTypeSizes->pop_back();
+}
 
 /// Return true if the specified string is the debug type
 /// specified on the command line, or if none was specified on the command line
@@ -72,6 +101,8 @@
 }
 
 void setCurrentDebugTypes(const char **Types, unsigned Count) {
+  assert(AppendedDebugTypeSizes->empty() &&
+         "Resetting CurrentDebugType when it was previously appended to");
   CurrentDebugType->clear();
   llvm::append_range(*CurrentDebugType, ArrayRef(Types, Count));
 }
diff --git a/llvm/unittests/Support/DebugTest.cpp b/llvm/unittests/Support/DebugTest.cpp
index e8b7548..3874142 100644
--- a/llvm/unittests/Support/DebugTest.cpp
+++ b/llvm/unittests/Support/DebugTest.cpp
@@ -19,8 +19,8 @@
 TEST(DebugTest, Basic) {
   std::string s1, s2;
   raw_string_ostream os1(s1), os2(s2);
-  static const char *DT[] = {"A", "B"};  
-  
+  static const char *DT[] = {"A", "B"};
+
   llvm::DebugFlag = true;
   setCurrentDebugTypes(DT, 2);
   DEBUG_WITH_TYPE("A", os1 << "A");
@@ -50,4 +50,41 @@
   });
   EXPECT_EQ("ZYX", os1.str());
 }
+
+TEST(DebugTest, AppendAndPop) {
+  std::string s1, s2, s3, s4;
+  raw_string_ostream os1(s1), os2(s2), os3(s3), os4(s4);
+  static const char *DT[] = {"C", "D"};
+
+  // Earlier tests in this file leave their own debug types selected; reset to
+  // an empty selection so that only the types appended below are enabled.
+  llvm::DebugFlag = true;
+  setCurrentDebugTypes(nullptr, 0);
+
+  appendDebugType("A");
+  DEBUG_WITH_TYPE("A", os1 << "A");
+  DEBUG_WITH_TYPE("B", os1 << "B");
+  EXPECT_EQ("A", os1.str());
+
+  appendDebugType("B");
+  DEBUG_WITH_TYPE("A", os2 << "A");
+  DEBUG_WITH_TYPE("B", os2 << "B");
+  EXPECT_EQ("AB", os2.str());
+
+  popAppendedDebugTypes();
+  DEBUG_WITH_TYPE("A", os3 << "A");
+  DEBUG_WITH_TYPE("B", os3 << "B");
+  EXPECT_EQ("A", os3.str());
+
+  // Append two types at once; the single pop below removes both.
+  appendDebugTypes(DT, 2);
+  DEBUG_WITH_TYPE("B", os4 << "B");
+  DEBUG_WITH_TYPE("C", os4 << "C");
+  DEBUG_WITH_TYPE("D", os4 << "D");
+  EXPECT_EQ("CD", os4.str());
+  popAppendedDebugTypes();
+
+  // Undo the remaining append so no saved state leaks into later tests.
+  popAppendedDebugTypes();
+}
 #endif
diff --git a/mlir/include/mlir-c/Debug.h b/mlir/include/mlir-c/Debug.h
index 7dad735..8f206e4 100644
--- a/mlir/include/mlir-c/Debug.h
+++ b/mlir/include/mlir-c/Debug.h
@@ -31,6 +31,24 @@
 /// output to be produced.
 MLIR_CAPI_EXPORTED void mlirSetGlobalDebugTypes(const char **types, intptr_t n);
 
+/// Adds `type` to the current debug type state, similarly to
+/// `-debug-only=prev_type,new_type` in the command-line tools. Note that global
+/// debug should be enabled for any output to be produced. A single append call
+/// can be reverted with mlirPopAppendedGlobalDebugTypes.
+MLIR_CAPI_EXPORTED void mlirAppendGlobalDebugType(const char *type);
+
+/// Adds the `n` strings in `types` to the current debug type state, similarly
+/// to `-debug-only=prev_type,new_type1,new_type2` in the command-line tools.
+/// Note that global debug should be enabled for any output to be produced. The
+/// whole call is reverted by a single mlirPopAppendedGlobalDebugTypes.
+MLIR_CAPI_EXPORTED void mlirAppendGlobalDebugTypes(const char **types,
+                                                   intptr_t n);
+
+/// Undoes the last append call that has not been popped yet, e.g. turning
+/// `-debug-only=prev_type,new_type1,new_type2` back into
+/// `-debug-only=prev_type`.
+MLIR_CAPI_EXPORTED void mlirPopAppendedGlobalDebugTypes(void);
+
 /// Checks if `type` is set as the current debug type.
 MLIR_CAPI_EXPORTED bool mlirIsCurrentDebugType(const char *type);
 
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 002923b..6ce27d1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -282,7 +282,34 @@
             pointers.push_back(str.c_str());
           nb::ft_lock_guard lock(mutex);
           mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
-        });
+        })
+        .def_static(
+            "push_debug_only_flags",
+            [](const std::string &type) {
+              nb::ft_lock_guard lock(mutex);
+              mlirAppendGlobalDebugType(type.c_str());
+            },
+            "flags"_a,
+            "Appends specific debug only flags which can be popped later.")
+        .def_static(
+            "push_debug_only_flags",
+            [](const std::vector<std::string> &types) {
+              std::vector<const char *> pointers;
+              pointers.reserve(types.size());
+              for (const std::string &str : types)
+                pointers.push_back(str.c_str());
+              nb::ft_lock_guard lock(mutex);
+              mlirAppendGlobalDebugTypes(pointers.data(), pointers.size());
+            },
+            "flags"_a,
+            "Appends specific debug only flags which can be popped later.")
+        .def_static(
+            "pop_debug_only_flags",
+            []() {
+              nb::ft_lock_guard lock(mutex);
+              mlirPopAppendedGlobalDebugTypes();
+            },
+            "Removes the latest non-popped addition from push_debug_only_flags.");
   }
 
 private:
diff --git a/mlir/lib/CAPI/Debug/Debug.cpp b/mlir/lib/CAPI/Debug/Debug.cpp
index 320ece4..c76a109 100644
--- a/mlir/lib/CAPI/Debug/Debug.cpp
+++ b/mlir/lib/CAPI/Debug/Debug.cpp
@@ -34,3 +34,24 @@
   using namespace llvm;
   return isCurrentDebugType(type);
 }
+
+// llvm/Support/Debug.h may define the functions called below as no-op
+// function-like macros, so they are called unqualified (with
+// `using namespace llvm`) rather than spelled `llvm::appendDebugType(...)`.
+
+void mlirAppendGlobalDebugType(const char *type) {
+  using namespace llvm;
+  appendDebugType(type);
+}
+
+void mlirAppendGlobalDebugTypes(const char **types, intptr_t n) {
+  using namespace llvm;
+  // `n` is implicitly converted to the `unsigned` count of appendDebugTypes.
+  appendDebugTypes(types, n);
+}
+
+// Undoes the most recent not-yet-popped mlirAppendGlobalDebugType(s) call.
+void mlirPopAppendedGlobalDebugTypes() {
+  using namespace llvm;
+  popAppendedDebugTypes();
+}
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 56b9f17..ffba1e3 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -2818,3 +2818,7 @@

 class _GlobalDebug:
     flag: ClassVar[bool] = False
+    @staticmethod
+    def push_debug_only_flags(flags: str | list[str]) -> None: ...
+    @staticmethod
+    def pop_debug_only_flags() -> None: ...
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi
index 0d2eaff..d23da98 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi
@@ -12,9 +12,10 @@
 ]
 
 class PassManager:
-    def __init__(self, context: _ir.Context | None = None) -> None: ...
+    def __init__(self, anchor_op: str = "any", context: _ir.Context | None = None) -> None: ...
     def _CAPICreate(self) -> object: ...
     def _testing_release(self) -> None: ...
+    def add(self, pipeline: str) -> None: ...
     def enable_ir_printing(
         self,
         print_before_all: bool = False,
diff --git a/mlir/python/mlir/utils.py b/mlir/python/mlir/utils.py
index c6e9b57..cf8ad22 100644
--- a/mlir/python/mlir/utils.py
+++ b/mlir/python/mlir/utils.py
@@ -209,3 +209,79 @@
     decorated = with_toplevel_context_create_module(f)
     decorated()
     return decorated
+
+
+def _debug_flags_impl(flags: Sequence[str]) -> Iterator[None]:
+    """Generator backing the debug context managers: appends `flags` to LLVM's
+    active debug types and enables global debugging, restoring both on exit."""
+    from mlir.ir import _GlobalDebug
+
+    # Push before modifying anything else so a failing push leaves no state to undo.
+    original_flag = _GlobalDebug.flag
+    _GlobalDebug.push_debug_only_flags(flags)
+    try:
+        _GlobalDebug.flag = True
+        yield
+    finally:
+        # Restore the flag and pop exactly the types pushed above. This assumes
+        # nothing inside the context popped more than it pushed.
+        _GlobalDebug.flag = original_flag
+        _GlobalDebug.pop_debug_only_flags()
+
+
+@contextmanager
+def debug_flags_context(flags: Sequence[str]) -> Iterator[None]:
+    """Enable LLVM debug output restricted to the debug types in `flags`.
+
+    Equivalent to running with `-debug-only=<flags>`. Nested contexts accumulate:
+    the active list is the concatenation of every enclosing context's flags.
+
+    This requires that the core MLIR units were compiled without NDEBUG.
+    """
+    yield from _debug_flags_impl(flags)
+
+
+@contextmanager
+def debug_conversion(flags: Sequence[str] = ()) -> Iterator[None]:
+    """Temporarily enable dialect-conversion debug output, potentially together
+    with the additional debug types given in `flags`.
+
+    Equivalent to running with `-debug-only=<flags>,dialect-conversion`. Nested
+    contexts accumulate: the active list includes every enclosing context's flags.
+
+    This requires that the core MLIR units were compiled without NDEBUG.
+    """
+    yield from _debug_flags_impl([*flags, "dialect-conversion"])
+
+
+@contextmanager
+def debug_greedy_rewriter(flags: Sequence[str] = ()) -> Iterator[None]:
+    """Temporarily enable greedy pattern rewrite driver debug output, potentially
+    together with the additional debug types given in `flags`.
+
+    Equivalent to running with `-debug-only=<flags>,greedy-rewriter`. Nested
+    contexts accumulate: the active list includes every enclosing context's flags.
+
+    This requires that the core MLIR units were compiled without NDEBUG.
+    """
+    yield from _debug_flags_impl([*flags, "greedy-rewriter"])
+
+
+@contextmanager
+def debug_td(flags: Sequence[str] = (), *, full_debug: bool = False) -> Iterator[None]:
+    """Temporarily enable transform dialect debug output, potentially together
+    with the additional debug types given in `flags`.
+
+    Equivalent to running with `-debug-only=<flags>,transform-dialect,
+    transform-dialect-print-top-level-after-all`, plus `transform-dialect-full`
+    when `full_debug` is true. Nested contexts accumulate: the active list
+    includes every enclosing context's flags.
+
+    This requires that the core MLIR units were compiled without NDEBUG.
+    """
+    debug_types = [
+        *flags,
+        "transform-dialect",
+        "transform-dialect-print-top-level-after-all",
+    ] + (["transform-dialect-full"] if full_debug else [])
+    yield from _debug_flags_impl(debug_types)
diff --git a/mlir/test/python/utils.py b/mlir/test/python/utils.py
index 8435fdd..c700808 100644
--- a/mlir/test/python/utils.py
+++ b/mlir/test/python/utils.py
@@ -1,13 +1,15 @@
 # RUN: %python %s | FileCheck %s
+# RUN: %python %s 2>&1 | FileCheck %s --check-prefix=DEBUG_ONLY
 
 import unittest
 
 from mlir import ir
-from mlir.dialects import arith, builtin
-from mlir.extras import types as T
+from mlir.passmanager import PassManager
+from mlir.dialects import arith
 from mlir.utils import (
     call_with_toplevel_context_create_module,
     caller_mlir_context,
+    debug_conversion,
     using_mlir_context,
 )
 
@@ -31,6 +33,7 @@
             c = arith.ConstantOp(value=42.42, result=ir.F32Type.get()).result
             multiple_adds(c, c)
 
+            # CHECK-LABEL: module {
             # CHECK: constant
             # CHECK-NEXT: arith.addf
             # CHECK-NEXT: arith.addf
@@ -54,5 +57,25 @@
             pass
 
 
+class TestDebugOnlyFlags(unittest.TestCase):
+    def test_debug_types(self):
+        """Checks that debug_conversion() turns on `-debug-only=dialect-conversion`
+        output; this only works when MLIR is built without NDEBUG."""
+        @debug_conversion()
+        def lower(module) -> None:
+            pm = PassManager("builtin.module")
+            pm.add("convert-arith-to-llvm")
+            pm.run(module.operation)
+
+        @call_with_toplevel_context_create_module
+        def _(module) -> None:
+            c = arith.ConstantOp(value=42.42, result=ir.F32Type.get()).result
+            arith.AddFOp(c, c, fastmath=arith.FastMathFlags.nnan | arith.FastMathFlags.ninf)
+
+            # DEBUG_ONLY-LABEL: Legalizing operation : 'builtin.module'
+            #       DEBUG_ONLY: Legalizing operation : 'arith.addf'
+            lower(module)
+
+
 if __name__ == "__main__":
     unittest.main()