@@ -10,8 +10,6 @@ PartialsFn{T}(dual::Dual) where {T} = PartialsFn{T,typeof(dual)}(dual)
1010
1111(f:: PartialsFn{T} )(i) where {T} = partials (T, f. dual, i)
1212
13- _take (itr, N:: Integer ) = Iterators. take (itr, min (length (itr), N))
14-
1513function ForwardDiff. seed! (duals:: AbstractGPUArray{Dual{T,V,N}} , x,
1614 seed:: Partials{N,V} ) where {T,V,N}
1715 idxs = collect (ForwardDiff. structural_eachindex (duals, x))
2119
2220function ForwardDiff. seed! (duals:: AbstractGPUArray{Dual{T,V,N}} , x,
2321 seeds:: NTuple{N,Partials{N,V}} ) where {T,V,N}
24- idxs = collect (_take (ForwardDiff. structural_eachindex (duals, x), N))
22+ idxs = collect (Iterators . take (ForwardDiff. structural_eachindex (duals, x), N))
2523 duals[idxs] .= Dual {T,V,N} .(view (x, idxs), getindex .(Ref (seeds), 1 : length (idxs)))
2624 return duals
2725end
@@ -38,7 +36,7 @@ function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index,
3836 seeds:: NTuple{N,Partials{N,V}} , chunksize) where {T,V,N}
3937 offset = index - 1
4038 idxs = collect (
41- _take (Iterators. drop (ForwardDiff. structural_eachindex (duals, x), offset), chunksize)
39+ Iterators . take (Iterators. drop (ForwardDiff. structural_eachindex (duals, x), offset), chunksize)
4240 )
4341 duals[idxs] .= Dual {T,V,N} .(view (x, idxs), getindex .(Ref (seeds), 1 : length (idxs)))
4442 return duals
4846function ForwardDiff. extract_gradient! (:: Type{T} , result:: AbstractGPUArray ,
4947 dual:: Dual ) where {T}
5048 fn = PartialsFn {T} (dual)
51- idxs = collect (_take (ForwardDiff. structural_eachindex (result), npartials (dual)))
49+ idxs = collect (Iterators . take (ForwardDiff. structural_eachindex (result), npartials (dual)))
5250 result[idxs] .= fn .(1 : length (idxs))
5351 return result
5452end
@@ -58,7 +56,7 @@ function ForwardDiff.extract_gradient_chunk!(::Type{T}, result::AbstractGPUArray
5856 fn = PartialsFn {T} (dual)
5957 offset = index - 1
6058 idxs = collect (
61- _take (Iterators. drop (ForwardDiff. structural_eachindex (result), offset), chunksize)
59+ Iterators . take (Iterators. drop (ForwardDiff. structural_eachindex (result), offset), chunksize)
6260 )
6361 result[idxs] .= fn .(1 : length (idxs))
6462 return result
0 commit comments