Automatic differentiation
There are two ways to compute the gradient of an einsum expression. The first is the built-in support in the OMEinsum package, a custom implementation of reverse-mode automatic differentiation. The second is the Zygote package, a general-purpose source-to-source automatic differentiation tool.
Built-in automatic differentiation
The OMEinsum package provides a built-in function, cost_and_gradient, that computes both the cost (the value of the einsum expression) and its gradient with respect to each input tensor.
julia> using OMEinsum # the 1st way
julia> A, B, C = randn(2, 3), randn(3, 4), randn(4, 2);
julia> y, g = cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))
(fill(-2.1377526189087814), Any[[0.5063370937562979 0.005581259151837478 -1.0457426048982317; -1.386773097810222 -1.189564517546563 -0.7428377220431881], [2.2863240614394678 -0.446536407987949 -1.764945042725527 0.8570908967967413; -0.2518013955005649 0.9526258847254069 0.2450053082918978 -0.23708152456992845; -4.1804176347523425 1.5673819858049614 3.2691828746541076 -1.6857400395701023], [-0.30392375042888575 -1.2434442105842511; -1.1780528332784956 -1.2997585175928115; -2.3167478409194517 -0.7500308631964716; 0.7746605628128713 1.4575260650922275]])
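The expression ein"(ij, jk), ki->" sums A_ij B_jk C_ki over all three indices, so it computes tr(ABC). Its gradients therefore have closed forms: (BC)ᵀ for A, (CA)ᵀ for B and (AB)ᵀ for C. The following minimal sketch, which assumes only OMEinsum and the LinearAlgebra standard library, uses these closed forms to check the output of cost_and_gradient.

```julia
using OMEinsum, LinearAlgebra

A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)

# y is a zero-dimensional array holding sum_{ijk} A[i,j] * B[j,k] * C[k,i] = tr(A*B*C);
# g holds one gradient array per input, in the same order as (A, B, C).
y, g = cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))

@assert y[] ≈ tr(A * B * C)
@assert g[1] ≈ transpose(B * C)   # ∂y/∂A[i,j] = (B*C)[j,i]
@assert g[2] ≈ transpose(C * A)   # ∂y/∂B[j,k] = (C*A)[k,j]
@assert g[3] ≈ transpose(A * B)   # ∂y/∂C[k,i] = (A*B)[i,k]
```

Each gradient has the same shape as the tensor it belongs to, which is why the transposes appear.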
This built-in automatic differentiation is specialized for tensor contractions and is more efficient than general-purpose automatic differentiation tools.
For complex-valued tensors, automatic differentiation follows the convention that treats the real and imaginary parts of each entry as independent real variables.
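Concretely, for a real-valued loss L and a complex entry z = x + iy, the gradient under this convention is ∂L/∂x + i ∂L/∂y. For L = Σ|A_ij|², this gives 2A. The minimal sketch below checks this through Zygote (covered in the next section), assuming the einsum backward rule follows the convention stated above. The loss name sqnorm is hypothetical.

```julia
using OMEinsum, Zygote

A = randn(ComplexF64, 2, 3)

# Hypothetical real-valued loss: sum_ij |A[i,j]|^2, written as an einsum over conj(A) and A.
sqnorm(A) = real(ein"ij,ij->"(conj(A), A)[])

# Treating real and imaginary parts as independent variables, the gradient is
# ∂L/∂Re(A) + im * ∂L/∂Im(A) = 2Re(A) + 2im*Im(A) = 2A.
gA, = Zygote.gradient(sqnorm, A)
@assert gA ≈ 2A
```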
Using Zygote
The backward rule for the basic einsum operation is also provided as a ChainRulesCore rule, which is the rule interface used by the Zygote package. Zygote is a source-to-source automatic differentiation tool that can compute the gradient of an einsum expression. It is more general than the built-in approach: it differentiates ordinary Julia code, not only einsum expressions.
julia> using Zygote # the 2nd way
julia> Zygote.gradient((A, B, C)->ein"(ij, jk), ki->"(A, B, C)[], A, B, C)
([0.5063370937562979 0.005581259151837478 -1.0457426048982317; -1.386773097810222 -1.189564517546563 -0.7428377220431881], [2.2863240614394678 -0.446536407987949 -1.764945042725527 0.8570908967967413; -0.2518013955005649 0.9526258847254069 0.2450053082918978 -0.23708152456992845; -4.1804176347523425 1.5673819858049614 3.2691828746541076 -1.6857400395701023], [-0.30392375042888575 -1.2434442105842511; -1.1780528332784956 -1.2997585175928115; -2.3167478409194517 -0.7500308631964716; 0.7746605628128713 1.4575260650922275])
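The gradients printed above are identical to those returned by cost_and_gradient in the previous section. What Zygote adds is the ability to embed an einsum call inside ordinary Julia code. The following is a minimal sketch that assumes a hypothetical loss, frob2 (the squared Frobenius norm of a matrix product), and uses the closed-form gradients 2(AB)Bᵀ and 2Aᵀ(AB) as a check.

```julia
using OMEinsum, Zygote

A, B = randn(2, 3), randn(3, 4)

# Hypothetical loss: an einsum matrix product followed by ordinary Julia code.
frob2(A, B) = sum(abs2, ein"ij,jk->ik"(A, B))

gA, gB = Zygote.gradient(frob2, A, B)

# Closed form: with M = A*B and L = sum(M .^ 2), dL/dA = 2M*B' and dL/dB = 2A'*M.
M = A * B
@assert gA ≈ 2 * M * B'
@assert gB ≈ 2 * A' * M
```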