Arithmetics and Broadcasting
As a tensor toolkit, RSTSR provides many basic arithmetic operations.
This section covers arithmetics only; computations based on mapping are discussed in the next section.
1. Examples of Arithmetic Operations
RSTSR can handle `+`, `-`, `*`, `/` operations:
let a = rt::arange(5.0);
let b = rt::arange(5.0) + 1.0;
let c = &a + &b;
println!("{:}", c);
// output: [ 1 3 5 7 9]
let d = &a / &b;
println!("{:6.3}", d);
// output: [ 0.000 0.500 0.667 0.750 0.800]
let e = 2.0 * &a;
println!("{:}", e);
// output: [ 0 2 4 6 8]
RSTSR handles matmul operations with the operator `%` (matrix-matrix, matrix-vector, or vector-vector inner product; this has been optimized on some devices such as `DeviceFaer`):
let mat = rt::arange(12).into_shape([3, 4]);
let vec = rt::arange(4).into_shape([4]);
// matrix multiplication
let res = &mat % mat.t();
println!("{:3}", res);
// output:
// [[ 14 38 62]
// [ 38 126 214]
// [ 62 214 366]]
// matrix-vector multiplication
let res = &mat % &vec;
println!("{:}", res);
// output: [ 14 38 62]
// vector-matrix multiplication
let res = &vec % &mat.t();
println!("{:}", res);
// output: [ 14 38 62]
// vector inner dot
let res = &vec % &vec;
println!("{:}", res);
// output: 14
For some special cases, bit operations and shift are also available:
let a = rt::asarray(vec![true, true, false, false]);
let b = rt::asarray(vec![true, false, true, false]);
// bitwise xor
let c = a ^ b;
println!("{:?}", c);
// output: [false true true false]
let a = rt::asarray(vec![9, 7, 5, 3]);
let b = rt::asarray(vec![5, 6, 7, 8]);
// shift left
let c = a << b;
println!("{:?}", c);
// output: [ 288 448 640 768]
The examples above should cover most usages of tensor arithmetics. The rest of this section covers some advanced topics.
2. Overridden Operator %
We have already shown that `%` is the operator for matrix multiplication. This usage is specific to RSTSR.
It may cause some confusion, so we discuss it here.
First, we follow the NumPy convention that `*` is always elementwise multiplication, similar to `+`; it never gives matrix multiplication or a vector inner product.
let mat = rt::arange(12).into_shape([3, 4]);
let vec = rt::arange(4);
// element-wise matrix multiplication
let c = &mat * &mat;
println!("{:3}", c);
// output:
// [[ 0 1 4 9]
// [ 16 25 36 49]
// [ 64 81 100 121]]
// element-wise matrix-vector multiplication (broadcasting involved)
let d = &mat * &vec;
println!("{:2}", d);
// output:
// [[ 0 1 4 9]
// [ 0 5 12 21]
// [ 0 9 20 33]]
// element-wise vector multiplication
let e = &vec * &vec;
println!("{:}", e);
// output: [ 0 1 4 9]
NumPy introduced the `@` notation for matrix multiplication in version 1.10, following PEP 465.
In Rust, it is virtually hopeless to use the same `@` operator for matrix multiplication, as has been fully discussed in the Rust internals forum (`@` is already used as the binary operator for pattern binding).
From the RSTSR developer's perspective, this is very unfortunate.
Other operators that distinguish the kinds of multiplication (such as `%*%` in R, `.*` in Matlab and Julia, `.` in Mathematica) simply do not exist as binary operators in Rust.
Using that kind of notation would require support at the language level, and such features are not promised to be stabilized soon.
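To illustrate why `@` is not free for operator overloading, here is a minimal std-only sketch of its existing role in Rust patterns. The function name `classify` is hypothetical and unrelated to RSTSR.

```rust
/// Classify a number, binding the matched value to a name with `@`.
fn classify(n: u32) -> String {
    match n {
        // `small @ 1..=5` tests the range pattern and binds the value to `small`
        small @ 1..=5 => format!("small: {small}"),
        // any other value is bound to `other`
        other => format!("other: {other}"),
    }
}

fn main() {
    println!("{}", classify(3));
    println!("{}", classify(9));
}
```

Since `name @ pattern` is already part of the pattern grammar, reusing `@` as an arithmetic operator would need a language-level change.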
However, although `%` is commonly used as the remainder operator, remainder is less used in vector or matrix computation.
`%` also shares the same operator precedence as `*` and `/`.
Thus, we decided to use `%` as the matrix multiplication notation where appropriate.
We reserve the function name `rem` for remainder computation, and the name `matmul` for matrix multiplication.
let a = rt::arange(6);
// remainder to scalar
let c = rt::rem(&a, 3);
println!("{:}", c);
// output: [ 0 1 2 0 1 2]
// remainder to array
let b = rt::asarray(vec![3, 2, 3, 3, 2, 2]);
let c = rt::rem(&a, &b);
println!("{:}", c);
// output: [ 0 1 2 0 0 1]
Do not use `rem` as associated (struct member) function
We have shown that `rt::rem` is a valid function for evaluating tensor remainder:
let a = rt::arange(6);
let b = rt::asarray(vec![3, 2, 3, 3, 2, 2]);
// remainder to array
let c = rt::rem(&a, &b);
println!("{:}", c);
// output: [ 0 1 2 0 0 1]
However, the function `tensor.rem(other)` is not `rt::rem` by definition.
It is defined as a Rust associated function by the trait `core::ops::Rem`.
Since we override this trait with matmul, `tensor.rem(other)` also calls the matmul operation.
// inner product (due to override to `Rem`)
let c = a.view().rem(&b);
println!("{:}", c);
// output: 35
Since this kind of code causes confusion, we advise API users not to use `rem` as an associated function.
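The mechanism behind this pitfall can be reproduced without RSTSR. The following is a minimal std-only sketch, assuming a hypothetical type `Vec1` and a hypothetical free function `elementwise_rem`; it is not RSTSR's implementation, only an illustration of how implementing `core::ops::Rem` makes both `%` and the method call `.rem(...)` dispatch to the same code.

```rust
use std::ops::Rem;

/// Hypothetical 1-D vector type, standing in for a tensor.
#[derive(Debug, PartialEq)]
struct Vec1(Vec<i64>);

/// `%` on references is overloaded as the inner product, not as remainder.
impl Rem<&Vec1> for &Vec1 {
    type Output = i64;
    fn rem(self, rhs: &Vec1) -> i64 {
        self.0.iter().zip(&rhs.0).map(|(x, y)| x * y).sum()
    }
}

/// Free function that keeps the usual elementwise remainder available.
fn elementwise_rem(a: &Vec1, b: &Vec1) -> Vec1 {
    Vec1(a.0.iter().zip(&b.0).map(|(x, y)| x % y).collect())
}

fn main() {
    let a = Vec1(vec![0, 1, 2, 3, 4, 5]);
    let b = Vec1(vec![3, 2, 3, 3, 2, 2]);
    // operator form and method form resolve to the same `Rem::rem`
    println!("{}", &a % &b); // 35
    println!("{}", (&a).rem(&b)); // 35
    // the remainder lives in a separately named function
    println!("{:?}", elementwise_rem(&a, &b)); // Vec1([0, 1, 2, 0, 0, 1])
}
```

The method name `rem` is fixed by the trait, so once `%` means something other than remainder, the method form inherits that meaning as well.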
3. Broadcasting
Broadcasting makes many tensor operations very simple. RSTSR applies most broadcasting rules from NumPy or the Python Array API. We refer interested users to the NumPy and Python Array API documents.
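As a rough orientation, the shape part of the NumPy-style rule can be sketched in plain Rust. The function `broadcast_shape` below is a hypothetical helper written for this illustration, not an RSTSR function: shapes are aligned from the last axis, missing leading axes count as length 1, and two lengths are compatible when they are equal or one of them is 1.

```rust
/// Compute the broadcast shape of two shapes under NumPy-style rules,
/// or `None` if the shapes are incompatible.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let ndim = a.len().max(b.len());
    let mut out = vec![0; ndim];
    for k in 0..ndim {
        // walk from the last axis; a missing leading axis counts as length 1
        let da = if k < a.len() { a[a.len() - 1 - k] } else { 1 };
        let db = if k < b.len() { b[b.len() - 1 - k] } else { 1 };
        out[ndim - 1 - k] = match (da, db) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return None,
        };
    }
    Some(out)
}

fn main() {
    println!("{:?}", broadcast_shape(&[8, 2, 4], &[2, 1])); // Some([8, 2, 4])
    println!("{:?}", broadcast_shape(&[2, 1], &[4])); // Some([2, 4])
    println!("{:?}", broadcast_shape(&[3, 4], &[3])); // None
}
```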
The initial developer of RSTSR is a computational chemist. We will use an example from chemistry programming to show how broadcasting is used in real-world situations.
3.1 Example of elementwise multiplication
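The sum-of-exponent approximation to RI-MP2 (resolution-of-identity Møller-Plesset second-order perturbation theory), also termed LT-OS-MP2, involves the following computation (written here with the variable names of the code):
converted_y[P, i, a] = y[P, i, a] * ei[i] * ea[a]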
// task definition
let (naux, nocc, nvir) = (8, 2, 4); // subscripts (P, i, a)
let y = rt::arange(naux * nocc * nvir).into_shape([naux, nocc, nvir]);
let ei = rt::arange(nocc);
let ea = rt::arange(nvir);
This is an elementwise multiplication of the 3-D tensor `y` with the 1-D tensors `ei` and `ea`. In the usual approach, the 1-D tensors `ei` and `ea` would be expanded and repeated into 3-D counterparts, and then multiplied with `y`. This is both inconvenient and inefficient. With broadcasting, we can insert axes into the 1-D tensors without repeating values:
// elementwise multiplication with broadcasting
// `None` means inserting axis, equivalent to `np.newaxis` in NumPy or `NewAxis` in RSTSR
let converted_y = &y * ei.slice((None, .., None)) * ea.slice((None, None, ..));
This multiplication can be simplified further. Under NumPy's broadcasting rule, missing dimensions are always added at the front, as if an ellipsis were placed before the first dimension. So any operation that only inserts axes at the front can be removed:
// elementwise multiplication with simplified broadcasting
let converted_y = &y * &ei.slice((.., None)) * &ea;
Finally, for memory and efficiency reasons, it is preferable to perform the elementwise multiplication of the small tensors `ei` and `ea` first:
// optimize for memory access cost
let converted_y = &y * (&ei.slice((.., None)) * &ea);
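The reason for this ordering can be made explicit with plain loops. The sketch below is std-only and assumes row-major flat storage; the function name `scale_by_outer` is hypothetical. It builds the small 2-D factor `ei[i] * ea[a]` once, then makes a single pass over the large 3-D data, which is presumably what the parenthesized expression achieves.

```rust
/// Scale `y[P, i, a]` (row-major, shape [naux, nocc, nvir]) by `ei[i] * ea[a]`.
fn scale_by_outer(y: &[f64], ei: &[f64], ea: &[f64], naux: usize) -> Vec<f64> {
    let (nocc, nvir) = (ei.len(), ea.len());
    assert_eq!(y.len(), naux * nocc * nvir);
    // small 2-D factor, computed once: factor[i, a] = ei[i] * ea[a]
    let mut factor = vec![0.0; nocc * nvir];
    for i in 0..nocc {
        for a in 0..nvir {
            factor[i * nvir + a] = ei[i] * ea[a];
        }
    }
    // single pass over the large 3-D tensor; the factor is reused for every P
    let mut out = vec![0.0; y.len()];
    for p in 0..naux {
        for k in 0..nocc * nvir {
            let idx = p * nocc * nvir + k;
            out[idx] = y[idx] * factor[k];
        }
    }
    out
}

fn main() {
    let y: Vec<f64> = (0..64).map(|v| v as f64).collect();
    let ei = [0.0, 1.0];
    let ea = [0.0, 1.0, 2.0, 3.0];
    let out = scale_by_outer(&y, &ei, &ea, 8);
    // element (P=1, i=1, a=3) has flat index 15: 15 * (1 * 3) = 45
    println!("{}", out[15]);
}
```

Multiplying `y` by `ei` first would instead create an intermediate of the full 3-D size before the second multiplication.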
3.2 Example of matrix multiplication
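Many post-HF methods involve an integral basis transformation, mostly from a raw basis (for example the atomic orbital basis, denoted AO) to the molecular orbital basis (denoted MO). Written with the variable names of the code:
y_mo[P, i, a] = sum over (μ, ν) of c_occ[μ, i] * y_ao[P, μ, ν] * c_vir[ν, a]
This operation involves five indices (P, i, a, μ, ν), where the numbers of indices i and a are smaller than those of μ and ν.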
// task definition
let (naux, nocc, nvir, nao, _) = (8, 2, 4, 6, 6); // subscripts (P, i, a, μ, ν)
let y_ao = rt::arange(naux * nao * nao).into_shape([naux, nao, nao]);
let c_occ = rt::arange(nao * nocc).into_shape([nao, nocc]);
let c_vir = rt::arange(nao * nvir).into_shape([nao, nvir]);
The broadcasting rule is slightly more complicated for matrix multiplication. However, if you are familiar with the broadcasting rule, this task can be realized with very simple code:
let y_mo = &c_occ.t() % &y_ao % &c_vir;
println!("{:?}", y_mo.layout());
This operation can be further optimized in efficiency.
This code is simple and elegant. It properly handles multi-threading on devices with rayon support.
However, it accesses 3-D tensors multiple times and generates a temporary 3-D tensor. This is inefficient in both memory access and memory cost.
To resolve this memory inefficiency, the computation can be performed with a parallel axis iterator. However, RSTSR has not finished this part yet. We will return to this topic later.
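Until such an iterator is available, the idea of iterating over the first axis can be sketched with plain loops. The following is a serial, std-only illustration under the assumption of row-major flat storage; `ao_to_mo` is a hypothetical name and this is not RSTSR code. It processes one `P` slice at a time, so the temporary is only a 2-D buffer of shape [nocc, nao] instead of a full 3-D tensor.

```rust
/// y_mo[P, i, a] = sum_{mu, nu} c_occ[mu, i] * y_ao[P, mu, nu] * c_vir[nu, a]
/// All arrays are row-major flat slices.
fn ao_to_mo(
    y_ao: &[f64],
    c_occ: &[f64],
    c_vir: &[f64],
    naux: usize,
    nao: usize,
    nocc: usize,
    nvir: usize,
) -> Vec<f64> {
    let mut y_mo = vec![0.0; naux * nocc * nvir];
    // 2-D scratch buffer reused for every P: tmp[i, nu]
    let mut tmp = vec![0.0; nocc * nao];
    for p in 0..naux {
        let y_p = &y_ao[p * nao * nao..(p + 1) * nao * nao];
        // first half-transformation: tmp[i, nu] = sum_mu c_occ[mu, i] * y_p[mu, nu]
        tmp.fill(0.0);
        for mu in 0..nao {
            for i in 0..nocc {
                for nu in 0..nao {
                    tmp[i * nao + nu] += c_occ[mu * nocc + i] * y_p[mu * nao + nu];
                }
            }
        }
        // second half-transformation: y_mo[P, i, a] = sum_nu tmp[i, nu] * c_vir[nu, a]
        for i in 0..nocc {
            for nu in 0..nao {
                for a in 0..nvir {
                    y_mo[(p * nocc + i) * nvir + a] += tmp[i * nao + nu] * c_vir[nu * nvir + a];
                }
            }
        }
    }
    y_mo
}

fn main() {
    let (naux, nocc, nvir, nao) = (8, 2, 4, 6);
    let y_ao: Vec<f64> = (0..naux * nao * nao).map(|v| v as f64).collect();
    let c_occ: Vec<f64> = (0..nao * nocc).map(|v| v as f64).collect();
    let c_vir: Vec<f64> = (0..nao * nvir).map(|v| v as f64).collect();
    let y_mo = ao_to_mo(&y_ao, &c_occ, &c_vir, naux, nao, nocc, nvir);
    println!("{}", y_mo.len()); // naux * nocc * nvir = 64
}
```

Because each `P` slice is independent, the outer loop is the natural place to parallelize.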
4. Memory Aspects
This section is about how values are passed to arithmetic operations.
4.1 Computation by arithmetic operators
In Rust, variable ownership and lifetime rules are strict. The following code gives a compiler error:
let a = rt::arange(5.0);
let b = rt::arange(5.0) + 1.0;
let c = a + b;
let d = a * b;
| let c = a + b;
| - value moved here
| let d = a * b;
| ^ value used here after move
|
help: consider cloning the value if the performance cost is acceptable
|
| let c = a + b.clone();
| ++++++++
However, in many cases, the performance and memory cost of cloning the tensor is not acceptable. It is therefore preferable to perform the computation in one of the following ways, which avoid memory copies and lifetime limitations:
- use a reference to the tensor,
- use a view of the tensor,
- use a clone of a view of the tensor.
// arithmetic by reference
let c = &a + &b;
// arithmetic by view
let d = a.view() * b.view();
// view clone is cheap, given tensor is large
let a_view = a.view();
let b_view = b.view();
let e = a_view.clone() * b_view.clone();
It should be noted that, apart from the lifetime limitation, an owned tensor can still be passed to arithmetic operations.
Moreover, inplace arithmetics will be applied when possible (subject to type constraints and broadcastability).
For example, in 1-D tensor addition, the memory of variable `c` is not newly allocated, but instead reused from variable `a`.
So if you are sure that `a` will not be used anymore, you can pass `a` by value, which is more efficient.
let a = rt::arange(5.0);
let b = rt::arange(5.0) + 1.0;
let ptr_a = a.rawvec().as_ptr();
// if sure that `a` is not used anymore, pass `a` by value instead of reference
let c = a + &b;
let ptr_c = c.rawvec().as_ptr();
// raw data of `a` is reused in `c`
// similar to `a += &b; let c = a;`
assert_eq!(ptr_a, ptr_c);
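How an owned operand allows buffer reuse can be sketched with std only. The type `Tensor1` below is a hypothetical stand-in backed by a `Vec<f64>`; the two `Add` implementations are an assumption about the general technique, not RSTSR's actual code.

```rust
use std::ops::Add;

/// Hypothetical owned 1-D tensor backed by a `Vec<f64>`.
struct Tensor1(Vec<f64>);

/// Owned + borrowed: the left operand's buffer is updated in place and returned.
impl Add<&Tensor1> for Tensor1 {
    type Output = Tensor1;
    fn add(mut self, rhs: &Tensor1) -> Tensor1 {
        for (x, y) in self.0.iter_mut().zip(&rhs.0) {
            *x += y;
        }
        self
    }
}

/// Borrowed + borrowed: a fresh buffer has to be allocated for the result.
impl Add<&Tensor1> for &Tensor1 {
    type Output = Tensor1;
    fn add(self, rhs: &Tensor1) -> Tensor1 {
        Tensor1(self.0.iter().zip(&rhs.0).map(|(x, y)| x + y).collect())
    }
}

fn main() {
    let a = Tensor1(vec![0.0, 1.0, 2.0]);
    let b = Tensor1(vec![1.0, 2.0, 3.0]);
    let ptr_a = a.0.as_ptr();
    // both operands borrowed: the result lives in a new allocation
    let c = &a + &b;
    assert_ne!(c.0.as_ptr(), ptr_a);
    // left operand passed by value: its allocation is reused for the result
    let d = a + &b;
    assert_eq!(d.0.as_ptr(), ptr_a);
    assert_eq!(d.0, vec![1.0, 3.0, 5.0]);
}
```

Moving a `Vec` does not move its heap allocation, which is why the pointer comparison holds.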
4.2 Computation by associated functions
In RSTSR, there are three ways to perform arithmetic operations:
- by operator: `&a + &b`;
- by function: `rt::add(&a, &b)`;
- by associated function: `(&a).add(&b)` or `a.view().add(&b)`.
You may find the usage of the associated function somewhat weird.
In fact, `a.add(&b)` is also valid in Rust, but it consumes variable `a`.
The following code will not compile due to this problem:
let a = rt::arange(5.0);
let b = rt::arange(5.0) + 1.0;
// below is valid, however `a` is moved
let c = a.add(&b);
// below is invalid
let d = a.div(&b);
// ^ value used here after move
// note: `std::ops::Add::add` takes ownership of the receiver `self`, which moves `a`
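The receiver rule behind this error can be tried out with std only. In the sketch below, `Tensor1` and `add_impl` are hypothetical names for illustration; the point is that `std::ops::Add::add` takes `self` by value, so the receiver expression decides whether the tensor itself or only a reference to it is consumed.

```rust
use std::ops::Add;

/// Hypothetical 1-D tensor used only for this illustration.
#[derive(Debug, PartialEq)]
struct Tensor1(Vec<f64>);

/// Shared elementwise addition used by both trait implementations.
fn add_impl(a: &Tensor1, b: &Tensor1) -> Tensor1 {
    Tensor1(a.0.iter().zip(&b.0).map(|(x, y)| x + y).collect())
}

impl Add<&Tensor1> for &Tensor1 {
    type Output = Tensor1;
    fn add(self, rhs: &Tensor1) -> Tensor1 {
        add_impl(self, rhs)
    }
}

impl Add<&Tensor1> for Tensor1 {
    type Output = Tensor1;
    fn add(self, rhs: &Tensor1) -> Tensor1 {
        add_impl(&self, rhs)
    }
}

fn main() {
    let a = Tensor1(vec![0.0, 1.0]);
    let b = Tensor1(vec![1.0, 2.0]);
    // receiver is `&Tensor1`: only the reference is consumed, `a` stays usable
    let c = (&a).add(&b);
    // fully qualified form with the same reference receiver
    let d = Add::add(&a, &b);
    // receiver is `Tensor1`: `a` is moved here and cannot be used afterwards
    let e = a.add(&b);
    assert_eq!(c, d);
    assert_eq!(d, e);
}
```

Writing the receiver as `(&a)` or as a view therefore keeps the original tensor available for later expressions.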