Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
Expand All @@ -31,6 +33,8 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
TensorKitAMDGPUExt = "AMDGPU"
TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
TensorKitChainRulesCoreExt = "ChainRulesCore"
TensorKitEnzymeExt = "Enzyme"
TensorKitEnzymeTestUtilsExt = "EnzymeTestUtils"
TensorKitFiniteDifferencesExt = "FiniteDifferences"
TensorKitMooncakeExt = "Mooncake"

Expand All @@ -43,19 +47,21 @@ AMDGPU = "2"
CUDA = "6"
ChainRulesCore = "1"
Dictionaries = "0.4"
Enzyme = "0.13.157"
EnzymeTestUtils = "0.2.7"
FiniteDifferences = "0.12"
LRUCache = "1.0.2"
LinearAlgebra = "1"
MatrixAlgebraKit = "0.6.7"
MatrixAlgebraKit = "0.6.8"
Mooncake = "0.5.27"
OhMyThreads = "0.8.0"
Printf = "1"
Random = "1"
ScopedValues = "1.3.0"
Strided = "2"
TensorKitSectors = "0.3.7"
TensorOperations = "5.5"
TensorOperations = "5.5.2"
TupleTools = "1.5"
VectorInterface = "0.4.8, 0.5, 0.6"
VectorInterface = "0.4.8, 0.5"
cuTENSOR = "6"
julia = "1.10"
16 changes: 16 additions & 0 deletions ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module TensorKitEnzymeExt

using Enzyme
using TensorKit
import TensorKit as TK
using VectorInterface
using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize
import TensorOperations as TO
using MatrixAlgebraKit
using TupleTools
using Random: AbstractRNG

include("utility.jl")
include("tensoroperations.jl")

end
208 changes: 208 additions & 0 deletions ext/TensorKitEnzymeExt/tensoroperations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
# tensorcontract!
# ---------------
# TODO: it might be beneficial to compare here if it would make sense to simply compute the
# rrule of permute-permute-gemm-permute, rather than using the contractions directly.
# This could possibly out save some permutations being carried out twice, at the cost of having
# to store some more intermediate objects.
# For example, the combination `ΔC, pΔC, false` appears in the pullback for ΔA and ΔB, so effectively
# this permutation is done multiple times.

function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(TensorKit.blas_contract!)},
::Type{RT},
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
pA::Const{<:Index2Tuple},
B::Annotation{<:AbstractTensorMap},
pB::Const{<:Index2Tuple},
pAB::Const{<:Index2Tuple},
α::Annotation{<:Number},
β::Annotation{<:Number},
backend::Const,
allocator::Const
) where {RT}
Ccache = isa(β, Const) ? nothing : copy(C.val)
A_needs_cache = EnzymeRules.overwritten(config)[3] && !(typeof(B) <: Const) && !(typeof(C) <: Const)
Acache = A_needs_cache ? copy(A.val) : nothing
B_needs_cache = EnzymeRules.overwritten(config)[5] && !(typeof(A) <: Const) && !(typeof(C) <: Const)
Bcache = B_needs_cache ? copy(B.val) : nothing
AB = if !isa(α, Const)
AB = TO.tensorcontract(A.val, pA.val, false, B.val, pB.val, false, pAB.val, One(), backend.val, allocator.val)
add!(C.val, AB, α.val, β.val)
AB
else
TensorKit.blas_contract!(C.val, A.val, pA.val, B.val, pB.val, pAB.val, α.val, β.val, backend.val, allocator.val)
nothing
end
primal = EnzymeRules.needs_primal(config) ? C.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing
cache = (Ccache, Acache, Bcache, AB)
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end

function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(TensorKit.blas_contract!)},
::Type{RT},
cache,
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
pA::Const{<:Index2Tuple},
B::Annotation{<:AbstractTensorMap},
pB::Const{<:Index2Tuple},
pAB::Const{<:Index2Tuple},
α::Annotation{<:Number},
β::Annotation{<:Number},
backend::Const,
allocator::Const
) where {RT}
cacheC, cacheA, cacheB, AB = cache
Cval = cacheC
Aval = something(cacheA, A.val)
Bval = something(cacheB, B.val)

Δα = pullback_dα(α, C, AB)
Δβ = pullback_dβ(β, C, Cval)

if !isa(A, Const)
TensorKit.blas_contract_pullback_ΔA!(
A.dval, C.dval, Aval, pA.val, Bval, pB.val, pAB.val, α.val, backend.val, allocator.val
) # this typically returns nothing
end
if !isa(B, Const)
TensorKit.blas_contract_pullback_ΔB!(
B.dval, C.dval, Aval, pA.val, Bval, pB.val, pAB.val, α.val, backend.val, allocator.val
) # this typically returns nothing
end
!isa(C, Const) && pullback_dC!(C.dval, β.val) # this typically returns nothing
return nothing, nothing, nothing, nothing, nothing, nothing, Δα, Δβ, nothing, nothing
end

function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof(TensorKit.blas_contract!)},
::Type{RT},
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
pA::Annotation{<:Index2Tuple},
B::Annotation{<:AbstractTensorMap},
pB::Annotation{<:Index2Tuple},
pAB::Annotation{<:Index2Tuple},
α::Annotation{<:Number},
β::Annotation{<:Number},
backend::Const,
allocator::Const
) where {RT}
# ΔC′ = ΔC*β + C*Δβ + A*B*Δα + ΔA*B*α + A*ΔB*α
if !isa(C, Const)
if isa(β, Const)
scale!(C.dval, β.val)
else
add!(C.dval, C.val, β.dval, β.val)
end
!isa(α, Const) && TensorKit.blas_contract!(C.dval, A.val, pA.val, B.val, pB.val, pAB.val, α.dval, One(), backend.val, allocator.val)
!isa(A, Const) && TensorKit.blas_contract!(C.dval, A.dval, pA.val, B.val, pB.val, pAB.val, α.val, One(), backend.val, allocator.val)
!isa(B, Const) && TensorKit.blas_contract!(C.dval, A.val, pA.val, B.dval, pB.val, pAB.val, α.val, One(), backend.val, allocator.val)
end
TensorKit.blas_contract!(C.val, A.val, pA.val, B.val, pB.val, pAB.val, α.val, β.val, backend.val, allocator.val)
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return C
elseif EnzymeRules.needs_primal(config)
return C.val
elseif EnzymeRules.needs_shadow(config)
return C.dval
else
return nothing
end
end

# tensortrace!
# ------------

function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(TensorKit.trace_permute!)},
::Type{RT},
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
p::Const{<:Index2Tuple},
q::Const{<:Index2Tuple},
α::Annotation{<:Number},
β::Annotation{<:Number},
backend::Const,
) where {RT}
C_cache = !isa(β, Const) ? copy(C.val) : nothing
A_cache = EnzymeRules.overwritten(config)[3] ? copy(A.val) : nothing
At = if !isa(α, Const)
At = TO.tensortrace(A.val, p.val, q.val, false, One(), backend.val)
add!(C.val, At, α.val, β.val)
At
else
TensorKit.trace_permute!(C.val, A.val, p.val, q.val, α.val, β.val, backend.val)
nothing
end
primal = EnzymeRules.needs_primal(config) ? C.val : nothing
shadow = EnzymeRules.needs_shadow(config) ? C.dval : nothing
cache = (C_cache, A_cache, At)
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
end


function EnzymeRules.reverse(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(TensorKit.trace_permute!)},
::Type{RT},
cache,
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
p::Const{<:Index2Tuple},
q::Const{<:Index2Tuple},
α::Annotation{<:Number},
β::Annotation{<:Number},
backend::Const,
) where {RT}
C_cache, A_cache, At = cache
Aval = something(A_cache, A.val)
Cval = something(C_cache, C.val)
!isa(A, Const) && !isa(C, Const) && TensorKit.trace_permute_pullback_ΔA!(A.dval, C.dval, Aval, p.val, q.val, α.val, backend.val)
Δαr = pullback_dα(α, C, At)
Δβr = pullback_dβ(β, C, Cval)
!isa(C, Const) && pullback_dC!(C.dval, β.val)
return nothing, nothing, nothing, nothing, Δαr, Δβr, nothing
end

function EnzymeRules.forward(
config::EnzymeRules.FwdConfigWidth{1},
func::Const{typeof(TensorKit.trace_permute!)},
::Type{RT},
C::Annotation{<:AbstractTensorMap},
A::Annotation{<:AbstractTensorMap},
p::Annotation{<:Index2Tuple},
q::Annotation{<:Index2Tuple},
α::Annotation{<:Number},
β::Annotation{<:Number},
backend::Const,
) where {RT}
# dD = dα * tr(A) + α * tr(dA) + dβ * C + β * dC
# dC1 = dβ * C + β * dC
if !isa(C, Const)
if isa(β, Const)
scale!(C.dval, β.val)
else
add!(C.dval, C.val, β.dval, β.val)
end
!isa(α, Const) && TensorKit.trace_permute!(C.dval, A.val, p.val, q.val, α.dval, One(), backend.val)
!isa(A, Const) && TensorKit.trace_permute!(C.dval, A.dval, p.val, q.val, α.val, One(), backend.val)
end
TensorKit.trace_permute!(C.val, A.val, p.val, q.val, α.val, β.val, backend.val)
if EnzymeRules.needs_primal(config) && EnzymeRules.needs_shadow(config)
return C
elseif EnzymeRules.needs_primal(config)
return C.val
elseif EnzymeRules.needs_shadow(config)
return C.dval
else
return nothing
end
end
80 changes: 80 additions & 0 deletions ext/TensorKitEnzymeExt/utility.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Projection
# ----------
pullback_dα(α::Const, C::Const, A) = nothing
pullback_dα(α::Const, C::Annotation, A) = nothing
pullback_dα(α::Annotation, C::Const, A) = zero(α.val)
pullback_dα(α::Annotation, C::Annotation, A) = project_scalar(α.val, inner(A, C.dval))

pullback_dβ(β::Const, C::Const, Ccache) = nothing
pullback_dβ(β::Const, C::Annotation, Ccache) = nothing
pullback_dβ(β::Annotation, C::Const, Ccache) = zero(β.val)
pullback_dβ(β::Annotation, C::Annotation, Ccache) = project_scalar(β.val, inner(Ccache, C.dval))

pullback_dC!(ΔC, β::Number) = scale!(ΔC, conj(β))

"""
project_scalar(x::Number, dx::Number)

Project a computed tangent `dx` onto the correct tangent type for `x`.
For example, we might compute a complex `dx` but only require the real part.
"""
project_scalar(x::Number, dx::Number) = oftype(x, dx)
project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx))

# in-place multiplication and accumulation which might project to (real)
# TODO: this could probably be done without allocating
function project_mul!(C, A, B, α)
TC = TO.promote_contract(scalartype(A), scalartype(B), scalartype(α))
return if !(TC <: Real) && scalartype(C) <: Real
add!(C, real(mul!(zerovector(C, TC), A, B, α)))
else
mul!(C, A, B, α, One())
end
end
function project_contract!(C, A, pA, conjA, B, pB, conjB, pAB, α, backend, allocator)
TA = TensorKit.promote_permute(A)
TB = TensorKit.promote_permute(B)
TC = TO.promote_contract(TA, TB, scalartype(α))

return if scalartype(C) <: Real && !(TC <: Real)
add!(C, real(TO.tensorcontract!(zerovector(C, TC), A, pA, conjA, B, pB, conjB, pAB, α, Zero(), backend, allocator)))
else
TO.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, One(), backend, allocator)
end
end

# IndexTuple utility
# ------------------
trivtuple(N) = ntuple(identity, N)

Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int)
length(p) >= N₁ ||
throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)"))
return TupleTools.getindices(p, trivtuple(N₁)),
TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁)
end
Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int)
return _repartition(linearize(p), N₁)
end
function _repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁}
return _repartition(p, N₁)
end
function _repartition(p::Union{IndexTuple, Index2Tuple}, t::AbstractTensorMap)
return _repartition(p, TensorKit.numout(t))
end

# Ignore derivatives
# ------------------

@inline EnzymeRules.inactive_type(::Type{<:TensorKit.FusionTree}) = true
@inline EnzymeRules.inactive_type(::Type{<:TensorKit.GenericTreeTransformer}) = true
@inline EnzymeRules.inactive_type(::Type{<:TensorKit.VectorSpace}) = true

@inline EnzymeRules.inactive(::typeof(TensorKit.sectorstructure), ::Any) = nothing
@inline EnzymeRules.inactive(::typeof(TensorKit.degeneracystructure), ::Any) = nothing
@inline EnzymeRules.inactive(::typeof(TensorKit.select), s::HomSpace, i::Index2Tuple) = nothing
@inline EnzymeRules.inactive(::typeof(TensorKit.flip), s::HomSpace, i::Any) = nothing
@inline EnzymeRules.inactive(::typeof(TensorKit.permute), s::HomSpace, i::Index2Tuple) = nothing
@inline EnzymeRules.inactive(::typeof(TensorKit.braid), s::HomSpace, i::Index2Tuple, ::IndexTuple) = nothing
@inline EnzymeRules.inactive(::typeof(TensorKit.compose), s1::HomSpace, s2::HomSpace) = nothing
@inline EnzymeRules.inactive(::typeof(TensorOperations.tensorcontract), c::HomSpace, p::Index2Tuple, α::Bool, b::HomSpace, q::Index2Tuple, β::Bool, pq::Index2Tuple) = nothing
Loading
Loading