You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
172 lines
5.0 KiB
172 lines
5.0 KiB
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2017 Gagan Goel <gagan.nith@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at the mozilla.org home page.
|
|
|
#include "main.h" |
|
|
|
#include <Eigen/CXX11/Tensor> |
|
|
|
using Eigen::Tensor; |
|
using Eigen::array; |
|
|
|
template <int DataLayout> |
|
static void test_0D_trace() { |
|
Tensor<float, 0, DataLayout> tensor; |
|
tensor.setRandom(); |
|
array<ptrdiff_t, 0> dims; |
|
Tensor<float, 0, DataLayout> result = tensor.trace(dims); |
|
VERIFY_IS_EQUAL(result(), tensor()); |
|
} |
|
|
|
|
|
template <int DataLayout> |
|
static void test_all_dimensions_trace() { |
|
Tensor<float, 3, DataLayout> tensor1(5, 5, 5); |
|
tensor1.setRandom(); |
|
Tensor<float, 0, DataLayout> result1 = tensor1.trace(); |
|
VERIFY_IS_EQUAL(result1.rank(), 0); |
|
float sum = 0.0f; |
|
for (int i = 0; i < 5; ++i) { |
|
sum += tensor1(i, i, i); |
|
} |
|
VERIFY_IS_EQUAL(result1(), sum); |
|
|
|
Tensor<float, 5, DataLayout> tensor2(7, 7, 7, 7, 7); |
|
tensor2.setRandom(); |
|
array<ptrdiff_t, 5> dims = { { 2, 1, 0, 3, 4 } }; |
|
Tensor<float, 0, DataLayout> result2 = tensor2.trace(dims); |
|
VERIFY_IS_EQUAL(result2.rank(), 0); |
|
sum = 0.0f; |
|
for (int i = 0; i < 7; ++i) { |
|
sum += tensor2(i, i, i, i, i); |
|
} |
|
VERIFY_IS_EQUAL(result2(), sum); |
|
} |
|
|
|
|
|
template <int DataLayout> |
|
static void test_simple_trace() { |
|
Tensor<float, 3, DataLayout> tensor1(3, 5, 3); |
|
tensor1.setRandom(); |
|
array<ptrdiff_t, 2> dims1 = { { 0, 2 } }; |
|
Tensor<float, 1, DataLayout> result1 = tensor1.trace(dims1); |
|
VERIFY_IS_EQUAL(result1.rank(), 1); |
|
VERIFY_IS_EQUAL(result1.dimension(0), 5); |
|
float sum = 0.0f; |
|
for (int i = 0; i < 5; ++i) { |
|
sum = 0.0f; |
|
for (int j = 0; j < 3; ++j) { |
|
sum += tensor1(j, i, j); |
|
} |
|
VERIFY_IS_EQUAL(result1(i), sum); |
|
} |
|
|
|
Tensor<float, 4, DataLayout> tensor2(5, 5, 7, 7); |
|
tensor2.setRandom(); |
|
array<ptrdiff_t, 2> dims2 = { { 2, 3 } }; |
|
Tensor<float, 2, DataLayout> result2 = tensor2.trace(dims2); |
|
VERIFY_IS_EQUAL(result2.rank(), 2); |
|
VERIFY_IS_EQUAL(result2.dimension(0), 5); |
|
VERIFY_IS_EQUAL(result2.dimension(1), 5); |
|
for (int i = 0; i < 5; ++i) { |
|
for (int j = 0; j < 5; ++j) { |
|
sum = 0.0f; |
|
for (int k = 0; k < 7; ++k) { |
|
sum += tensor2(i, j, k, k); |
|
} |
|
VERIFY_IS_EQUAL(result2(i, j), sum); |
|
} |
|
} |
|
|
|
array<ptrdiff_t, 2> dims3 = { { 1, 0 } }; |
|
Tensor<float, 2, DataLayout> result3 = tensor2.trace(dims3); |
|
VERIFY_IS_EQUAL(result3.rank(), 2); |
|
VERIFY_IS_EQUAL(result3.dimension(0), 7); |
|
VERIFY_IS_EQUAL(result3.dimension(1), 7); |
|
for (int i = 0; i < 7; ++i) { |
|
for (int j = 0; j < 7; ++j) { |
|
sum = 0.0f; |
|
for (int k = 0; k < 5; ++k) { |
|
sum += tensor2(k, k, i, j); |
|
} |
|
VERIFY_IS_EQUAL(result3(i, j), sum); |
|
} |
|
} |
|
|
|
Tensor<float, 5, DataLayout> tensor3(3, 7, 3, 7, 3); |
|
tensor3.setRandom(); |
|
array<ptrdiff_t, 3> dims4 = { { 0, 2, 4 } }; |
|
Tensor<float, 2, DataLayout> result4 = tensor3.trace(dims4); |
|
VERIFY_IS_EQUAL(result4.rank(), 2); |
|
VERIFY_IS_EQUAL(result4.dimension(0), 7); |
|
VERIFY_IS_EQUAL(result4.dimension(1), 7); |
|
for (int i = 0; i < 7; ++i) { |
|
for (int j = 0; j < 7; ++j) { |
|
sum = 0.0f; |
|
for (int k = 0; k < 3; ++k) { |
|
sum += tensor3(k, i, k, j, k); |
|
} |
|
VERIFY_IS_EQUAL(result4(i, j), sum); |
|
} |
|
} |
|
|
|
Tensor<float, 5, DataLayout> tensor4(3, 7, 4, 7, 5); |
|
tensor4.setRandom(); |
|
array<ptrdiff_t, 2> dims5 = { { 1, 3 } }; |
|
Tensor<float, 3, DataLayout> result5 = tensor4.trace(dims5); |
|
VERIFY_IS_EQUAL(result5.rank(), 3); |
|
VERIFY_IS_EQUAL(result5.dimension(0), 3); |
|
VERIFY_IS_EQUAL(result5.dimension(1), 4); |
|
VERIFY_IS_EQUAL(result5.dimension(2), 5); |
|
for (int i = 0; i < 3; ++i) { |
|
for (int j = 0; j < 4; ++j) { |
|
for (int k = 0; k < 5; ++k) { |
|
sum = 0.0f; |
|
for (int l = 0; l < 7; ++l) { |
|
sum += tensor4(i, l, j, l, k); |
|
} |
|
VERIFY_IS_EQUAL(result5(i, j, k), sum); |
|
} |
|
} |
|
} |
|
} |
|
|
|
|
|
template<int DataLayout> |
|
static void test_trace_in_expr() { |
|
Tensor<float, 4, DataLayout> tensor(2, 3, 5, 3); |
|
tensor.setRandom(); |
|
array<ptrdiff_t, 2> dims = { { 1, 3 } }; |
|
Tensor<float, 2, DataLayout> result(2, 5); |
|
result = result.constant(1.0f) - tensor.trace(dims); |
|
VERIFY_IS_EQUAL(result.rank(), 2); |
|
VERIFY_IS_EQUAL(result.dimension(0), 2); |
|
VERIFY_IS_EQUAL(result.dimension(1), 5); |
|
float sum = 0.0f; |
|
for (int i = 0; i < 2; ++i) { |
|
for (int j = 0; j < 5; ++j) { |
|
sum = 0.0f; |
|
for (int k = 0; k < 3; ++k) { |
|
sum += tensor(i, k, j, k); |
|
} |
|
VERIFY_IS_EQUAL(result(i, j), 1.0f - sum); |
|
} |
|
} |
|
} |
|
|
|
|
|
// Test entry point: runs each trace test once per storage order.
// The call order is kept fixed so every subtest draws the same random data
// from run to run for a given seed.
EIGEN_DECLARE_TEST(cxx11_tensor_trace) {
  // Rank-0 input.
  CALL_SUBTEST(test_0D_trace<ColMajor>());
  CALL_SUBTEST(test_0D_trace<RowMajor>());

  // Trace over every dimension.
  CALL_SUBTEST(test_all_dimensions_trace<ColMajor>());
  CALL_SUBTEST(test_all_dimensions_trace<RowMajor>());

  // Trace over a subset of dimensions.
  CALL_SUBTEST(test_simple_trace<ColMajor>());
  CALL_SUBTEST(test_simple_trace<RowMajor>());

  // Trace embedded in a larger expression.
  CALL_SUBTEST(test_trace_in_expr<ColMajor>());
  CALL_SUBTEST(test_trace_in_expr<RowMajor>());
}
|
|
|