Automatic differentiation
There are two ways to compute the gradient of an einsum expression. The first is the built-in automatic differentiation of the OMEinsum package, a custom implementation of reverse-mode automatic differentiation. The second is the Zygote package, a source-to-source automatic differentiation tool.
Built-in automatic differentiation
The OMEinsum package provides a built-in function cost_and_gradient that computes both the value (cost) and the gradient of an einsum expression.
julia> using OMEinsum # the 1st way
julia> A, B, C = randn(2, 3), randn(3, 4), randn(4, 2);
julia> y, g = cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))
(fill(1.179265566367559), Any[[-1.7951011334692109 3.1568791630975785 -0.7088567341821649; 0.24353904967354187 -0.0570871659434461 2.434662615292539], [1.292994337464655 -1.6329686299727988 2.017180115922347 -0.13917260765943967; 0.513230620404249 0.5663549420775839 -2.672728243394723 1.4493261897497107; 1.0710675686508204 -1.630389454337067 2.465143322180202 -0.4593014592888689], [-2.2142263331962093 -1.2990422150146306; 2.175533694913128 7.054662160407904; 1.5580633515081406 1.785562570968924; -1.2448724752878155 -1.5254303206361648]])
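The code ein"(ij, jk), ki->" has an empty output index list, so it contracts everything into the scalar y = sum over i, j, k of A[i,j] B[j,k] C[k,i], which equals tr(A*B*C). Its gradients therefore have closed forms: the gradient with respect to A is transpose(B*C), with respect to B is transpose(C*A), and with respect to C is transpose(A*B). The sketch below checks these formulas against central finite differences using only Base and the LinearAlgebra and Random standard libraries (no OMEinsum). The helper names contract and fd_grad are hypothetical, and the sketch illustrates what the three arrays in g represent rather than how cost_and_gradient computes them.

```julia
using LinearAlgebra, Random

# The einsum "(ij, jk), ki->" written out as explicit loops:
# y = sum over i, j, k of A[i,j] * B[j,k] * C[k,i]
function contract(A, B, C)
    y = zero(promote_type(eltype(A), eltype(B), eltype(C)))
    for i in axes(A, 1), j in axes(A, 2), k in axes(B, 2)
        y += A[i, j] * B[j, k] * C[k, i]
    end
    return y
end

# Central finite-difference gradient of a scalar function f with respect to X.
function fd_grad(f, X; h = 1e-6)
    G = similar(X)
    for idx in eachindex(X)
        Xp = copy(X); Xp[idx] += h
        Xm = copy(X); Xm[idx] -= h
        G[idx] = (f(Xp) - f(Xm)) / (2h)
    end
    return G
end

Random.seed!(42)
A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)

y = contract(A, B, C)
# Closed-form gradients of y = tr(A*B*C)
gA, gB, gC = transpose(B * C), transpose(C * A), transpose(A * B)

@assert y ≈ tr(A * B * C)
@assert isapprox(gA, fd_grad(X -> contract(X, B, C), A); atol = 1e-6)
@assert isapprox(gB, fd_grad(X -> contract(A, X, C), B); atol = 1e-6)
@assert isapprox(gC, fd_grad(X -> contract(A, B, X), C); atol = 1e-6)
```

For the same A, B and C, these closed-form matrices should agree with the three arrays returned in g by cost_and_gradient, up to floating-point error.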
Because it is specialized for tensor contractions, this built-in automatic differentiation is more efficient than general-purpose automatic differentiation tools.
For complex-valued tensors, gradients are defined under the convention that treats the real and imaginary parts as independent variables.
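As an illustration of this convention (not a description of OMEinsum internals), take the real-valued loss L(A) = real(tr(A*B)) of a complex matrix A. One common way to package the two partial derivatives, used for example by Zygote and ChainRules, is G = dL/dRe(A) + im * dL/dIm(A); this sketch assumes that packing. Working it out by hand gives dL/dRe(A[i,j]) = real(B[j,i]) and dL/dIm(A[i,j]) = -imag(B[j,i]), so G equals B', the conjugate transpose of B. The names loss and fd_complex_grad are hypothetical.

```julia
using LinearAlgebra, Random

# A real-valued loss of a complex matrix A (B is held fixed).
loss(A, B) = real(tr(A * B))

# Gradient under the "real and imaginary parts are independent" convention,
# assuming the packing G[idx] = dL/dRe(A[idx]) + im * dL/dIm(A[idx]),
# estimated with central finite differences.
function fd_complex_grad(f, A; h = 1e-6)
    G = similar(A)
    for idx in eachindex(A)
        e = zero(A); e[idx] = 1                                # unit direction at idx
        d_re = (f(A + h * e) - f(A - h * e)) / (2h)            # step along the real part
        d_im = (f(A + im * h * e) - f(A - im * h * e)) / (2h)  # step along the imaginary part
        G[idx] = d_re + im * d_im
    end
    return G
end

Random.seed!(0)
A = randn(ComplexF64, 2, 3)
B = randn(ComplexF64, 3, 2)

G = fd_complex_grad(X -> loss(X, B), A)
@assert isapprox(G, B'; atol = 1e-6)   # analytic result: the adjoint of B
```

Note that the result is B', not transpose(B): treating the real and imaginary parts separately introduces the complex conjugate.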
Using Zygote
The backward rule for the basic einsum operation is ported to ChainRulesCore, which the Zygote package uses. Zygote is a source-to-source automatic differentiation tool that can compute the gradient of an einsum expression; it is more general and can also differentiate ordinary Julia code surrounding the einsum call.
julia> using Zygote # the 2nd way
julia> Zygote.gradient((A, B, C)->ein"(ij, jk), ki->"(A, B, C)[], A, B, C)
([-1.7951011334692109 3.1568791630975785 -0.7088567341821649; 0.24353904967354187 -0.0570871659434461 2.434662615292539], [1.292994337464655 -1.6329686299727988 2.017180115922347 -0.13917260765943967; 0.513230620404249 0.5663549420775839 -2.672728243394723 1.4493261897497107; 1.0710675686508204 -1.630389454337067 2.465143322180202 -0.4593014592888689], [-2.2142263331962093 -1.2990422150146306; 2.175533694913128 7.054662160407904; 1.5580633515081406 1.785562570968924; -1.2448724752878155 -1.5254303206361648])
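To give a sense of what a backward rule for a single einsum looks like, consider the pairwise contraction "ij,jk->ik" (matrix multiplication). Given the cotangent Ybar of the output, the input cotangents are again contractions: Abar = Ybar * transpose(B), i.e. the einsum "ik,jk->ij", and Bbar = transpose(A) * Ybar, i.e. the einsum "ij,ik->jk". The sketch below is a hand-written, hypothetical rule for this one case with real inputs, checked against a finite-difference vector-Jacobian product; it is not the rule OMEinsum registers with ChainRulesCore, and the names matmul_einsum, matmul_einsum_back and fd_vjp are hypothetical.

```julia
using LinearAlgebra, Random

# Forward pass of the einsum "ij,jk->ik" written as explicit loops.
function matmul_einsum(A, B)
    Y = zeros(promote_type(eltype(A), eltype(B)), size(A, 1), size(B, 2))
    for i in axes(A, 1), j in axes(A, 2), k in axes(B, 2)
        Y[i, k] += A[i, j] * B[j, k]
    end
    return Y
end

# Hypothetical backward rule for real inputs: given the output cotangent Ybar,
# return the input cotangents (Abar, Bbar).
#   Abar = einsum("ik,jk->ij", Ybar, B) = Ybar * transpose(B)
#   Bbar = einsum("ij,ik->jk", A, Ybar) = transpose(A) * Ybar
matmul_einsum_back(A, B, Ybar) = (Ybar * transpose(B), transpose(A) * Ybar)

# Numerical gradient of X -> dot(Ybar, f(X)), i.e. a finite-difference VJP.
function fd_vjp(f, X, Ybar; h = 1e-6)
    G = similar(X)
    for idx in eachindex(X)
        Xp = copy(X); Xp[idx] += h
        Xm = copy(X); Xm[idx] -= h
        G[idx] = (dot(Ybar, f(Xp)) - dot(Ybar, f(Xm))) / (2h)
    end
    return G
end

Random.seed!(1)
A, B = randn(2, 3), randn(3, 4)
Ybar = randn(2, 4)                       # arbitrary cotangent of the output

Abar, Bbar = matmul_einsum_back(A, B, Ybar)

@assert matmul_einsum(A, B) ≈ A * B
@assert isapprox(Abar, fd_vjp(X -> matmul_einsum(X, B), A, Ybar); atol = 1e-6)
@assert isapprox(Bbar, fd_vjp(X -> matmul_einsum(A, X), B, Ybar); atol = 1e-6)
```

Each cotangent is obtained by contracting Ybar with the other operand and reading off the indices of the input being differentiated, so the backward pass of a contraction is itself a contraction.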