@angelz913
Last active June 6, 2024 14:36

Table of Contents

  1. ConvertToLLVM
  2. ConvertToSPIRV(IREE)
  3. Comparison

ConvertToLLVM

ConvertToLLVMPass

  • A generic pass to convert to LLVM, using the ConvertToLLVMPatternInterface to delegate to dialects the injection of conversion patterns
for (Dialect *dialect : context->getLoadedDialects()) {
  // First time we encounter this dialect: if it implements the interface,
  // let's populate patterns!
  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
  if (!iface)
    continue;
  iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
                                                 tempPatterns);
}

ConvertToLLVMPatternInterface

Implementation of each dialect -> LLVM (e.g. ArithToLLVM)

struct ArithToLLVMConversionPass
    : public impl::ArithToLLVMConversionPassBase<ArithToLLVMConversionPass> {
  using Base::Base;

  void runOnOperation() override {
    LLVMConversionTarget target(getContext());
    RewritePatternSet patterns(&getContext());

    LowerToLLVMOptions options(&getContext());
    if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
      options.overrideIndexBitwidth(indexBitwidth);

    LLVMTypeConverter converter(&getContext(), options);
    mlir::arith::populateArithToLLVMConversionPatterns(converter, patterns);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
// Inside populateArithToLLVMConversionPatterns:
patterns.add<
    AddFOpLowering,
    AddIOpLowering,
    ...>(converter);

ConvertToSPIRV(IREE)

ConvertToSPIRVPass

void getDependentDialects(DialectRegistry &registry) const override {
  registry.insert<spirv::SPIRVDialect>();
}
  • runOnOperation
    1. Walk every funcOp and check its attributes, such as workgroupSize and subgroupSize.
    2. Walk every funcOp again, apply the RemoveStaticDynamicCast patterns, and fold greedily via applyPatternsAndFoldGreedily.
    3. Populate VectorNarrowTypeRewritePatterns, rewriting extui/si(bitcast) as a mix of vector.shuffle + bitwise arithmetic.
    4. Populate ExpandBFloat16, expanding any remaining bf16 extf and trunc patterns
    5. Populate MMAToSPIRVCoopMatrixTypeConversion, for GPU subgroup MMA ops
    6. Populate more to-spirv patterns, including gpu, scf, memref, func, math, complex, tensor, vector, etc.
    7. Add IREE HAL interface op conversions.
    8. Fold certain operations as no-ops
    9. Try applying full conversion by calling applyFullConversion
  FrozenRewritePatternSet frozenPatterns(std::move(patterns));
  for (auto fn : functions) {
    if (failed(applyFullConversion(fn, *target, frozenPatterns))) {
      return signalPassFailure();
    }
  }

Comparison

  • In ConvertToLLVMPass, the patterns are populated by the dialects themselves, via their interfaces. There is no need to call the individual populate...Patterns methods.
    • More modularized approach
    • Need to implement the interface for each dialect
  • In IREE's ConvertToSPIRVPass, by contrast, a single large runOnOperation populates the patterns by calling each populate...Patterns method directly.