Projections do not play well with GPUCompiler #429

@p-zubieta

Here's an example that does not play well with GPUCompiler: https://gist.github.com/pabloferz/1390d85383e3243015be7ad5b162bcc4

A possible, but probably incomplete, fix discussed with @mcabbott is to add the following specializations:

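using ChainRulesCore
import ChainRulesCore: ProjectTo  # explicit import, needed to add constructor methods outside the package

# Float arrays: keep only an element projector and the axes.
function ProjectTo(x::AbstractArray{T}) where {T <: AbstractFloat}
    return ProjectTo{AbstractArray}(; element=ProjectTo(zero(T)), axes=axes(x))
end

# Bool arrays are treated as non-differentiable.
ProjectTo(x::AbstractArray{T}) where {T <: Bool} = ProjectTo{NoTangent}()

# Return dx unchanged when its eltype already matches; otherwise project each element.
function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S}) where {S <: Number}
    T = ChainRulesCore.project_type(project.element)
    return S <: T ? dx : map(project.element, dx)
end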
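
To make the intent of the last method easier to see, here is a minimal, dependency-free sketch of the same eltype check. `project_like` is a hypothetical stand-in written only for illustration; it is not part of ChainRulesCore, and it uses plain conversion with `T` where the specialization above calls `project.element`.

```julia
# Hypothetical helper mirroring the `S <: T ? dx : map(...)` branch above.
# Both S and T are type parameters, so the check depends only on types.
function project_like(::Type{T}, dx::AbstractArray{S}) where {T <: AbstractFloat, S <: Number}
    return S <: T ? dx : map(T, dx)
end

dx32 = ones(Float32, 3)
dx64 = ones(Float64, 3)

same      = project_like(Float32, dx32)  # eltype matches: dx is returned as-is
converted = project_like(Float32, dx64)  # eltype differs: converted elementwise
```

In the matching case the same array object is returned and nothing is copied; in the mismatched case a new array with the target eltype is built.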

Labels: ProjectTo (related to the projection functionality)
