Automatic differentiation

There are two ways to compute the gradient of an einsum expression. The first is the reverse-mode automatic differentiation built into the OMEinsum package. The second is the Zygote package, a source-to-source automatic differentiation tool.

Built-in automatic differentiation

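The OMEinsum package provides a built-in function cost_and_gradient that computes the cost and the gradient of an einsum expression. In the example below, the cost is returned as a zero-dimensional array and the gradient as a vector with one array per input tensor.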

julia> using OMEinsum  # the 1st way
julia> A, B, C = randn(2, 3), randn(3, 4), randn(4, 2);
julia> y, g = cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))
(fill(1.7006481335554082), Any[[1.8685989855045142 0.019612588283935525 -0.827769626404614; -1.3083365397142133 1.6092200992442012 -1.4285206201565757], [-0.47171467426358654 0.7324493209867919 0.31193470105469 0.5756264437925518; 0.012423887679033793 0.02285436899231847 0.047503890146007706 0.02989709495081763; 0.5369177010421136 -0.7954754487420234 -0.3045260591594871 -0.6143347297597604], [0.1275205430051164 0.09268021980172801; 1.716316160795448 1.231796345590306; 0.30010618998434435 0.22713875454190335; -0.5285097360687832 -0.38580153264721406]])

This built-in automatic differentiation is designed specifically for tensor contractions and is more efficient than general-purpose automatic differentiation tools.
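One way to sanity-check the result is to compare it with a closed form. The contraction ein"(ij, jk), ki->" sums A[i,j] * B[j,k] * C[k,i] over all indices, which is the trace of A*B*C, so the gradients with respect to A, B and C should be the transposes of B*C, C*A and A*B respectively. The following is a minimal sketch of that check; the variable names cost_ok and grads_ok are introduced here for illustration only.

```julia
using OMEinsum
using LinearAlgebra: tr

A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)

# Same contraction as above: sum over i, j, k of A[i,j] * B[j,k] * C[k,i].
y, g = cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))

# Closed-form reference for the cost: the trace of the matrix product.
cost_ok = y[] ≈ tr(A * B * C)

# Closed-form references for the three gradients.
grads_ok = g[1] ≈ transpose(B * C) &&
           g[2] ≈ transpose(C * A) &&
           g[3] ≈ transpose(A * B)
```

Both cost_ok and grads_ok should evaluate to true up to floating-point tolerance, whatever random inputs are drawn.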

Using Zygote

The backward rule for the basic einsum operation is ported to ChainRulesCore, whose rules are used by the Zygote package. Zygote is a source-to-source automatic differentiation tool, so it can compute the gradient of an einsum expression and, being more general, of any Julia code that contains one.

julia> using Zygote  # the 2nd way
julia> Zygote.gradient((A, B, C)->ein"(ij, jk), ki->"(A, B, C)[], A, B, C)
([1.8685989855045142 0.019612588283935525 -0.827769626404614; -1.3083365397142133 1.6092200992442012 -1.4285206201565757], [-0.47171467426358654 0.7324493209867919 0.31193470105469 0.5756264437925518; 0.012423887679033793 0.02285436899231847 0.047503890146007706 0.02989709495081763; 0.5369177010421136 -0.7954754487420234 -0.3045260591594871 -0.6143347297597604], [0.1275205430051164 0.09268021980172801; 1.716316160795448 1.231796345590306; 0.30010618998434435 0.22713875454190335; -0.5285097360687832 -0.38580153264721406])
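
Because Zygote differentiates ordinary Julia code, the einsum call can sit inside a larger function. The sketch below assumes a hypothetical loss, the sum of squared entries of a matrix product written as an einsum, and compares the Zygote gradients with the closed-form expressions 2(AB)Bᵀ and 2Aᵀ(AB). The names loss, gA_ok and gB_ok are illustrative and not part of either package.

```julia
using OMEinsum, Zygote

A, B = randn(2, 3), randn(3, 4)

# Hypothetical loss: sum of squared entries of the contraction result.
loss(a, b) = sum(abs2, ein"ij,jk->ik"(a, b))

# Zygote returns one gradient per argument of `loss`.
gA, gB = Zygote.gradient(loss, A, B)

# Closed-form gradients of sum(abs2, A * B) for comparison.
gA_ok = gA ≈ 2 * (A * B) * B'
gB_ok = gB ≈ 2 * A' * (A * B)
```

Here the einsum is only one step of the computation, which is the case the built-in cost_and_gradient does not cover on its own.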