diff --git a/include/tc/core/autodiff.h b/include/tc/core/autodiff.h new file mode 100644 index 000000000..310d0fddc --- /dev/null +++ b/include/tc/core/autodiff.h @@ -0,0 +1,26 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "tc/lang/tree.h" + +#include + +namespace tc { + +std::string differentiate(const std::string& source); + +} // namespace tc diff --git a/include/tc/lang/sema.h b/include/tc/lang/sema.h index c39dda608..3df88c286 100644 --- a/include/tc/lang/sema.h +++ b/include/tc/lang/sema.h @@ -155,6 +155,8 @@ static inline TreeRef match_types(TreeRef a, TreeRef b) { /// - replace TK_APPLY with TK_BUILT_IN for built in functions /// - checks that all variables are defined, and creates index/reduction /// variable objects. +// - replaces augumented assignments that have no reduction variables +// with regular assignents struct Sema { std::unordered_map expr_to_type; diff --git a/include/tc/lang/tree_views.h b/include/tc/lang/tree_views.h index 1e26b8437..ebfcbc877 100644 --- a/include/tc/lang/tree_views.h +++ b/include/tc/lang/tree_views.h @@ -125,14 +125,19 @@ struct ListViewIterator { bool operator!=(const ListViewIterator& rhs) const { return it != rhs.it; } + bool operator==(const ListViewIterator& rhs) const { + return it == rhs.it; + } T operator*() const { return T(*it); } - void operator++() { + ListViewIterator& operator++() { ++it; + return *this; } - void operator--() { + ListViewIterator& operator--() { --it; + return *this; } private: diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 03db09a43..a9a132756 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -3,6 +3,7 @@ add_library( SHARED + autodiff.cc flags.cc mapping_options.cc mapping_options_cpp_printer.cc diff --git a/src/core/autodiff.cc b/src/core/autodiff.cc new file mode 100644 index 000000000..5d3cfcd76 --- /dev/null +++ b/src/core/autodiff.cc @@ -0,0 +1,387 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tc/core/autodiff.h" +#include "tc/core/tc2halide.h" +#include "tc/lang/parser.h" +#include "tc/lang/sema.h" +#include "tc/lang/tc_format.h" +#include "tc/lang/tree_views.h" + +#include +#include + +namespace tc { + +using namespace lang; + +static const lang::SourceRange dummyRange{std::make_shared(""), + 0, + 0}; + +int32_t getTcType(Halide::Type t) { + if (t.is_int()) { + switch (t.bits()) { + case 64: + return TK_INT64; + case 32: + return TK_INT32; + case 16: + return TK_INT16; + case 8: + return TK_INT8; + } + } else if (t.is_uint()) { + switch (t.bits()) { + case 64: + return TK_UINT64; + case 32: + return TK_UINT32; + case 16: + return TK_UINT16; + case 8: + return TK_UINT8; + } + } else if (t.is_float()) { + switch (t.bits()) { + case 64: + return TK_DOUBLE; + case 32: + return TK_FLOAT; + } + } + throw std::runtime_error("Unknown Halide type"); +} + +void findAccessedTensors( + std::unordered_set& read_only, + const TreeRef& tree) { + if (tree->kind() == TK_ACCESS) { + read_only.insert(Access(tree).name().name()); + } else { + for (const TreeRef& subtree : tree->trees()) { + findAccessedTensors(read_only, subtree); + } + } +} + +void assertNoWriteAfterRead(Def def) { + std::unordered_set read_only; + // Inputs are always read-only + for (Param input : def.params()) + read_only.insert(input.ident().name()); + for (Comprehension comp : def.statements()) { + findAccessedTensors(read_only, comp.rhs()); + auto lhs_name = comp.ident().name(); + if (read_only.count(lhs_name) > 0) + throw std::runtime_error( + "AD not supported in TCs that write to a value after reading it"); + } +} + +void findIndexVars( + std::unordered_set& index_vars, + const TreeRef& tree, + bool gather_idents) { + if (tree->kind() == TK_IDENT && gather_idents) { + index_vars.insert(Ident(tree).name()); + } else if (tree->kind() == TK_ACCESS) { + for (const TreeRef& idx : Access(tree).arguments()) { + findIndexVars(index_vars, idx, true); + } + } else if (tree->kind() == TK_BUILT_IN) { + // BuiltIn holds the name of a function as an ident, so we have to skip it + for (const TreeRef& subtree : BuiltIn(tree).arguments()) { + findIndexVars(index_vars, subtree, gather_idents); + } + } else { + for (const TreeRef& subtree : tree->trees()) { + findIndexVars(index_vars, subtree, gather_idents); + } + } +} + +// XXX: this is a bit of a fragile hack, and can easily break when the AST will +// get more idents in different nodes, but it's quite simple and the worst thing +// that can happen is that we will be too conservative and throw, so it's ok. +std::unordered_set usedIndexVars(Comprehension comp) { + std::unordered_set index_vars; + for (Ident idx : comp.indices()) + index_vars.insert(idx.name()); + findIndexVars(index_vars, comp.rhs(), false); + return index_vars; +} + +// This struct holds a lot of the information required to perform bookkeeping +// of gradient values. For example: +// - do we already have a writeable gradient for a value, or only a seed +// - should we use the seed value, or the writeable gradient in an expression +// - is the gradient implicitly zero now (because the value was overwritten) +struct GradInfo { + GradInfo(const ListView& primal_outputs) { + for (const Param& output : primal_outputs) { + primal_outputs_.insert(output.ident().name()); + } + } + + void addGradComprehension( + Ident primal_lhs_name, + ListView lhs_indices, + TreeRef rhs_expr) { + auto lhs_name = makeGradName(primal_lhs_name); + if (has_writeable_grad_.count(primal_lhs_name.name()) == 0) { + auto rhs_expr = primal_outputs_.count(primal_lhs_name.name()) > 0 + ? Access::create(dummyRange, seedNameOf(lhs_name), lhs_indices) + : Const::create( + dummyRange, + Number::create(0), + Compound::create(TK_FLOAT, dummyRange, {})); + + grad_comps_.push_back(Comprehension::create( + dummyRange, + lhs_name, + lhs_indices, + Compound::create('=', dummyRange, {}), + rhs_expr, + ListView::create(dummyRange, TreeList{}), + Compound::create(TK_OPTION, dummyRange, {}), + ListView::create(dummyRange, TreeList{}))); + has_writeable_grad_.insert(primal_lhs_name.name()); + } + grad_comps_.push_back(Comprehension::create( + dummyRange, + lhs_name, + lhs_indices, + Compound::create(TK_PLUS_EQ, dummyRange, {}), + rhs_expr, + ListView::create(dummyRange, TreeList{}), + Compound::create(TK_OPTION, dummyRange, {}), + ListView::create(dummyRange, TreeList{}))); + if (usedIndexVars(Comprehension(grad_comps_.back())) != + required_index_vars_) + throw std::runtime_error( + "Not all index variables are used in gradient comprehension. " + "AD will require range inference to support this case."); + } + + bool hasZeroGrad(const std::string& name) { + return has_zero_grad_.count(name) > 0; + } + void markZeroGrad(const std::string& name) { + has_zero_grad_.count(name); + } + + std::vector&& getGradComps() { + return std::move(grad_comps_); + } + + void requireAllIndexVarsOf(const Comprehension& comp) { + required_index_vars_ = usedIndexVars(comp); + } + + Ident gradNameOf(const Ident& primal_name) { + if (has_writeable_grad_.count(primal_name.name()) > 0) { + return makeGradName(primal_name); + } + return seedNameOf(primal_name); + } + Ident seedNameOf(const Ident& primal_name) { + return makeSeedName(makeGradName(primal_name)); + } + + private: + Ident makeGradName(const Ident& name) { + return Ident(Ident::create(dummyRange, std::string("d_") + name.name())); + } + + Ident makeSeedName(const Ident& name) { + return Ident(Ident::create(dummyRange, std::string("seed_") + name.name())); + } + + std::unordered_set required_index_vars_; + std::vector grad_comps_; + // Keys in these sets are always names of primal variables. + std::unordered_set primal_outputs_; + std::unordered_set has_writeable_grad_; + std::unordered_set has_zero_grad_; +}; + +void differentiateExpr( + GradInfo& grad_info, + lang::TreeRef expr, + lang::TreeRef grad_output_expr) { + using namespace lang; + switch (expr->kind()) { + case TK_ACCESS: { + Access acc{expr}; + grad_info.addGradComprehension( + acc.name(), acc.arguments(), grad_output_expr); + break; + } + case '+': { + differentiateExpr(grad_info, expr->tree(0), grad_output_expr); + differentiateExpr(grad_info, expr->tree(1), grad_output_expr); + break; + } + case '*': { + differentiateExpr( + grad_info, + expr->tree(0), + Compound::create( + '*', expr->range(), {grad_output_expr, expr->tree(1)})); + differentiateExpr( + grad_info, + expr->tree(1), + Compound::create( + '*', expr->range(), {grad_output_expr, expr->tree(0)})); + break; + } + case TK_CONST: { + // There's nothing we have to do to handle constants, because we don't + // differentiate w.r.t. them. + break; + } + default: + throw ErrorReport(expr) << "Unsupported expression kind in AD: " + << kindToString(expr->kind()); + } +} + +// XXX: Sema isn't nilpotent, so we have to reparse the source +std::vector inferOutputTypes(const std::string& source) { + auto halide_def = + tc2halide::translate(isl::with_exceptions::globalIslCtx(), source, true); + std::vector output_types; + for (const auto& halide_output : halide_def.outputs) { + std::vector dim_exprs; + for (int d = 0; d < halide_output.dimensions(); ++d) { + auto halide_constr = halide_output.parameter().extent_constraint(d); + if (auto* param = halide_constr.as()) { + dim_exprs.push_back(Ident::create(dummyRange, param->name)); + } else if (auto* num = halide_constr.as()) { + dim_exprs.push_back(Const::create( + dummyRange, + Number::create(num->value), + Compound::create(TK_INT32, dummyRange, {}))); + } else { + std::stringstream s; + s << "AD only supports TCs in which sizes of outputs can be expressed as " + "size parameters or constants. This is not the case for " + << halide_output.name() << " which has an inferred size of ("; + for (int d = 0; d < halide_output.dimensions(); ++d) { + s << halide_output.parameter().extent_constraint(d); + if (d != halide_output.dimensions() - 1) + s << ", "; + } + s << ")"; + throw std::runtime_error(s.str()); + } + } + + auto dim_sizes = + ListView::create(dummyRange, std::move(dim_exprs)); + auto scalar_type = + Compound::create(getTcType(halide_output.type()), dummyRange, {}); + output_types.push_back( + TensorType::create(dummyRange, scalar_type, dim_sizes)); + } + + return output_types; +} + +std::string differentiate(const std::string& source) { + // Parse and check the source + auto def = Def(Sema().checkFunction(Parser(source).parseFunction())); + assertNoWriteAfterRead(def); + + GradInfo grad_info{def.returns()}; + + // -------------------------------------------------------------------------- + // Prepare inputs of the gradient Def. + std::vector reverse_inputs; + auto output_types = inferOutputTypes(source); + auto returns = def.returns(); + for (Param input : def.params()) { + reverse_inputs.push_back(input); + } + for (size_t i = 0, num_returns = returns.size(); i < num_returns; ++i) { + reverse_inputs.push_back( + Param::create(dummyRange, returns[i].ident(), output_types.at(i))); + } + for (size_t i = 0, num_returns = returns.size(); i < num_returns; ++i) { + reverse_inputs.push_back(Param::create( + dummyRange, + grad_info.seedNameOf(returns[i].ident()), + output_types.at(i))); + } + + // -------------------------------------------------------------------------- + // Differentiate the body + auto body = def.statements(); + auto it = body.end(); + if (it == body.begin()) + throw std::runtime_error("empty body"); + do { + Comprehension comp = *(--it); + + int assign_kind = comp.assignment()->kind(); + if (assign_kind != '=' && assign_kind != TK_PLUS_EQ_B && + assign_kind != TK_PLUS_EQ) + throw ErrorReport(comp) + << "Only =, += and +=! assignments are supported in AD"; + if (comp.whereClauses().size() > 0 || comp.equivalent().present()) + throw ErrorReport(comp) + << "Comprehensions with range constraints or equivalent are not supported in AD"; + + // See note [Implicit zero gradients] below. + auto primal_output = comp.ident(); + if (grad_info.hasZeroGrad(primal_output.name())) + continue; + + grad_info.requireAllIndexVarsOf(comp); + auto grad_output_expr = Access::create( + dummyRange, grad_info.gradNameOf(primal_output), comp.indices()); + differentiateExpr(grad_info, comp.rhs(), grad_output_expr); + + // Note [Implicit zero gradients] + // If we see one of the overwriting assignments, then we know that all + // previous values of primal output didn't have any effect on TC outputs + // and so their gradients are implicitly zero. + if (assign_kind == '=' || assign_kind == TK_PLUS_EQ_B) { + grad_info.markZeroGrad(primal_output.name()); + } + } while (it != body.begin()); + + // -------------------------------------------------------------------------- + // Prepare outputs, create the gradient Def, and print it + auto inferred_type = Compound::create(TK_INFERRED, dummyRange, {}); + std::vector reverse_outputs; + for (Param input : def.params()) + reverse_outputs.push_back(Param::create( + dummyRange, grad_info.gradNameOf(input.ident()), inferred_type)); + + auto reverseDef = Def::create( + dummyRange, + Ident::create(dummyRange, "grad_" + def.name().name()), + ListView::create(dummyRange, std::move(reverse_inputs)), + ListView::create(dummyRange, std::move(reverse_outputs)), + ListView::create(dummyRange, grad_info.getGradComps())); + + std::ostringstream s; + tcFormat(s, reverseDef); + return s.str(); +} + +} // namespace tc diff --git a/src/lang/test_expected/general_ad.expected b/src/lang/test_expected/general_ad.expected new file mode 100644 index 000000000..522d77fab --- /dev/null +++ b/src/lang/test_expected/general_ad.expected @@ -0,0 +1,8 @@ +def grad_mm(double(M, K) A, double(K, N) B, double(M, N) O, double(M, N) W, double(M, N) seed_d_O, double(M, N) seed_d_W) -> (d_A, d_B) { + d_O(i, j) = seed_d_d_O(i, j) + d_O(i, j) += (seed_d_W(i, j) * 2) + d_A(i, k) = 0 + d_A(i, k) += (d_O(i, j) * B(k, j)) + d_B(k, j) = 0 + d_B(k, j) += (d_O(i, j) * A(i, k)) +} \ No newline at end of file diff --git a/tensor_comprehensions/pybinds/pybind_engine.cc b/tensor_comprehensions/pybinds/pybind_engine.cc index 3d53fee9c..aad9f588f 100644 --- a/tensor_comprehensions/pybinds/pybind_engine.cc +++ b/tensor_comprehensions/pybinds/pybind_engine.cc @@ -14,6 +14,7 @@ * limitations under the License. */ #include +#include #include #include @@ -25,6 +26,7 @@ #include "pybind_utils.h" #include "tc/aten/aten_compiler.h" +#include "tc/core/autodiff.h" #include "tc/core/cuda/cuda_compilation_cache.h" #include "tc/core/cuda/cuda_tc_executor.h" #include "tc/core/flags.h" @@ -39,6 +41,7 @@ namespace py = pybind11; using ATenCudaCompilationUnit = tc::ATenCompilationUnit; PYBIND11_MODULE(tc, m) { + m.def("_differentiate", differentiate); m.def("set_logtostderr", [](bool logtostderr) { FLAGS_logtostderr = logtostderr; }); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 4c9ed04b8..4f62e308b 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -41,12 +41,13 @@ set(CORE_TEST_FILES test_inference test_isl_functionality test_tc2halide + test_autodiff ) foreach(i ${CORE_TEST_FILES}) add_executable(${i} ${i}.cc) add_test(${i} ${i}) - target_link_libraries(${i} ${GOOGLE_LIBS} tc_core) + target_link_libraries(${i} tc_core ${GOOGLE_LIBS}) endforeach() ################################################################################ @@ -80,7 +81,7 @@ if (WITH_CUDA) ${i} ${GOOGLE_LIBS} - ${ATEN_LIBRARIES} + ${ATEN_LIBRARIES} tc_autotuner) endforeach() endif() diff --git a/test/lang_utils.h b/test/lang_utils.h new file mode 100644 index 000000000..97f0d3d0e --- /dev/null +++ b/test/lang_utils.h @@ -0,0 +1,110 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +const std::string expected_file_path = "src/lang/test_expected/"; + +static inline void barf(const char* fmt, ...) { + char msg[2048]; + va_list args; + va_start(args, fmt); + vsnprintf(msg, 2048, fmt, args); + va_end(args); + throw std::runtime_error(msg); +} + +#define ASSERT(cond) \ + if (__builtin_expect(!(cond), 0)) { \ + barf( \ + "%s:%u: %s: Assertion `%s` failed.", \ + __FILE__, \ + __LINE__, \ + __func__, \ + #cond); \ + } + +// note: msg must be a string literal +// node: In, ##__VA_ARGS '##' supresses the comma if __VA_ARGS__ is empty +#define ASSERTM(cond, msg, ...) \ + if (__builtin_expect(!(cond), 0)) { \ + barf( \ + "%s:%u: %s: Assertion `%s` failed: " msg, \ + __FILE__, \ + __LINE__, \ + __func__, \ + #cond, \ + ##__VA_ARGS__); \ + } + +void writeFile(const std::string& filename, const std::string& value) { + std::ofstream ofile(filename.c_str()); + ASSERT(ofile.good()); + ofile << value; +} + +bool readFile(const std::string& filename, std::string& v) { + std::ifstream ifile(filename.c_str()); + if (!ifile.good()) + return false; + std::stringstream input; + input << ifile.rdbuf(); + v = input.str(); + return true; +} + +bool acceptChanges = false; + +void assertEqual( + const std::string& expected_filename_, + const std::string& the_value) { + std::string expected_filename = expected_file_path + expected_filename_; + std::string expected_value; + if (acceptChanges) { + writeFile(expected_filename, the_value); + return; + } + if (!readFile(expected_filename, expected_value)) { + throw std::runtime_error("expect file not found: " + expected_filename); + } + if (the_value != expected_value) { + std::string output = expected_filename + "-actual"; + writeFile(output, the_value); + std::stringstream ss; + ss << expected_filename << " did not match. Run:\n diff -u " + << expected_filename << " " << output + << "\n to compare. Re-run with --accept to accept changes."; + throw std::runtime_error(ss.str()); + } +} + +void parseArgs(int argc, char** argv) { + std::vector args; + for (int i = 1; i < argc; i++) { + args.push_back(argv[i]); + } + for (auto& a : args) { + if (a == "--accept") { + acceptChanges = true; + } + } +} diff --git a/test/test_autodiff.cc b/test/test_autodiff.cc new file mode 100644 index 000000000..78a4a0ea6 --- /dev/null +++ b/test/test_autodiff.cc @@ -0,0 +1,46 @@ +/** + * Copyright (c) 2017-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "tc/core/autodiff.h" + +#include "lang_utils.h" + +using namespace tc; + +void generalTest() { + auto source = R"( +def mm(double(M, K) A, double(K, N) B) -> (O, W) { + O(i, j) +=! A(i, k) * B(k, j) + W(i, j) = O(i, j) * 2 +} +)"; + auto grad_source = differentiate(source); + assertEqual("general_ad.expected", grad_source); +} + +int main(int argc, char** argv) { + parseArgs(argc, argv); + generalTest(); +} diff --git a/test/test_lang.cc b/test/test_lang.cc index ce5f3bfa3..20b2a994f 100644 --- a/test/test_lang.cc +++ b/test/test_lang.cc @@ -22,6 +22,8 @@ #include #include +#include "lang_utils.h" + #include "tc/lang/canonicalize.h" #include "tc/lang/parser.h" #include "tc/lang/sema.h" @@ -29,79 +31,6 @@ using namespace lang; -const std::string expected_file_path = "src/lang/test_expected/"; - -static inline void barf(const char* fmt, ...) { - char msg[2048]; - va_list args; - va_start(args, fmt); - vsnprintf(msg, 2048, fmt, args); - va_end(args); - throw std::runtime_error(msg); -} - -#define ASSERT(cond) \ - if (__builtin_expect(!(cond), 0)) { \ - barf( \ - "%s:%u: %s: Assertion `%s` failed.", \ - __FILE__, \ - __LINE__, \ - __func__, \ - #cond); \ - } - -// note: msg must be a string literal -// node: In, ##__VA_ARGS '##' supresses the comma if __VA_ARGS__ is empty -#define ASSERTM(cond, msg, ...) \ - if (__builtin_expect(!(cond), 0)) { \ - barf( \ - "%s:%u: %s: Assertion `%s` failed: " msg, \ - __FILE__, \ - __LINE__, \ - __func__, \ - #cond, \ - ##__VA_ARGS__); \ - } - -void writeFile(const std::string& filename, const std::string& value) { - std::ofstream ofile(filename.c_str()); - ASSERT(ofile.good()); - ofile << value; -} -bool readFile(const std::string& filename, std::string& v) { - std::ifstream ifile(filename.c_str()); - if (!ifile.good()) - return false; - std::stringstream input; - input << ifile.rdbuf(); - v = input.str(); - return true; -} - -bool acceptChanges = false; -void assertEqual( - const std::string& expected_filename_, - const std::string& the_value) { - std::string expected_filename = expected_file_path + expected_filename_; - std::string expected_value; - if (acceptChanges) { - writeFile(expected_filename, the_value); - return; - } - if (!readFile(expected_filename, expected_value)) { - throw std::runtime_error("expect file not found: " + expected_filename); - } - if (the_value != expected_value) { - std::string output = expected_filename + "-actual"; - writeFile(output, the_value); - std::stringstream ss; - ss << expected_filename << " did not match. Run:\n diff -u " - << expected_filename << " " << output - << "\n to compare. Re-run with --accept to accept changes."; - throw std::runtime_error(ss.str()); - } -} - void assertParseEqual( const std::string& test_name, const std::string& text, @@ -149,7 +78,6 @@ std::string canonicalText(const std::string& text) { } void testTcFormat() { - static std::ios_base::Init initIostreams; auto source = R"(def fun2(float(B, N, M) X, float(B, M, K) Y) -> (Q) { Q(b, ii, j) += (((exp(X(b, ii, k)) * int(Y(b, k, j))) * 2.5) + 3) })"; @@ -160,15 +88,7 @@ void testTcFormat() { } int main(int argc, char** argv) { - std::vector args; - for (int i = 1; i < argc; i++) { - args.push_back(argv[i]); - } - for (auto& a : args) { - if (a == "--accept") { - acceptChanges = true; - } - } + parseArgs(argc, argv); { std::string s = "min min+max 1.4 .3 3 3.\n3e-3 .5e-7 3E-5 foobar";