diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 558fcec45..442ee0133 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -8,7 +8,7 @@ using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, Abstract using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, gesvd!, gesvdp!, gesvdr!, gesvdj! import MatrixAlgebraKit: heevj!, heevd!, geev! -import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank, svd_pullback!, eigh_pullback!, eig_pullback! +import MatrixAlgebraKit: _gpu_Xgesvdr!, _sylvester, svd_rank, svd_pullback!, eigh_pullback!, eig_pullback!, svd_pushforward! using CUDA, CUDA.cuBLAS using CUDA: i32 using LinearAlgebra @@ -213,4 +213,14 @@ function eig_pullback!(ΔA::AnyCuMatrix, A, DV, ΔDV, ind::AnyCuVector; kwargs.. return eig_pullback!(ΔA, A, DV, ΔDV, collect(ind); kwargs...) 
end +# have to override this as methods are missing in GPUArrays for the various +# views of Diagonal of ΔA +function svd_pushforward!( + ΔA::Diagonal{T, <:CuVector{T}}, A, USVᴴ, ΔUSVᴴ, ind = Colon(); + rank_atol::Real = MatrixAlgebraKit.default_pullback_rank_atol(USVᴴ[2]), + degeneracy_atol::Real = MatrixAlgebraKit.default_pullback_rank_atol(USVᴴ[2]) + ) where {T} + return MatrixAlgebraKit.svd_pushforward!(diagm(diagview(ΔA)), A, USVᴴ, ΔUSVᴴ, ind; rank_atol, degeneracy_atol) +end + end diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index 343bd9681..1ec15a8b6 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -1,13 +1,14 @@ module MatrixAlgebraKitEnzymeExt using MatrixAlgebraKit -using MatrixAlgebraKit: copy_input, initialize_output, zero! +using MatrixAlgebraKit: copy_input, initialize_output, zero!, has_equal_storage using MatrixAlgebraKit: diagview, inv_safe, truncate using MatrixAlgebraKit: qr_pullback!, lq_pullback! using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback! using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!, eigh_vals_pullback! using MatrixAlgebraKit: eig_pushforward!, eigh_pushforward!, eig_vals_pushforward!, eigh_vals_pushforward! using MatrixAlgebraKit: svd_pullback!, svd_vals_pullback! +using MatrixAlgebraKit: svd_pushforward!, svd_vals_pushforward! using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback! using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward! using Enzyme @@ -65,15 +66,13 @@ for (f, pb) in ( arg::Annotation{Tuple{TA, TB}}, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT, TA, TB} - # A is overwritten in the primal, but NOT used in the pullback, - # so we do not need to cache it. This may change if future pullbacks - # depend directly on A! 
+ # A is overwritten in the primal, but not used in the pullback, so we do not need to cache it. ret = func.val(A.val, arg.val, alg.val) # if arg.val === ret, the annotation must be Duplicated or DuplicatedNoNeed # if arg isa Const, ret may still be modified further down the call graph so we should # copy it to protect ourselves - A_is_arg1 = !isa(A, Const) && A.val === arg.val[1] - A_is_arg2 = !isa(A, Const) && A.val === arg.val[2] + A_is_arg1 = !isa(A, Const) && has_equal_storage(A.val, arg.val[1]) + A_is_arg2 = !isa(A, Const) && has_equal_storage(A.val, arg.val[2]) A_is_arg = A_is_arg1 || A_is_arg2 cache_arg = arg.val !== ret || A_is_arg || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing dret = if EnzymeRules.needs_shadow(config) && ((TA == Nothing && TB == Nothing) || isa(arg, Const)) @@ -96,13 +95,11 @@ for (f, pb) in ( alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT, TA, TB} cache_arg, darg = cache - # A is NOT used in the pullback, so we assign Aval = nothing - # to trigger an error in case the pullback is modified to directly - # use A (so that whoever does this is forced to handle caching A - # appropriately here) + # A is not used in the pullback; since we have destroyed A, we insert Aval = nothing + # to trigger an error in case the pullback is modified to directly use A Aval = nothing - A_is_arg1 = !isa(A, Const) && A.dval === arg.dval[1] - A_is_arg2 = !isa(A, Const) && A.dval === arg.dval[2] + A_is_arg1 = !isa(A, Const) && has_equal_storage(A.dval, arg.dval[1]) + A_is_arg2 = !isa(A, Const) && has_equal_storage(A.dval, arg.dval[2]) A_is_arg = A_is_arg1 || A_is_arg2 argval = something(cache_arg, arg.val) if !isa(A, Const) @@ -134,8 +131,8 @@ for (f, pf) in ( arg::Annotation, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT} - A_is_arg1 = !isa(A, Const) && A.val === arg.val[1] - A_is_arg2 = !isa(A, Const) && A.val === arg.val[2] + A_is_arg1 = !isa(A, Const) && has_equal_storage(A.val, arg.val[1]) + A_is_arg2 = !isa(A, 
Const) && has_equal_storage(A.val, arg.val[2]) A_is_arg = A_is_arg1 || A_is_arg2 $f(A.val, arg.val, alg.val) if !isa(A, Const) && !isa(arg, Const) @@ -209,33 +206,25 @@ for (f, pb) in ( end end -for f in (:svd_compact!, :svd_full!) +for f! in (:svd_compact!, :svd_full!) @eval begin function EnzymeRules.augmented_primal( config::EnzymeRules.RevConfigWidth{1}, - func::Const{typeof($f)}, + func::Const{typeof($f!)}, ::Type{RT}, A::Annotation, USVᴴ::Annotation{Tuple{TU, TS, TVᴴ}}, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT, TU, TS, TVᴴ} - # A is overwritten in the primal, but NOT used in the pullback, - # so we do not need to cache it. This may change if future pullbacks - # depend directly on A! + # A is overwritten in the primal, but not used in the pullback, so we do not need to cache it. ret = func.val(A.val, USVᴴ.val, alg.val) # if USVᴴ.val == ret, the annotation must be Duplicated or DuplicatedNoNeed # if USVᴴ isa Const, ret may still be modified further down the call graph so we should # copy it to protect ourselves cache_USVᴴ = (USVᴴ.val !== ret) || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing # the USVᴴ may be nothing for eltypes handled by GenericLinearAlgebra - dret = if EnzymeRules.needs_shadow(config) && ((TU == TS == TVᴴ == Nothing) || isa(USVᴴ, Const)) - dU = zero(ret[1]) - # special casing `Diagonal` seems to be necessary due to Enzyme's type analysis - dS = $(f == svd_compact!) ? Diagonal(zero(ret[2].diag)) : zero(ret[2]) - dVᴴ = zero(ret[3]) - (dU, dS, dVᴴ) - elseif EnzymeRules.needs_shadow(config) - USVᴴ.dval + dret = if EnzymeRules.needs_shadow(config) + ((TU == TS == TVᴴ == Nothing) || isa(USVᴴ, Const)) ? make_zero(ret) : USVᴴ.dval else nothing end @@ -244,7 +233,7 @@ for f in (:svd_compact!, :svd_full!) 
end function EnzymeRules.reverse( config::EnzymeRules.RevConfigWidth{1}, - func::Const{typeof($f)}, + func::Const{typeof($f!)}, ::Type{RT}, cache, A::Annotation, @@ -252,10 +241,8 @@ for f in (:svd_compact!, :svd_full!) alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT} cache_USVᴴ, dUSVᴴ = cache - # A is NOT used in the pullback, so we assign Aval = nothing - # to trigger an error in case the pullback is modified to directly - # use A (so that whoever does this is forced to handle caching A - # appropriately here) + # A is not used in the pullback; since we have destroyed A, we insert Aval = nothing + # to trigger an error in case the pullback is modified to directly use A Aval = nothing USVᴴval = something(cache_USVᴴ, USVᴴ.val) if !isa(A, Const) @@ -264,6 +251,30 @@ for f in (:svd_compact!, :svd_full!) !isa(USVᴴ, Const) && make_zero!(USVᴴ.dval) return (nothing, nothing, nothing) end + function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof($f!)}, + ::Type{RT}, + A::Annotation{TA}, + USVᴴ::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT, TA} + $f!(A.val, USVᴴ.val, alg.val) + A_is_arg = !isa(A, Const) && has_equal_storage(A.dval, USVᴴ.dval[2]) + if !isa(A, Const) + !isa(USVᴴ, Const) && svd_pushforward!(A.dval, A.val, USVᴴ.val, USVᴴ.dval) + !A_is_arg && make_zero!(A.dval) + end + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return USVᴴ + elseif EnzymeRules.needs_primal(config) + return USVᴴ.val + elseif EnzymeRules.needs_shadow(config) + return USVᴴ.dval + else + return nothing + end + end end end @@ -275,20 +286,13 @@ function EnzymeRules.augmented_primal( USVᴴ::Annotation{Tuple{TU, TS, TVᴴ}}, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT, TU, TS, TVᴴ} - # A is overwritten in the primal, but NOT used in the pullback, - # so we do not need to cache it. This may change if future pullbacks - # depend directly on A! 
+ # A is overwritten in the primal, but not used in the pullback, so we do not need to cache it. ret = svd_compact!(A.val, USVᴴ.val, alg.val.alg) cache_USVᴴ = (USVᴴ.val !== ret) || EnzymeRules.overwritten(config)[3] ? copy.(ret) : nothing USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, ret, alg.val.trunc) primal = EnzymeRules.needs_primal(config) ? USVᴴ′ : nothing - # This creates new output shadow matrices, we use USVᴴ′ to ensure the - # eltypes and dimensions are correct. - # These new shadow matrices are "filled in" with the accumulated - # results from earlier in reverse-mode AD after this function exits - # and before `reverse` is called. dret = if EnzymeRules.needs_shadow(config) - (zero(USVᴴ′[1]), Diagonal(zero(USVᴴ′[2].diag)), zero(USVᴴ′[3])) + make_zero(USVᴴ′) else nothing end @@ -305,10 +309,8 @@ function EnzymeRules.reverse( alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT} cache_USVᴴ, dUSVᴴ, ind = cache - # A is NOT used in the pullback, so we assign Aval = nothing - # to trigger an error in case the pullback is modified to directly - # use A (so that whoever does this is forced to handle caching A - # appropriately here) + # A is not used in the pullback; since we have destroyed A, we insert Aval = nothing + # to trigger an error in case the pullback is modified to directly use A Aval = nothing USVᴴval = something(cache_USVᴴ, USVᴴ.val) if !isa(A, Const) @@ -330,15 +332,13 @@ for (f, trunc_f, full_f, pb) in ( DV::Annotation{Tuple{TA, TB}}, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT, TA, TB} - # A is overwritten in the primal, but NOT used in the pullback, - # so we do not need to cache it. This may change if future pullbacks - # depend directly on A! + # A is overwritten in the primal, but not used in the pullback, so we do not need to cache it. ret = $full_f(A.val, DV.val, alg.val.alg) cache_DV = (DV.val !== ret) || EnzymeRules.overwritten(config)[3] ? 
copy.(ret) : nothing DV′, ind = truncate($trunc_f, ret, alg.val.trunc) primal = EnzymeRules.needs_primal(config) ? DV′ : nothing dret = if EnzymeRules.needs_shadow(config) - (Diagonal(zero(diagview(DV′[1]))), zero(DV′[2])) + make_zero(DV′) else nothing end @@ -354,10 +354,8 @@ for (f, trunc_f, full_f, pb) in ( alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT} cache_DV, dDVtrunc, ind = cache - # A is NOT used in the pullback, so we assign Aval = nothing - # to trigger an error in case the pullback is modified to directly - # use A (so that whoever does this is forced to handle caching A - # appropriately here) + # A is not used in the pullback; since we have destroyed A, we insert Aval = nothing + # to trigger an error in case the pullback is modified to directly use A Aval = nothing DVval = something(cache_DV, DV.val) if !isa(A, Const) @@ -387,9 +385,7 @@ for (f!, f_full!, pb!, pf!) in ( D::Annotation{TD}, alg::Annotation{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT, TD} - # A is overwritten in the primal, but NOT used in the pullback, - # so we do not need to cache it. This may change if future pullbacks - # depend directly on A! + # A is overwritten in the primal, but not used in the pullback, so we do not need to cache it. nD, V = $f_full!(A.val, alg.val) ret = TD == Nothing ? diagview(nD) : copy!(D.val, diagview(nD)) cache_D = (D.val !== ret) || EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing @@ -411,12 +407,10 @@ for (f!, f_full!, pb!, pf!) 
in ( ) where {RT, TA} cache_D, dD, V = cache Dval = something(cache_D, D.val) - # A is NOT used in the pullback, so we assign Aval = nothing - # to trigger an error in case the pullback is modified to directly - # use A (so that whoever does this is forced to handle caching A - # appropriately here) + # A is not used in the pullback; since we have destroyed A, we insert Aval = nothing + # to trigger an error in case the pullback is modified to directly use A Aval = nothing - A_is_arg = !isa(A, Const) && TA <: Diagonal && diagview(A.dval) === D.dval + A_is_arg = !isa(A, Const) && has_equal_storage(A.dval, D.dval) if !isa(A, Const) ΔA = A_is_arg ? make_zero(A.dval) : A.dval $pb!(ΔA, Aval, (Diagonal(Dval), V), dD) @@ -433,12 +427,12 @@ for (f!, f_full!, pb!, pf!) in ( D::Annotation, alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT, TA} - A_is_arg = !isa(A, Const) && TA <: Diagonal && diagview(A.dval) === D.dval + A_is_arg = !isa(A, Const) && has_equal_storage(A.dval, D.dval) DV = $f_full!(A.val, alg.val) Dval, V = DV if !isa(A, Const) && !isa(D, Const) ΔD = A_is_arg ? make_zero(D.dval) : D.dval - $pf!(A.dval, A.val, (Diagonal(diagview(Dval)), V), ΔD) + $pf!(A.dval, A.val, DV, ΔD) A_is_arg && (D.dval .= ΔD) end copyto!(D.val, diagview(Dval)) @@ -464,16 +458,14 @@ function EnzymeRules.augmented_primal( S::Annotation{TS}, alg::Annotation{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT, TS} - # A is overwritten in the primal, but NOT used in the pullback, - # so we do not need to cache it. This may change if future pullbacks - # depend directly on A! + # A is overwritten in the primal, but not used in the pullback, so we do not need to cache it. U, nS, Vᴴ = svd_compact!(A.val, alg.val) ret = TS == Nothing ? diagview(nS) : copy!(S.val, diagview(nS)) cache_S = (S.val !== ret) || EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing primal = EnzymeRules.needs_primal(config) ? 
ret : nothing # on 1.10, Enzyme can get confused about whether it needs the shadow # create dret no matter what to account for this - dret = TS == Nothing || isa(S, Const) ? zero(ret) : S.dval + dret = TS == Nothing || isa(S, Const) ? make_zero(ret) : S.dval shadow = EnzymeRules.needs_shadow(config) ? dret : nothing return EnzymeRules.AugmentedReturn(primal, shadow, (cache_S, dret, U, Vᴴ)) end @@ -487,20 +479,45 @@ function EnzymeRules.reverse( alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, ) where {RT, TA} cache_S, dS, U, Vᴴ = cache - # A is NOT used in the pullback, so we assign Aval = nothing - # to trigger an error in case the pullback is modified to directly - # use A (so that whoever does this is forced to handle caching A - # appropriately here) + # A is not used in the pullback; since we have destroyed A, we insert Aval = nothing + # to trigger an error in case the pullback is modified to directly use A Aval = nothing Sval = something(cache_S, S.val) - A_is_arg = !isa(A, Const) && TA <: Diagonal && diagview(A.dval) === S.dval + A_is_arg = !isa(A, Const) && has_equal_storage(A.dval, S.dval) if !isa(A, Const) ΔA = A_is_arg ? make_zero(A.dval) : A.dval - svd_vals_pullback!(ΔA, Aval, (U, Diagonal(Sval), Vᴴ), dS) + svd_vals_pullback!(ΔA, Aval, (U, diagonal(Sval), Vᴴ), dS) A_is_arg && (A.dval .= ΔA) end !isa(S, Const) && !A_is_arg && make_zero!(S.dval) return (nothing, nothing, nothing) end +function EnzymeRules.forward( + config::EnzymeRules.FwdConfigWidth{1}, + func::Const{typeof(svd_vals!)}, + ::Type{RT}, + A::Annotation{TA}, + S::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT, TA} + A_is_arg = !isa(A, Const) && has_equal_storage(A.dval, S.dval) + USVᴴ = svd_compact!(A.val, alg.val) + if !isa(A, Const) && !isa(S, Const) + ΔS = A_is_arg ? 
make_zero(S.dval) : S.dval + svd_vals_pushforward!(A.dval, A.val, USVᴴ, ΔS) + A_is_arg && (S.dval .= ΔS) + end + !A_is_arg && make_zero!(A.dval) + copyto!(S.val, diagview(USVᴴ[2])) + if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config) + return S + elseif EnzymeRules.needs_primal(config) + return S.val + elseif EnzymeRules.needs_shadow(config) + return S.dval + else + return nothing + end +end end diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index 16241385b..da857136f 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -3,7 +3,7 @@ module MatrixAlgebraKitMooncakeExt using Mooncake using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive using MatrixAlgebraKit -using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output, zero! +using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output, zero!, has_equal_storage using MatrixAlgebraKit: qr_pullback!, lq_pullback! using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback! using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback! @@ -13,6 +13,7 @@ using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pul using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback! using MatrixAlgebraKit: left_polar_pushforward!, right_polar_pushforward! using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback! +using MatrixAlgebraKit: svd_pushforward!, svd_vals_pushforward! 
using MatrixAlgebraKit: TruncatedAlgorithm using LinearAlgebra @@ -66,17 +67,17 @@ for (f!, f, pb, adj) in ( # of the output args -- this can # mess up the pullback because # generally the args are used there - if !(A === arg1 || A === arg2) + if !(has_equal_storage(A, arg1) || has_equal_storage(A, arg2)) copy!(A, Ac) $pb(dA, A, (arg1, arg2), (darg1, darg2)) else ΔA = zero(A) - $pb(ΔA, A, (arg1, arg2), (darg1, darg2)) + $pb(ΔA, Ac, (arg1, arg2), (darg1, darg2)) dA .= ΔA end - if A === arg1 + if has_equal_storage(A, arg1) zero!(darg2) - elseif A === arg2 + elseif has_equal_storage(A, arg2) zero!(darg1) else zero!(darg1) @@ -92,11 +93,7 @@ for (f!, f, pb, adj) in ( function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) A, dA = arrayify(A_dA) output = $f(A, Mooncake.primal(alg_dalg)) - # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal - # of ComplexF32) into the correct **forwards** data type (since we are now in the forward - # pass). 
For many types this is done automatically when the forward step returns, but - # not for nested structs with various fields (like Diagonal{Complex}) - output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) + output_codual = Mooncake.zero_fcodual(output) function $adj(::NoRData) arg1, arg2 = Mooncake.primal(output_codual) darg1_, darg2_ = Mooncake.tangent(output_codual) @@ -127,7 +124,9 @@ for (f!, f, pf) in ( arg1, darg1 = arrayify(args[1], dargs[1]) arg2, darg2 = arrayify(args[2], dargs[2]) $f!(A, args, Mooncake.primal(alg_dalg)) - $pf(dA, A, (arg1, arg2), (darg1, darg2)) + # A is not used in the pushforward; since we have destroyed A, we insert A = nothing + # to trigger an error in case the pushforward is modified to directly use A + $pf(dA, nothing, (arg1, arg2), (darg1, darg2)) return args_dargs end @is_primitive Mooncake.DefaultCtx Mooncake.ForwardMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} @@ -169,7 +168,7 @@ for (f!, f, pb, adj) in ( function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) A, dA = arrayify(A_dA) output = $f(A, Mooncake.primal(alg_dalg)) - output_codual = CoDual(output, Mooncake.zero_tangent(output)) + output_codual = Mooncake.zero_fcodual(output) function $adj(::NoRData) arg, darg = arrayify(output_codual) $pb(dA, A, arg, darg) @@ -181,9 +180,9 @@ for (f!, f, pb, adj) in ( end end -for (f!, f, f_full, pb, pf, adj) in ( - (:eig_vals!, :eig_vals, :eig_full, :eig_vals_pullback!, :eig_vals_pushforward!, :eig_vals_adjoint), - (:eigh_vals!, :eigh_vals, :eigh_full, :eigh_vals_pullback!, :eigh_vals_pushforward!, :eigh_vals_adjoint), +for (f!, f, f_full, f_full!, pb, pf, adj) in ( + (:eig_vals!, :eig_vals, :eig_full, :eig_full!, :eig_vals_pullback!, :eig_vals_pushforward!, :eig_vals_adjoint), + (:eigh_vals!, :eigh_vals, :eigh_full, :eigh_full!, :eigh_vals_pullback!, :eigh_vals_pushforward!, :eigh_vals_adjoint), ) @eval begin @is_primitive 
Mooncake.DefaultCtx Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} @@ -191,38 +190,45 @@ for (f!, f, f_full, pb, pf, adj) in ( # compute primal A, dA = arrayify(A_dA) D, dD = arrayify(D_dD) - Dc = copy(D) - # update primal DV = $f_full(A, Mooncake.primal(alg_dalg)) + Ac = has_equal_storage(A, D) ? copy(A) : A + Dc = copy(D) copy!(D, diagview(DV[1])) V = DV[2] function $adj(::NoRData) - if A !== D + if !has_equal_storage(A, D) # A is unchanged $pb(dA, A, DV, dD) - else - ΔA = zero(A) - $pb(ΔA, A, DV, dD) - dA .= A - end - if A !== D zero!(dD) copy!(D, Dc) - else + else # A and D have the same storage + ΔA = zero(A) + $pb(ΔA, Ac, DV, dD) + dA .= ΔA copy!(A, Ac) end return NoRData(), NoRData(), NoRData(), NoRData() end return D_dD, $adj end - function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, D_dD::Dual, alg_dalg::Dual) + function Mooncake.frule!!(f_df::Dual{typeof($f!)}, A_dA::Dual, D_dD::Dual, alg_dalg::Dual) # compute primal A, dA = arrayify(A_dA) D, dD = arrayify(D_dD) - # update primal - DV = $f_full(A, Mooncake.primal(alg_dalg)) - V = DV[2] - copyto!(D, diagview(DV[1])) - $pf(dA, A, (D, V), dD) + # have to do it like this to make Mooncake tests pass for both the case when D === A.diag and when not + _, V = initialize_output($f_full!, A, Mooncake.primal(alg_dalg)) + DV = (diagonal(D), V) + DV = $f_full!(A, DV, Mooncake.primal(alg_dalg)) + Dmat = DV[1] + if !(has_equal_storage(Dmat, D)) + copy!(D, diagview(Dmat)) + end + # A is not used in the pushforward; since we have destroyed A, we insert A = nothing + # to trigger an error in case the pushforward is modified to directly use A + if !(has_equal_storage(dA, dD)) + $pf(dA, nothing, DV, dD) + else + $pf(copy(dA), nothing, DV, dD) + end return D_dD end @is_primitive Mooncake.DefaultCtx Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} @@ -231,16 +237,15 @@ for (f!, f, f_full, pb, pf, adj) in ( A, dA = arrayify(A_dA) # update primal DV = $f_full(A, Mooncake.primal(alg_dalg)) - V = 
DV[2] - output = diagview(DV[1]) - output_codual = CoDual(output, Mooncake.zero_tangent(output)) + D = diagview(DV[1]) + D_codual = Mooncake.zero_fcodual(D) function $adj(::NoRData) - D, dD = arrayify(output_codual) + D, dD = arrayify(D_codual) $pb(dA, A, DV, dD) zero!(dD) return NoRData(), NoRData(), NoRData() end - return output_codual, $adj + return D_codual, $adj end function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual) # compute primal @@ -248,11 +253,11 @@ for (f!, f, f_full, pb, pf, adj) in ( # update primal DV = $f_full(A, Mooncake.primal(alg_dalg)) V = DV[2] - output = diagview(DV[1]) - output_dual = Dual(output, Mooncake.zero_tangent(output)) - D, dD = arrayify(output_dual) + D = diagview(DV[1]) + D_dual = Dual(D, Mooncake.zero_tangent(D)) + _, dD = arrayify(D_dual) $pf(dA, A, DV, dD) - return output_dual + return D_dual end end end @@ -283,10 +288,6 @@ for f in (:eig, :eigh) DVc = copy.(DV) alg = Mooncake.primal(alg_dalg) output = $f_trunc!(A, DV, alg) - # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal - # of ComplexF32) into the correct **forwards** data type (since we are now in the forward - # pass). For many types this is done automatically when the forward step returns, but - # not for nested structs with various fields (like Diagonal{Complex}) output_codual = Mooncake.zero_fcodual(output) function $f_adjoint!(dy::Tuple{NoRData, NoRData, <:Real}) Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual) @@ -297,7 +298,7 @@ for f in (:eig, :eigh) D, dD = arrayify(DV[1], dDV[1]) V, dV = arrayify(DV[2], dDV[2]) copy!(A, Ac) - if !(A === D || A === V) + if !has_equal_storage(A, D) # A is unchanged $f_trunc_pullback!(dA, A, (D′, V′), (dD′, dV′)) else ΔA = zero(A) @@ -368,11 +369,7 @@ for f in (:eig, :eigh) A, dA = arrayify(A_dA) alg = Mooncake.primal(alg_dalg) output = $f_trunc(A, alg) - # fdata call here is necessary to convert complicated Tangent type (e.g. 
of a Diagonal - # of ComplexF32) into the correct **forwards** data type (since we are now in the forward - # pass). For many types this is done automatically when the forward step returns, but - # not for nested structs with various fields (like Diagonal{Complex}) - output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) + output_codual = Mooncake.zero_fcodual(output) function $f_adjoint!(dy::Tuple{NoRData, NoRData, T}) where {T <: Real} Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual) dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual) @@ -421,11 +418,7 @@ for f in (:eig, :eigh) Ac = copy(A) DVc = copy.(DV) output = $f_trunc_no_error!(A, DV, alg) - # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal - # of ComplexF32) into the correct **forwards** data type (since we are now in the forward - # pass). For many types this is done automatically when the forward step returns, but - # not for nested structs with various fields (like Diagonal{Complex}) - output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) + output_codual = Mooncake.zero_fcodual(output) function $f_adjoint!(::NoRData) copy!(A, Ac) Dtrunc, Vtrunc = Mooncake.primal(output_codual) @@ -433,8 +426,7 @@ for f in (:eig, :eigh) D′, dD′ = arrayify(Dtrunc, dDtrunc_) V′, dV′ = arrayify(Vtrunc, dVtrunc_) $f_pullback!(dA, A, (D′, V′), (dD′, dV′)) - copy!(DV[1], DVc[1]) - copy!(DV[2], DVc[2]) + copy!.(DV, DVc) zero!(dD′) zero!(dV′) return NoRData(), NoRData(), NoRData(), NoRData() @@ -491,11 +483,7 @@ for f in (:eig, :eigh) A, dA = arrayify(A_dA) alg = Mooncake.primal(alg_dalg) output = $f_trunc_no_error(A, alg) - # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal - # of ComplexF32) into the correct **forwards** data type (since we are now in the forward - # pass). 
For many types this is done automatically when the forward step returns, but - # not for nested structs with various fields (like Diagonal{Complex}) - output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) + output_codual = Mooncake.zero_fcodual(output) function $f_adjoint!(::NoRData) Dtrunc, Vtrunc = Mooncake.primal(output_codual) dDtrunc_, dVtrunc_ = Mooncake.tangent(output_codual) @@ -538,7 +526,7 @@ for (f!, f) in ( (:svd_compact!, :svd_compact), ) @eval begin - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) A, dA = arrayify(A_dA) USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ) @@ -552,9 +540,7 @@ for (f!, f) in ( function svd_adjoint(::NoRData) copy!(A, Ac) svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) - copy!(U, USVᴴc[1]) - copy!(S, USVᴴc[2]) - copy!(Vᴴ, USVᴴc[3]) + copy!.(USVᴴ, USVᴴc) zero!(dU) zero!(dS) zero!(dVᴴ) @@ -562,15 +548,24 @@ for (f!, f) in ( end return USVᴴ_dUSVᴴ, svd_adjoint end - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, USVᴴ_dUSVᴴ::Dual, alg_dalg::Dual) + A, dA = arrayify(A_dA) + USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ) + dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ) + U, dU = arrayify(USVᴴ[1], dUSVᴴ[1]) + S, dS = arrayify(USVᴴ[2], dUSVᴴ[2]) + Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3]) + $f!(A, USVᴴ, Mooncake.primal(alg_dalg)) + # A is not used in the pushforward; since we have destroyed A, we insert A = nothing + # to trigger an error in case the pushforward is modified to directly use A + svd_pushforward!(dA, nothing, (U, S, Vᴴ), (dU, dS, dVᴴ)) + return USVᴴ_dUSVᴴ + end + @is_primitive Mooncake.DefaultCtx 
Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual) A, dA = arrayify(A_dA) USVᴴ = $f(A, Mooncake.primal(alg_dalg)) - # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal - # of ComplexF32) into the correct **forwards** data type (since we are now in the forward - # pass). For many types this is done automatically when the forward step returns, but - # not for nested structs with various fields (like Diagonal{Complex}) - USVᴴ_codual = CoDual(USVᴴ, Mooncake.fdata(Mooncake.zero_tangent(USVᴴ))) + USVᴴ_codual = Mooncake.zero_fcodual(USVᴴ) function svd_adjoint(::NoRData) U, S, Vᴴ = Mooncake.primal(USVᴴ_codual) dU_, dS_, dVᴴ_ = Mooncake.tangent(USVᴴ_codual) @@ -585,10 +580,23 @@ for (f!, f) in ( end return USVᴴ_codual, svd_adjoint end + function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual) + A, dA = arrayify(A_dA) + USVᴴ = $f(A, Mooncake.primal(alg_dalg)) + dUSVᴴ = Mooncake.zero_tangent(USVᴴ) + USVᴴ_dual = Dual(USVᴴ, dUSVᴴ) + U, S, Vᴴ = Mooncake.primal(USVᴴ_dual) + dU_, dS_, dVᴴ_ = Mooncake.tangent(USVᴴ_dual) + U, dU = arrayify(U, dU_) + S, dS = arrayify(S, dS_) + Vᴴ, dVᴴ = arrayify(Vᴴ, dVᴴ_) + svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) + return USVᴴ_dual + end end end -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} +@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -604,18 +612,25 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua end return S_dS, svd_vals_adjoint end +function Mooncake.frule!!(::Dual{typeof(svd_vals!)}, A_dA::Dual, S_dS::Dual, alg_dalg::Dual) + # compute primal + A, dA = arrayify(A_dA) + S, dS = 
arrayify(S_dS) + USVᴴ = svd_compact!(A, Mooncake.primal(alg_dalg)) + copy!(S, diagview(USVᴴ[2])) + # A is not used in the pushforward; since we have destroyed A, we insert A = nothing + # to trigger an error in case the pushforward is modified to directly use A + svd_vals_pushforward!(dA, nothing, USVᴴ, dS) + return S_dS +end -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm} +@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) USVᴴ = svd_compact(A, Mooncake.primal(alg_dalg)) - # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal - # of ComplexF32) into the correct **forwards** data type (since we are now in the forward - # pass). For many types this is done automatically when the forward step returns, but - # not for nested structs with various fields (like Diagonal{Complex}) S = diagview(USVᴴ[2]) - S_codual = CoDual(S, Mooncake.fdata(Mooncake.zero_tangent(S))) + S_codual = Mooncake.zero_fcodual(S) function svd_vals_adjoint(::NoRData) S, dS = arrayify(S_codual) svd_vals_pullback!(dA, A, USVᴴ, dS) @@ -624,6 +639,16 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co end return S_codual, svd_vals_adjoint end +function Mooncake.frule!!(::Dual{typeof(svd_vals)}, A_dA::Dual, alg_dalg::Dual) + # compute primal + A, dA = arrayify(A_dA) + USVᴴ = svd_compact(A, Mooncake.primal(alg_dalg)) + S = diagview(USVᴴ[2]) + S_dual = Dual(S, Mooncake.zero_tangent(S)) + S_, dS = arrayify(S_dual) + svd_vals_pushforward!(dA, A, USVᴴ, dS) + return S_dual +end @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, 
alg_dalg::CoDual) @@ -638,10 +663,6 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3]) USVᴴc = copy.(USVᴴ) output = svd_trunc!(A, USVᴴ, alg) - # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal - # of ComplexF32) into the correct **forwards** data type (since we are now in the forward - # pass). For many types this is done automatically when the forward step returns, but - # not for nested structs with various fields (like Diagonal{Complex}) output_codual = Mooncake.zero_fcodual(output) function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real} copy!(A, Ac) @@ -652,9 +673,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS S′, dS′ = arrayify(Strunc, dStrunc_) Vᴴ′, dVᴴ′ = arrayify(Vᴴtrunc, dVᴴtrunc_) svd_trunc_pullback!(dA, A, (U′, S′, Vᴴ′), (dU′, dS′, dVᴴ′)) - copy!(U, USVᴴc[1]) - copy!(S, USVᴴc[2]) - copy!(Vᴴ, USVᴴc[3]) + copy!.(USVᴴ, USVᴴc) zero!(dU) zero!(dS) zero!(dVᴴ) @@ -711,11 +730,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C A, dA = arrayify(A_dA) alg = Mooncake.primal(alg_dalg) output = svd_trunc(A, alg) - # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal - # of ComplexF32) into the correct **forwards** data type (since we are now in the forward - # pass). 
For many types this is done automatically when the forward step returns, but - # not for nested structs with various fields (like Diagonal{Complex}) - output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) + output_codual = Mooncake.zero_fcodual(output) function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real} Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual) dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual) @@ -769,11 +784,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3]) USVᴴc = copy.(USVᴴ) output = svd_trunc_no_error!(A, USVᴴ, alg) - # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal - # of ComplexF32) into the correct **forwards** data type (since we are now in the forward - # pass). For many types this is done automatically when the forward step returns, but - # not for nested structs with various fields (like Diagonal{Complex}) - output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) + output_codual = Mooncake.zero_fcodual(output) function svd_trunc_adjoint(::NoRData) copy!(A, Ac) Utrunc, Strunc, Vᴴtrunc = Mooncake.primal(output_codual) @@ -837,11 +848,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al A, dA = arrayify(A_dA) alg = Mooncake.primal(alg_dalg) output = svd_trunc_no_error(A, alg) - # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal - # of ComplexF32) into the correct **forwards** data type (since we are now in the forward - # pass). 
For many types this is done automatically when the forward step returns, but - # not for nested structs with various fields (like Diagonal{Complex}) - output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) + output_codual = Mooncake.zero_fcodual(output) function svd_trunc_adjoint(::NoRData) Utrunc, Strunc, Vᴴtrunc = Mooncake.primal(output_codual) dUtrunc_, dStrunc_, dVᴴtrunc_ = Mooncake.tangent(output_codual) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 65de152c4..0f69822cd 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -92,6 +92,7 @@ include("common/safemethods.jl") include("common/view.jl") include("common/regularinv.jl") include("common/matrixproperties.jl") +include("common/utility.jl") include("yalapack.jl") include("algorithms.jl") @@ -132,6 +133,7 @@ include("pullbacks/polar.jl") include("pushforwards/polar.jl") include("pushforwards/eig.jl") include("pushforwards/eigh.jl") +include("pushforwards/svd.jl") include("precompile.jl") diff --git a/src/common/utility.jl b/src/common/utility.jl new file mode 100644 index 000000000..db49493bb --- /dev/null +++ b/src/common/utility.jl @@ -0,0 +1,16 @@ + +function has_equal_storage(A::Diagonal, B::Diagonal) + return diagview(A) === diagview(B) +end +function has_equal_storage(A::AbstractMatrix, B::AbstractMatrix) + return A === B +end + +function has_equal_storage(A::Diagonal, B::AbstractVector) + return diagview(A) === B +end +function has_equal_storage(A::AbstractVector, B::Diagonal) + return A === diagview(B) +end +has_equal_storage(A::AbstractMatrix, B::AbstractVector) = false +has_equal_storage(A::AbstractVector, B::AbstractMatrix) = false \ No newline at end of file diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 8b09d7ad3..f8eddad0b 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -165,7 +165,7 @@ function eig_full!(A::Diagonal, DV, alg::DiagonalAlgorithm) D, V = DV diagA = diagview(A) I = 
sortperm(diagA; by = eig_sortby) - if D === A + if has_equal_storage(A, D) permute!(diagA, I) else diagview(D) .= view(diagA, I) @@ -179,8 +179,8 @@ end function eig_vals!(A::Diagonal, D::AbstractVector, alg::DiagonalAlgorithm) check_input(eig_vals!, A, D, alg) - Ad = diagview(A) - D === Ad || copy!(D, Ad) + diagA = diagview(A) + has_equal_storage(A, D) || copy!(D, diagA) sort!(D; by = eig_sortby) return D end diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index cce8577bc..29744c93d 100755 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -101,11 +101,15 @@ function eig_pullback!( end return ΔA end +# Diagonal: do not specialize on `A`, since we may insert `A = nothing` to assert independence of `A` in the implementation function eig_pullback!( ΔA::Diagonal, A, DV, ΔDV, ind = Colon(); degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) ) + # TODO: + # If A and ΔA are diagonal, then V is a permutation matrix and so is inv(V) = V'. + # Furthermore, since V̇ is 0, the pullback ΔV cannot contribute and we only have to unpermute ΔD. ΔA_full = zero!(similar(ΔA, size(ΔA))) ΔA_full = eig_pullback!(ΔA_full, A, DV, ΔDV, ind; degeneracy_atol, gauge_atol) diagview(ΔA) .+= diagview(ΔA_full) diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl index 7a2f1365c..e2687a427 100755 --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -94,11 +94,15 @@ function eigh_pullback!( ΔA = mul!(ΔA, V * VᴴΔAV, V', 1, 1) return ΔA end +# Diagonal: do not specialize on `A`, since we may insert `A = nothing` to assert independence of `A` in the implementation function eigh_pullback!( ΔA::Diagonal, A, DV, ΔDV, ind = Colon(); degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) ) + # TODO: + # If A and ΔA are diagonal, then V is a permutation matrix and so is inv(V) = V'. 
+ # Furthermore, since V̇ is 0, the pullback ΔV cannot contribute and we only have to unpermute ΔD. ΔA_full = zero!(similar(ΔA, size(ΔA))) ΔA_full = eigh_pullback!(ΔA_full, A, DV, ΔDV, ind; degeneracy_atol, gauge_atol) diagview(ΔA) .+= diagview(ΔA_full) diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index de1f91fa8..5d5020535 100755 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -171,12 +171,16 @@ function svd_pullback!( end return ΔA end +# Diagonal: do not specialize on `A`, since we may insert `A = nothing` to assert independence of `A` in the implementation function svd_pullback!( ΔA::Diagonal, A, USVᴴ, ΔUSVᴴ, ind = Colon(); rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ...) ) + # TODO: + # If A and ΔA are diagonal, then U and V are permutation matrices (up to signs/phases). + # Furthermore, since U̇ and V̇ are 0, the pullbacks ΔU and ΔV cannot contribute and we only have to unpermute ΔS. 
ΔA_full = zero!(similar(ΔA, size(ΔA))) ΔA_full = svd_pullback!(ΔA_full, A, USVᴴ, ΔUSVᴴ, ind; rank_atol, degeneracy_atol, gauge_atol) diagview(ΔA) .+= diagview(ΔA_full) diff --git a/src/pushforwards/eig.jl b/src/pushforwards/eig.jl index 9e39f6395..46f3de1e8 100644 --- a/src/pushforwards/eig.jl +++ b/src/pushforwards/eig.jl @@ -1,7 +1,6 @@ function eig_pushforward!( ΔA, A, DV, ΔDV; - degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), - gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]) ) D, V = DV ΔD, ΔV = ΔDV @@ -12,11 +11,16 @@ function eig_pushforward!( end if !iszerotangent(ΔV) ∂K .*= inv_safe.(transpose(diagview(D)) .- diagview(D), degeneracy_atol) - mul!(ΔV, V, ∂K, 1, 0) + mul!(ΔV, V, ∂K) + if eltype(V) <: Complex # fix gauge for `gaugefix!` compatibility + _, I = findmax(abs, V; dims = 1) + infinitesimal_phases = imag.(ΔV[I] ./ V[I]) + ΔV .-= im .* V .* infinitesimal_phases + end end return ΔDV end function eig_vals_pushforward!(ΔA, A, DV, ΔD; kwargs...) - return eig_pushforward!(ΔA, A, DV, (Diagonal(ΔD), nothing); kwargs...) + return eig_pushforward!(ΔA, A, DV, (diagonal(ΔD), nothing); kwargs...) 
end diff --git a/src/pushforwards/eigh.jl b/src/pushforwards/eigh.jl index e610867fe..894bd97cb 100644 --- a/src/pushforwards/eigh.jl +++ b/src/pushforwards/eigh.jl @@ -1,7 +1,6 @@ function eigh_pushforward!( ΔA, A, DV, ΔDV; - degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), - gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]) ) D, V = DV ΔD, ΔV = ΔDV @@ -13,6 +12,11 @@ function eigh_pushforward!( if !iszerotangent(ΔV) ∂K .*= inv_safe.(transpose(diagview(D)) .- diagview(D), degeneracy_atol) ΔV = mul!(ΔV, V, ∂K) + if eltype(V) <: Complex # fix gauge for `gaugefix!` compatibility + _, I = findmax(abs, V; dims = 1) + infinitesimal_phases = imag.(ΔV[I] ./ V[I]) + ΔV .-= im .* V .* infinitesimal_phases + end end return (ΔD, ΔV) end diff --git a/src/pushforwards/svd.jl b/src/pushforwards/svd.jl new file mode 100644 index 000000000..958ded7dc --- /dev/null +++ b/src/pushforwards/svd.jl @@ -0,0 +1,82 @@ +function svd_pushforward!( + ΔA, A, USVᴴ, ΔUSVᴴ, ind = Colon(); + rank_atol = default_pullback_rank_atol(USVᴴ[2]), + degeneracy_atol = default_pullback_rank_atol(USVᴴ[2]) + ) + U, Smat, Vᴴ = USVᴴ + m, n = size(U, 1), size(Vᴴ, 2) + (m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)")) + minmn = min(m, n) + S = diagview(Smat) + ΔU, ΔS, ΔVᴴ = ΔUSVᴴ + r = svd_rank(S; rank_atol) + + U₁ = view(U, :, 1:r) + S₁ = view(S, 1:r) + V₁ᴴ = view(Vᴴ, 1:r, :) + + # compact region + V₁ = adjoint(V₁ᴴ) + ΔAV₁ = ΔA * V₁ + UᴴΔAV₁ = U₁' * ΔAV₁ + if !iszerotangent(ΔS) + zero!(ΔS) # make off-diagonal entries zero in case of full ΔS (svd_full!) 
+ ΔS₁ = view(diagview(ΔS), 1:r) + ΔS₁ .= real.(diagview(UᴴΔAV₁)) + end + if !iszerotangent(ΔU) || !iszerotangent(ΔVᴴ) + hUᴴΔAV₁ = inv_safe.(transpose(S₁) .- S₁, degeneracy_atol) .* project_hermitian(UᴴΔAV₁) + aUᴴΔAV₁ = inv_safe.(transpose(S₁) .+ S₁, degeneracy_atol) .* project_antihermitian(UᴴΔAV₁) + if !iszerotangent(ΔU) + ΔU₁ = view(ΔU, :, 1:r) + K̇ = hUᴴΔAV₁ + aUᴴΔAV₁ + mul!(ΔU₁, U₁, K̇) + if m > r + ΔAV₁ = mul!(ΔAV₁, U₁, UᴴΔAV₁, -1, 1) + ΔU₁ .+= ΔAV₁ ./ transpose(S₁) + end + if size(U, 2) > r # these columns of U are undetermined, but U' * U̇ should be antihermitian + U₂ = view(U, :, (r + 1):size(U, 2)) + ΔU₁ᴴU₂ = ΔU₁' * U₂ + ΔU₂ = view(ΔU, :, (r + 1):size(U, 2)) + mul!(ΔU₂, U₁, ΔU₁ᴴU₂, -1, 0) + end + end + if !iszerotangent(ΔVᴴ) + ΔV₁ᴴ = view(ΔVᴴ, 1:r, :) + Ṁ = hUᴴΔAV₁ - aUᴴΔAV₁ + mul!(ΔV₁ᴴ, Ṁ', V₁ᴴ) + if n > r + UᴴΔA₁ = U₁' * ΔA + UᴴΔA₁ = mul!(UᴴΔA₁, UᴴΔAV₁, V₁ᴴ, -1, 1) + ΔV₁ᴴ .+= S₁ .\ UᴴΔA₁ + end + if size(Vᴴ, 1) > r # these rows of Vᴴ are undetermined, but V * V̇ should be antihermitian + V₂ᴴ = view(Vᴴ, (r + 1):size(Vᴴ, 1), :) + V₂ᴴΔV₁ = V₂ᴴ * ΔV₁ᴴ' + ΔV₂ᴴ = view(ΔVᴴ, (r + 1):size(Vᴴ, 1), :) + mul!(ΔV₂ᴴ, V₂ᴴΔV₁, V₁ᴴ, -1, 0) + end + end + if eltype(U) <: Complex && !iszerotangent(ΔU) && !iszerotangent(ΔVᴴ) # fix gauge for `gaugefix!` compatibility + _, I = findmax(abs, U₁; dims = 1) + infinitesimal_phases = imag.(ΔU₁[I] .* inv_safe.(U₁[I])) + ΔU₁ .-= im .* U₁ .* infinitesimal_phases + ΔV₁ᴴ .+= im .* transpose(infinitesimal_phases) .* V₁ᴴ + end + end + return (ΔU, ΔS, ΔVᴴ) +end + +# TODO +#=function svd_trunc_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol = default_pullback_rank_atol(A), kwargs...) 
+end=# + +function svd_vals_pushforward!( + ΔA, A, USVᴴ, ΔS, ind = Colon(); + rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]) + ) + ΔUSVᴴ = (nothing, diagonal(ΔS), nothing) + return svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol, degeneracy_atol) +end diff --git a/test/enzyme/svd.jl b/test/enzyme/svd.jl index e4aaa7aa1..bef41e5c7 100644 --- a/test/enzyme/svd.jl +++ b/test/enzyme/svd.jl @@ -3,7 +3,7 @@ using Test using LinearAlgebra: Diagonal using CUDA, AMDGPU -BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +BLASFloats = (Float64, ComplexF64) # full suite is too expensive on CI GenericFloats = () @isdefined(TestSuite) || include("../testsuite/TestSuite.jl") using .TestSuite @@ -16,6 +16,6 @@ for T in (BLASFloats..., GenericFloats...), n in (17, m, 23) if !is_buildkite TestSuite.test_enzyme_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T)) AT = Diagonal{T, Vector{T}} - m == n && TestSuite.test_enzyme_svd(AT, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + m == n && TestSuite.test_enzyme_svd(AT, m; atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) end end diff --git a/test/testsuite/enzyme/svd.jl b/test/testsuite/enzyme/svd.jl index 2131aa8d5..07a35d6d4 100644 --- a/test/testsuite/enzyme/svd.jl +++ b/test/testsuite/enzyme/svd.jl @@ -8,48 +8,78 @@ function test_enzyme_svd(T::Type, sz; kwargs...) end end +""" + test_enzyme_svd_compact(T, sz; rng, atol, rtol) + +Test the Enzyme forward- and reverse-mode AD rule for `svd_compact` and its in-place variant. 
+""" function test_enzyme_svd_compact( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), fdm = enzyme_fdm(T) ) - return @testset "svd_compact reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + return @testset "svd_compact: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) A = instantiate_matrix(T, sz) alg = MatrixAlgebraKit.select_algorithm(svd_compact, A) USVᴴ, ΔUSVᴴ = ad_svd_compact_setup(A) test_reverse(svd_compact, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) - test_reverse(call_and_zero!, RT, (svd_compact!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) + test_reverse(call_and_zero!, RT, (svd_compact!, Const), (copy(A), TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) + test_forward(svd_compact, RT, (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(call_and_zero!, RT, (svd_compact!, Const), (copy(A), TA), (alg, Const); atol, rtol, fdm) end end +""" + test_enzyme_svd_full(T, sz; rng, atol, rtol) + +Test the Enzyme forward- and reverse-mode AD rule for `svd_full` and its in-place variant. The +gauge-dependent extra columns of `U` and rows of `Vᴴ` are zeroed out in the cotangent. 
+""" function test_enzyme_svd_full( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), fdm = enzyme_fdm(T) ) - return @testset "svd_full reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + return @testset "svd_full: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) A = instantiate_matrix(T, sz) alg = MatrixAlgebraKit.select_algorithm(svd_full, A) USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A) test_reverse(svd_full, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) - test_reverse(call_and_zero!, RT, (svd_full!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) + test_reverse(call_and_zero!, RT, (svd_full!, Const), (copy(A), TA), (alg, Const); atol, rtol, output_tangent = ΔUSVᴴ, fdm) + if size(A, 1) == size(A, 2) # finite differences check for free component is very finicky + test_forward(svd_full, RT, (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(call_and_zero!, RT, (svd_full!, Const), (copy(A), TA), (alg, Const); atol, rtol, fdm) + end end end +""" + test_enzyme_svd_vals(T, sz; rng, atol, rtol) + +Test the Enzyme forward- and reverse-mode AD rule for `svd_vals` and its in-place variant. 
+""" function test_enzyme_svd_vals( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), fdm = enzyme_fdm(T) ) - return @testset "svd_vals reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + return @testset "svd_vals: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) A = instantiate_matrix(T, sz) alg = MatrixAlgebraKit.select_algorithm(svd_vals, A) S, ΔS = ad_svd_vals_setup(A) test_reverse(svd_vals, RT, (A, TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm) - test_reverse(call_and_zero!, RT, (svd_vals!, Const), (A, TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm) + test_reverse(call_and_zero!, RT, (svd_vals!, Const), (copy(A), TA), (alg, Const); atol, rtol, output_tangent = ΔS, fdm) + test_forward(svd_vals, RT, (A, TA), (alg, Const); atol, rtol, fdm) + test_forward(call_and_zero!, RT, (svd_vals!, Const), (copy(A), TA), (alg, Const); atol, rtol, fdm) end end +""" + test_enzyme_svd_trunc(T, sz; rng, atol, rtol) + +Test the Enzyme reverse-mode AD rules for `svd_trunc`, `svd_trunc_no_error`, and their +in-place variants, over a range of truncation ranks and a tolerance-based truncation. 
+""" function test_enzyme_svd_trunc( T, sz; rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T), @@ -64,7 +94,7 @@ function test_enzyme_svd_trunc( trunc = truncrank(r) truncalg = TruncatedAlgorithm(alg, trunc) USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) - test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) + test_reverse(svd_trunc_no_error, RT, (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) test_reverse(call_and_zero!, RT, (svd_trunc_no_error!, Const), (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) end @testset "trunctol" begin @@ -72,7 +102,7 @@ function test_enzyme_svd_trunc( trunc = trunctol(atol = maximum(S) / 2) truncalg = TruncatedAlgorithm(alg, trunc) USVᴴ, _, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) - test_reverse(svd_trunc_no_error, RT, (A, TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) + test_reverse(svd_trunc_no_error, RT, (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) test_reverse(call_and_zero!, RT, (svd_trunc_no_error!, Const), (copy(A), TA), (truncalg, Const); atol, rtol, output_tangent = ΔUSVᴴtrunc, fdm) end end diff --git a/test/testsuite/mooncake/eig.jl b/test/testsuite/mooncake/eig.jl index 3cc0063c7..8c9c29199 100644 --- a/test/testsuite/mooncake/eig.jl +++ b/test/testsuite/mooncake/eig.jl @@ -77,9 +77,15 @@ function test_mooncake_eig_vals( rng, eig_vals, A, alg; output_tangent, atol, rtol ) - if T <: Diagonal{<:Complex} + if A isa Diagonal{<:Complex} + A2 = copy(A) + Mooncake.TestUtils.test_rule( + rng, eig_vals!, A2, copy(A2.diag), alg; + output_tangent, atol, rtol + ) + A2 = copy(A) Mooncake.TestUtils.test_rule( - rng, eig_vals!, A, A.diag, alg; + rng, eig_vals!, A2, A2.diag, alg; output_tangent, atol, rtol ) end diff --git a/test/testsuite/mooncake/eigh.jl b/test/testsuite/mooncake/eigh.jl index 00f29a6f7..481c4ea54 
100644 --- a/test/testsuite/mooncake/eigh.jl +++ b/test/testsuite/mooncake/eigh.jl @@ -67,6 +67,18 @@ function test_mooncake_eigh_vals( rng, eigh_wrapper, eigh_vals, A, alg; output_tangent, is_primitive = false, atol, rtol ) + if A isa Diagonal{<:Real} + A2 = copy(A) + Mooncake.TestUtils.test_rule( + rng, eig_vals!, A2, copy(A2.diag), alg; + output_tangent, atol, rtol + ) + A2 = copy(A) + Mooncake.TestUtils.test_rule( + rng, eig_vals!, A2, A2.diag, alg; + output_tangent, atol, rtol + ) + end Mooncake.TestUtils.test_rule( rng, eigh!_wrapper, eigh_vals!, A, alg; output_tangent, is_primitive = false, atol, rtol diff --git a/test/testsuite/mooncake/svd.jl b/test/testsuite/mooncake/svd.jl index 5ac79744e..de9dbc543 100644 --- a/test/testsuite/mooncake/svd.jl +++ b/test/testsuite/mooncake/svd.jl @@ -16,7 +16,7 @@ end """ test_mooncake_svd_compact(T, sz; rng, atol, rtol) -Test the Mooncake reverse-mode AD rule for `svd_compact` and its in-place variant. +Test the Mooncake forward- and reverse-mode AD rule for `svd_compact` and its in-place variant. """ function test_mooncake_svd_compact( T, sz; @@ -30,11 +30,11 @@ function test_mooncake_svd_compact( Mooncake.TestUtils.test_rule( rng, svd_compact, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol + output_tangent, atol, rtol ) Mooncake.TestUtils.test_rule( rng, call_and_zero!, svd_compact!, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false + output_tangent, atol, rtol, is_primitive = false ) end end @@ -42,7 +42,7 @@ end """ test_mooncake_svd_full(T, sz; rng, atol, rtol) -Test the Mooncake reverse-mode AD rule for `svd_full` and its in-place variant. The +Test the Mooncake forward- and reverse-mode AD rule for `svd_full` and its in-place variant. The gauge-dependent extra columns of `U` and rows of `Vᴴ` are zeroed out in the cotangent. 
""" function test_mooncake_svd_full( @@ -55,13 +55,15 @@ function test_mooncake_svd_full( USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A) output_tangent = Mooncake.primal_to_tangent!!(Mooncake.zero_tangent(USVᴴ), ΔUSVᴴ) + mode = (size(A, 1) == size(A, 2)) ? nothing : Mooncake.ReverseMode + # svd_full of nonsquare has no well-determined forward derivative Mooncake.TestUtils.test_rule( rng, svd_full, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol + mode = mode, output_tangent, atol, rtol ) Mooncake.TestUtils.test_rule( - rng, call_and_zero!, svd_full!, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false + rng, call_and_zero!, svd_full!, copy(A), alg; + mode = mode, output_tangent, atol, rtol, is_primitive = false ) end end @@ -69,7 +71,7 @@ end """ test_mooncake_svd_vals(T, sz; rng, atol, rtol) -Test the Mooncake reverse-mode AD rule for `svd_vals` and its in-place variant. +Test the Mooncake forward- and reverse-mode AD rule for `svd_vals` and its in-place variant. """ function test_mooncake_svd_vals( T, sz; @@ -83,11 +85,11 @@ function test_mooncake_svd_vals( Mooncake.TestUtils.test_rule( rng, svd_vals, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol + output_tangent, atol, rtol ) Mooncake.TestUtils.test_rule( rng, call_and_zero!, svd_vals!, A, alg; - mode = Mooncake.ReverseMode, output_tangent, atol, rtol, is_primitive = false + output_tangent, atol, rtol, is_primitive = false ) end end