Automatic differentiation
There are two ways to compute the gradient of an einsum expression. The first is the built-in reverse-mode automatic differentiation of the OMEinsum package, a custom implementation specialized for tensor contractions. The second is the Zygote package, a general-purpose source-to-source automatic differentiation tool.
Built-in automatic differentiation
The OMEinsum package provides a built-in function, cost_and_gradient, that computes both the value (cost) of an einsum expression and its gradient with respect to each input tensor.
julia> using OMEinsum # the 1st way
julia> A, B, C = randn(2, 3), randn(3, 4), randn(4, 2);
julia> y, g = cost_and_gradient(ein"(ij, jk), ki->", (A, B, C))
(fill(2.5097942737942214), Any[[-0.6998192665853843 -0.1811199040419234 -1.6483673457060808; 0.5388969312121031 -0.331607580302559 3.9087828303058174], [0.4700588343337228 0.7781534940810114 2.9392091631698163 0.25367031081450486; 0.05858440070277229 -1.038360221500565 -3.346868751724879 -0.16225830700140417; -0.3349094855874679 -0.792788193489649 -2.8737267481785302 -0.22144002629577128], [-4.927428477234182 -0.44416249575655825; -1.3973895876787197 2.0802628141178205; 2.1787199879515606 -1.412072724774909; 0.18538804529177405 0.6388094174180385]])
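The contraction ein"(ij, jk), ki->" sums A[i,j]*B[j,k]*C[k,i] over all indices, so its value equals tr(A*B*C). Its gradients therefore have closed forms: the gradient with respect to A is (B*C)ᵀ, with respect to B is (C*A)ᵀ, and with respect to C is (A*B)ᵀ. The sketch below checks these closed forms against central finite differences using only Julia's LinearAlgebra standard library, not OMEinsum. The helper names f, closed_form_grads and fd_grad are hypothetical and exist only for this illustration. For the same A, B, C, the closed forms can be compared with the arrays returned by cost_and_gradient.

```julia
using LinearAlgebra

# Value of the einsum "(ij, jk), ki->": sum over i, j, k of A[i,j]*B[j,k]*C[k,i] == tr(A*B*C)
f(A, B, C) = tr(A * B * C)

# Closed-form gradients of tr(A*B*C) with respect to A, B and C (hypothetical helper)
closed_form_grads(A, B, C) = (permutedims(B * C), permutedims(C * A), permutedims(A * B))

# Central finite-difference gradient of a scalar function g with respect to array X
function fd_grad(g, X; h = 1e-6)
    G = similar(X)
    for idx in eachindex(X)
        Xp = copy(X); Xp[idx] += h   # step forward in one entry
        Xm = copy(X); Xm[idx] -= h   # step backward in the same entry
        G[idx] = (g(Xp) - g(Xm)) / (2h)
    end
    return G
end

A, B, C = randn(2, 3), randn(3, 4), randn(4, 2)
gA, gB, gC = closed_form_grads(A, B, C)

# Each closed-form gradient should agree with its finite-difference estimate
println(isapprox(gA, fd_grad(X -> f(X, B, C), A); atol = 1e-6))
println(isapprox(gB, fd_grad(X -> f(A, X, C), B); atol = 1e-6))
println(isapprox(gC, fd_grad(X -> f(A, B, X), C); atol = 1e-6))
```

Because the expression is linear in each input, the finite-difference estimates agree with the closed forms up to rounding error.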
Because it is specialized for tensor contractions, this built-in automatic differentiation is more efficient on einsum expressions than general-purpose automatic differentiation tools.
For complex-valued tensors, gradients follow a convention in which the real and imaginary parts are treated as independent variables.
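One reading of this convention, which is assumed here and not taken from OMEinsum's source, is that for a real-valued loss L of a complex variable z = x + iy, the gradient is ∂L/∂x + i·∂L/∂y, with x and y differentiated as separate real variables. The plain-Julia sketch below uses a hypothetical helper, complex_fd_grad, and makes no OMEinsum calls. It estimates both partial derivatives by finite differences for L(z) = |z|², whose gradient under this reading is 2z.

```julia
# Assumed convention: gradient of a real loss L at z = x + im*y is dL/dx + im*dL/dy,
# i.e. the real and imaginary parts are perturbed as independent real variables.
function complex_fd_grad(L, z::Complex; h = 1e-6)
    dx = (L(z + h) - L(z - h)) / (2h)            # perturb the real part only
    dy = (L(z + im * h) - L(z - im * h)) / (2h)  # perturb the imaginary part only
    return dx + im * dy
end

L(z) = abs2(z)  # |z|^2 = x^2 + y^2, a real-valued loss

z = 1.5 - 0.5im
g = complex_fd_grad(L, z)
println(g)                              # approximately 3.0 - 1.0im, i.e. 2z
println(isapprox(g, 2z; atol = 1e-6))
```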
Using Zygote
The backward rule for the basic einsum operation is also defined through ChainRulesCore, the rule interface used by the Zygote package. Zygote is a source-to-source automatic differentiation tool that can compute the gradient of an einsum expression. It is more general than the built-in routine and can differentiate arbitrary Julia code, including code that calls einsum.
julia> using Zygote # the 2nd way
julia> Zygote.gradient((A, B, C)->ein"(ij, jk), ki->"(A, B, C)[], A, B, C)
([-0.6998192665853843 -0.1811199040419234 -1.6483673457060808; 0.5388969312121031 -0.331607580302559 3.9087828303058174], [0.4700588343337228 0.7781534940810114 2.9392091631698163 0.25367031081450486; 0.05858440070277229 -1.038360221500565 -3.346868751724879 -0.16225830700140417; -0.3349094855874679 -0.792788193489649 -2.8737267481785302 -0.22144002629577128], [-4.927428477234182 -0.44416249575655825; -1.3973895876787197 2.0802628141178205; 2.1787199879515606 -1.412072724774909; 0.18538804529177405 0.6388094174180385])
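To show what a backward rule for a basic einsum looks like, consider the binary contraction ein"ij,jk->ik", which is the matrix product Y = A*B. Given the cotangent Ybar of the output, the input cotangents are Abar = Ybar*Bᵀ and Bbar = Aᵀ*Ybar, and both are themselves contractions. The sketch below is a hand-written pullback in plain Julia. Its names (contract, matmul_pullback, s, fd_grad) are hypothetical, and it is not the rule OMEinsum registers with ChainRulesCore, which handles general einsum codes. It checks the pullback against finite differences of the scalar probe s(A, B) = ⟨Ybar, A*B⟩.

```julia
using LinearAlgebra

# Forward pass of the contraction "ij,jk->ik" (an ordinary matrix product)
contract(A, B) = A * B

# Hypothetical hand-written pullback: maps the output cotangent Ybar to (Abar, Bbar)
matmul_pullback(A, B, Ybar) = (Ybar * B', A' * Ybar)

# Central finite-difference gradient of a scalar function g with respect to array X
function fd_grad(g, X; h = 1e-6)
    G = similar(X)
    for idx in eachindex(X)
        Xp = copy(X); Xp[idx] += h
        Xm = copy(X); Xm[idx] -= h
        G[idx] = (g(Xp) - g(Xm)) / (2h)
    end
    return G
end

A, B = randn(2, 3), randn(3, 4)
Ybar = randn(2, 4)  # arbitrary cotangent for the 2x4 output
Abar, Bbar = matmul_pullback(A, B, Ybar)

# Scalar probe s(A, B) = sum(Ybar .* (A*B)); its gradients equal the pullback outputs
s(A, B) = dot(Ybar, contract(A, B))

println(isapprox(Abar, fd_grad(X -> s(X, B), A); atol = 1e-6))
println(isapprox(Bbar, fd_grad(X -> s(A, X), B); atol = 1e-6))
```

Expressing each input cotangent as another contraction is what lets a single einsum rule serve both the built-in routine and Zygote. The exact form OMEinsum uses is not shown here.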