Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Autodiff Upstreaming - enzyme frontend #129458

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
281 changes: 281 additions & 0 deletions compiler/rustc_ast/src/expand/autodiff_attrs.rs
ZuseZ4 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
//! This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute,
//! we create an `AutoDiffItem` which contains the source and target function names. The source
//! is the function to which the autodiff attribute is applied, and the target is the function
//! getting generated by us (with a name given by the user as the first autodiff arg).

use std::fmt::{self, Display, Formatter};
use std::str::FromStr;

use crate::expand::typetree::TypeTree;
use crate::expand::{Decodable, Encodable, HashStable_Generic};
use crate::ptr::P;
use crate::{Ty, TyKind};

/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
/// are a hack to support higher order derivatives. We need to compute first order derivatives
/// before we compute second order derivatives, otherwise we would differentiate our placeholder
/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
/// as it's already done in the C++ and Julia frontend of Enzyme. (FIXME) remove *First variants.
/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum DiffMode {
ZuseZ4 marked this conversation as resolved.
Show resolved Hide resolved
/// No autodiff is applied (used during error handling).
Error,
/// The primal function which we will differentiate.
Source,
/// The target function, to be created using forward mode AD.
Forward,
/// The target function, to be created using reverse mode AD.
Reverse,
/// The target function, to be created using forward mode AD.
/// This target function will also be used as a source for higher order derivatives,
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
ForwardFirst,
/// The target function, to be created using reverse mode AD.
/// This target function will also be used as a source for higher order derivatives,
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
ReverseFirst,
}

/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
/// we add to the previous shadow value. To not surprise users, we picked different names.
/// Dual numbers is also a quite well known name for forward mode AD types.
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum DiffActivity {
/// Implicit or Explicit () return type, so a special case of Const.
None,
/// Don't compute derivatives with respect to this input/output.
Const,
/// Reverse Mode, Compute derivatives for this scalar input/output.
Active,
/// Reverse Mode, Compute derivatives for this scalar output, but don't compute
/// the original return value.
ActiveOnly,
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
/// with it.
Dual,
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
/// with it. Drop the code which updates the original input/output for maximum performance.
DualOnly,
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
Duplicated,
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
/// Drop the code which updates the original input for maximum performance.
DuplicatedOnly,
/// All Integers must be Const, but these are used to mark the integer which represents the
/// length of a slice/vec. This is used for safety checks on slices.
FakeActivitySize,
}
/// We generate one of these structs for each `#[autodiff(...)]` attribute.
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffItem {
/// The name of the function getting differentiated
pub source: String,
/// The name of the function being generated
pub target: String,
pub attrs: AutoDiffAttrs,
/// Describe the memory layout of input types
pub inputs: Vec<TypeTree>,
/// Describe the memory layout of the output type
pub output: TypeTree,
}
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffAttrs {
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
/// e.g. in the [JAX
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
pub mode: DiffMode,
pub ret_activity: DiffActivity,
pub input_activity: Vec<DiffActivity>,
}

impl DiffMode {
pub fn is_rev(&self) -> bool {
matches!(self, DiffMode::Reverse | DiffMode::ReverseFirst)
}
pub fn is_fwd(&self) -> bool {
matches!(self, DiffMode::Forward | DiffMode::ForwardFirst)
}
}

impl Display for DiffMode {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
DiffMode::Error => write!(f, "Error"),
DiffMode::Source => write!(f, "Source"),
DiffMode::Forward => write!(f, "Forward"),
DiffMode::Reverse => write!(f, "Reverse"),
DiffMode::ForwardFirst => write!(f, "ForwardFirst"),
DiffMode::ReverseFirst => write!(f, "ReverseFirst"),
}
}
}

/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).
/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).
/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.
/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,
/// but this is too complex to verify here. Also it's just a logic error if users get this wrong.
pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
ZuseZ4 marked this conversation as resolved.
Show resolved Hide resolved
if activity == DiffActivity::None {
// Only valid if primal returns (), but we can't check that here.
return true;
}
match mode {
DiffMode::Error => false,
DiffMode::Source => false,
DiffMode::Forward | DiffMode::ForwardFirst => {
activity == DiffActivity::Dual
|| activity == DiffActivity::DualOnly
|| activity == DiffActivity::Const
}
DiffMode::Reverse | DiffMode::ReverseFirst => {
activity == DiffActivity::Const
|| activity == DiffActivity::Active
|| activity == DiffActivity::ActiveOnly
}
}
}

/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value
/// for the given argument, but we generally can't know the size of such a type.
/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,
/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value
/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent
/// users here from marking scalars as Duplicated, due to type aliases.
pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
ZuseZ4 marked this conversation as resolved.
Show resolved Hide resolved
use DiffActivity::*;
// It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
if matches!(activity, Const) {
return true;
}
if matches!(activity, Dual | DualOnly) {
return true;
}
// FIXME(ZuseZ4) We should make this more robust to also
// handle type aliases. Once that is done, we can be more restrictive here.
if matches!(activity, Active | ActiveOnly) {
return true;
}
matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))
&& matches!(activity, Duplicated | DuplicatedOnly)
}
pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
ZuseZ4 marked this conversation as resolved.
Show resolved Hide resolved
use DiffActivity::*;
return match mode {
DiffMode::Error => false,
DiffMode::Source => false,
DiffMode::Forward | DiffMode::ForwardFirst => {
matches!(activity, Dual | DualOnly | Const)
}
DiffMode::Reverse | DiffMode::ReverseFirst => {
matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
}
};
}

impl Display for DiffActivity {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
DiffActivity::None => write!(f, "None"),
DiffActivity::Const => write!(f, "Const"),
DiffActivity::Active => write!(f, "Active"),
DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
DiffActivity::Dual => write!(f, "Dual"),
DiffActivity::DualOnly => write!(f, "DualOnly"),
DiffActivity::Duplicated => write!(f, "Duplicated"),
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
}
}
}

impl FromStr for DiffMode {
type Err = ();

fn from_str(s: &str) -> Result<DiffMode, ()> {
match s {
"Error" => Ok(DiffMode::Error),
"Source" => Ok(DiffMode::Source),
"Forward" => Ok(DiffMode::Forward),
"Reverse" => Ok(DiffMode::Reverse),
"ForwardFirst" => Ok(DiffMode::ForwardFirst),
"ReverseFirst" => Ok(DiffMode::ReverseFirst),
_ => Err(()),
}
}
}
impl FromStr for DiffActivity {
type Err = ();

fn from_str(s: &str) -> Result<DiffActivity, ()> {
match s {
"None" => Ok(DiffActivity::None),
"Active" => Ok(DiffActivity::Active),
"ActiveOnly" => Ok(DiffActivity::ActiveOnly),
"Const" => Ok(DiffActivity::Const),
"Dual" => Ok(DiffActivity::Dual),
"DualOnly" => Ok(DiffActivity::DualOnly),
"Duplicated" => Ok(DiffActivity::Duplicated),
"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
_ => Err(()),
}
}
}

impl AutoDiffAttrs {
pub fn has_ret_activity(&self) -> bool {
self.ret_activity != DiffActivity::None
}
pub fn has_active_only_ret(&self) -> bool {
self.ret_activity == DiffActivity::ActiveOnly
}

pub fn error() -> Self {
AutoDiffAttrs {
mode: DiffMode::Error,
ret_activity: DiffActivity::None,
input_activity: Vec::new(),
}
}
pub fn source() -> Self {
AutoDiffAttrs {
mode: DiffMode::Source,
ret_activity: DiffActivity::None,
input_activity: Vec::new(),
}
}

pub fn is_active(&self) -> bool {
self.mode != DiffMode::Error
}

pub fn is_source(&self) -> bool {
self.mode == DiffMode::Source
}
pub fn apply_autodiff(&self) -> bool {
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
}

pub fn into_item(
self,
source: String,
target: String,
inputs: Vec<TypeTree>,
output: TypeTree,
) -> AutoDiffItem {
AutoDiffItem { source, target, inputs, output, attrs: self }
}
}

impl fmt::Display for AutoDiffItem {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
write!(f, " with attributes: {:?}", self.attrs)?;
write!(f, " with inputs: {:?}", self.inputs)?;
write!(f, " with output: {:?}", self.output)
}
}
2 changes: 2 additions & 0 deletions compiler/rustc_ast/src/expand/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use rustc_span::symbol::Ident;
use crate::MetaItem;

pub mod allocator;
pub mod autodiff_attrs;
pub mod typetree;

#[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)]
pub struct StrippedCfgItem<ModId = DefId> {
Expand Down
69 changes: 69 additions & 0 deletions compiler/rustc_ast/src/expand/typetree.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use std::fmt;

use crate::expand::{Decodable, Encodable, HashStable_Generic};

#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum Kind {
Anything,
Integer,
Pointer,
Half,
Float,
Double,
Unknown,
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct TypeTree(pub Vec<Type>);

impl TypeTree {
pub fn new() -> Self {
Self(Vec::new())
}
pub fn all_ints() -> Self {
Self(vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: what's the significance of offset being -1?

Copy link
Contributor Author

@ZuseZ4 ZuseZ4 Sep 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's the bane of my existence. :D Enzyme uses a type layout representation which is suboptimal, it should move over to type trie's, but another student and I both gave up on our refactor PRs and the Enzyme core Lead dev also isn't actively working this. Luckily there are some AD users which seem to have the resources to eventually rewrite the whole typetree infrastructure, so I consider it just an implementation detail.

In less dramatic, -1 in Enzyme speech means everywhere.
that is {0:-1: Float} means at index 0 you have a ptr, if you dereference it it will be floats everywhere. Thus * f32.
If you have {-1:int} it means int's everywhere, e.g. [i32; N].
{0:-1:-1 float} then means one pointer at offset 0, if you dereference it there will be only pointers, if you dereference these new pointers they will point to array of floats.
Generally, it allows byte-specific descriptions.

This design has no way of handling recursive datastructures, it skips things that have more than 5 indirections and there are some hacks to make it handle gaps in layouts, as well as other issues with it. If Enzyme is slow at compile time than this is usually the culprit. Also it should be extended at some point to make use of const/mut knowledge, right now that get's lost once we have more than one indirection. To be fair I find it pretty cool that Enzyme is already so extremely fast while leaving some information like here still unused.

The middle-end PR will include tests for typetrees, since that's where we construct them.
I'll add this also to the docs, so I can link to them next time.

}
pub fn int(size: usize) -> Self {
let mut ints = Vec::with_capacity(size);
for i in 0..size {
ints.push(Type {
offset: i as isize,
size: 1,
kind: Kind::Integer,
child: TypeTree::new(),
});
}
Self(ints)
}
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct FncTree {
pub args: Vec<TypeTree>,
pub ret: TypeTree,
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct Type {
pub offset: isize,
pub size: usize,
pub kind: Kind,
pub child: TypeTree,
}

impl Type {
pub fn add_offset(self, add: isize) -> Self {
let offset = match self.offset {
-1 => add,
x => add + x,
};

Self { size: self.size, kind: self.kind, child: self.child, offset }
}
}

impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
<Self as fmt::Debug>::fmt(self, f)
}
}
4 changes: 4 additions & 0 deletions compiler/rustc_builtin_macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ name = "rustc_builtin_macros"
version = "0.0.0"
edition = "2021"


[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(llvm_enzyme)'] }

[lib]
doctest = false

Expand Down
9 changes: 9 additions & 0 deletions compiler/rustc_builtin_macros/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ builtin_macros_assert_requires_boolean = macro requires a boolean expression as
builtin_macros_assert_requires_expression = macro requires an expression as an argument
.suggestion = try removing semicolon

builtin_macros_autodiff = autodiff must be applied to function
builtin_macros_autodiff_missing_config = autodiff requires at least a name and mode
builtin_macros_autodiff_mode = unknown Mode: `{$mode}`. Use `Forward` or `Reverse`
builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode
builtin_macros_autodiff_not_build = this rustc version does not support autodiff
builtin_macros_autodiff_number_activities = expected {$expected} activities, but found {$found}
builtin_macros_autodiff_ty_activity = {$act} can not be used for this type

builtin_macros_autodiff_unknown_activity = did not recognize Activity: `{$act}`
builtin_macros_bad_derive_target = `derive` may only be applied to `struct`s, `enum`s and `union`s
.label = not applicable here
.label2 = not a `struct`, `enum` or `union`
Expand Down
Loading
Loading