Automatic differentiation

There are two ways to compute the gradient of an einsum expression. The first is the reverse-mode automatic differentiation built into the OMEinsum package, which is specialized for tensor contractions. The second is the Zygote package, a general-purpose source-to-source automatic differentiation tool.

Built-in automatic differentiation

The OMEinsum package provides a built-in function cost_and_gradient, which returns the result of the contraction (the cost) together with its gradient with respect to each input tensor.

julia> using OMEinsum  # the 1st way
julia> A, B, C = randn(2, 3), randn(3, 4), randn(4, 2);
julia> y, g = cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))
(fill(3.7094240444327355), Any[[2.331132377789704 0.3604625399389822 1.4076411849274724; -3.178648482861502 2.515894721286762 0.8097343771956073], [-0.8027850994441331 2.5620941795921874 -1.3477437068689984 -0.3790973461177588; -1.472624200757705 2.7316691697770628 1.9774859837804084 -3.023113183747034; -0.6881983430969775 1.926605351386717 -0.5454394831381844 -0.6440444449360182], [-3.5263887255114326 1.0447945836875157; -2.5087820253112456 -1.3341850183986637; 2.6086869328276063 -1.1679457738967522; 0.6541905155748717 0.23427950824126367]])

This built-in automatic differentiation is designed for tensor contractions and is typically more efficient for them than general-purpose automatic differentiation tools.
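One way to gain confidence in the returned gradient is a finite-difference check on a single entry. The sketch below reuses the contraction from the example above. The step size ϵ, the tolerance, and the choice of the entry A[1, 1] are arbitrary assumptions made for illustration.

```julia
using OMEinsum

A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
code = ein"(ij, jk), ki->"

# y is a 0-dimensional array; g holds one gradient array per input tensor.
y, g = cost_and_gradient(code, (A, B, C))

# Central finite difference with respect to the single entry A[1, 1].
ϵ = 1e-6
Ap = copy(A); Ap[1, 1] += ϵ
Am = copy(A); Am[1, 1] -= ϵ
fd = (code(Ap, B, C)[] - code(Am, B, C)[]) / (2ϵ)

# The contraction is linear in A, so the two values should agree up to round-off.
@assert isapprox(g[1][1, 1], fd; atol = 1e-6)
```

Because the contraction is linear in each input, the central difference carries no truncation error here, so only floating-point round-off separates the two numbers.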

Using Zygote

The backward rule for the basic einsum operation is also defined through ChainRulesCore, which Zygote uses, so Zygote can differentiate code that calls einsum. Zygote is a source-to-source automatic differentiation tool; it is more general than the built-in approach and also works on Julia code that goes beyond einsum expressions.
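For the contraction used in this section, the backward rule can be written out by hand. The equations below are the standard result for a multilinear contraction with a scalar output and unit upstream gradient. They are offered as a worked illustration, not as a description of how OMEinsum organizes the computation internally.

```latex
y = \sum_{i,j,k} A_{ij} B_{jk} C_{ki},
\qquad
\frac{\partial y}{\partial A_{ij}} = \sum_{k} B_{jk} C_{ki},
\qquad
\frac{\partial y}{\partial B_{jk}} = \sum_{i} C_{ki} A_{ij},
\qquad
\frac{\partial y}{\partial C_{ki}} = \sum_{j} A_{ij} B_{jk}.
```

Each gradient is itself an einsum over the remaining tensors. For example, the gradient with respect to A corresponds to the index pattern jk, ki -> ij, which is why the backward pass of a contraction can again be expressed as a contraction.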

julia> using Zygote  # the 2nd way
julia> Zygote.gradient((A, B, C)->ein"(ij, jk), ki->"(A, B, C)[], A, B, C)
([2.331132377789704 0.3604625399389822 1.4076411849274724; -3.178648482861502 2.515894721286762 0.8097343771956073], [-0.8027850994441331 2.5620941795921874 -1.3477437068689984 -0.3790973461177588; -1.472624200757705 2.7316691697770628 1.9774859837804084 -3.023113183747034; -0.6881983430969775 1.926605351386717 -0.5454394831381844 -0.6440444449360182], [-3.5263887255114326 1.0447945836875157; -2.5087820253112456 -1.3341850183986637; 2.6086869328276063 -1.1679457738967522; 0.6541905155748717 0.23427950824126367])
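The two outputs shown in this section agree entry by entry. A minimal sketch of checking this programmatically follows; the variable names code and gz are introduced here for illustration only.

```julia
using OMEinsum, Zygote

A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
code = ein"(ij, jk), ki->"

# Built-in reverse mode: g holds one gradient array per input tensor.
y, g = cost_and_gradient(code, (A, B, C))

# Zygote: differentiate the scalar extracted from the 0-dimensional result with [].
gz = Zygote.gradient((A, B, C) -> code(A, B, C)[], A, B, C)

# Both approaches should give the same gradients up to floating-point error.
@assert all(isapprox(g[i], gz[i]) for i in 1:3)
```

Note that cost_and_gradient returns the gradients as a vector while Zygote.gradient returns a tuple, so the comparison is done element by element.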