Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, Abstract
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj!
import MatrixAlgebraKit: heevj!, heevd!, geev!
import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank, svd_pullback!, eigh_pullback!, eig_pullback!
import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank, svd_pullback!, eigh_pullback!, eig_pullback!, svd_pushforward!
using CUDA, CUDA.cuBLAS
using CUDA: i32
using LinearAlgebra
Expand Down Expand Up @@ -213,4 +213,14 @@ function eig_pullback!(ΔA::AnyCuMatrix, A, DV, ΔDV, ind::AnyCuVector; kwargs..
return eig_pullback!(ΔA, A, DV, ΔDV, collect(ind); kwargs...)
end

# have to override this as methods are missing in GPUArrays for the various
# views of Diagonal of ΔA
Comment on lines +216 to +217

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit of a confusing comment: what exactly is missing?
Might be useful to keep track of this, in case it gets fixed.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's again a situation of mul!(::Diagonal{T, CuVector{T}}, [horrific view of adjoint of view], CuArray) which GPUArrays cannot dispatch onto at all.

function svd_pushforward!(
ΔA::Diagonal{T, <:CuVector{T}}, A, USVᴴ, ΔUSVᴴ, ind = Colon();
rank_atol::Real = MatrixAlgebraKit.default_pullback_rank_atol(USVᴴ[2]),
degeneracy_atol::Real = MatrixAlgebraKit.default_pullback_rank_atol(USVᴴ[2])
) where {T}
return MatrixAlgebraKit.svd_pushforward!(diagm(diagview(ΔA)), A, USVᴴ, ΔUSVᴴ, ind; rank_atol, degeneracy_atol)
end

end
163 changes: 90 additions & 73 deletions ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl

Large diffs are not rendered by default.

209 changes: 108 additions & 101 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ include("common/safemethods.jl")
include("common/view.jl")
include("common/regularinv.jl")
include("common/matrixproperties.jl")
include("common/utility.jl")

include("yalapack.jl")
include("algorithms.jl")
Expand Down Expand Up @@ -132,6 +133,7 @@ include("pullbacks/polar.jl")
include("pushforwards/polar.jl")
include("pushforwards/eig.jl")
include("pushforwards/eigh.jl")
include("pushforwards/svd.jl")

include("precompile.jl")

Expand Down
16 changes: 16 additions & 0 deletions src/common/utility.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@

function has_equal_storage(A::Diagonal, B::Diagonal)
return diagview(A) === diagview(B)
end
function has_equal_storage(A::AbstractMatrix, B::AbstractMatrix)
return A === B
end

function has_equal_storage(A::Diagonal, B::AbstractVector)
return diagview(A) === B
end
function has_equal_storage(A::AbstractVector, B::Diagonal)
return A === diagview(B)
end
has_equal_storage(A::AbstractMatrix, B::AbstractVector) = false
has_equal_storage(A::AbstractVector, B::AbstractMatrix) = false
6 changes: 3 additions & 3 deletions src/implementations/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ function eig_full!(A::Diagonal, DV, alg::DiagonalAlgorithm)
D, V = DV
diagA = diagview(A)
I = sortperm(diagA; by = eig_sortby)
if D === A
if has_equal_storage(A, D)
permute!(diagA, I)
else
diagview(D) .= view(diagA, I)
Expand All @@ -179,8 +179,8 @@ end

function eig_vals!(A::Diagonal, D::AbstractVector, alg::DiagonalAlgorithm)
check_input(eig_vals!, A, D, alg)
Ad = diagview(A)
D === Ad || copy!(D, Ad)
diagA = diagview(A)
has_equal_storage(A, D) || copy!(D, diagA)
sort!(D; by = eig_sortby)
return D
end
Expand Down
4 changes: 4 additions & 0 deletions src/pullbacks/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,15 @@ function eig_pullback!(
end
return ΔA
end
# Diagonal: do not specialize on `A`, since we may insert `A = nothing` to assert independence of `A` in the implementation
function eig_pullback!(
ΔA::Diagonal, A, DV, ΔDV, ind = Colon();
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
)
# TODO:
# If A and ΔA are diagonal, then V is a permutation matrix and so is inv(V) = V'.
# Furthermore, since V̇ is 0, the pullback ΔV cannot contribute and we only have to unpermute ΔD.
ΔA_full = zero!(similar(ΔA, size(ΔA)))
ΔA_full = eig_pullback!(ΔA_full, A, DV, ΔDV, ind; degeneracy_atol, gauge_atol)
diagview(ΔA) .+= diagview(ΔA_full)
Expand Down
4 changes: 4 additions & 0 deletions src/pullbacks/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,15 @@ function eigh_pullback!(
ΔA = mul!(ΔA, V * VᴴΔAV, V', 1, 1)
return ΔA
end
# Diagonal: do not specialize on `A`, since we may insert `A = nothing` to assert independence of `A` in the implementation
function eigh_pullback!(
ΔA::Diagonal, A, DV, ΔDV, ind = Colon();
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
)
# TODO:
# If A and ΔA are diagonal, then V is a permutation matrix and so is inv(V) = V'.
# Furthermore, since V̇ is 0, the pullback ΔV cannot contribute and we only have to unpermute ΔD.
ΔA_full = zero!(similar(ΔA, size(ΔA)))
ΔA_full = eigh_pullback!(ΔA_full, A, DV, ΔDV, ind; degeneracy_atol, gauge_atol)
diagview(ΔA) .+= diagview(ΔA_full)
Expand Down
4 changes: 4 additions & 0 deletions src/pullbacks/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,16 @@ function svd_pullback!(
end
return ΔA
end
# Diagonal: do not specialize on `A`, since we may insert `A = nothing` to assert independence of `A` in the implementation
function svd_pullback!(
ΔA::Diagonal, A, USVᴴ, ΔUSVᴴ, ind = Colon();
rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ...)
)
# TODO:
# If A and ΔA are diagonal, then U and V are permutation matrices (up to signs/phases).
# Furthermore, since U̇ and V̇ are 0, the pullbacks ΔU and ΔV cannot contribute and we only have to unpermute ΔS.
ΔA_full = zero!(similar(ΔA, size(ΔA)))
ΔA_full = svd_pullback!(ΔA_full, A, USVᴴ, ΔUSVᴴ, ind; rank_atol, degeneracy_atol, gauge_atol)
diagview(ΔA) .+= diagview(ΔA_full)
Expand Down
12 changes: 8 additions & 4 deletions src/pushforwards/eig.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
function eig_pushforward!(
ΔA, A, DV, ΔDV;
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
degeneracy_atol::Real = default_pullback_rank_atol(DV[1])
)
D, V = DV
ΔD, ΔV = ΔDV
Expand All @@ -12,11 +11,16 @@ function eig_pushforward!(
end
if !iszerotangent(ΔV)
∂K .*= inv_safe.(transpose(diagview(D)) .- diagview(D), degeneracy_atol)
mul!(ΔV, V, ∂K, 1, 0)
mul!(ΔV, V, ∂K)
if eltype(V) <: Complex # fix gauge for `gaugefix!` compatibility
_, I = findmax(abs, V; dims = 1)
infinitesimal_phases = imag.(ΔV[I] ./ V[I])
ΔV .-= im .* V .* infinitesimal_phases
end
end
return ΔDV
end

function eig_vals_pushforward!(ΔA, A, DV, ΔD; kwargs...)
return eig_pushforward!(ΔA, A, DV, (Diagonal(ΔD), nothing); kwargs...)
return eig_pushforward!(ΔA, A, DV, (diagonal(ΔD), nothing); kwargs...)
end
8 changes: 6 additions & 2 deletions src/pushforwards/eigh.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
function eigh_pushforward!(
ΔA, A, DV, ΔDV;
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
degeneracy_atol::Real = default_pullback_rank_atol(DV[1])
)
D, V = DV
ΔD, ΔV = ΔDV
Expand All @@ -13,6 +12,11 @@ function eigh_pushforward!(
if !iszerotangent(ΔV)
∂K .*= inv_safe.(transpose(diagview(D)) .- diagview(D), degeneracy_atol)
ΔV = mul!(ΔV, V, ∂K)
if eltype(V) <: Complex # fix gauge for `gaugefix!` compatibility
_, I = findmax(abs, V; dims = 1)
infinitesimal_phases = imag.(ΔV[I] ./ V[I])
ΔV .-= im .* V .* infinitesimal_phases
end
end
return (ΔD, ΔV)
end
Expand Down
82 changes: 82 additions & 0 deletions src/pushforwards/svd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
function svd_pushforward!(
ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon();
rank_atol = default_pullback_rank_atol(USVᴴ[2]),
degeneracy_atol = default_pullback_rank_atol(USVᴴ[2])
)
U, Smat, Vᴴ = USVᴴ
m, n = size(U, 1), size(Vᴴ, 2)
(m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)"))
minmn = min(m, n)
S = diagview(Smat)
ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
r = svd_rank(S; rank_atol)

U₁ = view(U, :, 1:r)
S₁ = view(S, 1:r)
V₁ᴴ = view(Vᴴ, 1:r, :)

# compact region
V₁ = adjoint(V₁ᴴ)
ΔAV₁ = ΔA * V₁
UᴴΔAV₁ = U₁' * ΔAV₁
if !iszerotangent(ΔS)
zero!(ΔS) # make off-diagonal entries zero in case of full ΔS (svd_full!)
ΔS₁ = view(diagview(ΔS), 1:r)
ΔS₁ .= real.(diagview(UᴴΔAV₁))
end
if !iszerotangent(ΔU) || !iszerotangent(ΔVᴴ)
hUᴴΔAV₁ = inv_safe.(transpose(S₁) .- S₁, degeneracy_atol) .* project_hermitian(UᴴΔAV₁)
aUᴴΔAV₁ = inv_safe.(transpose(S₁) .+ S₁, degeneracy_atol) .* project_antihermitian(UᴴΔAV₁)
if !iszerotangent(ΔU)
ΔU₁ = view(ΔU, :, 1:r)
K̇ = hUᴴΔAV₁ + aUᴴΔAV₁
mul!(ΔU₁, U₁, K̇)
if m > r
ΔAV₁ = mul!(ΔAV₁, U₁, UᴴΔAV₁, -1, 1)
ΔU₁ .+= ΔAV₁ ./ transpose(S₁)
end
if size(U, 2) > r # these columns of U are undetermined, but U' * U̇ should be antihermitian
U₂ = view(U, :, (r + 1):size(U, 2))
ΔU₁ᴴU₂ = ΔU₁' * U₂
ΔU₂ = view(ΔU, :, (r + 1):size(U, 2))
mul!(ΔU₂, U₁, ΔU₁ᴴU₂, -1, 0)
end
end
if !iszerotangent(ΔVᴴ)
ΔV₁ᴴ = view(ΔVᴴ, 1:r, :)
Ṁ = hUᴴΔAV₁ - aUᴴΔAV₁
mul!(ΔV₁ᴴ, Ṁ', V₁ᴴ)
if n > r
UᴴΔA₁ = U₁' * ΔA
UᴴΔA₁ = mul!(UᴴΔA₁, UᴴΔAV₁, V₁ᴴ, -1, 1)
ΔV₁ᴴ .+= S₁ .\ UᴴΔA₁
end
if size(Vᴴ, 1) > r # these rows of Vᴴ are undetermined, but V * V̇ should be antihermitian
V₂ᴴ = view(Vᴴ, (r + 1):size(Vᴴ, 1), :)
V₂ᴴΔV₁ = V₂ᴴ * ΔV₁ᴴ'
ΔV₂ᴴ = view(ΔVᴴ, (r + 1):size(Vᴴ, 1), :)
mul!(ΔV₂ᴴ, V₂ᴴΔV₁, V₁ᴴ, -1, 0)
end
end
if eltype(U) <: Complex && !iszerotangent(ΔU) && !iszerotangent(ΔVᴴ) # fix gauge for `gaugefix!` compatibility
_, I = findmax(abs, U₁; dims = 1)
infinitesimal_phases = imag.(ΔU₁[I] .* inv_safe.(U₁[I]))

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would be suprised if that is needed or makes a difference. U₁[I] is the maximum element of every column of U₁. As U₁ is an isometric matrix, all of its columns have norm 1, and therefore, the largest element needs to be at least 1/sqrt(m), in magnitude, with m the number of rows. Typically, it will be larger.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would be quite surprised too but I'm otherwise very confused where the NaN are emerging from (the ./transpose(S1) line has the same objection). I'll try stepping through the pushfoward to see if I can find the culprit.

ΔU₁ .-= im .* U₁ .* infinitesimal_phases
ΔV₁ᴴ .+= im .* transpose(infinitesimal_phases) .* V₁ᴴ
end
end
return (ΔU, ΔS, ΔVᴴ)
end

# TODO
#=function svd_trunc_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol = default_pullback_rank_atol(A), kwargs...)
end=#

function svd_vals_pushforward!(
ΔA, A, USVᴴ, ΔS, ind = Colon();
rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2])
)
ΔUSVᴴ = (nothing, diagonal(ΔS), nothing)
return svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol, degeneracy_atol)
end
4 changes: 2 additions & 2 deletions test/enzyme/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite
Expand All @@ -16,6 +16,6 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
if !is_buildkite
TestSuite.test_enzyme_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
AT = Diagonal{T, Vector{T}}
m == n && TestSuite.test_enzyme_svd(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
m == n && TestSuite.test_enzyme_svd(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
end
end
46 changes: 38 additions & 8 deletions test/testsuite/enzyme/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,48 +8,78 @@ function test_enzyme_svd(T::Type, sz; kwargs...)
end
end

"""
test_enzyme_svd_compact(T, sz; rng, atol, rtol)

Test the Enzyme forward- and reverse-mode AD rule for `svd_compact` and its in-place variant.
"""
function test_enzyme_svd_compact(
T, sz;
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
fdm = enzyme_fdm(T)
)
return @testset "svd_compact reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
return @testset "svd_compact: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
A = instantiate_matrix(T, sz)
alg = MatrixAlgebraKit.select_algorithm(svd_compact, A)
USVᴴ, ΔUSVᴴ = ad_svd_compact_setup(A)
test_reverse(svd_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm)
test_reverse(call_and_zero!, RT, (svd_compact!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm)
test_reverse(call_and_zero!, RT, (svd_compact!, Const), (copy(A), TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm)
test_forward(svd_compact, RT, (A, TA), (alg, Const); atol, rtol, fdm)
test_forward(call_and_zero!, RT, (svd_compact!, Const), (copy(A), TA), (alg, Const); atol, rtol, fdm)
end
end

"""
test_enzyme_svd_full(T, sz; rng, atol, rtol)

Test the Enzyme forward- and reverse-mode AD rule for `svd_full` and its in-place variant. The
gauge-dependent extra columns of `U` and rows of `Vᴴ` are zeroed out in the cotangent.
"""
function test_enzyme_svd_full(
T, sz;
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
fdm = enzyme_fdm(T)
)
return @testset "svd_full reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
return @testset "svd_full: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
A = instantiate_matrix(T, sz)
alg = MatrixAlgebraKit.select_algorithm(svd_full, A)
USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A)
test_reverse(svd_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm)
test_reverse(call_and_zero!, RT, (svd_full!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm)
test_reverse(call_and_zero!, RT, (svd_full!, Const), (copy(A), TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm)
if size(A, 1) == size(A, 2) # finite differences check for free component is very finicky
test_forward(svd_full, RT, (A, TA), (alg, Const); atol, rtol, fdm)
test_forward(call_and_zero!, RT, (svd_full!, Const), (copy(A), TA), (alg, Const); atol, rtol, fdm)
end
end
end

"""
test_enzyme_svd_vals(T, sz; rng, atol, rtol)

Test the Enzyme forward- and reverse-mode AD rule for `svd_vals` and its in-place variant.
"""
function test_enzyme_svd_vals(
T, sz;
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
fdm = enzyme_fdm(T)
)
return @testset "svd_vals reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
return @testset "svd_vals: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
A = instantiate_matrix(T, sz)
alg = MatrixAlgebraKit.select_algorithm(svd_vals, A)
S, ΔS = ad_svd_vals_setup(A)
test_reverse(svd_vals, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm)
test_reverse(call_and_zero!, RT, (svd_vals!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm)
test_reverse(call_and_zero!, RT, (svd_vals!, Const), (copy(A), TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm)
test_forward(svd_vals, RT, (A, TA), (alg, Const); atol, rtol, fdm)
test_forward(call_and_zero!, RT, (svd_vals!, Const), (copy(A), TA), (alg, Const); atol, rtol, fdm)
end
end

"""
test_enzyme_svd_trunc(T, sz; rng, atol, rtol)

Test the Enzyme reverse-mode AD rules for `svd_trunc`, `svd_trunc_no_error`, and their
in-place variants, over a range of truncation ranks and a tolerance-based truncation.
"""
function test_enzyme_svd_trunc(
T, sz;
rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T),
Expand All @@ -64,15 +94,15 @@ function test_enzyme_svd_trunc(
trunc = truncrank(r)
truncalg = TruncatedAlgorithm(alg, trunc)
USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg)
test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
test_reverse(svd_trunc_no_error, RT, (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
test_reverse(call_and_zero!, RT, (svd_trunc_no_error!, Const), (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
end
@testset "trunctol" begin
S = svd_vals(A, alg)
trunc = trunctol(atol = maximum(S) / 2)
truncalg = TruncatedAlgorithm(alg, trunc)
USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg)
test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
test_reverse(svd_trunc_no_error, RT, (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
test_reverse(call_and_zero!, RT, (svd_trunc_no_error!, Const), (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm)
end
end
Expand Down
10 changes: 8 additions & 2 deletions test/testsuite/mooncake/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,15 @@ function test_mooncake_eig_vals(
rng, eig_vals, A, alg;
output_tangent, atol, rtol
)
if T <: Diagonal{<:Complex}
if A isa Diagonal{<:Complex}
A2 = copy(A)
Mooncake.TestUtils.test_rule(
rng, eig_vals!, A2, copy(A2.diag), alg;
output_tangent, atol, rtol
)
A2 = copy(A)
Mooncake.TestUtils.test_rule(
rng, eig_vals!, A, A.diag, alg;
rng, eig_vals!, A2, A2.diag, alg;
output_tangent, atol, rtol
)
end
Expand Down
Loading
Loading