Skip to content

Commit b2b2ad8

Browse files
authored
Merge pull request #3059 from stan-dev/arm64-tests
Fixes for ARM64
2 parents 9202f1f + 7326216 commit b2b2ad8

File tree

8 files changed

+43
-18
lines changed

8 files changed

+43
-18
lines changed

make/compiler_flags

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,18 @@ endif
1717

1818
## Set OS specific library filename extensions
1919
ifeq ($(OS),Windows_NT)
20-
WINARM64 := $(shell echo | $(CXX) -E -dM - | findstr __aarch64__)
2120
LIBRARY_SUFFIX ?= .dll
21+
STR_SEARCH ?= findstr
2222
endif
2323

2424
ifeq ($(OS),Darwin)
2525
LIBRARY_SUFFIX ?= .dylib
26+
STR_SEARCH ?= grep
2627
endif
2728

2829
ifeq ($(OS),Linux)
2930
LIBRARY_SUFFIX ?= .so
31+
STR_SEARCH ?= grep
3032
endif
3133

3234
## Set default compiler
@@ -42,6 +44,11 @@ ifeq (default,$(origin CXX))
4244
endif
4345
endif
4446

47+
ARM64_CHECK := $(shell echo | $(CXX) -E -dM - | $(STR_SEARCH) __aarch64__)
48+
ifneq ($(ARM64_CHECK),)
49+
ARM64 = true
50+
endif
51+
4552
# Detect compiler type
4653
# - CXX_TYPE: {gcc, clang, mingw32-gcc, other}
4754
# - CXX_MAJOR: major version of CXX
@@ -164,7 +171,7 @@ ifeq ($(OS),Windows_NT)
164171

165172
make/ucrt:
166173
pound := \#
167-
UCRT_STRING := $(shell echo '$(pound)include <windows.h>' | $(CXX) -E -dM - | findstr _UCRT)
174+
UCRT_STRING := $(shell echo '$(pound)include <windows.h>' | $(CXX) -E -dM - | $(STR_SEARCH) _UCRT)
168175
ifneq (,$(UCRT_STRING))
169176
IS_UCRT ?= true
170177
else
@@ -211,6 +218,10 @@ endif
211218
## makes reentrant version lgamma_r available from cmath
212219
CXXFLAGS_OS += -D_REENTRANT
213220

221+
ifeq ($(ARM64), true)
222+
CXXFLAGS_OS += -ffp-contract=off
223+
endif
224+
214225
## silence warnings occurring due to the TBB and Eigen libraries
215226
CXXFLAGS_WARNINGS += -Wno-ignored-attributes
216227

@@ -275,7 +286,7 @@ endif
275286
LDFLAGS_TBB ?= -Wl,-L,"$(TBB_LIB)" -Wl,--disable-new-dtags
276287

277288
# Windows LLVM/Clang does not support -rpath, but it is not needed on Windows anyway
278-
ifeq ($(WINARM64),)
289+
ifneq ($(OS), Windows_NT)
279290
LDFLAGS_TBB += -Wl,-rpath,"$(TBB_LIB)"
280291
endif
281292

@@ -299,7 +310,7 @@ CXXFLAGS_TBB ?= -I $(TBB)/include
299310
LDFLAGS_TBB ?= -Wl,-L,"$(TBB_BIN_ABSOLUTE_PATH)" $(LDFLAGS_FLTO_FLTO) $(LDFLAGS_OPTIM_TBB)
300311

301312
# Windows LLVM/Clang does not support -rpath, but it is not needed on Windows anyway
302-
ifeq ($(WINARM64),)
313+
ifneq ($(OS), Windows_NT)
303314
LDFLAGS_TBB += -Wl,-rpath,"$(TBB_BIN_ABSOLUTE_PATH)"
304315
endif
305316
LDLIBS_TBB ?= -ltbb

make/libraries

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,11 @@ ifeq (Windows_NT, $(OS))
141141
TBB_CXXFLAGS += -D_UCRT
142142
endif
143143
# TBB does not have assembly code for Windows ARM64, so we need to use GCC builtins
144-
ifneq ($(WINARM64),)
145-
TBB_CXXFLAGS += -DTBB_USE_GCC_BUILTINS
146-
CXXFLAGS_TBB += -DTBB_USE_GCC_BUILTINS
147-
endif
144+
ifeq ($(ARM64),true)
145+
TBB_CXXFLAGS += -DTBB_USE_GCC_BUILTINS
146+
CXXFLAGS_TBB += -DTBB_USE_GCC_BUILTINS
147+
WINARM64 = true
148+
endif
148149
SH_CHECK := $(shell command -v sh 2>/dev/null)
149150
ifdef SH_CHECK
150151
WINDOWS_HAS_SH ?= true

stan/math/prim/fun/inv_sqrt.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,14 @@ inline auto inv_sqrt(const Container& x) {
6060
template <typename Container, require_not_var_matrix_t<Container>* = nullptr,
6161
require_container_st<std::is_arithmetic, Container>* = nullptr>
6262
inline auto inv_sqrt(const Container& x) {
63+
// Eigen 3.4.0 has precision issues on ARM64 with vectorised rsqrt
64+
// Resolved in the current master branch; the workaround below can be removed on the next release
65+
#ifdef __aarch64__
66+
return apply_scalar_unary<inv_sqrt_fun, Container>::apply(x);
67+
#else
6368
return apply_vector_unary<Container>::apply(
6469
x, [](const auto& v) { return v.array().rsqrt(); });
70+
#endif
6571
}
6672

6773
} // namespace math

test/unit/math/fwd/core/std_numeric_limits_test.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,10 @@ TEST(AgradFwdNumericLimits, All_Fvar) {
103103
EXPECT_FALSE(std::numeric_limits<fvar<double> >::traps);
104104
EXPECT_FALSE(std::numeric_limits<fvar<fvar<double> > >::traps);
105105

106-
EXPECT_FALSE(std::numeric_limits<fvar<double> >::tinyness_before);
107-
EXPECT_FALSE(std::numeric_limits<fvar<fvar<double> > >::tinyness_before);
106+
EXPECT_EQ(std::numeric_limits<fvar<double> >::tinyness_before,
107+
std::numeric_limits<double>::tinyness_before);
108+
EXPECT_EQ(std::numeric_limits<fvar<fvar<double> > >::tinyness_before,
109+
std::numeric_limits<double>::tinyness_before);
108110

109111
EXPECT_TRUE(std::numeric_limits<fvar<double> >::round_style);
110112
EXPECT_TRUE(std::numeric_limits<fvar<fvar<double> > >::round_style);

test/unit/math/mix/core/std_numeric_limits_test.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,10 @@ TEST(AgradMixNumericLimits, All_Fvar) {
108108
EXPECT_FALSE(std::numeric_limits<fvar<var> >::traps);
109109
EXPECT_FALSE(std::numeric_limits<fvar<fvar<var> > >::traps);
110110

111-
EXPECT_FALSE(std::numeric_limits<fvar<var> >::tinyness_before);
112-
EXPECT_FALSE(std::numeric_limits<fvar<fvar<var> > >::tinyness_before);
111+
EXPECT_EQ(std::numeric_limits<fvar<var> >::tinyness_before,
112+
std::numeric_limits<double>::tinyness_before);
113+
EXPECT_EQ(std::numeric_limits<fvar<fvar<var> > >::tinyness_before,
114+
std::numeric_limits<double>::tinyness_before);
113115

114116
EXPECT_TRUE(std::numeric_limits<fvar<var> >::round_style);
115117
EXPECT_TRUE(std::numeric_limits<fvar<fvar<var> > >::round_style);

test/unit/math/prim/fun/offset_multiplier_transform_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ TEST(prob_transform, offset_multiplier_constrain_matrix) {
193193
EXPECT_FLOAT_EQ(result(i), stan::math::offset_multiplier_constrain(
194194
x(i), offsetd, sigma(i), lp1));
195195
}
196-
EXPECT_EQ(lp0, lp1);
196+
EXPECT_FLOAT_EQ(lp0, lp1);
197197
auto x_free = stan::math::offset_multiplier_free(result, offsetd, sigma);
198198
for (size_t i = 0; i < x.size(); ++i) {
199199
EXPECT_FLOAT_EQ(x.coeff(i), x_free.coeff(i));
@@ -211,7 +211,7 @@ TEST(prob_transform, offset_multiplier_constrain_matrix) {
211211
EXPECT_FLOAT_EQ(result(i), stan::math::offset_multiplier_constrain(
212212
x(i), offset(i), sigma(i), lp1));
213213
}
214-
EXPECT_EQ(lp0, lp1);
214+
EXPECT_FLOAT_EQ(lp0, lp1);
215215
auto x_free = stan::math::offset_multiplier_free(result, offset, sigma);
216216
for (size_t i = 0; i < x.size(); ++i) {
217217
EXPECT_FLOAT_EQ(x.coeff(i), x_free.coeff(i));

test/unit/math/prim/prob/neg_binomial_test.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,11 @@ TEST(ProbDistributionsNegBinomial, chiSquareGoodnessFitTest3) {
189189

190190
double chi = 0;
191191

192-
for (int j = 0; j < K; j++)
193-
chi += ((bin[j] - expect[j]) * (bin[j] - expect[j]) / expect[j]);
192+
for (int j = 0; j < K; j++) {
193+
if (expect[j] != 0) {
194+
chi += ((bin[j] - expect[j]) * (bin[j] - expect[j]) / expect[j]);
195+
}
196+
}
194197

195198
EXPECT_LT(chi, boost::math::quantile(boost::math::complement(mydist, 1e-6)));
196199
}

test/unit/math/test_ad.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1991,7 +1991,7 @@ void expect_common_unary_vectorized(const F& f) {
19911991
for (double x1 : args)
19921992
stan::test::expect_ad_vectorized<ComplexSupport>(tols, f, x1);
19931993
auto int_args = internal::common_int_args();
1994-
for (int x1 : args)
1994+
for (int x1 : int_args)
19951995
stan::test::expect_ad_vectorized<ComplexSupport>(tols, f, x1);
19961996
}
19971997

@@ -2022,7 +2022,7 @@ void expect_common_unary_vectorized(const F& f) {
20222022
for (double x1 : args)
20232023
stan::test::expect_ad_vectorized<ComplexSupport>(tols, f, x1);
20242024
auto int_args = internal::common_int_args();
2025-
for (int x1 : args)
2025+
for (int x1 : int_args)
20262026
stan::test::expect_ad_vectorized<ComplexSupport>(tols, f, x1);
20272027
for (auto x1 : common_complex())
20282028
stan::test::expect_ad_vectorized<ComplexSupport>(tols, f, x1);

0 commit comments

Comments
 (0)