Skip to content

Commit df00ef8

Browse files
authored
Merge pull request #464 from dfdx/dfdx/literal-pow-2
Add rules for Base.literal_pow
2 parents 5116c95 + e3c5895 commit df00ef8

File tree

3 files changed

+19
-1
lines changed

3 files changed

+19
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.8.19"
3+
version = "0.8.20"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/base.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,15 @@ end
161161
@scalar_rule round(x) zero(x)
162162
@scalar_rule floor(x) zero(x)
163163
@scalar_rule ceil(x) zero(x)
164+
165+
# note: rules for ^ are defined in the fastmath_able.jl
166+
function frule((_, _, Δx, _), ::typeof(Base.literal_pow), ::typeof(^), x::Real, pv::Val{p}) where p
167+
y = Base.literal_pow(^, x, pv)
168+
return y, (p * y / x * Δx)
169+
end
170+
171+
function rrule(::typeof(Base.literal_pow), ::typeof(^), x::Real, pv::Val{p}) where p
172+
y = Base.literal_pow(^, x, pv)
173+
literal_pow_pullback(dy) = NoTangent(), NoTangent(), (p * y / x * dy), NoTangent()
174+
return y, literal_pow_pullback
175+
end

test/rulesets/Base/base.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,4 +176,10 @@
176176
@test frule((NoTangent(), NoTangent(), NoTangent()), Base.depwarn, "message", :f) !== nothing
177177
@test rrule(Base.depwarn, "message", :f) !== nothing
178178
end
179+
180+
@testset "literal_pow" begin
181+
# for real x and n, x must be >0
182+
test_frule(Base.literal_pow, ^, 3.5, Val(3))
183+
test_rrule(Base.literal_pow, ^, 3.5, Val(3))
184+
end
179185
end

0 commit comments

Comments
 (0)