Skip to content

Commit accc294

Browse files
KristofferCjrevels
authored andcommitted
fix wrong partials multiplied in FMA (#206)
1 parent 6c61b61 commit accc294

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

src/dual.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -414,13 +414,13 @@ end
414414
vx, vy = value(x), value(y)
415415
result = fma(vx, vy, value(z))
416416
return Dual(result,
417-
_mul_partials(partials(x), partials(y), vx, vy) + partials(z))
417+
_mul_partials(partials(x), partials(y), vy, vx) + partials(z))
418418
end
419419

420420
@inline function Base.fma(x::Dual, y::Dual, z::Real)
421421
vx, vy = value(x), value(y)
422422
result = fma(vx, vy, z)
423-
return Dual(result, _mul_partials(partials(x), partials(y), vx, vy))
423+
return Dual(result, _mul_partials(partials(x), partials(y), vy, vx))
424424
end
425425

426426
@inline function Base.fma(x::Dual, y::Real, z::Dual)

test/DualTest.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -387,20 +387,20 @@ for N in (0,3), M in (0,4), T in (Int, Float32)
387387

388388
@test partials(NaNMath.pow(Dual(-2.0, 1.0), Dual(2.0, 0.0)), 1) == -4.0
389389

390-
@test fma(FDNUM, FDNUM2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
390+
test_approx_diffnums(fma(FDNUM, FDNUM2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
391391
PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS +
392-
PARTIALS3)
393-
@test fma(FDNUM, FDNUM2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
394-
PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS)
395-
@test fma(PRIMAL, FDNUM2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
396-
PRIMAL*PARTIALS2 + PARTIALS3)
397-
@test fma(PRIMAL, FDNUM2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
398-
PRIMAL*PARTIALS2)
399-
@test fma(FDNUM, PRIMAL2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
400-
PRIMAL2*PARTIALS + PARTIALS3)
401-
@test fma(FDNUM, PRIMAL2, PRIMAL3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
402-
PRIMAL2*PARTIALS)
403-
@test fma(PRIMAL, PRIMAL2, FDNUM3) == Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PARTIALS3)
392+
PARTIALS3))
393+
test_approx_diffnums(fma(FDNUM, FDNUM2, PRIMAL3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
394+
PRIMAL*PARTIALS2 + PRIMAL2*PARTIALS))
395+
test_approx_diffnums(fma(PRIMAL, FDNUM2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
396+
PRIMAL*PARTIALS2 + PARTIALS3))
397+
test_approx_diffnums(fma(PRIMAL, FDNUM2, PRIMAL3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
398+
PRIMAL*PARTIALS2))
399+
test_approx_diffnums(fma(FDNUM, PRIMAL2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
400+
PRIMAL2*PARTIALS + PARTIALS3))
401+
test_approx_diffnums(fma(FDNUM, PRIMAL2, PRIMAL3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3),
402+
PRIMAL2*PARTIALS))
403+
test_approx_diffnums(fma(PRIMAL, PRIMAL2, FDNUM3), Dual(fma(PRIMAL, PRIMAL2, PRIMAL3), PARTIALS3))
404404

405405
# Unary Functions #
406406
#-----------------#

0 commit comments

Comments
 (0)