Skip to content

Commit

Permalink
SPIR-V requires that kernels return void, while functions return empt…
Browse files Browse the repository at this point in the history
…y struct.
  • Loading branch information
elliottslaughter committed Nov 21, 2023
1 parent 4d8cf3b commit d824ce7
Showing 1 changed file with 29 additions and 18 deletions.
47 changes: 29 additions & 18 deletions src/tcompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1140,14 +1140,14 @@ struct CCallingConv {

return Argument(C_AGGREGATE_REG, t, StructType::get(*CU->TT->ctx, elements));
}
void Classify(Obj *ftype, Obj *params, Classification *info) {
void Classify(Obj *ftype, CallingConv::ID cconv, Obj *params, Classification *info) {
Obj fparams, returntype;
ftype->obj("parameters", &fparams);
ftype->obj("returntype", &returntype);
int zero = 0;
info->returntype = ClassifyArgument(&returntype, &zero, &zero, true);

if (return_empty_struct_as_void) {
if (return_empty_struct_as_void || cconv == CallingConv::SPIR_KERNEL) {
// windows classifies empty structs as pass by pointer, but we need a return
// value of unit (an empty tuple) to be translated to void. So if it is unit,
// force the return value to be void by overriding the normal classification
Expand All @@ -1173,13 +1173,13 @@ struct CCallingConv {
CreateFunctionType(info, fparams.size(), ftype->boolean("isvararg"));
}

Classification *ClassifyFunction(Obj *fntyp) {
Classification *ClassifyFunction(Obj *fntyp, CallingConv::ID cconv) {
Classification *info = (Classification *)CU->symbols->getud(fntyp);
if (!info) {
info = new Classification(); // TODO: fix leak
Obj params;
fntyp->obj("parameters", &params);
Classify(fntyp, &params, info);
Classify(fntyp, cconv, &params, info);
CU->symbols->setud(fntyp, info);
}
return info;
Expand Down Expand Up @@ -1291,8 +1291,9 @@ struct CCallingConv {
}
}

Function *CreateFunction(Module *M, Obj *ftype, const Twine &name) {
Classification *info = ClassifyFunction(ftype);
Function *CreateFunction(Module *M, Obj *ftype, CallingConv::ID cconv,
const Twine &name) {
Classification *info = ClassifyFunction(ftype, cconv);
Function *fn = Function::Create(info->fntype, Function::InternalLinkage, name, M);
AttributeFnOrCall(fn, info);
return fn;
Expand Down Expand Up @@ -1323,7 +1324,7 @@ struct CCallingConv {
}
void EmitEntry(IRBuilder<> *B, Obj *ftype, Function *func,
std::vector<Value *> *variables) {
Classification *info = ClassifyFunction(ftype);
Classification *info = ClassifyFunction(ftype, func->getCallingConv());
assert(info->paramtypes.size() == variables->size());
Function::arg_iterator ai = func->arg_begin();
if (info->returntype.kind == C_AGGREGATE_MEM)
Expand Down Expand Up @@ -1371,7 +1372,7 @@ struct CCallingConv {
}
}
void EmitReturn(IRBuilder<> *B, Obj *ftype, Function *function, Value *result) {
Classification *info = ClassifyFunction(ftype);
Classification *info = ClassifyFunction(ftype, function->getCallingConv());
ArgumentKind kind = info->returntype.kind;

if (C_AGGREGATE_REG == kind &&
Expand Down Expand Up @@ -1432,10 +1433,10 @@ struct CCallingConv {
}
}

Value *EmitCall(IRBuilder<> *B, Obj *ftype, Obj *paramtypes, Value *callee,
std::vector<Value *> *actuals) {
Value *EmitCall(IRBuilder<> *B, Obj *ftype, CallingConv::ID cconv, Obj *paramtypes,
Value *callee, std::vector<Value *> *actuals) {
Classification info;
Classify(ftype, paramtypes, &info);
Classify(ftype, cconv, paramtypes, &info);

std::vector<Value *> arguments;

Expand Down Expand Up @@ -1879,16 +1880,25 @@ struct FunctionEmitter {
if (fstate->func) return fstate;
}

CallingConv::ID callingconv = CallingConv::MaxID;
if (funcobj->hasfield("callingconv")) {
callingconv = ParseCallingConv(funcobj->string("callingconv"));
}

Obj ftype;
funcobj->obj("type", &ftype);
// function name is $+name so that it can't conflict with any symbols imported
// from the C namespace
fstate->func = CC->CreateFunction(
M, &ftype, Twine(StringRef((isextern) ? "" : "$"), name));
fstate->func =
CC->CreateFunction(M, &ftype, callingconv,
Twine(StringRef((isextern) ? "" : "$"), name));
if (isextern) {
// Set external linkage for extern functions.
fstate->func->setLinkage(GlobalValue::ExternalLinkage);
}
if (callingconv != CallingConv::MaxID) {
fstate->func->setCallingConv(callingconv);
}

if (funcobj->hasfield("alwaysinline")) {
if (funcobj->boolean("alwaysinline")) {
Expand All @@ -1903,10 +1913,6 @@ struct FunctionEmitter {
fstate->func->addFnAttr(Attribute::NoInline);
}
}
if (funcobj->hasfield("callingconv")) {
const char *callingconv = funcobj->string("callingconv");
fstate->func->setCallingConv(ParseCallingConv(callingconv));
}
if (funcobj->hasfield("noreturn")) {
if (funcobj->boolean("noreturn")) {
fstate->func->addFnAttr(Attribute::NoReturn);
Expand Down Expand Up @@ -3217,6 +3223,11 @@ struct FunctionEmitter {

call->obj("value", &func);

CallingConv::ID callingconv = CallingConv::MaxID;
if (func.hasfield("callingconv")) {
callingconv = ParseCallingConv(func.string("callingconv"));
}

Value *fn = emitExp(&func);

Obj fnptrtyp;
Expand All @@ -3232,7 +3243,7 @@ struct FunctionEmitter {
setInsertBlock(bb);
deferred.push_back(bb);
}
Value *r = CC->EmitCall(B, &fntyp, &paramtypes, fn, &actuals);
Value *r = CC->EmitCall(B, &fntyp, callingconv, &paramtypes, fn, &actuals);
setInsertBlock(cur); // defer may have changed it
return r;
}
Expand Down

0 comments on commit d824ce7

Please sign in to comment.