#ifndef MAT_HXX
#define MAT_HXX
#include "vect.hxx"
#include "vec.hxx"
#include <cctk.h>
#include <algorithm>
#include <array>
#include <cassert>
#include <functional>
#include <initializer_list>
#include <iostream>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
// ARITH_INLINE is attached to the small helper functions of the Arith classes.
// Without CCTK_DEBUG it expands to CCTK_ATTRIBUTE_ALWAYS_INLINE (presumably an
// always-inline attribute provided by cctk.h); when CCTK_DEBUG is defined it
// expands to nothing, leaving the inlining decision to the compiler.
#if !defined(CCTK_DEBUG)
#define ARITH_INLINE CCTK_ATTRIBUTE_ALWAYS_INLINE
#else
#define ARITH_INLINE
#endif
namespace Arith {

/// Symmetric D x D matrix.
///
/// Only the upper triangle (i <= j) is stored, packed row by row into a flat
/// vect of N = D * (D + 1) / 2 elements. The element accessors map (i, j) and
/// (j, i) to the same storage slot.
///
/// @tparam T     element type
/// @tparam D     number of rows (== number of columns)
/// @tparam dnup1 index position of the first index
/// @tparam dnup2 index position of the second index; must equal dnup1
template <typename T, int D, dnup_t dnup1, dnup_t dnup2> class mat {
  // A symmetric matrix requires both indices to have the same position.
  static_assert(dnup1 == dnup2, "");

  // Grant access to the storage of mats with other element types (used by the
  // converting constructors and by map()).
  template <typename, int, dnup_t, dnup_t> friend class mat;

  // Number of independent components of a symmetric D x D matrix.
  constexpr static int N = D * (D + 1) / 2;
  vect<T, N> elts;

  // Offset of upper-triangle element (i, j), 0 <= i <= j < D, in the packed
  // storage. Rows 0..i-1 hold D*i - i*(i-1)/2 elements, and (i, j) sits j - i
  // slots into row i; together this simplifies to D*i - i*(i+1)/2 + j.
  static constexpr ARITH_INLINE int symind(const int i, const int j) {
#ifdef CCTK_DEBUG
    assert(i >= 0 && i <= j && j < D);
#endif
    const int n = D * i - i * (i + 1) / 2 + j;
#ifdef CCTK_DEBUG
    assert(n >= 0 && n < N);
#endif
    return n;
  }

  // Offset of element (i, j) for either index order: symmetry is implemented
  // by sorting the indices into the upper triangle.
  static constexpr ARITH_INLINE int ind(const int i, const int j) {
    using std::max, std::min;
    return symind(min(i, j), max(i, j));
  }

  // Layout sanity checks. The guards short-circuit so that symind() is never
  // evaluated with indices outside [0, D).
  static_assert(D < 1 || symind(0, 0) == 0, "");
  static_assert(D < 1 || symind(D - 1, D - 1) == N - 1, "");
  static_assert(D < 2 || ind(1, 0) == ind(0, 1), "");
  static_assert(D < 3 || ind(2, 0) == ind(0, 2), "");
  static_assert(D < 3 || ind(2, 1) == ind(1, 2), "");
  // Explicit packed layout for the 3 x 3 case.
  static_assert(D != 3 || symind(0, 1) == 1, "");
  static_assert(D != 3 || symind(0, 2) == 2, "");
  static_assert(D != 3 || symind(1, 1) == 3, "");
  static_assert(D != 3 || symind(1, 2) == 4, "");
  static_assert(D != 3 || symind(2, 2) == 5, "");

public:
  /// Value-initialises the element storage.
  explicit constexpr ARITH_INLINE mat() : elts() {}

  constexpr ARITH_INLINE mat(const mat &) = default;
  constexpr ARITH_INLINE mat(mat &&) = default;
  constexpr ARITH_INLINE mat &operator=(const mat &) = default;
  constexpr ARITH_INLINE mat &operator=(mat &&) = default;

  /// Converts from a mat with a different element type U.
  template <typename U>
  constexpr ARITH_INLINE mat(const mat<U, D, dnup1, dnup2> &x) : elts(x.elts) {}
  template <typename U>
  constexpr ARITH_INLINE mat(mat<U, D, dnup1, dnup2> &&x)
      : elts(move(x.elts)) {}

  /// Constructs from the packed upper-triangle storage (row by row, N values).
  constexpr ARITH_INLINE mat(const vect<T, N> &elts) : elts(elts) {}
  constexpr ARITH_INLINE mat(vect<T, N> &&elts) : elts(move(elts)) {}
  constexpr ARITH_INLINE mat(initializer_list<T> A) : elts(A) {}
  constexpr ARITH_INLINE mat(const vector<T> &A) : elts(A) {}
  constexpr ARITH_INLINE mat(vector<T> &&A) : elts(move(A)) {}

  /// Builds the matrix from a callable f(i, j), evaluated once per stored
  /// element (i <= j).
  template <typename F, typename = std::invoke_result_t<F, int, int> >
  constexpr ARITH_INLINE mat(F f) : mat(iota1().map(f, iota2())) {}

  /// Matrix whose element (i, j) (and thus (j, i)) is 1; all other elements
  /// are value-initialised.
  static constexpr ARITH_INLINE mat unit(int i, int j) {
    mat r;
    r(i, j) = 1;
    return r;
  }

  /// Matrix whose stored element (i, j), i <= j, holds the index pair {i, j}.
  static constexpr ARITH_INLINE mat<array<int, 2>, D, dnup1, dnup2> iota() {
    mat<array<int, 2>, D, dnup1, dnup2> r;
    for (int i = 0; i < D; ++i)
      for (int j = i; j < D; ++j)
        r(i, j) = {i, j};
    return r;
  }

  /// Matrix whose stored element (i, j), i <= j, holds the row index i.
  static constexpr ARITH_INLINE mat<int, D, dnup1, dnup2> iota1() {
    mat<int, D, dnup1, dnup2> r;
    for (int i = 0; i < D; ++i)
      for (int j = i; j < D; ++j)
        r(i, j) = i;
    return r;
  }

  /// Matrix whose stored element (i, j), i <= j, holds the column index j.
  static constexpr ARITH_INLINE mat<int, D, dnup1, dnup2> iota2() {
    mat<int, D, dnup1, dnup2> r;
    for (int i = 0; i < D; ++i)
      for (int j = i; j < D; ++j)
        r(i, j) = j;
    return r;
  }

  /// Applies f to every stored element; returns a mat of f's result type.
  template <typename F, typename R = remove_cv_t<
                            remove_reference_t<std::invoke_result_t<F, T> > > >
  constexpr ARITH_INLINE mat<R, D, dnup1, dnup2> map(F f) const {
    mat<R, D, dnup1, dnup2> r;
    for (int i = 0; i < N; ++i)
      r.elts[i] = f(elts[i]);
    return r;
  }

  /// Applies f element-wise to (*this, x); returns a mat of f's result type.
  template <typename F, typename U,
            typename R = remove_cv_t<
                remove_reference_t<std::invoke_result_t<F, T, U> > > >
  constexpr ARITH_INLINE mat<R, D, dnup1, dnup2>
  map(F f, const mat<U, D, dnup1, dnup2> &x) const {
    mat<R, D, dnup1, dnup2> r;
    for (int i = 0; i < N; ++i)
      r.elts[i] = f(elts[i], x.elts[i]);
    return r;
  }

  /// Element access; (i, j) and (j, i) refer to the same element.
  constexpr ARITH_INLINE const T &operator()(int i, int j) const {
    return elts[ind(i, j)];
  }
  constexpr ARITH_INLINE T &operator()(int i, int j) { return elts[ind(i, j)]; }

  friend constexpr ARITH_INLINE mat<T, D, dnup1, dnup2>
  operator+(const mat<T, D, dnup1, dnup2> &x) {
    return {+x.elts};
  }
  friend constexpr ARITH_INLINE mat<T, D, dnup1, dnup2>
  operator-(const mat<T, D, dnup1, dnup2> &x) {
    return {-x.elts};
  }
  friend constexpr ARITH_INLINE mat<T, D, dnup1, dnup2>
  operator+(const mat<T, D, dnup1, dnup2> &x,
            const mat<T, D, dnup1, dnup2> &y) {
    return {x.elts + y.elts};
  }
  friend constexpr ARITH_INLINE mat<T, D, dnup1, dnup2>
  operator-(const mat<T, D, dnup1, dnup2> &x,
            const mat<T, D, dnup1, dnup2> &y) {
    return {x.elts - y.elts};
  }
  friend constexpr ARITH_INLINE mat<T, D, dnup1, dnup2>
  operator*(const T &a, const mat<T, D, dnup1, dnup2> &x) {
    return {a * x.elts};
  }
  friend constexpr ARITH_INLINE mat<T, D, dnup1, dnup2>
  operator*(const mat<T, D, dnup1, dnup2> &x, const T &a) {
    return {x.elts * a};
  }
  /// Divides every element by the scalar a. This operator is required by
  /// operator/= below.
  friend constexpr ARITH_INLINE mat<T, D, dnup1, dnup2>
  operator/(const mat<T, D, dnup1, dnup2> &x, const T &a) {
    mat<T, D, dnup1, dnup2> r;
    for (int i = 0; i < N; ++i)
      r.elts[i] = x.elts[i] / a;
    return r;
  }

  // Compound assignments return a reference to *this, as built-in types do.
  constexpr ARITH_INLINE mat &operator+=(const mat &x) {
    return *this = *this + x;
  }
  constexpr ARITH_INLINE mat &operator-=(const mat &x) {
    return *this = *this - x;
  }
  constexpr ARITH_INLINE mat &operator*=(const T &a) {
    return *this = *this * a;
  }
  constexpr ARITH_INLINE mat &operator/=(const T &a) {
    return *this = *this / a;
  }

  /// Equality compares the packed storage via equal_to on vect.
  friend constexpr ARITH_INLINE bool
  operator==(const mat<T, D, dnup1, dnup2> &x,
             const mat<T, D, dnup1, dnup2> &y) {
    return equal_to<vect<T, N> >()(x.elts, y.elts);
  }
  friend constexpr ARITH_INLINE bool
  operator!=(const mat<T, D, dnup1, dnup2> &x,
             const mat<T, D, dnup1, dnup2> &y) {
    return !(x == y);
  }

  /// Largest absolute value among the stored elements (delegates to vect).
  constexpr ARITH_INLINE T maxabs() const { return elts.maxabs(); }

  /// Prints "(<dnup1><dnup2>)" followed by all D x D elements as nested
  /// bracketed lists. Both triangles are printed; they are equal by symmetry.
  friend ostream &operator<<(ostream &os, const mat<T, D, dnup1, dnup2> &A) {
    os << "(" << dnup1 << dnup2 << ")[";
    for (int j = 0; j < D; ++j) {
      if (j > 0)
        os << ",";
      os << "[";
      for (int i = 0; i < D; ++i) {
        if (i > 0)
          os << ",";
        os << A(i, j);
      }
      os << "]";
    }
    os << "]";
    return os;
  }
};

} // namespace Arith
#endif