Complex Backend #3608
Conversation
- Add Complex as first-class TensorKind alongside Float, Int, Bool
- Add ComplexTensorPrimitive and ComplexElem to Backend trait
- Add complex tensor type aliases and exports

- Add ComplexTensorPrimitive support to NdArray backend
- Implement complex arithmetic and transcendental functions in NdArray backend
- Add autodiff backend wrapper for complex tensors
- Begin enabling support across backend ecosystem

- Add high-level Tensor<B, D, Complex> API with BasicOps and Numeric traits
- Add complex-specific methods: conj(), real(), imag(), magnitude(), phase()
- Add creation utilities: from_parts(), from_polar(), zeros(), ones()
- Start adding test suite covering operations

- Remove non-existent testgen_complex!() macro call that was causing compilation errors
- Add ComplexTensorOps implementations for all backends (tch, candle, cubecl, fusion, router)
- Fix complex tensor assertion logic in CubeCL backend to avoid Float trait requirements
- Add missing transcendental functions (exp, log, sin, cos, tan, sqrt, powc) to Complex tensor API
@laggui
Alternatively, if I make some more stuff in burn-tensor public (such as kind), then I can lift almost everything out of burn-tensor.
laggui
left a comment
@skewballfox for now, I think we should treat the complex backend as an extension. Thus, almost all types, traits, and impls should live only in the burn-complex extension at this time.
The only part that should be added to burn-tensor is the DType variant, since we don't currently have a way to add/support custom dtypes anyway.
I believe the rest can easily live as an extension in a separate crate.
```rust
pub type ComplexTensor<B> = <B as ComplexTensorBackend>::ComplexTensorPrimitive;

pub trait ComplexTensorBackend: Backend {
    /// The inner backend type.
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem>;
    /// Tensor primitive to be used for all complex operations.
    type ComplexTensorPrimitive: TensorMetadata + 'static;

    /// Returns the real part of a complex tensor.
    fn real(tensor: ComplexTensor<Self>) -> FloatTensor<Self::InnerBackend>;
    /// Returns the imaginary part of a complex tensor.
    fn imag(tensor: ComplexTensor<Self>) -> FloatTensor<Self::InnerBackend>;
    fn to_complex(tensor: FloatTensor<Self::InnerBackend>) -> ComplexTensor<Self>;
}

/// A type-level representation of the kind of a complex tensor.
#[derive(Clone, Debug)]
pub struct Complex;

impl<B: ComplexTensorBackend> TensorKind<B> for Complex {
    type Primitive = B::ComplexTensorPrimitive;

    fn name() -> &'static str {
        "Complex"
    }
}
```

You can easily implement the tensor ops traits for the Complex type, e.g.
```rust
impl<B: ComplexTensorBackend> BasicOps<B> for Complex {
    // ...
}
```

For the element type, I think the make_elem! macro will not work because ToElement does not have to_complex, but that is fine. The macro was mostly there to avoid repeating the implementation; we can either implement the Element trait manually for these types or make the macro a bit more flexible for custom external types. Maybe that will require adding ToComplex (in the complex crate) and implementing it for types that implement ToElement, so we can convert types to/from complex.
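To make the ToComplex idea concrete, here is a minimal, self-contained sketch of what such a conversion trait could look like. Complex32 and the Into<f64> bound are illustrative stand-ins, not burn's actual ToElement machinery:

```rust
/// Stand-in complex scalar for the sketch.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Complex32 {
    pub re: f32,
    pub im: f32,
}

/// Conversion into complex, living in the complex crate.
pub trait ToComplex {
    fn to_complex32(self) -> Complex32;
}

// Blanket impl: any real scalar becomes a complex number with a zero
// imaginary part. A real version would bound on ToElement instead.
impl<T: Into<f64>> ToComplex for T {
    fn to_complex32(self) -> Complex32 {
        Complex32 {
            re: self.into() as f32,
            im: 0.0,
        }
    }
}

fn main() {
    assert_eq!(3u8.to_complex32(), Complex32 { re: 3.0, im: 0.0 });
}
```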
Then, for the concrete implementations we can have feature flags similar to burn-vision
#[cfg(feature = "ndarray")]
mod ndarray {
use crate::ComplexTensorBackend;
use burn_ndarray::{
FloatNdArrayElement, IntNdArrayElement, NdArray, NdArrayTensorFloat, QuantElement,
};
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> ComplexTensorBackend
for NdArray<E, I, Q>
{
// ...
}
}So we can limit the extension incrementally.
Eventually, we might move it as a core feature/trait, but I believe starting as an extension is the right approach to limit the scope.
CC'ing @nathanielsimard in case you have other thoughts on this subject.
@laggui I'm currently running into issues with Complex{32,64} needing to implement the Element trait for BasicOps. I can't implement ElementConversion and a few others I've duplicated in burn-complex. I think moving the complex element into burn-tensor and leaving everything else in burn-complex is probably the right approach, but I'm open to other suggestions. An alternative is to create a new trait that would be shared by Element and ComplexElement and make that the requirement for BasicOps, but that seems like the wrong approach to take.
This PR has been marked as stale because it has not been updated for over a month.
Still working on this. Just had a busy few weekends. I'm hoping to work on this a little this weekend.
@laggui I think I have a full implementation for split tensors and interleaved for burn-flex; I'm currently trying to rework the test Prakash started. I'm having trouble getting one specific method to work with the types I was using to refer to the inner backend's int tensor primitive.

While it's not quite ready, the only things I think are left for a first pass are to fix the issue with the ambiguous/mismatched associated type blocking powi and to write the tests for both the interleaved and complex versions. If you want to start reviewing, those last few pieces should be done next weekend or the one after that.
@skewballfox sounds good, will try to review this on Wednesday 🙏 |
antimora
left a comment
Took a pass over the diff. Great progress on the scope, and the decorator direction is the right call. A few things grouped by theme so the inline comments aren't swimming in context.
Design (responding to the goals in the PR description)
Goal 3 (ComplexBackend is not a supertrait of Backend) is satisfied syntactically, but in practice BasicOps / Numeric for ComplexKind in base.rs require C: ComplexTensorBackend<InnerBackend = C> + Backend wherever they're used through Tensor<B, Complex>. That effectively excludes SplitBackend<B>, which is PhantomData<B> and never implements Backend. So the "any backend via SplitBackend" story doesn't flow to the public Tensor surface today. Worth deciding whether that constraint belongs on the trait itself or whether SplitBackend needs to implement Backend (even as a thin passthrough).
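To make the constraint problem and the thin-passthrough option concrete, here is a self-contained toy with stand-in traits; burn's real Backend carries many more associated items, so this is shape only:

```rust
use std::marker::PhantomData;

trait Backend {
    type FloatTensorPrimitive;
}

trait ComplexTensorBackend {
    type InnerBackend: Backend;
}

// The decorator: a pure type-level wrapper with no Backend impl of its own.
struct SplitBackend<B: Backend>(PhantomData<B>);

impl<B: Backend> ComplexTensorBackend for SplitBackend<B> {
    type InnerBackend = B;
}

// A bound shaped like the one in base.rs demands the complex backend itself
// be a Backend, which SplitBackend<B> only satisfies if it gets a
// passthrough impl like this one.
impl<B: Backend> Backend for SplitBackend<B> {
    type FloatTensorPrimitive = B::FloatTensorPrimitive;
}

fn usable_from_tensor_api<C: ComplexTensorBackend + Backend>() {}

fn main() {
    struct Nd;
    impl Backend for Nd {
        type FloatTensorPrimitive = Vec<f32>;
    }
    // Compiles only because of the passthrough impl above.
    usable_from_tensor_api::<SplitBackend<Nd>>();
}
```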
The Layout trait is currently carrying one associated type and no methods; DefaultComplexOps ends up with two near-duplicate impls for the two layouts. Until there's a layout-specialized op that actually needs the split (butterfly/FFT), it might be worth folding Layout into ComplexTensorBackend and reintroducing it when a concrete need shows up. As it stands the separation is paying ceremonial cost without a payoff yet.
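A minimal sketch of what that folding could look like, with stand-in names rather than the PR's actual code:

```rust
/// Type-level layout markers.
struct Interleaved;
struct Split;

trait ComplexTensorBackend {
    /// Layout folded in as a plain associated type; a separate Layout trait
    /// only returns once an op genuinely dispatches on it (e.g. FFT butterflies).
    type Layout;
    type ComplexTensorPrimitive: 'static;
}
```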
There are now two parallel Complex types: burn_tensor::Complex32/64 (concrete, derived Pod) and burn_complex::Complex<E> (generic, hand-rolled unsafe impl Pod). They're sound, but the duplication comes back to bite in the element/dtype traits (the ToComplex* vs ToElement split, ElementLimits on Complex<E>, etc.). Picking one and letting DType::Complex32/64 map through it would cut a lot of surface area.
Backing laggui's earlier recommendation: leaving only the DType variant in burn-tensor and pulling the rest into burn-complex still looks right. ComplexKind / BasicOps / Numeric impls currently living in burn-tensor are the main bit keeping the crate boundary fuzzy.
Correctness hotspots
A few real bugs below (inline). The big one is complex_into_imag_data in burn-flex calling the real-extraction helper. There are also several todo!() arms reachable from normal flows (safetensors save, TensorData::assert_approx_eq, fusion IR, Numeric::powi, complex_powf when rhs dtype doesn't match) which would panic rather than fail gracefully.
Tests
The 908-line tests/ops.rs doesn't currently run in any harness: the #[testgen(complex)] attribute is commented out, the export-tests feature is commented out in Cargo.toml, and TestBackend is never defined. cargo test -p burn-complex runs only the four small unit tests in base/element.rs. Worth wiring this up first, because most of the bugs below are the kind a half-decent test suite catches. Also: the SplitBackend code path (~900 LOC) has no tests at all, so the layout-independence claim is unverified end-to-end. Conversion helpers in utils.rs have no tests either, which is exactly where the imag bug came from.
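For reference, the minimal wiring is roughly the following; the feature gate and the f32 NdArray choice are assumptions, not requirements:

```rust
// In burn-complex, gated so the ops tests run against a concrete backend.
#[cfg(all(test, feature = "ndarray"))]
mod tests {
    /// The backend tests/ops.rs would exercise.
    pub type TestBackend = burn_ndarray::NdArray<f32>;
}
```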
Minor / housekeeping
Commented-out stubs in router/fusion/cubecl/remote leave a false impression those backends work with complex. Consider deleting them or gating behind a clearly-unfinished marker. base.rs also has a fair amount of author-scratch comments (//Note:, //TODO: without issue links, a handful of commented-out methods) that would be worth cleaning before a real review pass.
Happy to go deeper on any of these.
```rust
fn complex_exp(tensor: ComplexTensor<SplitBackend<B>>) -> ComplexTensor<SplitBackend<B>> {
    // formula: e^(a + bi) = e^a * (cos(b) + i*sin(b)) = from_polar(e^a, b)
    //TODO: add the checks for corner cases +∞, -∞, and NaN
```
The commented-out URL points to a fork (skewballfox/burn), which will rot. If the TODO is worth keeping, converting it into an upstream issue link is more durable.
I'll create an issue once this is ready to merge, or try to go ahead and fix it after I write the test
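In the meantime, a scalar sketch of the formula in the snippet above helps pin down the expected behavior, including why the TODO's corner-case checks matter; this is just the math, not burn's code:

```rust
// e^(a + bi) = e^a * (cos(b) + i*sin(b)), i.e. from_polar(e^a, b)
fn complex_exp_scalar(re: f64, im: f64) -> (f64, f64) {
    let magnitude = re.exp();
    (magnitude * im.cos(), magnitude * im.sin())
}

fn main() {
    // Euler's identity: e^(iπ) = -1 + 0i
    let (re, im) = complex_exp_scalar(0.0, std::f64::consts::PI);
    assert!((re + 1.0).abs() < 1e-12 && im.abs() < 1e-12);

    // One corner case from the TODO: re = -∞ should collapse the magnitude
    // to 0, but the naive formula yields 0 * NaN = NaN without explicit checks.
    let (re, im) = complex_exp_scalar(f64::NEG_INFINITY, f64::INFINITY);
    assert!(re.is_nan() && im.is_nan());
}
```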
…lated traits and tests Co-authored-by: Copilot <copilot@github.com>
…missing doc strings
…ch when testing burn-complex Co-authored-by: Copilot <copilot@github.com>
```rust
    Self::Elem: Element,
{
    /// The type of the integer tensor primitive associated with this numeric kind.
    type IntTensor: TensorKind<B>;
```
@laggui so I was having trouble resolving non-complex primitives for complex kinds in the Numeric impl. From what I understand, Rust resolves traits by path, so even doing something like making a supertrait of BackendTypes with complex primitives, and making BackendTypes a supertrait of that supertrait, doesn't work: because Self::Primitive is defined at a different level of that path, it coerces to the complex kind's primitive rather than the BackendTypes primitive. I tried multiple ways to work around that other than the supertrait, but ultimately, without equality constraints on associated types (currently unstable), I think we'll have to pin the type for methods that need to resolve to a specific primitive, like I did above.
If there is an alternative I'd love to hear it.
That's a code smell...
It's not clear why this would be required for the ComplexTensorBackend when you also have the + BackendTypes bound. But I wouldn't introduce an explicit associated type on Numeric to mark the IntTensor kind.
I also noticed you moved backend methods to BackendTypes; I thought this was only introduced in #4868 for associated types?
> That's a code smell...
>
> It's not clear why this would be required for the ComplexTensorBackend when you also have the + BackendTypes bound. But I wouldn't introduce an explicit associated type on Numeric to mark the IntTensor kind.
Honestly, I agree, but the only other solution that might work is associated type defaults, which are unstable, and that's only if the issue is the trait resolution path (option 2 below). If you clone the branch locally, revert to 84514e5, and make it so that CBT or ComplexBackend is how you're trying to supply the Self::Primitive, you basically run into this:
```
mismatched types
expected associated type `<C as Backend>::IntTensorPrimitive`
   found associated type `<<C as ComplexTensorBackend>::Layout as base::Layout>::ComplexTensorPrimitive`
an associated type was expected, but a different one was found
```
And that's whether ComplexTensorPrimitive exists in CBT, Layout, or both. Trait resolution works by path: Self::Primitive always resolves to the complex tensor. I think this means one of two things:
- For other backends, the rhs for powi isn't resolving to IntTensorPrimitive; it's resolving to whatever the primitive is for that TensorKind, even if the underlying dtype isn't aligned with the primitive it's being resolved to. If that's the case, then the type mapping in BackendTypes is redundant.
- The issue is with the layering: if Int/Float/etc. are actually distinct mappings at compile time, then this is going to pose an issue for any backend decorator introducing a new tensor kind or a new underlying dtype. The options are either to provide a mapping at the trait level for the ops that need one of the inputs to be a specific type, or to provide a mapping at the trait level for the default primitive and then use BackendTypes for the inputs (see the toy sketch below).
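A self-contained toy of the resolution issue and the "pin the type" workaround, with stand-in traits rather than burn's real ones:

```rust
trait Backend {
    type IntTensorPrimitive;
}

trait Layout {
    type ComplexTensorPrimitive;
}

trait ComplexTensorBackend: Backend {
    type Layout: Layout;
}

// rhs must be the *backend's* int primitive; spelling out the fully
// qualified path is the workaround, since a `Self::Primitive`-style lookup
// resolves along the complex trait's path instead.
fn powi<C: ComplexTensorBackend>(
    _lhs: <<C as ComplexTensorBackend>::Layout as Layout>::ComplexTensorPrimitive,
    _rhs: <C as Backend>::IntTensorPrimitive,
) {
}

struct Nd;
struct Interleaved;

impl Backend for Nd {
    type IntTensorPrimitive = Vec<i64>;
}
impl Layout for Interleaved {
    type ComplexTensorPrimitive = Vec<(f32, f32)>;
}
impl ComplexTensorBackend for Nd {
    type Layout = Interleaved;
}

fn main() {
    powi::<Nd>(vec![(1.0, 2.0)], vec![3]);
}
```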
> I also noticed you moved backend methods to BackendTypes; I thought this was only introduced in #4868 for associated types?
Check the other changes in this commit: we had a few functions that needed device information in addition to type mapping. It didn't pop up in the PR because I wasn't able to test it on a decorator that didn't already implement Backend. I can move them into a separate trait and then propagate that to the parent trait requirements.
It's a bit early, but I could definitely use feedback on what works and doesn't in terms of the design.
The goals are:
Checklist

The cargo run-checks command has been executed.

Related Issues/PRs
Changes
The main changes in relation to the goals are:

- a burn-complex crate that will house a lot of the shared definitions. I'm guessing most of the stuff other than the ComplexTensorBackend trait and the dtype for complex numbers will be moved here
- a ComplexLayout trait that will be implemented on unit structs that indicate what type of complex layout is in use for an implementation, which allows implementors to define functions and traits only meant to be used for a specific data layout.

Testing
TODO
Notes
My current plan is to get all ops implemented for split and burn-flex, write tests for the ops, and then, once it's almost ready to merge, stash the ndarray implementation and make that a separate PR. It's just way easier to implement for flex. Ndarray macros do not spark joy.