Skip to content

rrule for Broadcast.broadcasted fails when broadcasted over Tuple{} #830

@AayushSabharwal

Description

@AayushSabharwal

A minimal working example (MWE) explains it better than text can:

julia> using DifferentiationInterface, Zygote, ADTypes
julia> DifferentiationInterface.jacobian(AutoZygote(), ones(2)) do e
           2e .+ sum(isone.(()); init = 0.0)
       end

This errors with:

Tuple field type cannot be Union{}
Stacktrace:
  [1] may_bc_derivatives(::Type{Union{}}, f::typeof(isone), args::Tuple{})
    @ ChainRules ~/.julia/packages/ChainRules/14CDN/src/rulesets/Base/broadcast.jl:51
  [2] rrule(cfg::Zygote.ZygoteRuleConfig{Zygote.Context{false}}, ::typeof(Base.Broadcast.broadcasted), ::Base.Broadcast.Style{Tuple}, f::typeof(isone), args::Tuple{})
    @ ChainRules ~/.julia/packages/ChainRules/14CDN/src/rulesets/Base/broadcast.jl:36
  [3] chain_rrule
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/chainrules.jl:234 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0 [inlined]
  [5] _pullback(::Zygote.Context{false}, ::typeof(Base.Broadcast.broadcasted), ::Base.Broadcast.Style{Tuple}, ::typeof(isone), ::Tuple{})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:81
  [6] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
  [7] adjoint
    @ ~/.julia/packages/Zygote/55SqB/src/lib/lib.jl:211 [inlined]
  [8] _pullback
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
  [9] broadcasted
    @ ./broadcast.jl:1341 [inlined]
 [10] #43
    @ ./REPL[25]:2 [inlined]
 [11] _pullback(ctx::Zygote.Context{false}, f::var"#43#44", args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [12] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [13] adjoint
    @ ~/.julia/packages/Zygote/55SqB/src/lib/lib.jl:211 [inlined]
 [14] _pullback
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
 [15] call_composed
    @ ./operators.jl:1045 [inlined]
 [16] call_composed
    @ ./operators.jl:1044 [inlined]
 [17] #_#103
    @ ./operators.jl:1041 [inlined]
 [18] _pullback(::Zygote.Context{false}, ::Base.var"##_#103", ::@Kwargs{}, ::ComposedFunction{typeof(Zygote._jvec), var"#43#44"}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [19] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [20] adjoint
    @ ~/.julia/packages/Zygote/55SqB/src/lib/lib.jl:211 [inlined]
 [21] _pullback
    @ ~/.julia/packages/ZygoteRules/CkVIK/src/adjoint.jl:67 [inlined]
 [22] ComposedFunction
    @ ./operators.jl:1041 [inlined]
 [23] _pullback(ctx::Zygote.Context{false}, f::ComposedFunction{typeof(Zygote._jvec), var"#43#44"}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface2.jl:0
 [24] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:96
 [25] pullback
    @ ~/.julia/packages/Zygote/55SqB/src/compiler/interface.jl:94 [inlined]
 [26] withjacobian(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/55SqB/src/lib/grad.jl:181
 [27] jacobian
    @ ~/.julia/packages/Zygote/55SqB/src/lib/grad.jl:168 [inlined]
 [28] jacobian
    @ ~/.julia/packages/DifferentiationInterface/MgcE4/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl:165 [inlined]
 [29] jacobian(::var"#43#44", ::AutoZygote, ::Vector{Float64})
    @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/MgcE4/src/first_order/jacobian.jl:100
 [30] top-level scope
    @ REPL[25]:1
 [31] top-level scope
    @ none:1

The issue is in this branch of the ChainRules `rrule` for `broadcasted` (`src/rulesets/Base/broadcast.jl`):

T = Broadcast.combine_eltypes(f, args)
if T === Bool # TODO use nondifftype here
# 1: Trivial case: non-differentiable output, e.g. `x .> 0`
@debug("split broadcasting trivial", f, T)
bc_trivial_back(_) = (TRI_NO..., ntuple(Returns(ZeroTangent()), length(args))...)
return f.(args...), bc_trivial_back
elseif T <: Number && may_bc_derivatives(T, f, args...)

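When broadcasting over an empty tuple, `T` is `Union{}`. `Union{}` is the bottom type and therefore a subtype of every type, so the check `T <: Number` passes. Execution then reaches `may_bc_derivatives`, which errors because it tries to construct a `Tuple{Union{}, ...}` type.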

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions