Automatic differentiation
There are two ways to compute the gradient of an einsum expression. The first is the reverse-mode automatic differentiation built into the OMEinsum package, a custom implementation specialized for einsum. The second is the Zygote package, a source-to-source automatic differentiation tool.
Built-in automatic differentiation
The OMEinsum package provides a built-in function, cost_and_gradient, to compute the cost and the gradient of an einsum expression.
julia> using OMEinsum # the 1st way
julia> A, B, C = randn(2, 3), randn(3, 4), randn(4, 2);
julia> y, g = cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))
(fill(-3.0740985717829927), Any[[-1.145631922466385 -3.6412646042015617 -2.4131954537582616; 2.313704109415201 2.7152604473606923 -0.6378827396086256], [0.7136718917075405 0.19898886806164376 -2.696642639510973 -0.23787327001593733; -0.11522912599559562 0.8642911403267155 0.004109286548217061 2.410127039880415; 0.17046791822882398 -0.6773770551531757 -0.2953503297057019 -1.9747569249632837], [0.9249940435875346 -0.12331059253724527; 1.6019299109723968 1.9114953577700302; 0.7204157447891663 -0.26053529750997073; -0.8646799269381289 0.12003104025808471]])
This built-in automatic differentiation is designed specifically for tensor contractions and is more efficient than general-purpose automatic differentiation tools.
For complex-valued tensors, the gradient follows a convention that treats the real and imaginary parts as independent variables.
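The contraction used above, ein"(ij, jk), ki->", is the trace of the matrix product A*B*C, so its gradients have closed forms: (B*C)ᵀ with respect to A, (C*A)ᵀ with respect to B, and (A*B)ᵀ with respect to C. The following is a minimal sketch of a sanity check against those formulas for real-valued inputs. It assumes the standard library LinearAlgebra is loaded alongside OMEinsum, and the names ok_value, ok_A, ok_B and ok_C are illustrative only.

```julia
using OMEinsum, LinearAlgebra

A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
code = ein"(ij, jk), ki->"

# y is a zero-dimensional array holding the scalar result,
# g holds one gradient array per input tensor.
y, g = cost_and_gradient(code, (A, B, C))

# The contraction sums A[i,j] * B[j,k] * C[k,i], i.e. tr(A * B * C).
ok_value = y[] ≈ tr(A * B * C)

# Closed-form gradients of tr(A * B * C) for real matrices.
ok_A = g[1] ≈ transpose(B * C)
ok_B = g[2] ≈ transpose(C * A)
ok_C = g[3] ≈ transpose(A * B)

println((ok_value, ok_A, ok_B, ok_C))
```

Each gradient has the same shape as the tensor it belongs to, which is a quick way to tell which entry of g corresponds to which input.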
Using Zygote
The backward rule for the basic einsum operation is also ported to ChainRulesCore, which the Zygote package uses. Zygote is a source-to-source automatic differentiation tool that can compute the gradient of an einsum expression. It is more general than the built-in approach and can be applied to general Julia code, not only to einsum expressions.
julia> using Zygote # the 2nd way
julia> Zygote.gradient((A, B, C)->ein"(ij, jk), ki->"(A, B, C)[], A, B, C)
([-1.145631922466385 -3.6412646042015617 -2.4131954537582616; 2.313704109415201 2.7152604473606923 -0.6378827396086256], [0.7136718917075405 0.19898886806164376 -2.696642639510973 -0.23787327001593733; -0.11522912599559562 0.8642911403267155 0.004109286548217061 2.410127039880415; 0.17046791822882398 -0.6773770551531757 -0.2953503297057019 -1.9747569249632837], [0.9249940435875346 -0.12331059253724527; 1.6019299109723968 1.9114953577700302; 0.7204157447891663 -0.26053529750997073; -0.8646799269381289 0.12003104025808471])
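The two outputs above contain the same numbers, which suggests a simple cross-check between the two approaches. The sketch below first compares the built-in gradient with the Zygote gradient on the same inputs, then differentiates a small loss function that wraps an einsum call in ordinary Julia code, something only the Zygote route covers. The function name loss and the variable names are hypothetical and chosen for illustration. The sketch assumes real-valued inputs.

```julia
using OMEinsum, Zygote

A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
code = ein"(ij, jk), ki->"

# 1st way: built-in reverse-mode differentiation.
y, g_builtin = cost_and_gradient(code, (A, B, C))

# 2nd way: Zygote. The trailing [] extracts the scalar from the 0-dim result.
g_zygote = Zygote.gradient((a, b, c) -> code(a, b, c)[], A, B, C)

# Both approaches should agree up to floating-point error.
agree = all(g1 ≈ g2 for (g1, g2) in zip(g_builtin, g_zygote))

# An einsum embedded in a larger function: the squared Frobenius norm of A * B.
loss(a, b) = sum(abs2, ein"ij,jk->ik"(a, b))
gA, gB = Zygote.gradient(loss, A, B)

# Closed-form gradients of sum(abs2, A * B) for real matrices.
ok_gA = gA ≈ 2 * (A * B) * transpose(B)
ok_gB = gB ≈ 2 * transpose(A) * (A * B)

println((agree, ok_gA, ok_gB))
```

When the whole objective is a single contraction, cost_and_gradient is the more direct choice. The Zygote route is useful once the einsum is only one step inside a larger computation.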