-
Notifications
You must be signed in to change notification settings - Fork 350
Tensordot implementation #1545
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Tensordot implementation #1545
Conversation
Adds a full implementation of `tensordot` for n-dimensional arrays, supporting both numeric and paired axis specifications via the new `AxisSpec` enum. The design mirrors NumPy’s `tensordot` behaviour while integrating cleanly with ndarray’s trait-based approach for the ``dot`` product and existing shape and stride logic. All internal reshaping and permutation operations use `unwrap` and `expect` with explicit safety reasoning: each call is guarded by dimension and axis validation, ensuring panics can only occur under invalid `AxisSpec` input. Documentation and inline comments describe these invariants and the exact failure conditions. Includes tests verifying correct contraction for both paired and integer axis modes.
| /// # use ndarray::linalg::AxisSpec; | ||
| /// let axes = AxisSpec::Pair(vec![1, -1], vec![0, 2]); | ||
| /// ``` | ||
| Pair(Vec<isize>, Vec<isize>), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know both vectors must be of equal length, but must the number of elements be equal to the dimension?
I always try to reduce the number of memory allocations. I wonder if there's a better way than to accept 2 vectors. If the maximum number of elements = the dimension, we might be able to use constant size arrays. Although there are 2 dimensions here. Is one of the dimension always >= to the other?
src/linalg/impl_linalg.rs
Outdated
| /// - The computed product of reshaped dimensions does not equal the | ||
| /// array’s total element count (which would indicate internal logic error). | ||
| #[track_caller] | ||
| fn tensordot(&self, rhs: &Rhs, axes: AxisSpec) -> Self::Output; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have never used tensordot. Is this a kind of function that gets called often with the same arguments? If so, wa might want to use &AxisSpec instead.
src/linalg/impl_linalg.rs
Outdated
|
|
||
| let c2 = a2.dot(&b2); | ||
|
|
||
| let mut out_shape: Vec<usize> = notin_a.iter().map(|&ax| a.shape()[ax]).collect(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're allocating 3 vectors in this function. Can't you re-use one of them for out_shape?
Updates ``tensordot_impl`` by removing one vector allocation, eliminating repeated axis-membership scans, and reducing shape-index lookups. Also switches axes_a/axes_b to borrowed inputs for potential reuse upstream. Replace ``notin_a`` + ``clone`` with direct construction of ``out_shape``, removing 1 allocation and 1 clone. Allocation count is now: - ``is_contracted_a`` → ``O(ndim(a))`` - is_contracted_b → ``O(ndim(b))`` - newaxes_a → ``O(ndim(a))`` - newaxes_b → ``O(ndim(b))`` - out_shape → ``O(ndim(a) + ndim(b) - contracted)`` All sizes are exactly determined by the axis mask and do not depend on runtime data beyond shape rank. Precompute boolean membership arrays for contracted axes, replacing multiple ``iter().any()`` scans with O(1) lookups. Cache ``a.shape()`` and ``b.shape()`` slices to avoid repeated indexing. Update signature to accept borrowed axis lists, allowing the caller to reuse them without moving.
Tensordot implementation Issue #1517
Adds a full implementation of
tensordotfor n-dimensional arrays,supporting both numeric and paired axis specifications via the new
AxisSpecenum. The design mirrors NumPy’stensordotbehaviour whileintegrating cleanly with ndarray’s trait-based approach for the
dotproduct and existing shape and stride logic.
All internal reshaping and permutation operations use
unwrapandexpectwith explicit safety reasoning: each call is guarded bydimension and axis validation, ensuring panics can only occur under
invalid
AxisSpecinput. Documentation and inline commentsdescribe these invariants and the exact failure conditions.
Includes tests verifying correct contraction for
both paired and integer axis modes.