Automatic differentiation
There are two ways to compute the gradient of an einsum expression. The first is to use the reverse-mode automatic differentiation built into the OMEinsum package, which is a custom implementation. The second is to use the Zygote package, which is a source-to-source automatic differentiation tool.
Built-in automatic differentiation
The OMEinsum package provides a built-in function cost_and_gradient that computes the cost (the scalar value of the contraction) and its gradient with respect to each input tensor.
julia> using OMEinsum # the 1st way
julia> A, B, C = randn(2, 3), randn(3, 4), randn(4, 2);
julia> y, g = cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))
(fill(-2.285080387194546), Any[[-0.1287165078710394 -1.4952116108960334 -1.4320995336124256; -0.2763738024859782 -3.748588239874915 3.1075298273021734], [-0.3679231691509567 0.9347196988148129 -0.9038021005108982 -0.09541223948800297; -2.413771589763518 0.47888994551366143 -0.0765616078621508 -0.09266005086696817; -1.149029124085593 0.34499543535110294 -0.157604488526492 -0.05514864444090727], [0.40504968989307905 0.9007131490674809; -0.44140704009810733 0.44558310403650636; -0.7738381621821229 -1.2301028276566892; -1.2162672969410402 1.0491354048170558]])
This built-in automatic differentiation is designed specifically for tensor contractions and, for that task, is more efficient than general-purpose automatic differentiation tools.
For complex-valued tensors, the gradient is defined by the convention that treats the real and imaginary parts as independent variables.
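As a sanity check on the real-valued example above: the contraction ein"(ij, jk), ki->" equals tr(A * B * C), so its gradients have the closed forms (B * C)' for A, (C * A)' for B, and (A * B)' for C. The sketch below compares these with the output of cost_and_gradient. It assumes only the call signature shown above and the standard LinearAlgebra library; the closed forms are derived by hand and are not part of the package.

```julia
using OMEinsum, LinearAlgebra

A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)

# Same call as in the example above: y is a 0-dimensional array,
# g holds one gradient per input tensor.
y, g = cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))

# The contraction sums A[i,j] * B[j,k] * C[k,i] over all indices,
# which is the trace of the matrix product A * B * C.
@assert y[] ≈ tr(A * B * C)

# Hand-derived gradients of tr(A * B * C) for real inputs.
@assert g[1] ≈ (B * C)'   # d/dA
@assert g[2] ≈ (C * A)'   # d/dB
@assert g[3] ≈ (A * B)'   # d/dC
```

Each gradient has the same shape as the tensor it belongs to, which is a quick thing to verify when adapting the check to other contractions.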
Using Zygote
The backward rule for the basic einsum operation is ported to ChainRulesCore, which is used by the Zygote package. Zygote is a source-to-source automatic differentiation tool that can be used to compute the gradient of an einsum expression. It is more general than the built-in approach and can differentiate general Julia code, not only einsum expressions.
julia> using Zygote # the 2nd way
julia> Zygote.gradient((A, B, C)->ein"(ij, jk), ki->"(A, B, C)[], A, B, C)
([-0.1287165078710394 -1.4952116108960334 -1.4320995336124256; -0.2763738024859782 -3.748588239874915 3.1075298273021734], [-0.3679231691509567 0.9347196988148129 -0.9038021005108982 -0.09541223948800297; -2.413771589763518 0.47888994551366143 -0.0765616078621508 -0.09266005086696817; -1.149029124085593 0.34499543535110294 -0.157604488526492 -0.05514864444090727], [0.40504968989307905 0.9007131490674809; -0.44140704009810733 0.44558310403650636; -0.7738381621821229 -1.2301028276566892; -1.2162672969410402 1.0491354048170558])
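The two outputs above agree entry by entry. A minimal sketch of checking this programmatically, reusing only the calls already shown in this section (the variable names code, g_builtin and g_zygote are illustrative):

```julia
using OMEinsum, Zygote

A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
code = ein"(ij, jk), ki->"

# 1st way: built-in reverse-mode differentiation.
_, g_builtin = cost_and_gradient(code, (A, B, C))

# 2nd way: Zygote; `[]` extracts the scalar from the 0-dimensional result.
g_zygote = Zygote.gradient((a, b, c) -> code(a, b, c)[], A, B, C)

# Both approaches should return the same gradient for every input tensor.
@assert all(g_builtin[i] ≈ g_zygote[i] for i in 1:3)
```

Note that cost_and_gradient returns the gradients as a vector while Zygote.gradient returns a tuple, so the comparison is done element by element.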