Automatic differentiation

There are two ways to compute the gradient of an einsum expression. The first is the reverse-mode automatic differentiation built into OMEinsum, a custom implementation specialized for tensor contractions. The second is Zygote, a general-purpose source-to-source automatic differentiation tool.

Built-in automatic differentiation

OMEinsum provides the built-in function cost_and_gradient, which returns the value (the "cost") of an einsum expression together with its gradients with respect to each input tensor.

julia> using OMEinsum  # the 1st way
julia> A, B, C = randn(2, 3), randn(3, 4), randn(4, 2);
julia> y, g = cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))
(fill(-2.1377526189087814), Any[[0.5063370937562979 0.005581259151837478 -1.0457426048982317; -1.386773097810222 -1.189564517546563 -0.7428377220431881], [2.2863240614394678 -0.446536407987949 -1.764945042725527 0.8570908967967413; -0.2518013955005649 0.9526258847254069 0.2450053082918978 -0.23708152456992845; -4.1804176347523425 1.5673819858049614 3.2691828746541076 -1.6857400395701023], [-0.30392375042888575 -1.2434442105842511; -1.1780528332784956 -1.2997585175928115; -2.3167478409194517 -0.7500308631964716; 0.7746605628128713 1.4575260650922275]])
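The expression ein"(ij, jk), ki->" sums over every index, so its value is y = Σ_ijk A[i,j] B[j,k] C[k,i] = tr(A·B·C), and its gradients have closed forms: ∂y/∂A = (B·C)ᵀ, ∂y/∂B = (C·A)ᵀ, ∂y/∂C = (A·B)ᵀ. The sketch below uses only Base Julia and the LinearAlgebra standard library (not OMEinsum) to check these formulas against central finite differences; contract and fd_grad are hypothetical helper names introduced for illustration. For the same A, B, C, these closed-form gradients should agree with the ones returned by cost_and_gradient.

```julia
using LinearAlgebra

# Value of ein"(ij, jk), ki->": sum over i, j, k of A[i,j] * B[j,k] * C[k,i], i.e. tr(A*B*C).
contract(A, B, C) = tr(A * B * C)

# Hypothetical helper: central finite-difference gradient of a scalar function f at X.
function fd_grad(f, X; h = 1e-6)
    G = similar(X)
    for idx in eachindex(X)
        E = zero(X); E[idx] = h              # perturb a single entry
        G[idx] = (f(X + E) - f(X - E)) / (2h)
    end
    return G
end

A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)

# Closed-form gradients of tr(A*B*C); each has the shape of the tensor it belongs to.
gA = transpose(B * C)   # 2×3
gB = transpose(C * A)   # 3×4
gC = transpose(A * B)   # 4×2

# Compare against finite differences, perturbing one tensor at a time.
@assert isapprox(gA, fd_grad(X -> contract(X, B, C), A); atol = 1e-5)
@assert isapprox(gB, fd_grad(X -> contract(A, X, C), B); atol = 1e-5)
@assert isapprox(gC, fd_grad(X -> contract(A, B, X), C); atol = 1e-5)
```

Because the contraction is linear in each tensor separately, the central difference is exact up to rounding, so a tight tolerance is enough.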

Because it is designed specifically for tensor contractions, this built-in automatic differentiation is typically more efficient than general-purpose automatic differentiation tools.

For complex-valued tensors, gradients follow the convention that treats the real and imaginary parts as independent variables.
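Concretely, for a real-valued loss L of a complex tensor Z = X + iY, treating X and Y as independent variables gives the gradient ∂L/∂X + i·∂L/∂Y. As an illustration (an example chosen here, not taken from OMEinsum), take L(Z) = Re tr(Z·W) for a fixed complex W: differentiating by hand gives ∂L/∂X = Re(Wᵀ) and ∂L/∂Y = -Im(Wᵀ), so the gradient under this convention is the conjugate transpose W'. The Base Julia sketch below checks this by perturbing the real and imaginary parts of each entry separately; loss and fd_complex_grad are hypothetical names.

```julia
using LinearAlgebra

# Real-valued loss of a complex matrix Z: L(Z) = Re tr(Z * W) = Re Σ_ij Z[i,j] * W[j,i].
loss(Z, W) = real(tr(Z * W))

# Hypothetical helper: finite-difference gradient that perturbs the real and imaginary
# part of each entry independently and packs the result as ∂L/∂X + im * ∂L/∂Y.
function fd_complex_grad(f, Z; h = 1e-6)
    G = similar(Z)
    for idx in eachindex(Z)
        E = zero(Z); E[idx] = h                        # real-part perturbation
        dX = (f(Z + E) - f(Z - E)) / (2h)
        dY = (f(Z + im * E) - f(Z - im * E)) / (2h)    # imaginary-part perturbation
        G[idx] = dX + im * dY
    end
    return G
end

Z = randn(ComplexF64, 2, 3)
W = randn(ComplexF64, 3, 2)

g = fd_complex_grad(z -> loss(z, W), Z)
@assert isapprox(g, W'; atol = 1e-5)   # W' is the conjugate transpose of W
```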

Using Zygote

The backward rule for the basic einsum operation is implemented with ChainRulesCore, the rule interface that Zygote uses. Since Zygote is a source-to-source automatic differentiation tool, it can differentiate not only einsum expressions but also arbitrary Julia code that contains them.
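A backward rule maps the cotangent of the output, ybar, to cotangents of the inputs. For a single pairwise contraction such as ein"ij,jk->ik" (matrix multiplication), the rule is itself a contraction: Abar = ein"ik,jk->ij"(ybar, B) = ybar·Bᵀ and Bbar = ein"ij,ik->jk"(A, ybar) = Aᵀ·ybar. ChainRulesCore expresses such rules as an rrule that returns the primal output together with a pullback closure; the sketch below mimics that shape in plain Julia without depending on the package, and is not OMEinsum's actual rule. The names matmul_with_pullback, fd_grad, and s are hypothetical.

```julia
# Hypothetical forward function for ein"ij,jk->ik" that returns the output together
# with a pullback, mirroring the (output, pullback) shape of an rrule.
function matmul_with_pullback(A, B)
    y = A * B
    # Given the output cotangent ybar, return the cotangents of A and B. Each is again a
    # contraction: Abar = ein"ik,jk->ij"(ybar, B) and Bbar = ein"ij,ik->jk"(A, ybar).
    pullback(ybar) = (ybar * transpose(B), transpose(A) * ybar)
    return y, pullback
end

# Hypothetical helper: central finite-difference gradient of a scalar function f at X.
function fd_grad(f, X; h = 1e-6)
    G = similar(X)
    for idx in eachindex(X)
        E = zero(X); E[idx] = h              # perturb a single entry
        G[idx] = (f(X + E) - f(X - E)) / (2h)
    end
    return G
end

A, B = randn(2, 3), randn(3, 4)
ybar = randn(2, 4)                           # arbitrary cotangent of the 2×4 output

y, back = matmul_with_pullback(A, B)
Abar, Bbar = back(ybar)

# The pullback should equal the gradient of the scalar s(A, B) = sum(ybar .* (A * B)).
s(A, B) = sum(ybar .* (A * B))
@assert isapprox(Abar, fd_grad(X -> s(X, B), A); atol = 1e-5)
@assert isapprox(Bbar, fd_grad(X -> s(A, X), B); atol = 1e-5)
```

Since every backward step is again a contraction, the reverse pass of a contraction network can reuse the same machinery as the forward pass.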

julia> using Zygote  # the 2nd way
julia> Zygote.gradient((A, B, C)->ein"(ij, jk), ki->"(A, B, C)[], A, B, C)
([0.5063370937562979 0.005581259151837478 -1.0457426048982317; -1.386773097810222 -1.189564517546563 -0.7428377220431881], [2.2863240614394678 -0.446536407987949 -1.764945042725527 0.8570908967967413; -0.2518013955005649 0.9526258847254069 0.2450053082918978 -0.23708152456992845; -4.1804176347523425 1.5673819858049614 3.2691828746541076 -1.6857400395701023], [-0.30392375042888575 -1.2434442105842511; -1.1780528332784956 -1.2997585175928115; -2.3167478409194517 -0.7500308631964716; 0.7746605628128713 1.4575260650922275])
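The trailing [] in the Zygote call is there because an einsum with an empty output (ki->) returns a zero-dimensional array, while Zygote.gradient expects the function being differentiated to return a scalar. Indexing a zero-dimensional array with no indices extracts its single element. The Base Julia snippet below illustrates this, with fill standing in for the einsum result (the value -2.5 is arbitrary).

```julia
# A zero-dimensional array holds exactly one element; fill(x) constructs one.
y = fill(-2.5)
@assert ndims(y) == 0 && length(y) == 1

# Indexing with no indices returns the wrapped scalar, which is what the
# trailing [] does to the einsum result before Zygote differentiates it.
s = y[]
@assert s isa Float64 && s == -2.5
```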