Automatic differentiation

There are two ways to compute the gradient of an einsum expression. The first is the reverse-mode automatic differentiation built into the OMEinsum package, which is specialized for tensor contractions. The second is the Zygote package, a general-purpose source-to-source automatic differentiation tool.

Built-in automatic differentiation

The OMEinsum package provides a built-in function cost_and_gradient, which returns the cost (the value of the contraction) together with its gradient with respect to each input tensor.

julia> using OMEinsum  # the 1st way
julia> A, B, C = randn(2, 3), randn(3, 4), randn(4, 2);
julia> y, g = cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))
(fill(7.095515127645953), Any[[2.2901653789533842 -2.3429320300550835 0.8036763779307871; 1.573725197540953 -2.534699918050543 2.7690044617186618], [1.6178814719113452 5.62624833777079 2.7289459558325633 0.3990272660043035; -0.5833004955760545 -1.9891272585284658 -0.9575415964828057 -0.17216228542147055; -0.14187829809762087 -1.6318124658497102 -1.0017066521287445 0.7842805610161577], [1.2639722302537375 1.6458456796197984; 3.372970087617856 2.123428730429472; -1.8565008148374564 -1.6891516873398387; 1.4555729097690462 2.154756302004321]])

This built-in automatic differentiation is designed for tensor contractions and is more efficient than general-purpose automatic differentiation tools.
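
As a sanity check on what cost_and_gradient returns, the contraction used above equals the trace tr(A*B*C), so each gradient has a closed form. The following is a minimal sketch, assuming real-valued inputs; the comparison against the trace formulas is added here for illustration and is not part of the package.

```julia
using OMEinsum, LinearAlgebra

A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
y, g = cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))

# y is a 0-dimensional array; y[] extracts the scalar cost,
# which is sum_{ijk} A[i,j] * B[j,k] * C[k,i] = tr(A*B*C).
@assert y[] ≈ tr(A * B * C)

# Closed-form gradients of tr(A*B*C) for real matrices:
@assert g[1] ≈ transpose(B * C)   # d/dA, same shape as A
@assert g[2] ≈ transpose(C * A)   # d/dB, same shape as B
@assert g[3] ≈ transpose(A * B)   # d/dC, same shape as C
```

Each entry of g has the same shape as the corresponding input tensor.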

Using Zygote

The backward rule for the basic einsum operation is defined through ChainRulesCore, whose rules are used by the Zygote package, so Zygote can also compute the gradient of an einsum expression. Zygote is more general than the built-in approach and can differentiate any Julia code.

julia> using Zygote  # the 2nd way
julia> Zygote.gradient((A, B, C)->ein"(ij, jk), ki->"(A, B, C)[], A, B, C)
([2.2901653789533842 -2.3429320300550835 0.8036763779307871; 1.573725197540953 -2.534699918050543 2.7690044617186618], [1.6178814719113452 5.62624833777079 2.7289459558325633 0.3990272660043035; -0.5833004955760545 -1.9891272585284658 -0.9575415964828057 -0.17216228542147055; -0.14187829809762087 -1.6318124658497102 -1.0017066521287445 0.7842805610161577], [1.2639722302537375 1.6458456796197984; 3.372970087617856 2.123428730429472; -1.8565008148374564 -1.6891516873398387; 1.4555729097690462 2.154756302004321])
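
To illustrate what the backward rule of a basic einsum looks like, the sketch below pulls back an upstream adjoint through a single pairwise contraction and compares the result with the einsum obtained by swapping the labels of one input with those of the output. It assumes real-valued tensors; the names Cbar, Abar and Bbar are hypothetical and chosen only for this example.

```julia
using OMEinsum, Zygote

A, B = randn(2, 3), randn(3, 4)
Cbar = randn(2, 4)   # hypothetical upstream adjoint, same shape as the output

# Zygote.pullback returns the output and a function mapping output
# adjoints to input adjoints.
_, back = Zygote.pullback((a, b) -> ein"ij,jk->ik"(a, b), A, B)
Abar, Bbar = back(Cbar)

# For real tensors, the adjoint of each input is itself an einsum:
# the input's labels trade places with the output's labels.
@assert Abar ≈ ein"ik,jk->ij"(Cbar, B)
@assert Bbar ≈ ein"ij,ik->jk"(A, Cbar)
```

Because the adjoint of an einsum is again an einsum, the backward pass reuses the same contraction machinery as the forward pass.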