Skip to content

Commit 12f832f

Browse files
committed
Added expression functions
I'm not sure that I'm 100% happy with this design, but I'm not sure that I could really do any better. Previously, if you wanted to call frame::prepend_column(), frame::append_column(), frame::rows(), or frame::operator[] with an expression, you'd be limited to statements like fr[ _1 > 20 ]; If you wanted to actually test the rounded value of _1, OTOH, then you'd have to use a lambda: fr.rows([]( const auto& b, const auto& c, const auto& e) { return ::round( c.at( _1 ) ) > 20; }); Nothing wrong with that, but it's no longer a 1-liner. With this new functionality, you can now call fr[ fn( ::round, _1 ) > 20 ]; For now it's super-fragile. For one thing, it really can't deal with overloaded functions. And the error messages that it returns in that case are really messy. I'd like to figure out a static assert for that case. Also, I haven't added a static assert for the case where the user hasn't provided enough or has provided too many parameters for the function. But it does work.
1 parent 98280e3 commit 12f832f

File tree

5 files changed

+204
-16
lines changed

5 files changed

+204
-16
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ else()
2727
add_compile_options(-Wall -Wextra -pedantic -Werror)
2828
endif()
2929

30-
set( version 0.7.0 )
30+
set( version 0.8.0 )
3131

3232
# mainframe ==================================================================
3333

mainframe/detail/base.hpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,21 @@ struct is_equality_comparable<T, U, std::void_t<equality_comparison_t<T, U>>>
306306
: std::is_same<equality_comparison_t<T, U>, bool>
307307
{};
308308

309+
template<typename Func>
310+
struct get_return_type;
311+
312+
template<typename Ret, typename... Args>
313+
struct get_return_type<Ret(Args...)>
314+
{
315+
using type = Ret;
316+
};
317+
318+
template<typename Ret, typename... Args>
319+
struct get_return_type<Ret(Args...) noexcept>
320+
{
321+
using type = Ret;
322+
};
323+
309324
} // namespace detail
310325

311326
} // namespace mf

mainframe/detail/expression.hpp

Lines changed: 95 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ struct unary_expr;
175175
template<typename Op, typename L, typename R>
176176
struct binary_expr;
177177

178+
template<typename Func, typename... As>
179+
struct func_expr;
180+
178181
template<typename T>
179182
struct is_complex_expression : std::false_type
180183
{};
@@ -187,6 +190,10 @@ template<typename Op, typename L, typename R>
187190
struct is_complex_expression<binary_expr<Op, L, R>> : std::true_type
188191
{};
189192

193+
template<typename Func, typename... As>
194+
struct is_complex_expression<func_expr<Func, As...>> : std::true_type
195+
{};
196+
190197
template<size_t Ind>
191198
struct indexed_expr_column;
192199

@@ -405,6 +412,34 @@ struct terminal<frame_length>
405412
}
406413
};
407414

415+
// Op is a member type in expr_op
416+
// T must be either terminal<>, unary_expr<>, or binary_expr<> (no
417+
// unwrapped types - that's why make_unary_expr exists)
418+
template<typename Op, typename T>
419+
struct unary_expr
420+
{
421+
using is_expr = void;
422+
static_assert(is_expression<T>::value, "unary expression must contain expression");
423+
424+
explicit unary_expr(T _t)
425+
: t(_t)
426+
{}
427+
428+
template<template<bool, bool, typename...> typename Iter, bool IsConst, bool IsReverse,
429+
typename... Ts>
430+
auto
431+
operator()(const Iter<IsConst, IsReverse, Ts...>& begin,
432+
const Iter<IsConst, IsReverse, Ts...>& curr,
433+
const Iter<IsConst, IsReverse, Ts...>& end) const
434+
-> decltype(Op::exec(std::declval<T&>().
435+
operator()(curr)))
436+
{
437+
return Op::exec(t.operator()(begin, curr, end));
438+
}
439+
440+
T t;
441+
};
442+
408443
template<typename Op, typename L, typename R>
409444
struct binary_expr
410445
{
@@ -435,32 +470,46 @@ struct binary_expr
435470
R r;
436471
};
437472

438-
// Op is a member type in expr_op
439-
// T must be either terminal<>, unary_expr<>, or binary_expr<> (no
440-
// unwrapped types - that's why make_unary_expr exists)
441-
template<typename Op, typename T>
442-
struct unary_expr
473+
template<typename Func, typename... As>
474+
struct func_expr
443475
{
476+
func_expr() = default;
477+
func_expr(Func& f, As... a) : func(&f), args(a...) {}
478+
444479
using is_expr = void;
445-
static_assert(is_expression<T>::value, "unary expression must contain expression");
446480

447-
explicit unary_expr(T _t)
448-
: t(_t)
449-
{}
481+
template<typename Iter>
482+
using applied_args = std::tuple<decltype(std::declval<As&>()(
483+
std::declval<Iter&>(), std::declval<Iter&>(), std::declval<Iter&>())) ... >;
484+
485+
using return_type = typename detail::get_return_type<Func>::type;
450486

451487
template<template<bool, bool, typename...> typename Iter, bool IsConst, bool IsReverse,
452488
typename... Ts>
453-
auto
489+
return_type
454490
operator()(const Iter<IsConst, IsReverse, Ts...>& begin,
455491
const Iter<IsConst, IsReverse, Ts...>& curr,
456492
const Iter<IsConst, IsReverse, Ts...>& end) const
457-
-> decltype(Op::exec(std::declval<T&>().
458-
operator()(curr)))
459493
{
460-
return Op::exec(t.operator()(begin, curr, end));
494+
applied_args<Iter<IsConst, IsReverse, Ts...> > results;
495+
get_val<0>( begin, curr, end, results );
496+
return std::apply( func, results );
461497
}
462498

463-
T t;
499+
template< size_t Ind, template<bool, bool, typename...> typename Iter, bool IsConst,
500+
bool IsReverse, typename... Ts >
501+
void get_val( const Iter<IsConst, IsReverse, Ts...>& begin,
502+
const Iter<IsConst, IsReverse, Ts...>& curr, const Iter<IsConst, IsReverse, Ts...>& end,
503+
applied_args<Iter<IsConst, IsReverse, Ts...>> & results ) const
504+
{
505+
std::get<Ind>(results) = std::get<Ind>(args)(begin, curr, end);
506+
if constexpr (Ind+1 < sizeof...(As)) {
507+
get_val<Ind+1>(begin, curr, end, results);
508+
}
509+
}
510+
511+
Func* func = nullptr;
512+
std::tuple<As...> args;
464513
};
465514

466515
// If T is terminal<U>, unary_expr<Op,U> or binary_expr<Op,L,R>, just return T.
@@ -514,6 +563,18 @@ struct maybe_wrap<binary_expr<Op, L, R>>
514563
}
515564
};
516565

566+
template<typename Func, typename... As>
567+
struct maybe_wrap<func_expr<Func, As...>>
568+
{
569+
maybe_wrap() = delete;
570+
using type = func_expr<Func, As...>;
571+
static type
572+
wrap(const func_expr<Func, As...>& t)
573+
{
574+
return t;
575+
}
576+
};
577+
517578
template<typename Op, typename T>
518579
struct make_unary_expr
519580
{
@@ -542,6 +603,26 @@ struct make_binary_expr
542603
}
543604
};
544605

606+
template<typename Func, typename... Args>
607+
struct make_func_expr
608+
{
609+
make_func_expr() = delete;
610+
using type = func_expr< Func, typename maybe_wrap< Args >::type ... >;
611+
612+
static type
613+
create( Func& func, Args... args )
614+
{
615+
type out( func, maybe_wrap<Args>::wrap(args)... );
616+
return out;
617+
}
618+
};
619+
620+
template<typename Func, typename... Args>
621+
typename make_func_expr<Func, Args...>::type fn( Func& func, Args... args )
622+
{
623+
return make_func_expr<Func, Args...>::create( func, args... );
624+
}
625+
545626
template<typename L, typename R>
546627
typename std::enable_if<std::disjunction<is_expression<L>, is_expression<R>>::value,
547628
typename make_binary_expr<expr_op::LT, L, R>::type>::type

tests/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ option( ENABLE_SIMD_IN_TESTS "Enable SIMD optimizations in tests (default)" ON )
1010

1111
if (ENABLE_SIMD_IN_TESTS)
1212
if (MSVC)
13-
add_compile_options(/arch:AVX)
13+
add_compile_options(/arch:AVX /bigobj)
1414
else()
1515
add_compile_options(-march=native)
1616
endif()

tests/mainframe_test_main.cpp

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,6 +1173,98 @@ TEST_CASE("prepend_column()", "[frame]")
11731173
}
11741174
}
11751175

1176+
TEST_CASE("expression fn", "[frame]")
1177+
{
1178+
frame<year_month_day, double, bool> f1;
1179+
f1.set_column_names("date", "temperature", "rain");
1180+
f1.push_back(2022_y / January / 2, 10.0, false);
1181+
f1.push_back(2022_y / January / 3, 11.1, true);
1182+
f1.push_back(2022_y / January / 4, 12.2, false);
1183+
f1.push_back(2022_y / January / 5, 13.3, false);
1184+
f1.push_back(2022_y / January / 6, 14.4, true);
1185+
f1.push_back(2022_y / January / 7, 15.5, false);
1186+
auto f2 = f1.append_column<double>( "tempceil", fn<double(double)>( ceil, _1 ) )
1187+
.append_column<double>( "temptrunc", fn<double(double)>( trunc, _1 ) )
1188+
.append_column<double>( "temptruncp.5", fn<double(double)>( trunc, _1+0.5 ) )
1189+
.append_column<double>( "temptruncp.5p", fn<double(double)>( trunc, _1 ) +0.5 )
1190+
.append_column<double>( "max", fn<double(double,double)>( ::fmax, _5, _6 ) );
1191+
dout << "expression fn f2:\n";
1192+
dout << f2;
1193+
auto it = f2.cbegin();
1194+
1195+
// Same as before
1196+
REQUIRE((it + 0)->at(_0) == 2022_y / January / 2);
1197+
REQUIRE((it + 0)->at(_1) == 10.0);
1198+
REQUIRE((it + 0)->at(_2) == false);
1199+
REQUIRE((it + 1)->at(_0) == 2022_y / January / 3);
1200+
REQUIRE((it + 1)->at(_1) == 11.1);
1201+
REQUIRE((it + 1)->at(_2) == true);
1202+
REQUIRE((it + 2)->at(_0) == 2022_y / January / 4);
1203+
REQUIRE((it + 2)->at(_1) == 12.2);
1204+
REQUIRE((it + 2)->at(_2) == false);
1205+
REQUIRE((it + 3)->at(_0) == 2022_y / January / 5);
1206+
REQUIRE((it + 3)->at(_1) == 13.3);
1207+
REQUIRE((it + 3)->at(_2) == false);
1208+
REQUIRE((it + 4)->at(_0) == 2022_y / January / 6);
1209+
REQUIRE((it + 4)->at(_1) == 14.4);
1210+
REQUIRE((it + 4)->at(_2) == true);
1211+
REQUIRE((it + 5)->at(_0) == 2022_y / January / 7);
1212+
REQUIRE((it + 5)->at(_1) == 15.5);
1213+
REQUIRE((it + 5)->at(_2) == false);
1214+
1215+
// clang-format off
1216+
// | _0 | _1 | _2 | _3 | _4 | _5 | _6 | _7
1217+
// | date | temperature | rain | tempceil | temptrunc | temptruncp.5 | temptruncp.5p | max
1218+
//__|____________|_____________|_______|__________|___________|______________|_______________|______
1219+
// 0| 2022-01-02 | 10 | false | 10 | 10 | 10 | 10.5 | 10.5
1220+
// 1| 2022-01-03 | 11.1 | true | 12 | 11 | 11 | 11.5 | 11.5
1221+
// 2| 2022-01-04 | 12.2 | false | 13 | 12 | 12 | 12.5 | 12.5
1222+
// 3| 2022-01-05 | 13.3 | false | 14 | 13 | 13 | 13.5 | 13.5
1223+
// 4| 2022-01-06 | 14.4 | true | 15 | 14 | 14 | 14.5 | 14.5
1224+
// 5| 2022-01-07 | 15.5 | false | 16 | 15 | 16 | 15.5 | 16
1225+
// clang-format on
1226+
1227+
// tempceil
1228+
REQUIRE((it + 0)->at(_3) == 10.0);
1229+
REQUIRE((it + 1)->at(_3) == 12.0);
1230+
REQUIRE((it + 2)->at(_3) == 13.0);
1231+
REQUIRE((it + 3)->at(_3) == 14.0);
1232+
REQUIRE((it + 4)->at(_3) == 15.0);
1233+
REQUIRE((it + 5)->at(_3) == 16.0);
1234+
1235+
// temptrunc
1236+
REQUIRE((it + 0)->at(_4) == 10.0);
1237+
REQUIRE((it + 1)->at(_4) == 11.0);
1238+
REQUIRE((it + 2)->at(_4) == 12.0);
1239+
REQUIRE((it + 3)->at(_4) == 13.0);
1240+
REQUIRE((it + 4)->at(_4) == 14.0);
1241+
REQUIRE((it + 5)->at(_4) == 15.0);
1242+
1243+
// temptruncp.5
1244+
REQUIRE((it + 0)->at(_5) == 10.0);
1245+
REQUIRE((it + 1)->at(_5) == 11.0);
1246+
REQUIRE((it + 2)->at(_5) == 12.0);
1247+
REQUIRE((it + 3)->at(_5) == 13.0);
1248+
REQUIRE((it + 4)->at(_5) == 14.0);
1249+
REQUIRE((it + 5)->at(_5) == 16.0);
1250+
1251+
// temptruncp.5p
1252+
REQUIRE((it + 0)->at(_6) == 10.5);
1253+
REQUIRE((it + 1)->at(_6) == 11.5);
1254+
REQUIRE((it + 2)->at(_6) == 12.5);
1255+
REQUIRE((it + 3)->at(_6) == 13.5);
1256+
REQUIRE((it + 4)->at(_6) == 14.5);
1257+
REQUIRE((it + 5)->at(_6) == 15.5);
1258+
1259+
// max
1260+
REQUIRE((it + 0)->at(_7) == 10.5);
1261+
REQUIRE((it + 1)->at(_7) == 11.5);
1262+
REQUIRE((it + 2)->at(_7) == 12.5);
1263+
REQUIRE((it + 3)->at(_7) == 13.5);
1264+
REQUIRE((it + 4)->at(_7) == 14.5);
1265+
REQUIRE((it + 5)->at(_7) == 16.0);
1266+
}
1267+
11761268
TEST_CASE("operator<<()", "[frame]")
11771269
{
11781270
frame<year_month_day, double, bool> f1;

0 commit comments

Comments
 (0)