Commit a1e73f9

Unrolled build for #142097
Rollup merge of #142097 - ZuseZ4:offload-host1, r=oli-obk

gpu offload host code generation

r? ghost

This will generate most of the host-side code needed to use LLVM's offload feature. This first PR only handles automatic memory transfers to and from the device: when a user calls a kernel, we copy the inputs back and forth, but we do not yet perform the actual kernel launch. Before merging, we will use LLVM's Info infrastructure to verify that the memcpys match what OpenMP offload generates in C++. `LIBOMPTARGET_INFO=-1 ./my_rust_binary` should print that a memcpy to and later from the device is happening.

A follow-up PR will generate the actual device-side kernel, which will then do the computations on the GPU. A third PR will implement manual host2device and device2host functionality, though the goal is to minimize the cases where a user has to override our default handling due to performance issues.

I'm trying to get a full MVP out first, so this just recognizes GPU functions based on magic names. The final frontend will obviously move this over to proper macros, as I'm already doing for the autodiff work. This work will also be compatible with std::autodiff, so one can differentiate GPU kernels.

Tracking:
- #131513
2 parents 3f9f20f + c068599 commit a1e73f9
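For illustration, user code under this MVP might look like the sketch below. Everything in it is an assumption: the magic name `kernel_1`, the signature, and the copy granularity are not specified by this commit, which only states that GPU functions are recognized by magic names and that inputs are copied to and from the device around the call.

```rust
// Hypothetical user code under the offload MVP. The name `kernel_1`
// stands in for whatever magic name the backend recognizes; the final
// frontend will use proper macros instead.
#[no_mangle]
pub extern "C" fn kernel_1(x: &mut [f32; 256]) {
    // Device-side codegen is a follow-up PR; this commit only generates
    // the host-side copies around the call.
    for v in x.iter_mut() {
        *v *= 2.0;
    }
}

fn main() {
    let mut x = [1.0f32; 256];
    kernel_1(&mut x);
    // Running with `LIBOMPTARGET_INFO=-1` should report a memcpy to the
    // device before this call and one back from it afterwards.
}
```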

File tree

17 files changed: +755 −17 lines


compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 7 additions & 0 deletions
```diff
@@ -654,6 +654,7 @@ pub(crate) fn run_pass_manager(
     // We then run the llvm_optimize function a second time, to optimize the code which we generated
     // in the enzyme differentiation pass.
     let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
+    let enable_gpu = config.offload.contains(&config::Offload::Enable);
     let stage = if thin {
         write::AutodiffStage::PreAD
     } else {
@@ -668,6 +669,12 @@ pub(crate) fn run_pass_manager(
         write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
     }
 
+    if enable_gpu && !thin {
+        let cx =
+            SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
+        crate::builder::gpu_offload::handle_gpu_code(cgcx, &cx);
+    }
+
     if cfg!(llvm_enzyme) && enable_ad && !thin {
         let cx =
             SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);
```
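The hunk above gates the new path on `config.offload`, mirroring how autodiff is gated on `config.autodiff`. A minimal sketch of the session-config side this implies is shown below; the enum shape and any flag spelling (e.g. an unstable `-Zoffload=Enable`) are assumptions inferred from `config::Offload::Enable`, not part of this diff.

```rust
// Sketch (assumed, by analogy with config::AutoDiff): a rustc_session
// config enum whose `Enable` variant turns on the GPU offload
// host-code generation during the LTO pass-manager run.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum Offload {
    /// Generate the host-side offload code (memory transfers now,
    /// the kernel launch in a follow-up).
    Enable,
}
```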

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 69 additions & 0 deletions
```diff
@@ -3,6 +3,7 @@ use std::ops::Deref;
 use std::{iter, ptr};
 
 pub(crate) mod autodiff;
+pub(crate) mod gpu_offload;
 
 use libc::{c_char, c_uint, size_t};
 use rustc_abi as abi;
@@ -117,6 +118,74 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
         }
         bx
     }
+
+    // The generic builder has less functionality and thus (unlike the other alloca) we can not
+    // easily jump to the beginning of the function to place our allocas there. We trust the user
+    // to manually do that. FIXME(offload): improve the genericCx and add more llvm wrappers to
+    // handle this.
+    pub(crate) fn direct_alloca(&mut self, ty: &'ll Type, align: Align, name: &str) -> &'ll Value {
+        let val = unsafe {
+            let alloca = llvm::LLVMBuildAlloca(self.llbuilder, ty, UNNAMED);
+            llvm::LLVMSetAlignment(alloca, align.bytes() as c_uint);
+            // Cast to default addrspace if necessary
+            llvm::LLVMBuildPointerCast(self.llbuilder, alloca, self.cx.type_ptr(), UNNAMED)
+        };
+        if name != "" {
+            let name = std::ffi::CString::new(name).unwrap();
+            llvm::set_value_name(val, &name.as_bytes());
+        }
+        val
+    }
+
+    pub(crate) fn inbounds_gep(
+        &mut self,
+        ty: &'ll Type,
+        ptr: &'ll Value,
+        indices: &[&'ll Value],
+    ) -> &'ll Value {
+        unsafe {
+            llvm::LLVMBuildGEPWithNoWrapFlags(
+                self.llbuilder,
+                ty,
+                ptr,
+                indices.as_ptr(),
+                indices.len() as c_uint,
+                UNNAMED,
+                GEPNoWrapFlags::InBounds,
+            )
+        }
+    }
+
+    pub(crate) fn store(&mut self, val: &'ll Value, ptr: &'ll Value, align: Align) -> &'ll Value {
+        debug!("Store {:?} -> {:?}", val, ptr);
+        assert_eq!(self.cx.type_kind(self.cx.val_ty(ptr)), TypeKind::Pointer);
+        unsafe {
+            let store = llvm::LLVMBuildStore(self.llbuilder, val, ptr);
+            llvm::LLVMSetAlignment(store, align.bytes() as c_uint);
+            store
+        }
+    }
+
+    pub(crate) fn load(&mut self, ty: &'ll Type, ptr: &'ll Value, align: Align) -> &'ll Value {
+        unsafe {
+            let load = llvm::LLVMBuildLoad2(self.llbuilder, ty, ptr, UNNAMED);
+            llvm::LLVMSetAlignment(load, align.bytes() as c_uint);
+            load
+        }
+    }
+
+    fn memset(&mut self, ptr: &'ll Value, fill_byte: &'ll Value, size: &'ll Value, align: Align) {
+        unsafe {
+            llvm::LLVMRustBuildMemSet(
+                self.llbuilder,
+                ptr,
+                align.bytes() as c_uint,
+                fill_byte,
+                size,
+                false,
+            );
+        }
+    }
 }
 
 /// Empty string, to be used where LLVM expects an instruction name, indicating
```
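Taken together, these wrappers give the generic builder just enough to materialize and fill stack slots without the full `Builder`. A minimal usage sketch follows, using only the methods added above; the helper name, the `Align::EIGHT` choice, and the builder setup are assumptions, and the real call sites live in `gpu_offload.rs`.

```rust
// Sketch: spill a value into a fresh stack slot and reload it, using
// only the GenericBuilder methods added in this commit. `bx` must
// already be positioned where the alloca should go, since
// direct_alloca does not hoist to the entry block (see FIXME above).
fn spill_and_reload<'a, 'll>(
    bx: &mut GenericBuilder<'a, 'll, SimpleCx<'ll>>,
    ty: &'ll Type,
    val: &'ll Value,
) -> &'ll Value {
    let slot = bx.direct_alloca(ty, Align::EIGHT, "spill");
    bx.store(val, slot, Align::EIGHT);
    bx.load(ty, slot, Align::EIGHT)
}
```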
