Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle autodiff for lib builds #137570

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::ptr;

use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
use rustc_codegen_ssa::ModuleCodegen;
use rustc_codegen_ssa::back::write::ModuleConfig;
use rustc_errors::FatalError;
use tracing::{debug, trace};
Expand Down Expand Up @@ -276,7 +275,7 @@ fn generate_enzyme_call<'ll>(
}

pub(crate) fn differentiate<'ll>(
module: &'ll ModuleCodegen<ModuleLlvm>,
module_llvm: &'ll ModuleLlvm,
cgcx: &CodegenContext<LlvmCodegenBackend>,
diff_items: Vec<AutoDiffItem>,
_config: &ModuleConfig,
Expand All @@ -286,7 +285,7 @@ pub(crate) fn differentiate<'ll>(
}

let diag_handler = cgcx.create_dcx();
let cx = SimpleCx { llmod: module.module_llvm.llmod(), llcx: module.module_llvm.llcx };
let cx = SimpleCx { llmod: module_llvm.llmod(), llcx: module_llvm.llcx };

// First of all, did the user try to use autodiff without using the -Zautodiff=Enable flag?
if !diff_items.is_empty()
Expand Down
31 changes: 26 additions & 5 deletions compiler/rustc_codegen_llvm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,11 +237,32 @@ impl WriteBackendMethods for LlvmCodegenBackend {
diff_fncs: Vec<AutoDiffItem>,
config: &ModuleConfig,
) -> Result<(), FatalError> {
if cgcx.lto != Lto::Fat {
let dcx = cgcx.create_dcx();
return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutLTO));
}
builder::autodiff::differentiate(module, cgcx, diff_fncs, config)
//if cgcx.lto != Lto::Fat {
// let dcx = cgcx.create_dcx();
// return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutLTO));
//}
let module_llvm = &module.module_llvm;
builder::autodiff::differentiate(module_llvm, cgcx, diff_fncs, config)
}
fn autodiff_thin(
cgcx: &CodegenContext<Self>,
thin_module: &ThinModule<Self>,
diff_fncs: Vec<AutoDiffItem>,
config: &ModuleConfig,
) -> Result<(), FatalError> {
let dcx = cgcx.create_dcx();
let dcx = dcx.handle();

let module_name = &thin_module.shared.module_names[thin_module.idx];

// Right now the implementation we've got only works over serialized
// modules, so we create a fresh new LLVM context and parse the module
// into that context. One day, however, we may do this for upstream
// crates but for locally codegened modules we may be able to reuse
// that LLVM Context and Module.
let module_llvm = ModuleLlvm::parse(cgcx, module_name, thin_module.data(), dcx)?;

builder::autodiff::differentiate(&module_llvm, cgcx, diff_fncs, config)
}
}

Expand Down
5 changes: 3 additions & 2 deletions compiler/rustc_codegen_ssa/src/back/lto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,11 @@ impl<B: WriteBackendMethods> LtoModuleCodegen<B> {
match &self {
LtoModuleCodegen::Fat(module) => {
B::autodiff(cgcx, &module, diff_fncs, config)?;
},
LtoModuleCodegen::Thin(module) => {
B::autodiff_thin(cgcx, module, diff_fncs, config)?;
}
_ => panic!("autodiff called with non-fat LTO module"),
}

Ok(self)
}
}
Expand Down
11 changes: 8 additions & 3 deletions compiler/rustc_codegen_ssa/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,15 +412,18 @@ fn generate_lto_work<B: ExtraBackendMethods>(
vec![(WorkItem::LTO(module), 0)]
} else {
if !autodiff.is_empty() {
let dcx = cgcx.create_dcx();
dcx.handle().emit_fatal(AutodiffWithoutLto {});
//let dcx = cgcx.create_dcx();
//dcx.handle().emit_fatal(AutodiffWithoutLto {});
}
let config = cgcx.config(ModuleKind::Regular);
assert!(needs_fat_lto.is_empty());
let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules)
.unwrap_or_else(|e| e.raise());
lto_modules
.into_iter()
.map(|module| {
let mut module =
unsafe { module.autodiff(cgcx, autodiff.clone(), config).unwrap_or_else(|e| e.raise()) };
let cost = module.cost();
(WorkItem::LTO(module), cost)
})
Expand Down Expand Up @@ -1459,6 +1462,7 @@ fn start_executing_work<B: ExtraBackendMethods>(
if needs_fat_lto.is_empty()
&& needs_thin_lto.is_empty()
&& lto_import_only_modules.is_empty()
&& autodiff_items.is_empty()
{
// Nothing more to do!
break;
Expand All @@ -1472,13 +1476,14 @@ fn start_executing_work<B: ExtraBackendMethods>(
assert!(!started_lto);
started_lto = true;

let autodiff_items = mem::take(&mut autodiff_items);
let needs_fat_lto = mem::take(&mut needs_fat_lto);
let needs_thin_lto = mem::take(&mut needs_thin_lto);
let import_only_modules = mem::take(&mut lto_import_only_modules);

for (work, cost) in generate_lto_work(
&cgcx,
autodiff_items.clone(),
autodiff_items,
needs_fat_lto,
needs_thin_lto,
import_only_modules,
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_codegen_ssa/src/traits/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ pub trait WriteBackendMethods: 'static + Sized + Clone {
diff_fncs: Vec<AutoDiffItem>,
config: &ModuleConfig,
) -> Result<(), FatalError>;
fn autodiff_thin(
cgcx: &CodegenContext<Self>,
thin: &ThinModule<Self>,
diff_fncs: Vec<AutoDiffItem>,
config: &ModuleConfig,
) -> Result<(), FatalError>;
}

pub trait ThinBufferMethods: Send + Sync {
Expand Down
Loading