Skip to content

Instantly share code, notes, and snippets.

@stellaraccident
Created March 21, 2020 02:26
Show Gist options
  • Select an option

  • Save stellaraccident/79214621553f775da5a7afb67839b3bc to your computer and use it in GitHub Desktop.

Select an option

Save stellaraccident/79214621553f775da5a7afb67839b3bc to your computer and use it in GitHub Desktop.
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "third_party/absl/strings/string_view.h"
#include "third_party/iree/iree/base/signature_mangle.h"
#include "third_party/iree/iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "third_party/iree/iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "third_party/iree/iree/compiler/Dialect/VM/IR/VMOps.h"
#include "third_party/iree/iree/compiler/Dialect/VM/IR/VMTypes.h"
#include "third_party/iree/iree/compiler/Dialect/VM/Transforms/Passes.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/ErrorHandling.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
using iree::RawSignatureParser;
using iree::AbiConstants::ScalarType;
namespace mlir {
namespace iree_compiler {
namespace IREE {
namespace VM {
namespace {
class AutoImporter {
public:
AutoImporter(Operation *moduleOp)
: moduleOp(moduleOp), symbolTable(moduleOp) {}
void lookup(StringRef importName) {
auto ref = symbolTable.lookup<ImportOp>(importName);
///// I DON'T KNOW WHAT I'M DOING
}
private:
Operation *moduleOp;
SymbolTable symbolTable;
};
// Maps an ABI scalar type to the corresponding builtin MLIR type.
//  - Floating-point kinds map to the matching FloatType (f16, bf16, f32, f64).
//  - Signed and unsigned integer kinds of the same bit width map to the same
//    IntegerType; signedness is not encoded in the returned type.
// Returns a null Type for any scalar kind not handled below.
Type mapScalarType(MLIRContext *ctx, ScalarType scalarType) {
  unsigned bitWidth = 0;
  switch (scalarType) {
    case ScalarType::kIeeeFloat16:
      return FloatType::getF16(ctx);
    case ScalarType::kGoogleBfloat16:
      return FloatType::getBF16(ctx);
    case ScalarType::kIeeeFloat32:
      return FloatType::getF32(ctx);
    case ScalarType::kIeeeFloat64:
      return FloatType::getF64(ctx);
    case ScalarType::kSint8:
    case ScalarType::kUint8:
      bitWidth = 8;
      break;
    case ScalarType::kSint16:
    case ScalarType::kUint16:
      bitWidth = 16;
      break;
    case ScalarType::kSint32:
    case ScalarType::kUint32:
      bitWidth = 32;
      break;
    case ScalarType::kSint64:
    case ScalarType::kUint64:
      bitWidth = 64;
      break;
    default:
      return nullptr;
  }
  // Only the integer cases reach this point.
  return IntegerType::get(bitWidth, ctx);
}
// Appends to `types` one MLIR type per ABI description in `descs`, in order:
//  - buffers become a ref to a HAL buffer_view (no shape in the type);
//  - scalars are mapped through mapScalarType();
//  - ref objects, and scalars mapScalarType() cannot map, are rejected.
// On failure an "unsupported ABI type" error is emitted at `loc` and failure
// is returned; `types` may already contain entries for earlier descriptions.
LogicalResult mapRawAbiTypes(
    Location loc, SmallVectorImpl<RawSignatureParser::Description> &descs,
    SmallVectorImpl<Type> &types) {
  auto *context = loc.getContext();
  auto bufferRefType = RefType::get(HAL::BufferViewType::get(context));

  // Emits the shared diagnostic for a description this mapping cannot handle.
  auto reportUnsupported =
      [&](RawSignatureParser::Description &desc) -> LogicalResult {
    std::string descStr;
    desc.ToString(descStr);
    return emitError(loc) << "unsupported ABI type: " << descStr;
  };

  for (auto &desc : descs) {
    switch (desc.type) {
      case RawSignatureParser::Type::kBuffer:
        // ABI buffers map to shape-erased refs of buffer_views.
        types.push_back(bufferRefType);
        break;
      case RawSignatureParser::Type::kScalar: {
        auto scalarType = mapScalarType(context, desc.scalar.type);
        if (!scalarType) return reportUnsupported(desc);
        types.push_back(scalarType);
        break;
      }
      case RawSignatureParser::Type::kRefObject:
        // TODO(laurenzo): Map supported ref objects.
        return reportUnsupported(desc);
    }
  }
  return success();
}
// Populates the body of the synchronous export wrapper `funcOp`.
//
// An entry block matching funcOp's signature is added, and call operands are
// derived from each ABI input description:
//  - buffers: the backing buffer of the incoming buffer_view, followed by one
//    extra operand per dynamic (negative) dimension in the description;
//    static (non-negative) dimensions add no operand;
//  - scalars and ref objects: the block argument, passed through unchanged.
//
// NOTE(review): the body is incomplete. `callOperands` is built but no call
// to the wrapped function and no terminator are emitted, so the generated
// function presumably fails IR verification. `moduleBuilder`, `inputTypes`,
// `resultTypes` and `resultDescs` are currently unused.
// TODO(laurenzo): Emit the call, map results back to ABI types and return.
LogicalResult generateSynchronousBody(
    FuncOp funcOp, OpBuilder moduleBuilder, SmallVectorImpl<Type> &inputTypes,
    SmallVectorImpl<RawSignatureParser::Description> &inputDescs,
    SmallVectorImpl<Type> &resultTypes,
    SmallVectorImpl<RawSignatureParser::Description> &resultDescs) {
  auto loc = funcOp.getLoc();
  Block *entryBlock = funcOp.addEntryBlock();
  OpBuilder builder(entryBlock);

  // Build call operands.
  SmallVector<Value, 4> callOperands;
  for (const auto &input : llvm::enumerate(inputDescs)) {
    auto blockArg = entryBlock->getArgument(input.index());
    const auto &desc = input.value();
    switch (desc.type) {
      case RawSignatureParser::Type::kBuffer: {
        // Pass the backing buffer.
        // TODO(laurenzo): Validate shape.
        callOperands.push_back(
            builder.create<HAL::BufferViewBufferOp>(loc, blockArg));
        // Each dynamic dim is passed as its own operand.
        for (auto dim : desc.dims) {
          if (dim >= 0) continue;  // Static.
          // TODO(laurenzo): How to get the shape dim???
          // Placeholder: every dynamic dim is passed as the constant 1.
          callOperands.push_back(builder.create<ConstI32Op>(loc, 1));
        }
        break;
      }
      case RawSignatureParser::Type::kScalar:
      case RawSignatureParser::Type::kRefObject:
        // Assume that scalars and ref objects are pass-through.
        callOperands.push_back(blockArg);
        break;
    }
  }
  return success();
}
// Generates the public ABI entry points for `funcOp` from its raw signature
// string `signatureSr`:
//  1. parses input and result descriptions out of the signature;
//  2. maps each description to an MLIR type (see mapRawAbiTypes);
//  3. creates a `<name>$sync` FuncOp with the mapped signature, carrying an
//     "iree.reflection" string attribute holding the raw signature;
//  4. exports that function under funcOp's original name;
//  5. populates its body via generateSynchronousBody.
// New ops are created at `moduleBuilder`'s current insertion point.
// `reflection` is currently unused.
LogicalResult generateRawAbiFunctions(OpBuilder &moduleBuilder, FuncOp funcOp,
                                      DictionaryAttr reflection,
                                      StringRef signatureSr) {
  auto loc = funcOp.getLoc();

  // Parse the reflection metadata into per-value descriptions.
  SmallVector<RawSignatureParser::Description, 4> inputDescs;
  SmallVector<RawSignatureParser::Description, 4> resultDescs;
  const absl::string_view signature(signatureSr.data(), signatureSr.size());
  RawSignatureParser parser;
  parser.VisitInputs(signature,
                     [&](const RawSignatureParser::Description &desc) {
                       inputDescs.push_back(desc);
                     });
  parser.VisitResults(signature,
                      [&](const RawSignatureParser::Description &desc) {
                        resultDescs.push_back(desc);
                      });
  if (parser.GetError()) {
    return funcOp.emitError() << "illegal abi signature ('" << signatureSr
                              << "'): " << *parser.GetError();
  }

  // Map the descriptions to function signature types.
  SmallVector<Type, 4> inputTypes;
  SmallVector<Type, 4> resultTypes;
  if (failed(mapRawAbiTypes(loc, inputDescs, inputTypes))) return failure();
  assert(inputTypes.size() == inputDescs.size());
  if (failed(mapRawAbiTypes(loc, resultDescs, resultTypes))) return failure();
  assert(resultTypes.size() == resultDescs.size());

  // Create the synchronous wrapper function, tagged with the raw signature.
  SmallVector<NamedAttribute, 1> exportAttrs;
  exportAttrs.push_back(moduleBuilder.getNamedAttr(
      "iree.reflection", moduleBuilder.getStringAttr(signatureSr)));
  auto syncType =
      FunctionType::get(inputTypes, resultTypes, funcOp.getContext());
  std::string syncName = (funcOp.getName() + "$sync").str();
  auto syncFuncOp =
      moduleBuilder.create<FuncOp>(loc, syncName, syncType, exportAttrs);

  // Export the synchronous wrapper under the original function name.
  moduleBuilder.create<ExportOp>(
      loc, moduleBuilder.getSymbolRefAttr(syncFuncOp), funcOp.getName());

  return generateSynchronousBody(syncFuncOp, moduleBuilder, inputTypes,
                                 inputDescs, resultTypes, resultDescs);
}
// Generates public ABI wrappers for `funcOp` as described by its `reflection`
// dictionary. Only the raw-signature form (a string attribute under key "f")
// is handled; if that key is missing or not a string, nothing is generated
// and success is returned. Generated ops are inserted right after `funcOp`.
LogicalResult generateAbiFunctions(FuncOp funcOp, DictionaryAttr reflection) {
  auto rawSignatureSpec = reflection.get("f").dyn_cast_or_null<StringAttr>();
  if (!rawSignatureSpec) return success();

  OpBuilder builder(funcOp.getContext());
  builder.setInsertionPointAfter(funcOp);
  return generateRawAbiFunctions(builder, funcOp, reflection,
                                 rawSignatureSpec.getValue());
}
// Module pass that generates public ABI entry points for every FuncOp that
// carries an "iree.generateabi.reflection" dictionary attribute (see
// generateAbiFunctions). The pass stops and signals failure on the first
// function whose ABI generation fails.
class PublicABIGenerationPass
    : public OperationPass<PublicABIGenerationPass, ModuleOp> {
 public:
  void runOnOperation() override {
    // Collect candidates before generating anything: generation inserts new
    // ops into this same block right after each function, so generating
    // while walking the block would also visit the freshly created ops.
    SmallVector<std::pair<FuncOp, DictionaryAttr>, 8> worklist;
    for (auto &op : getOperation().getBlock().getOperations()) {
      auto funcOp = dyn_cast<FuncOp>(op);
      if (!funcOp) continue;
      auto reflection = funcOp.getAttr("iree.generateabi.reflection")
                            .dyn_cast_or_null<DictionaryAttr>();
      if (reflection) worklist.emplace_back(funcOp, reflection);
    }

    for (auto &entry : worklist) {
      if (failed(generateAbiFunctions(entry.first, entry.second))) {
        signalPassFailure();
        return;
      }
    }
  }
};
} // namespace
// Returns a new instance of the pass that generates public ABI entry points
// for functions annotated with "iree.generateabi.reflection".
std::unique_ptr<OpPassBase<IREE::VM::ModuleOp>>
createPublicABIGenerationPass() {
  return std::unique_ptr<OpPassBase<IREE::VM::ModuleOp>>(
      std::make_unique<PublicABIGenerationPass>());
}

// Registers the pass under its command-line name.
static PassRegistration<PublicABIGenerationPass> pass{
    "iree-vm-public-abi-generation", "Creates public ABI entry points"};
} // namespace VM
} // namespace IREE
} // namespace iree_compiler
} // namespace mlir
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment