Created
January 5, 2020 22:58
-
-
Save flaub/3a2448daf2ed50df2f25e77c7fb42d77 to your computer and use it in GitHub Desktop.
Experiment with DRR (MLIR Declarative Rewrite Rules) + DialectConversion
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Module pass that applies the DRR-generated rewrite patterns
/// (populateWithGenerated) through the dialect-conversion framework.
struct LoweringPass : public mlir::ModulePass<LoweringPass> {
  void runOnModule() final {
    // Set up target (i.e. what is legal)
    // ...
    // NOTE(review): `target` is not declared in this block; it must come from
    // the elided setup above (presumably an mlir::ConversionTarget built from
    // getContext()).

    // Set up rewrite patterns. populateWithGenerated takes the pattern list by
    // pointer (mlir::OwningRewritePatternList*), so pass its address.
    OwningRewritePatternList patterns;
    populateWithGenerated(&getContext(), &patterns);

    // Run the conversion; no TypeConverter is supplied (nullptr).
    if (failed(applyPartialConversion(getModule(), target, patterns, nullptr))) {
      signalPassFailure();
    }
  }
};
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
template <typename OpType> | |
Value eltwiseOpToParallelFor(OpBuilder& builder, Value from) { | |
auto rewriter = dynamic_cast<mlir::ConversionPatternRewriter*>(&builder); | |
if (!rewriter) { | |
// FIXME: how to report failure? | |
return {}; | |
} | |
TypeConverter typeConverter; | |
auto op = from.getDefiningOp(); | |
auto loc = op->getLoc(); | |
auto resultType = op->getResult(0)->getType(); | |
auto resultMemRefType = typeConverter.convertType(resultType).cast<MemRefType>(); | |
auto resultMemRef = builder.create<AllocOp>(loc, resultMemRefType).getResult(); | |
auto ranges = builder.getI64ArrayAttr(resultMemRefType.getShape()); | |
auto dynamicRanges = ArrayRef<Value>(); | |
auto forOp = builder.create<AffineParallelForOp>(loc, ranges, dynamicRanges); | |
auto body = builder.createBlock(&forOp.inner()); | |
SmallVector<Value, 8> idxs; | |
for (size_t i = 0; i < ranges.size(); i++) { | |
idxs.push_back(body->addArgument(builder.getIndexType())); | |
} | |
SmallVector<Value, 4> scalars; | |
for (auto operand : op->getOperands()) { | |
auto memref = rewriter->getRemappedValue(operand); | |
scalars.push_back(builder.create<AffineLoadOp>(loc, memref, idxs)); | |
} | |
auto attrs = ArrayRef<NamedAttribute>{}; | |
auto elementType = resultMemRefType.getElementType(); | |
auto resultTypes = llvm::makeArrayRef(elementType); | |
auto result = builder.create<OpType>(loc, resultTypes, scalars, attrs); | |
builder.create<AffineStoreOp>(loc, result, resultMemRef, idxs); | |
builder.create<AffineTerminatorOp>(loc); | |
return resultMemRef; | |
} | |
#include "rewrites.cc.inc" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once

// Forward declarations only, so including this header does not require the
// MLIR headers.
namespace mlir {
class OwningRewritePatternList;
class MLIRContext;
}  // namespace mlir

/// Adds the DRR-generated rewrite patterns to `patternList`, using `ctx` as the
/// MLIR context. Presumably defined by the generated rewrites.cc.inc — confirm.
void populateWithGenerated(mlir::MLIRContext* ctx,
                           mlir::OwningRewritePatternList* patternList);
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef __PML_CONVERSION_ELTWISE_TO_PXA__
#define __PML_CONVERSION_ELTWISE_TO_PXA__

include "pmlc/dialect/eltwise/ops.td"
include "mlir/Dialect/StandardOps/Ops.td"

// Result builder: emits a call to the C++ helper
// eltwiseOpToParallelFor<targetOp>($_builder, $0). The `#` paste splices the
// `targetOp` record into the call string.
class EltwiseOpToParallelFor<Op targetOp>
    : NativeCodeCall<"eltwiseOpToParallelFor<" # targetOp # ">($_builder, $0)">;

// Matches a two-operand `srcOp` whose bound result `$op` satisfies
// `resultConstraint`, and replaces it with EltwiseOpToParallelFor<dstOp>
// applied to `$op`. `$lhs`/`$rhs` are bound only to fix the operand count.
class EltwiseOpConversionPat<Op srcOp, Op dstOp, TypeConstraint resultConstraint>
    : Pat<(srcOp:$op $lhs, $rhs),
          (EltwiseOpToParallelFor<dstOp> $op),
          [(resultConstraint $op)]>;

def : EltwiseOpConversionPat<EW_AddOp, AddFOp, EltwiseFloat>;

#endif // __PML_CONVERSION_ELTWISE_TO_PXA__
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment