Input (flat)
An einsum specification should be given via the ein_str string-literal or with the @ein-macro as e.g.
julia> c = ein"ij,jk -> ik"(a,b)
julia> @ein c[i,k] := a[i,j] * b[j,k]where both specifications encode the same operation - a matrix multiplication. The ein_str-literal is parsed directly into an EinCode struct that holds the indices of the input ixs = (('i','j'),('j','k')) and output iy = ('i','k') as type parameters, making them accessible at compile time.
The string-literal form gets turned into
julia> c = EinCode((('i','j'),('j','k')),('i','k'))(a,b)Calling an EinCode-object gets lowered to
julia> c = einsum(EinCode((('i','j'),('j','k')),('i','k')), (a,b), size_dict = nothing)where nothing is the default argument for the (as of yet not used during specification) size_dict, which could allow to provide dimensions for index-labels that only appear in the output.
In the next step, a singleton-subtype of the abstract type EinRule is chosen which is later used for dispatch. Subtypes of EinRule specify the kind of operation and are created in such a way that they allow useful dispatch. They are defined in EinRule.jl.
The possible types are:
Identity- operation is the identity on one tensor, e.g.ein"ijk -> ijk"MatMul- operation is a matrix multiplication of two matrices, possibly with permutations of inputs and/or outputs, e.g.ein"ij,kj -> ik"Permutedims- operation is a permutation of the indices of one tensor, e.g.ein"ijk -> jki"Hadamard- operation is a hadamard-product of arbitrary many tensors, e.g.ein"ij,ij,ij -> ij"Tr- operation is a trace of one matrix, e.g.ein"ii ->"PTrace- operation is a partial trace of one tensor, e.g.ein"iij -> j"Sum- operation is a reduction over one or more indices of one tensor, e.g.ein"ijkl -> il"PairWise- operation is a tensor-contraction over arbitrary many tensors, e.g.ein"ijk,kl,lmn,no -> ijmo"DefaultRule- default if none of the above match, e.g.ein"ij,ik,il -> jkl"
Since ixs and iy are saved as type-parameters, the operation-matching can happen at compile time. The operation is chosen using match_rule(ixs,iy) by testing all subtypes of EinRule in the sequence above (top to bottom) and picking the first match.
This enables us to chose MatMul for a matrix multiplication which is also a legal tensor-contraction, i.e. a PairWise, assuming that we can have a lower-overhead implementation for MatMul than PairWise.
We proceed by calling einsum(<:EinRule, <:EinCode, xs, size_dict) which dispatches on the EinRule and the type of xs - the latter enables us to dispatch to e.g. cuda-specific routines for certain operations (as done in the cueinsum.jl file).
In the case of the matrix-multiplication above, einsum calls * which can dispatch to efficient routines for most Array-types including CuArray.
Input (Nested)
Whether with the ein_str string-literal or the @ein macro, nested expressions are mapped to a nested struct. Consider the example
julia> c = ein"(ij,jk),kl -> il"(a,b,c)
julia> @ein c[i,l] := (a[i,j] * b[j,k]) * c[k,l]which is a simply a product of three matrices evaluated as two matrix products in sequence.
This is equivalent to
julia> c = ein"ik,kl -> il"(ein"ij,jk -> ik"(a,b),c)
julia> @ein ab[i,k] := a[i,j] * b[j,k]
julia> @ein c[i,l] := ab[i,k] * c[k,l]and is expressed as a nested structure NestedEinsumStable which contains the EinCodes for the intermediate calculations as well as some logic to assign the correct input and output tensors to the correct EinCode.
NestedEinsumStable has the following definition:
struct NestedEinsumStable{T,S,N}
args::S
eins::T
endwhere the eins-field contains an EinCode of N arguments and args holds the arguments to that EinCode which can either be a integer to label a tensor or a NestedEinsumStable itself. The labeling works such that the ith input is represented by the number i.
Upon application to tensors, a NestedEinsumStable evaluates its arguments. If the argument is an integer i, the ith provided tensor is chosen, otherwise the NestedEinsumStable is evaluated.
To make it more concrete, consider the NestedEinsumStable for the expression above, where for easier reading the type signatures were removed and the EinCode-structs were replaced by ein-string literals.
julia> ein"(ij,jk),kl -> il"
NestedEinsumStable{...}((NestedEinsumStable{...}((1, 2), ein"ij,jk -> ik"), 3), ein"ik,kl -> il")Evaluating this expression with three arguments leads to the inner NestedEinsumStable to be evaluated first with the first and second argument and the specifiation ein"ij,jk -> ik". Then the result of that is given as the first argument to ein"ik,kl -> il" with the third argument as the second input.
To improve understanding, you might replace the integers with getindex operations in your head
ein"(ij,jk),kl -> il"(xs...)
โ NestedEinsumStable{...}((NestedEinsumStable{...}((xs[1], xs[2]), ein"ij,jk -> ik"), xs[3]), ein"ik,kl -> il")and finally turn it into
ein"(ij,jk),kl -> il"(xs...)
โ ein"ik,kl -> il"(ein"ij,jk -> ik"(xs[1], xs[2]), xs[3])