Skip to content

Instantly share code, notes, and snippets.

@kennytm
Created November 30, 2012 20:49
Show Gist options
  • Save kennytm/4178490 to your computer and use it in GitHub Desktop.
Save kennytm/4178490 to your computer and use it in GitHub Desktop.
Substituting template arguments in the clang C++ library
#include <type_traits>
#include <cstdio>
namespace L {

// Compile-time integer tag: `equals(m)` reports whether `m` is the wrapped value.
template <int Value>
struct N {
    static constexpr bool equals(int candidate) { return Value == candidate; }
};

// Default provider of the alias template `A`: `int*` when the tag equals 4,
// `void*` otherwise.
struct B {
    template <typename Tag>
    using A = typename std::conditional<Tag::equals(4), int*, void*>::type;
};

// Primary template: any (type, value) pair simply inherits `A` from `B`.
template <typename Fn, Fn fn>
struct F : B {};

// Specialization for `fopen`: `A` ignores its argument and is always `double*`.
template <>
struct F<decltype(&fopen), &fopen> : B {
    template <typename Tag>
    using A = double*;
};

// Specialization for `fclose`: `void*` when the tag equals 16, `char**` otherwise.
template <>
struct F<decltype(&fclose), &fclose> : B {
    template <typename Tag>
    using A = typename std::conditional<Tag::equals(16), void*, char**>::type;
};

}
#include <cstdlib>
#include <vector>
#include <string>
#include <algorithm>
#include <iostream>
#include <unordered_map>
#include <boost/filesystem.hpp>
#define __STDC_LIMIT_MACROS
#define __STDC_CONSTANT_MACROS
#include <llvm/ADT/ArrayRef.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/Support/Host.h>
#include <llvm/Support/Regex.h>
#include <clang/Frontend/DiagnosticOptions.h>
#include <clang/Frontend/CompilerInstance.h>
#include <clang/Frontend/ASTUnit.h>
#include <clang/Frontend/Utils.h>
#include <clang/AST/Decl.h>
#include <clang/AST/DeclTemplate.h>
#include <clang/Sema/SemaDiagnostic.h>
#include <clang/Sema/Template.h>
using namespace clang;
/**
* A small range adaptor so that a range-based `for` can walk the declarations
* of a DeclContext through `DeclContext::specific_decl_iterator<DeclType>`.
*
* The range is captured at construction time from `decls_begin()` and
* `decls_end()` of the given context; the context pointer must not be null.
*/
template <typename DeclType>
struct RangeOf {
    using iterator = DeclContext::specific_decl_iterator<DeclType>;

    iterator _begin;
    iterator _end;

    explicit RangeOf(DeclContext* decl)
        : _begin(decl->decls_begin())
        , _end(decl->decls_end())
    {}

    iterator begin() const { return _begin; }
    iterator end() const { return _end; }
};
/**
* Patch up the header search entries of a compiler invocation before it is
* handed to clang.
*
* This is a hack! Some default header search paths come out wrong, which makes
* libclang unable to find stddef.h. Every internal entry whose path is not an
* existing directory gets "/usr/bin/" prepended to it. Only tested on Arch Linux.
*/
void fix_invocation(CompilerInvocation& invoc) {
    auto& search_entries = invoc.getHeaderSearchOpts().UserEntries;
    for (auto& candidate : search_entries) {
        // Leave alone anything that is not internal or that already resolves.
        const bool needs_prefix = candidate.IsInternal
                               && !boost::filesystem::is_directory(candidate.Path);
        if (!needs_prefix) {
            continue;
        }
        candidate.Path = "/usr/bin/" + candidate.Path;
    }
}
/**
* Get the AST unit from the files provided in the command line.
*/
std::unique_ptr<ASTUnit> get_ast_unit(int argc, const char** argv) {
DiagnosticOptions diag_opts;
diag_opts.ShowColors = true;
auto diags = CompilerInstance::createDiagnostics(diag_opts, argc, argv);
std::vector<const char*> command_args_vector {"-std=c++11",
"-stdlib=libc++",
"-x", "c++"};
std::copy_n(argv + 1, argc - 1, std::back_inserter(command_args_vector));
auto command_args = llvm::makeArrayRef(command_args_vector);
auto raw_invoc = clang::createInvocationFromCommandLine(command_args, diags);
std::unique_ptr<CompilerInvocation> invoc (raw_invoc);
if (invoc == nullptr) {
return nullptr;
}
fix_invocation(*invoc);
auto raw_unit = ASTUnit::LoadFromCompilerInvocation(invoc.release(), diags,
/*OnlyLocalDecls*/false, /*CaptureDiagnostics*/false);
return std::unique_ptr<ASTUnit>(raw_unit);
}
/**
* Instantiate a class template with the given arguments.
*
* An existing specialization for `args` is reused when `decl` already has one;
* otherwise a new `ClassTemplateSpecializationDecl` is created under `parent`
* and registered with the template. The resulting type is then required to be
* complete: returns the specialization on success, or null when
* `Sema::RequireCompleteType` reports a failure.
*/
ClassTemplateSpecializationDecl* instantiate(ASTContext& ast, Sema& sema, DeclContext* parent,
                                             ClassTemplateDecl* decl,
                                             llvm::ArrayRef<TemplateArgument> args) {
    // Presumably filled in by findSpecialization as an out-parameter; it is
    // only consumed by AddSpecialization when nothing was found.
    void* insert_pos;
    ClassTemplateSpecializationDecl* specialization =
        decl->findSpecialization(args.data(), args.size(), insert_pos);

    if (!specialization) {
        specialization = ClassTemplateSpecializationDecl::Create(ast, TTK_Class, parent, {}, {}, decl,
                                                                 args.data(), args.size(), nullptr);
        decl->AddSpecialization(specialization, insert_pos);
    }

    // RequireCompleteType returns true on failure.
    if (sema.RequireCompleteType({}, ast.getTypeDeclType(specialization), diag::err_incomplete_type)) {
        return nullptr;
    }
    return specialization;
}
/**
* Instantiate a template alias (`template <...> using Foo = ...`).
*
* The arguments are wrapped in an on-stack `TemplateArgumentList` and fed to a
* `TemplateDeclInstantiator` rooted at `parent`. When the visitor yields a
* `TypeAliasTemplateDecl`, its templated `TypeAliasDecl` is returned.
*
* @param decl  the alias template to instantiate; null yields null.
* @param args  template arguments to substitute.
* @return the instantiated alias declaration, or null when `decl` is null,
*         the instantiator produced nothing, or it produced some other kind
*         of declaration.
*/
TypeAliasDecl* instantiate(ASTContext& ast, Sema& sema, DeclContext* parent,
                           TypeAliasTemplateDecl* decl, llvm::ArrayRef<TemplateArgument> args) {
    if (decl == nullptr) {
        return nullptr;
    }

    auto args_count = static_cast<unsigned>(args.size());
    TemplateArgumentList arg_list {TemplateArgumentList::OnStack, args.data(), args_count};
    MultiLevelTemplateArgumentList multi_arg_list {arg_list};
    TemplateDeclInstantiator instantiator {sema, parent, multi_arg_list};

    auto instantiated = instantiator.Visit(decl);
    // dyn_cast<> must not be handed a null pointer, and the visitor may return
    // one when instantiation fails, so bail out before casting.
    if (instantiated == nullptr) {
        return nullptr;
    }
    if (auto inst_decl = dyn_cast<TypeAliasTemplateDecl>(instantiated)) {
        return inst_decl->getTemplatedDecl();
    }
    return nullptr;
}
/**
* Build a template argument that carries an `int` value.
*
* The value is turned into an `IntegerLiteral` of the context's `int` type
* (using that type's bit width) and wrapped as an expression template argument.
*/
TemplateArgument int_argument(ASTContext& ast, int value) {
    auto int_type = ast.IntTy;
    const auto width = static_cast<unsigned>(ast.getTypeSize(int_type));
    const auto raw_bits = static_cast<uint64_t>(value);
    llvm::APInt ap_value {width, raw_bits, /*isSigned*/true};
    return TemplateArgument{IntegerLiteral::Create(ast, ap_value, int_type, {})};
}
/**
* Find the first declaration of type `DeclType` called `name` under a
* DeclContext, searching depth-first in pre-order.
*
* Only the children visited by `RangeOf<DeclType>` are tested, and the search
* only descends into those children that are themselves DeclContexts. Children
* of other kinds are neither tested nor entered.
*
* @param parent  context to search; null yields null.
* @param name    identifier to look for; null yields null.
* @return the first matching declaration, or null when there is none.
*/
template <typename DeclType>
DeclType* find(DeclContext* parent, const char* name) {
    if (parent == nullptr || name == nullptr) {
        return nullptr;
    }
    for (auto decl : RangeOf<DeclType>(parent)) {
        // getName() requires a simple identifier; declarations such as
        // operators or constructors would trip its assertion, so skip them.
        if (decl->getDeclName().isIdentifier() && decl->getName() == name) {
            return decl;
        }
        if (auto decl_context = dyn_cast<DeclContext>(decl)) {
            if (auto deeper_result = find<DeclType>(decl_context, name)) {
                return deeper_result;
            }
        }
    }
    return nullptr;
}
/**
* Find the first declaration as a descendant of a range of DeclContext in DFS
* order.
*/
template <typename DeclType, typename It>
DeclType* find_among(It begin, It end, const char* name) {
while (begin != end) {
if (auto decl_context = dyn_cast<DeclContext>(*begin)) {
if (auto result = find<DeclType>(decl_context, name)) {
return result;
}
}
++ begin;
}
return nullptr;
}
/**
* Find the first declaration of type `DeclType` called `name` in a C++ record,
* falling back to its base classes (direct and indirect, in declaration order,
* depth-first).
*
* @param decl  record to search; null yields null.
* @param name  identifier to look for.
* @return the first match found in `decl` itself or, failing that, in one of
*         its bases; null when nothing matches.
*/
template <typename DeclType>
DeclType* find_with_bases(CXXRecordDecl* decl, const char* name) {
    if (decl == nullptr) {
        return nullptr;
    }
    if (auto result = find<DeclType>(decl, name)) {
        return result;
    }
    // The base list is only meaningful for a record that has a definition.
    if (!decl->hasDefinition()) {
        return nullptr;
    }
    for (auto it = decl->bases_begin(), end = decl->bases_end(); it != end; ++it) {
        auto base_decl = it->getType()->getAsCXXRecordDecl();
        // A base whose type does not resolve to a record declaration (for
        // example a dependent base) cannot be searched; skip it instead of
        // dereferencing null.
        if (base_decl == nullptr) {
            continue;
        }
        if (auto result = find_with_bases<DeclType>(base_decl, name)) {
            return result;
        }
    }
    return nullptr;
}
/**
* Entry point.
*
* Reads a function name and an integer from stdin, parses the files named on
* the command line, then:
*   1. instantiates `L::F<decltype(&function), &function>`,
*   2. looks up the alias template `A` in that specialization or its bases,
*   3. instantiates `L::N<integer>` and substitutes it into `A`,
*   4. dumps the canonical underlying type of the result.
*
* Returns 0 on success and 1 when the input cannot be read, the sources cannot
* be parsed, or any lookup/instantiation step fails.
*/
int main(int argc, const char** argv) {
    // Report a failure on stderr and produce the process exit code.
    const auto fail = [](const std::string& what) {
        std::cerr << "error: " << what << '\n';
        return 1;
    };

    std::string function_name;
    int n_number = 0;
    if (!(std::cin >> function_name >> n_number)) {
        return fail("expected a function name and an integer on standard input");
    }

    auto unit = get_ast_unit(argc, argv);
    if (unit == nullptr) {
        return 1;
    }
    auto& ast = unit->getASTContext();
    auto& sema = unit->getSema();

    auto ns_decl = find<NamespaceDecl>(ast.getTranslationUnitDecl(), "L");
    if (ns_decl == nullptr) {
        return fail("namespace 'L' not found");
    }
    auto n_decl = find<ClassTemplateDecl>(ns_decl, "N");
    if (n_decl == nullptr) {
        return fail("class template 'N' not found in namespace 'L'");
    }
    auto f_decl = find<ClassTemplateDecl>(ns_decl, "F");
    if (f_decl == nullptr) {
        return fail("class template 'F' not found in namespace 'L'");
    }

    auto func_decl = find_among<FunctionDecl>(unit->top_level_begin(), unit->top_level_end(),
                                              function_name.c_str());
    if (func_decl == nullptr) {
        return fail("function '" + function_name + "' not found");
    }

    // F<pointer-to-function type, &function>
    auto fptr_type = ast.getPointerType(func_decl->getType()).getCanonicalType();
    TemplateArgument f_args[] = {fptr_type, func_decl};
    auto f_inst = instantiate(ast, sema, ns_decl, f_decl, llvm::makeArrayRef(f_args));
    if (f_inst == nullptr) {
        return fail("could not instantiate 'F' for '" + function_name + "'");
    }

    auto a_decl = find_with_bases<TypeAliasTemplateDecl>(f_inst, "A");
    if (a_decl == nullptr) {
        return fail("alias template 'A' not found in 'F' or its bases");
    }

    // N<n_number>
    TemplateArgument n_args[] = {int_argument(ast, n_number)};
    auto n_inst = instantiate(ast, sema, ns_decl, n_decl, llvm::makeArrayRef(n_args));
    if (n_inst == nullptr) {
        return fail("could not instantiate 'N'");
    }

    // A<N<n_number>>
    TemplateArgument a_args[] = {QualType{n_inst->getTypeForDecl(), 0}};
    auto a_inst = instantiate(ast, sema, ns_decl, a_decl, llvm::makeArrayRef(a_args));
    if (a_inst == nullptr) {
        return fail("could not instantiate alias template 'A'");
    }

    auto a_type = a_inst->getUnderlyingType().getCanonicalType();
    a_type->dump();
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment