diff --git a/ext/TensorKitMooncakeExt/factorizations.jl b/ext/TensorKitMooncakeExt/factorizations.jl
index 3bb1b3ae3..9927c3b5c 100644
--- a/ext/TensorKitMooncakeExt/factorizations.jl
+++ b/ext/TensorKitMooncakeExt/factorizations.jl
@@ -1,63 +1,2 @@
-for f in (:svd_compact, :svd_full)
-    f_pullback = Symbol(f, :_pullback)
-    @eval begin
-        @is_primitive DefaultCtx ReverseMode Tuple{typeof($f), AbstractTensorMap, MatrixAlgebraKit.AbstractAlgorithm}
-        function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractTensorMap}, alg_dalg::CoDual)
-            A, dA = arrayify(A_dA)
-            alg = primal(alg_dalg)
-
-            USVᴴ = $f(A, primal(alg_dalg))
-            USVᴴ_dUSVᴴ = Mooncake.zero_fcodual(USVᴴ)
-            dUSVᴴ = last.(arrayify.(USVᴴ, tangent(USVᴴ_dUSVᴴ)))
-
-            function $f_pullback(::NoRData)
-                MatrixAlgebraKit.svd_pullback!(dA, A, USVᴴ, dUSVᴴ)
-                MatrixAlgebraKit.zero!.(dUSVᴴ)
-                return ntuple(Returns(NoRData()), 3)
-            end
-
-            return USVᴴ_dUSVᴴ, $f_pullback
-        end
-    end
-
-    # mutating version is not guaranteed to actually mutate
-    # so we can simply use the non-mutating version instead and avoid having to worry about
-    # storing copies and restoring state
-    f! = Symbol(f, :!)
-    f!_pullback = Symbol(f!, :_pullback)
-    @eval begin
-        @is_primitive DefaultCtx ReverseMode Tuple{typeof($f!), AbstractTensorMap, Any, MatrixAlgebraKit.AbstractAlgorithm}
-        Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual{<:AbstractTensorMap}, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) =
-            Mooncake.rrule!!(Mooncake.zero_fcodual($f), A_dA, alg_dalg)
-    end
-end
-
-@is_primitive DefaultCtx ReverseMode Tuple{typeof(svd_trunc), AbstractTensorMap, MatrixAlgebraKit.AbstractAlgorithm}
-function Mooncake.rrule!!(
-        ::CoDual{typeof(svd_trunc)},
-        A_dA::CoDual{<:AbstractTensorMap},
-        alg_dalg::CoDual{<:MatrixAlgebraKit.TruncatedAlgorithm}
-    )
-    A, dA = arrayify(A_dA)
-    alg = primal(alg_dalg)
-
-    USVᴴ = svd_compact(A, alg.alg)
-    USVᴴtrunc, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
-    ϵ = MatrixAlgebraKit.truncation_error(diagview(USVᴴ[2]), ind)
-
-    USVᴴtrunc_dUSVᴴtrunc = Mooncake.zero_fcodual((USVᴴtrunc..., ϵ))
-    dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Base.front(tangent(USVᴴtrunc_dUSVᴴtrunc))))
-
-    function svd_trunc_pullback((_, _, _, dϵ)::Tuple{NoRData, NoRData, NoRData, Real})
-        abs(dϵ) ≤ MatrixAlgebraKit.defaulttol(dϵ) ||
-            @warn "Gradient for `svd_trunc` ignores non-zero tangents for truncation error"
-        MatrixAlgebraKit.svd_pullback!(dA, A, USVᴴ, dUSVᴴtrunc, ind)
-        return ntuple(Returns(NoRData()), 3)
-    end
-
-    return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_pullback
-end
-
-@is_primitive DefaultCtx ReverseMode Tuple{typeof(svd_trunc!), AbstractTensorMap, Any, MatrixAlgebraKit.AbstractAlgorithm}
-Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual{<:AbstractTensorMap}, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) =
-    Mooncake.rrule!!(Mooncake.zero_fcodual(svd_trunc), A_dA, alg_dalg)
+# needed for the ising bimodule case
+@zero_derivative DefaultCtx Tuple{typeof(MatrixAlgebraKit.initialize_output), Any, AbstractTensorMap, MatrixAlgebraKit.AbstractAlgorithm}
diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl
index 5f47a5260..755799370 100644
--- a/ext/TensorKitMooncakeExt/tensoroperations.jl
+++ b/ext/TensorKitMooncakeExt/tensoroperations.jl
@@ -250,3 +250,23 @@ function trace_permute_pullback_ΔA!(
     )
     return NoRData()
 end
+
+# Reverse-mode rule for `TensorKit.scalar`: the cotangent of the returned value is added
+# to the first entry of the first block of the tangent of `t`.
+@is_primitive(
+    DefaultCtx,
+    Tuple{
+        typeof(TensorKit.scalar),
+        AbstractTensorMap,
+    }
+)
+function Mooncake.rrule!!(::CoDual{typeof(TensorKit.scalar)}, t_dt::CoDual{<:AbstractTensorMap})
+    t, dt = arrayify(t_dt)
+    val = scalar(t)
+    function scalar_pullback(Δval)
+        # accumulate instead of overwriting, so that cotangent contributions already
+        # stored in `dt` (e.g. from other uses of `t`) are not discarded
+        first(blocks(dt))[2][1] += Δval
+        return NoRData(), NoRData()
+    end
+    return Mooncake.zero_fcodual(val), scalar_pullback
+end
diff --git a/ext/TensorKitMooncakeExt/utility.jl b/ext/TensorKitMooncakeExt/utility.jl
index ceb32d867..b54fda7cf 100644
--- a/ext/TensorKitMooncakeExt/utility.jl
+++ b/ext/TensorKitMooncakeExt/utility.jl
@@ -65,6 +65,12 @@ Mooncake.tangent_type(::Type{<:HomSpace}) = Mooncake.NoTangent
 
 @zero_derivative DefaultCtx Tuple{typeof(TensorKit.sectorstructure), Any}
 @zero_derivative DefaultCtx Tuple{typeof(TensorKit.degeneracystructure), Any}
+@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorstructure), AbstractTensorMap}
+@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorstructure), AbstractTensorMap, Int, Bool}
+
+@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorcontract_structure), AbstractTensorMap, Index2Tuple, Bool, AbstractTensorMap, Index2Tuple, Bool, Index2Tuple}
+
+@zero_derivative DefaultCtx Tuple{typeof(TensorKit.has_shared_permute), AbstractTensorMap, Index2Tuple}
 @zero_derivative DefaultCtx Tuple{typeof(TensorKit.select), HomSpace, Index2Tuple}
 @zero_derivative DefaultCtx Tuple{typeof(TensorKit.flip), HomSpace, Any}
 @zero_derivative DefaultCtx Tuple{typeof(TensorKit.permute), HomSpace, Index2Tuple}
diff --git a/test/mooncake/factorizations.jl b/test/mooncake/factorizations.jl
index 8955c4ecf..10810c979 100644
--- a/test/mooncake/factorizations.jl
+++ b/test/mooncake/factorizations.jl
@@ -8,6 +8,12 @@ using MatrixAlgebraKit: remove_qr_gauge_dependence!, remove_lq_gauge_dependence!
 using Mooncake
 using Random
 
+# Call `f!(A, alg)`, zero out `A` afterwards, and return the result of the call.
+function call_and_zero!(f!, A, alg)
+    F′ = f!(A, alg)
+    MatrixAlgebraKit.zero!(A)
+    return F′
+end
 mode = Mooncake.ReverseMode
 rng = Random.default_rng()
 
@@ -18,7 +24,6 @@ eltypes = (Float64, ComplexF64)
 @timedtestset "Mooncake - Factorizations: $(TensorKit.type_repr(sectortype(eltype(V)))) ($T)" for V in spacelist, T in eltypes
     atol = default_tol(T)
     rtol = default_tol(T)
-
     @timedtestset "QR" begin
         A = randn(T, V[1] ⊗ V[2] ← V[1] ⊗ V[2])
 
@@ -29,8 +34,7 @@ eltypes = (Float64, ComplexF64)
         ΔQR = Mooncake.randn_tangent(rng, QR)
         remove_qr_gauge_dependence!(ΔQR..., A, QR...)
         Mooncake.TestUtils.test_rule(rng, qr_full, A; output_tangent = ΔQR, atol, rtol, mode, is_primitive = false)
-        # TODO:
-        # Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false)
+        #Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false)
 
         A = randn(T, V[1] ⊗ V[2] ⊗ V[3] ← (V[4] ⊗ V[5])')
 
@@ -41,8 +45,7 @@ eltypes = (Float64, ComplexF64)
         ΔQR = Mooncake.randn_tangent(rng, QR)
         remove_qr_gauge_dependence!(ΔQR..., A, QR...)
         Mooncake.TestUtils.test_rule(rng, qr_full, A; output_tangent = ΔQR, atol, rtol, mode, is_primitive = false)
-        # TODO:
-        # Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false)
+        #Mooncake.TestUtils.test_rule(rng, qr_null, A; atol, rtol, mode, is_primitive = false)
     end
 
     @timedtestset "LQ" begin
@@ -50,25 +53,23 @@ eltypes = (Float64, ComplexF64)
 
         Mooncake.TestUtils.test_rule(rng, lq_compact, A; atol, rtol, mode, is_primitive = false)
 
-        # qr_full/qr_null requires being careful with gauges
+        # lq_full/lq_null requires being careful with gauges
         LQ = lq_full(A)
         ΔLQ = Mooncake.randn_tangent(rng, LQ)
         remove_lq_gauge_dependence!(ΔLQ..., A, LQ...)
         Mooncake.TestUtils.test_rule(rng, lq_full, A; output_tangent = ΔLQ, atol, rtol, mode, is_primitive = false)
-        # TODO:
-        # Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false)
+        #Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false)
 
         A = randn(T, V[1] ⊗ V[2] ← (V[3] ⊗ V[4] ⊗ V[5])')
 
         Mooncake.TestUtils.test_rule(rng, lq_compact, A; atol, rtol, mode, is_primitive = false)
 
-        # qr_full/qr_null requires being careful with gauges
+        # lq_full/lq_null requires being careful with gauges
         LQ = lq_full(A)
         ΔLQ = Mooncake.randn_tangent(rng, LQ)
         remove_lq_gauge_dependence!(ΔLQ..., A, LQ...)
         Mooncake.TestUtils.test_rule(rng, lq_full, A; output_tangent = ΔLQ, atol, rtol, mode, is_primitive = false)
-        # TODO:
-        # Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false)
+        #Mooncake.TestUtils.test_rule(rng, lq_null, A; atol, rtol, mode, is_primitive = false)
     end
 
     @timedtestset "Eigenvalue decomposition" begin
@@ -105,6 +106,15 @@ eltypes = (Float64, ComplexF64)
             ΔUSVᴴtrunc = (Mooncake.randn_tangent(rng, Base.front(USVᴴtrunc))..., zero(last(USVᴴtrunc)))
             remove_svd_gauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], Base.front(USVᴴtrunc)...)
             Mooncake.TestUtils.test_rule(rng, svd_trunc, t, alg; output_tangent = ΔUSVᴴtrunc, atol, rtol, mode)
+
+            V_trunc = spacetype(t)(c => min(size(b)...) ÷ 2 for (c, b) in blocks(t))
+            trunc = truncspace(V_trunc)
+            USVᴴ = svd_compact(t)
+            alg = MatrixAlgebraKit.select_algorithm(svd_trunc, t, nothing; trunc)
+            USVᴴtrunc = svd_trunc(t, alg)
+            ΔUSVᴴtrunc = (Mooncake.randn_tangent(rng, Base.front(USVᴴtrunc))..., zero(last(USVᴴtrunc)))
+            remove_svd_gauge_dependence!(ΔUSVᴴtrunc[1], ΔUSVᴴtrunc[3], Base.front(USVᴴtrunc)...)
+            Mooncake.TestUtils.test_rule(rng, call_and_zero!, svd_trunc!, t, alg; output_tangent = ΔUSVᴴtrunc, atol, rtol, mode, is_primitive = false)
         end
     end
 end