blob: 1d3b24a7aee15d941c2d7c18538c96217417791f [file] [log] [blame] [edit]
//===- DialectInterfacesGen.cpp - MLIR dialect interface utility generator ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// DialectInterfaceGen generates definitions for Dialect interfaces.
//
//===----------------------------------------------------------------------===//
#include "CppGenUtilities.h"
#include "DocGenUtilities.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/CodeGenHelpers.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using namespace mlir;
using llvm::Record;
using llvm::RecordKeeper;
using mlir::tblgen::Interface;
using mlir::tblgen::InterfaceMethod;
/// Emit the C++ type `type` to `os`, with surrounding whitespace trimmed.
///
/// The type is followed by a single space unless it ends in '&' or '*', so
/// that a declarator name can be appended directly afterwards. A type that is
/// empty after trimming emits nothing. Returns `os` to allow chaining.
static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
  type = type.trim();
  // StringRef::back() requires a non-empty string, so bail out before looking
  // at the last character.
  if (type.empty())
    return os;
  os << type;
  if (type.back() != '&' && type.back() != '*')
    os << " ";
  return os;
}
/// Emit the method name and argument list for the given method.
static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name,
raw_ostream &os) {
os << name << '(';
llvm::interleaveComma(method.getArguments(), os,
[&](const InterfaceMethod::Argument &arg) {
os << arg.type << " " << arg.name;
});
os << ") const";
}
/// Return every record deriving from `DialectInterface` whose first source
/// location lies in the main (top-level) TableGen file.
static std::vector<const Record *>
getAllInterfaceDefinitions(const RecordKeeper &records) {
  const unsigned mainFileID = llvm::SrcMgr.getMainFileID();
  // Interfaces that come from any other buffer (e.g. an included file) are
  // dropped.
  auto isOutsideMainFile = [mainFileID](const Record *record) {
    return llvm::SrcMgr.FindBufferContainingLoc(record->getLoc().front()) !=
           mainFileID;
  };

  std::vector<const Record *> interfaceDefs =
      records.getAllDerivedDefinitions("DialectInterface");
  llvm::erase_if(interfaceDefs, isOutsideMainFile);
  return interfaceDefs;
}
namespace {
/// This class is the generator used when processing tablegen dialect
/// interfaces. It gathers the interface records once at construction and
/// writes the generated C++ to the stream it was given.
class DialectInterfaceGenerator {
public:
/// Collect the dialect interface records from `records` (filtered by
/// getAllInterfaceDefinitions) and remember `os` as the output stream.
DialectInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
: defs(getAllInterfaceDefinitions(records)), os(os) {}
/// Emit the declarations for all collected interfaces. Always returns false.
bool emitInterfaceDecls();
protected:
/// Emit the class declaration for a single interface.
void emitInterfaceDecl(const Interface &interface);
/// The set of interface records to emit.
std::vector<const Record *> defs;
/// The stream to emit to.
raw_ostream &os;
};
} // namespace
//===----------------------------------------------------------------------===//
// GEN: Interface declarations
//===----------------------------------------------------------------------===//
/// Emit the description of `method` as a comment on `os`, each line prefixed
/// with `prefix`. Nothing is emitted when the method has no description.
static void emitInterfaceMethodDoc(const InterfaceMethod &method,
                                   raw_ostream &os, StringRef prefix = "") {
  std::optional<StringRef> methodDesc = method.getDescription();
  if (!methodDesc)
    return;
  tblgen::emitDescriptionComment(*methodDesc, os, prefix);
}
static void emitInterfaceMethodsDef(const Interface &interface,
raw_ostream &os) {
raw_indented_ostream ios(os);
ios.indent(2);
for (auto &method : interface.getMethods()) {
emitInterfaceMethodDoc(method, ios);
ios << "virtual ";
emitCPPType(method.getReturnType(), ios);
emitMethodNameAndArgs(method, method.getName(), ios);
ios << " {";
if (auto body = method.getBody()) {
ios << "\n";
ios.indent(4);
ios << body << "\n";
ios.indent(2);
}
os << "}\n";
}
}
/// Emit the C++ class declaration for `interface` to the output stream.
///
/// The output consists of the interface description as a comment, a class
/// deriving from `::mlir::DialectInterface::Base<...>` with a constructor
/// taking the owning dialect, and one virtual method per interface method.
void DialectInterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
  // Kept alive for the whole function so the namespace handling it performs
  // spans everything emitted below.
  llvm::NamespaceEmitter ns(os, interface.getCppNamespace());
  StringRef interfaceName = interface.getName();
  tblgen::emitSummaryAndDescComments(os, "",
                                     interface.getDescription().value_or(""));
  // Emit the main interface class declaration. Literal '{' characters are
  // escaped as '{{' so that formatv never has to guess whether a brace starts
  // a replacement field.
  os << llvm::formatv(
      "class {0} : public ::mlir::DialectInterface::Base<{0}> {{\n"
      "public:\n"
      "  {0}(::mlir::Dialect *dialect) : Base(dialect) {{}\n",
      interfaceName);
  emitInterfaceMethodsDef(interface, os);
  os << "};\n";
}
/// Emit the generated-file header followed by the declaration of every
/// collected dialect interface. Always returns false.
bool DialectInterfaceGenerator::emitInterfaceDecls() {
  llvm::emitSourceFileHeader("Dialect Interface Declarations", os);

  // Work on a copy ordered by record ID, so that the interfaces come out in
  // the order in which they appear in the Tablegen file.
  std::vector<const Record *> orderedDefs = defs;
  auto hasSmallerID = [](const Record *first, const Record *second) {
    return first->getID() < second->getID();
  };
  llvm::sort(orderedDefs, hasSmallerID);

  for (const Record *record : orderedDefs) {
    Interface dialectInterface(record);
    emitInterfaceDecl(dialectInterface);
  }
  return false;
}
//===----------------------------------------------------------------------===//
// GEN: Interface registration hooks
//===----------------------------------------------------------------------===//
/// Registers the `gen-dialect-interface-decls` generator, which runs
/// DialectInterfaceGenerator over the given records and output stream.
static mlir::GenRegistration genDecls(
    "gen-dialect-interface-decls", "Generate dialect interface declarations.",
    [](const RecordKeeper &records, raw_ostream &os) {
      DialectInterfaceGenerator generator(records, os);
      return generator.emitInterfaceDecls();
    });