diff --git a/src/apidata.h b/src/apidata.h index 9ec0e2f76..cd53b5a2f 100644 --- a/src/apidata.h +++ b/src/apidata.h @@ -38,6 +38,7 @@ #include #include #include +#include "oatpp/parser/json/mapping/ObjectMapper.hpp" namespace dd { @@ -288,6 +289,31 @@ namespace dd */ void toJDoc(JDoc &jd) const; + /** + * \brief converts APIData to oat++ DTO + */ + template inline std::shared_ptr createSharedDTO() const + { + rapidjson::Document d; + d.SetObject(); + toJDoc(reinterpret_cast(d)); + + rapidjson::StringBuffer buffer; + rapidjson::Writer, + rapidjson::UTF8<>, rapidjson::CrtAllocator, + rapidjson::kWriteNanAndInfFlag> + writer(buffer); + bool done = d.Accept(writer); + if (!done) + throw DataConversionException("JSON rendering failed"); + + std::shared_ptr object_mapper + = oatpp::parser::json::mapping::ObjectMapper::createShared(); + return object_mapper + ->readFromString>(buffer.GetString()) + .getPtr(); + } + /** * \brief converts APIData to rapidjson JSON value * @param jd JSON Document hosting the destination JSON value diff --git a/src/backends/ncnn/ncnnlib.cc b/src/backends/ncnn/ncnnlib.cc index 7bb8b291d..291f36344 100644 --- a/src/backends/ncnn/ncnnlib.cc +++ b/src/backends/ncnn/ncnnlib.cc @@ -22,7 +22,6 @@ #include "outputconnectorstrategy.h" #include #include -#include "utils/utils.hpp" // NCNN #include "ncnnlib.h" @@ -53,10 +52,10 @@ namespace dd { this->_libname = "ncnn"; _net = new ncnn::Net(); - _net->opt.num_threads = _threads; + _net->opt.num_threads = 1; _net->opt.blob_allocator = &_blob_pool_allocator; _net->opt.workspace_allocator = &_workspace_pool_allocator; - _net->opt.lightmode = _lightmode; + _net->opt.lightmode = true; } template _libname = "ncnn"; _net = tl._net; tl._net = nullptr; - _nclasses = tl._nclasses; - _threads = tl._threads; _timeserie = tl._timeserie; _old_height = tl._old_height; - _inputBlob = tl._inputBlob; - _outputBlob = tl._outputBlob; + _init_dto = tl._init_dto; } template ::init_mllib(const APIData &ad) { + _init_dto = ad.createSharedDTO(); + bool use_fp32 = (ad.has("datatype") && ad.get("datatype").get() == "fp32"); // default is fp16 @@ -124,35 +122,11 @@ namespace dd _old_height = this->_inputc.height(); _net->set_input_h(_old_height); - if (ad.has("nclasses")) - _nclasses = ad.get("nclasses").get(); - - if (ad.has("threads")) - _threads = ad.get("threads").get(); - else - _threads = dd_utils::my_hardware_concurrency(); - _timeserie = this->_inputc._timeserie; if (_timeserie) this->_mltype = "timeserie"; - if (ad.has("lightmode")) - { - _lightmode = ad.get("lightmode").get(); - _net->opt.lightmode = _lightmode; - } - - // setting the value of Input Layer - if (ad.has("inputblob")) - { - _inputBlob = ad.get("inputblob").get(); - } - // setting the final Output Layer - if (ad.has("outputblob")) - { - _outputBlob = ad.get("outputblob").get(); - } - + _net->opt.lightmode = _init_dto->lightmode; _blob_pool_allocator.set_size_compare_ratio(0.0f); _workspace_pool_allocator.set_size_compare_ratio(0.5f); model_type(this->_mlmodel._params, this->_mltype); @@ -212,33 +186,19 @@ namespace dd } APIData ad_output = ad.getobj("parameters").getobj("output"); - - // Get bbox - bool bbox = false; - if (ad_output.has("bbox")) - bbox = ad_output.get("bbox").get(); - - // Ctc model - bool ctc = false; - int blank_label = -1; - if (ad_output.has("ctc")) - { - ctc = ad_output.get("ctc").get(); - if (ctc) - { - if (ad_output.has("blank_label")) - blank_label = ad_output.get("blank_label").get(); - } - } + auto output_params + = ad_output.createSharedDTO(); // Extract detection or classification - int ret = 0; - std::string out_blob = _outputBlob; + std::string out_blob; + if (_init_dto->outputBlob != nullptr) + out_blob = _init_dto->outputBlob->std_str(); + if (out_blob.empty()) { - if (bbox == true) + if (output_params->bbox == true) out_blob = "detection_out"; - else if (ctc == true) + else if (output_params->ctc == true) out_blob = "probs"; else if (_timeserie) out_blob = "rnn_pred"; @@ -246,27 +206,14 @@ namespace dd out_blob = "prob"; } - std::vector vrad; - - // Get confidence_threshold - float confidence_threshold = 0.0; - if (ad_output.has("confidence_threshold")) - { - apitools::get_float(ad_output, "confidence_threshold", - confidence_threshold); - } - // Get best - int best = -1; - if (ad_output.has("best")) - { - best = ad_output.get("best").get(); - } - if (best == -1 || best > _nclasses) - best = _nclasses; + if (output_params->best == -1 || output_params->best > _init_dto->nclasses) + output_params->best = _init_dto->nclasses; + + std::vector vrad; - // for loop around batch size -#pragma omp parallel for num_threads(_threads) + // for loop around batch size +#pragma omp parallel for num_threads(*_init_dto->threads) for (size_t b = 0; b < inputc._ids.size(); b++) { std::vector probs; @@ -276,16 +223,16 @@ namespace dd APIData rad; ncnn::Extractor ex = _net->create_extractor(); - ex.set_num_threads(_threads); - ex.input(_inputBlob.c_str(), inputc._in.at(b)); + ex.set_num_threads(_init_dto->threads); + ex.input(_init_dto->inputBlob->c_str(), inputc._in.at(b)); - ret = ex.extract(out_blob.c_str(), inputc._out.at(b)); + int ret = ex.extract(out_blob.c_str(), inputc._out.at(b)); if (ret == -1) { throw MLLibInternalException("NCNN internal error"); } - if (bbox == true) + if (output_params->bbox == true) { std::string uri = inputc._ids.at(b); auto bit = inputc._imgs_size.find(uri); @@ -305,7 +252,7 @@ namespace dd for (int i = 0; i < inputc._out.at(b).h; i++) { const float *values = inputc._out.at(b).row(i); - if (values[1] < confidence_threshold) + if (values[1] < output_params->confidence_threshold) break; // output is sorted by confidence cats.push_back(this->_mlmodel.get_hcorresp(values[0])); @@ -323,7 +270,7 @@ namespace dd bboxes.push_back(ad_bbox); } } - else if (ctc == true) + else if (output_params->ctc == true) { int alphabet = inputc._out.at(b).w; int time_step = inputc._out.at(b).h; @@ -336,11 +283,11 @@ namespace dd } std::vector pred_label_seq; - int prev = blank_label; + int prev = output_params->blank_label; for (int t = 0; t < time_step; ++t) { int cur = pred_label_seq_with_blank[t]; - if (cur != prev && cur != blank_label) + if (cur != prev && cur != output_params->blank_label) pred_label_seq.push_back(cur); prev = cur; } @@ -388,12 +335,13 @@ namespace dd vec[i] = std::make_pair(cls_scores[i], i); } - std::partial_sort(vec.begin(), vec.begin() + best, vec.end(), + std::partial_sort(vec.begin(), vec.begin() + output_params->best, + vec.end(), std::greater>()); - for (int i = 0; i < best; i++) + for (int i = 0; i < output_params->best; i++) { - if (vec[i].first < confidence_threshold) + if (vec[i].first < output_params->confidence_threshold) continue; cats.push_back(this->_mlmodel.get_hcorresp(vec[i].second)); probs.push_back(vec[i].first); @@ -403,7 +351,7 @@ namespace dd rad.add("uri", inputc._ids.at(b)); rad.add("loss", 0.0); rad.add("cats", cats); - if (bbox == true) + if (output_params->bbox == true) rad.add("bboxes", bboxes); if (_timeserie) { @@ -423,8 +371,9 @@ namespace dd } // end for batch_size tout.add_results(vrad); - out.add("nclasses", this->_nclasses); - if (bbox == true) + int nclasses = this->_init_dto->nclasses; + out.add("nclasses", nclasses); + if (output_params->bbox == true) out.add("bbox", true); out.add("roi", false); out.add("multibox_rois", false); diff --git a/src/backends/ncnn/ncnnlib.h b/src/backends/ncnn/ncnnlib.h index 513d06cd8..63f4fea48 100644 --- a/src/backends/ncnn/ncnnlib.h +++ b/src/backends/ncnn/ncnnlib.h @@ -22,12 +22,15 @@ #ifndef NCNNLIB_H #define NCNNLIB_H +#include "apidata.h" +#include "utils/utils.hpp" + +#include "dto/ncnn.hpp" + // NCNN #include "net.h" #include "ncnnmodel.h" -#include "apidata.h" - namespace dd { template _init_dto; static ncnn::UnlockedPoolAllocator _blob_pool_allocator; static ncnn::PoolAllocator _workspace_pool_allocator; protected: - int _threads = 1; int _old_height = -1; - std::string _inputBlob = "data"; - std::string _outputBlob; }; + } #endif diff --git a/src/http/controller.hpp b/src/http/controller.hpp index 1a12791e7..0cf3360eb 100644 --- a/src/http/controller.hpp +++ b/src/http/controller.hpp @@ -26,12 +26,16 @@ #include #include +#include + #include "oatpp/web/server/api/ApiController.hpp" #include "oatpp/parser/json/mapping/ObjectMapper.hpp" #include "oatpp/core/macro/codegen.hpp" #include "oatpp/core/macro/component.hpp" +#include "apidata.h" #include "oatppjsonapi.h" +#include "http/dto/info.hpp" #include OATPP_CODEGEN_BEGIN(ApiController) @@ -61,11 +65,31 @@ class DedeController : public oatpp::web::server::api::ApiController } ENDPOINT("GET", "info", get_info, QUERIES(QueryParams, queryParams)) { - // TODO(sileht): why do serialize the query string to char* - // to later get again a APIData... - std::string jsonstr = _oja->uri_query_to_json(queryParams); - auto janswer = _oja->info(jsonstr); - return _oja->jdoc_to_response(janswer); + auto info_resp = InfoResponse::createShared(); + info_resp->head = InfoHead::createShared(); + info_resp->head->services = {}; + + auto qs_status = queryParams.get("status"); + bool status = false; + if (qs_status) + status = boost::lexical_cast(qs_status->std_str()); + + auto hit = _oja->_mlservices.begin(); + while (hit != _oja->_mlservices.end()) + { + // TODO(sileht): update visitor_info to return directly a Service() + JDoc jd; + jd.SetObject(); + mapbox::util::apply_visitor(dd::visitor_info(status), (*hit).second) + .toJDoc(jd); + auto json_str = _oja->jrender(jd); + auto service_info + = getDefaultObjectMapper()->readFromString>( + json_str.c_str()); + info_resp->head->services->emplace_back(service_info); + ++hit; + } + return createDtoResponse(Status::CODE_200, info_resp); } ENDPOINT_INFO(get_service) diff --git a/src/http/dto/info.hpp b/src/http/dto/info.hpp new file mode 100644 index 000000000..a05d0d160 --- /dev/null +++ b/src/http/dto/info.hpp @@ -0,0 +1,83 @@ +/** + * DeepDetect + * Copyright (c) 2020 Jolibrain SASU + * Author: Mehdi Abaakouk + * + * This file is part of deepdetect. + * + * deepdetect is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * deepdetect is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with deepdetect. If not, see . + */ + +#ifndef HTTP_DTO_INFO_H +#define HTTP_DTO_INFO_H + +#include "dd_config.h" +#include "oatpp/core/Types.hpp" +#include "oatpp/core/macro/codegen.hpp" + +#include OATPP_CODEGEN_BEGIN(DTO) ///< Begin DTO codegen section + +class Service : public oatpp::DTO +{ + DTO_INIT(Service, DTO /* extends */) + + DTO_FIELD(String, name); + DTO_FIELD(String, description); + DTO_FIELD(String, mllib); + DTO_FIELD(String, mltype); + DTO_FIELD(Boolean, predict) = false; + DTO_FIELD(Boolean, training) = false; +}; + +class InfoHead : public oatpp::DTO +{ + DTO_INIT(InfoHead, DTO /* extends */) + DTO_FIELD(String, method) = "/info"; + + // Why this is not in body ? + DTO_FIELD(String, build_type, "build-type") = BUILD_TYPE; + DTO_FIELD(String, version) = GIT_VERSION; + DTO_FIELD(String, branch) = GIT_BRANCH; + DTO_FIELD(String, commit) = GIT_COMMIT_HASH; + DTO_FIELD(String, compile_flags) = COMPLIE_FLAGS; + DTO_FIELD(String, deps_version) = DEPS_VERSION; + DTO_FIELD(List>, services); +}; + +class InfoBody : public oatpp::DTO +{ + DTO_INIT(InfoBody, DTO /* extends */) +}; + +class Status : public oatpp::DTO +{ + DTO_INIT(Status, DTO /* extends */) + DTO_FIELD(Int32, code); + DTO_FIELD(String, msg); + DTO_FIELD(Int32, dd_code); + DTO_FIELD(String, dd_msg); +}; + +class InfoResponse : public oatpp::DTO +{ + DTO_INIT(InfoResponse, DTO /* extends */) + DTO_FIELD(String, dd_msg); + DTO_FIELD(Object, status); + DTO_FIELD(Object, head); + DTO_FIELD(Object, body); +}; + +#include OATPP_CODEGEN_END(DTO) ///< End DTO codegen section + +#endif // HTTP_DTO_INFO_H diff --git a/src/http/dto/predict.hpp b/src/http/dto/predict.hpp new file mode 100644 index 000000000..1cef4a05d --- /dev/null +++ b/src/http/dto/predict.hpp @@ -0,0 +1,56 @@ +/** + * DeepDetect + * Copyright (c) 2021 Jolibrain SASU + * Author: Mehdi Abaakouk + * + * This file is part of deepdetect. + * + * deepdetect is free software: you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * deepdetect is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with deepdetect. If not, see . + */ + +#ifndef HTTP_DTO_PREDICT_H +#define HTTP_DTO_PREDICT_H +#include "dd_config.h" +#include "oatpp/core/Types.hpp" +#include "oatpp/core/macro/codegen.hpp" + +#include OATPP_CODEGEN_BEGIN(DTO) ///< Begin DTO codegen section + +class PredictOutputParametersDto : public oatpp::DTO +{ + DTO_INIT(PredictOutputParametersDto, DTO /* extends */) + + /* ncnn */ + DTO_FIELD(Boolean, bbox) = false; + DTO_FIELD(Boolean, ctc) = false; + DTO_FIELD(Int32, blank_label) = -1; + DTO_FIELD(Float32, confidence_threshold) = 0.0; + + /* ncnn && supervised init && supervised predict */ + DTO_FIELD(Int32, best) = -1; + + /* output supervised init */ + DTO_FIELD(Boolean, nclasses) = false; // Looks like a bug ? + + /* output supervised predict */ + DTO_FIELD(Boolean, index) = false; + DTO_FIELD(Boolean, build_index) = false; + DTO_FIELD(Boolean, search) = false; + DTO_FIELD(Int32, search_nn); + DTO_FIELD(Int32, nprobe); +}; + +#include OATPP_CODEGEN_END(DTO) ///< End DTO codegen section + +#endif diff --git a/src/oatppjsonapi.cc b/src/oatppjsonapi.cc index 053b75d86..ece668644 100644 --- a/src/oatppjsonapi.cc +++ b/src/oatppjsonapi.cc @@ -264,7 +264,10 @@ namespace dd auto router = components.httpRouter.getObject(); - auto dedeController = DedeController::createShared(this); + std::shared_ptr defaultObjectMapper + = oatpp::parser::json::mapping::ObjectMapper::createShared(); + auto dedeController + = DedeController::createShared(this, defaultObjectMapper); dedeController->addEndpointsToRouter(router); auto docEndpoints = oatpp::swagger::Controller::Endpoints::createShared(); diff --git a/src/supervisedoutputconnector.h b/src/supervisedoutputconnector.h index fce00d0f7..a0b1f2ab1 100644 --- a/src/supervisedoutputconnector.h +++ b/src/supervisedoutputconnector.h @@ -23,6 +23,8 @@ #define SUPERVISEDOUTPUTCONNECTOR_H #define TS_METRICS_EPSILON 1E-2 +#include "http/dto/predict.hpp" + template bool SortScorePairDescend(const std::pair &pair1, const std::pair &pair2) @@ -161,10 +163,11 @@ namespace dd void init(const APIData &ad) { APIData ad_out = ad.getobj("parameters").getobj("output"); - if (ad_out.has("best")) - _best = ad_out.get("best").get(); + auto output_params + = ad_out.createSharedDTO(); + _best = output_params->best; if (_best == -1) - _best = ad_out.get("nclasses").get(); + _best = output_params->nclasses; } /** @@ -242,13 +245,13 @@ namespace dd * @param ad_out output data object * @param bcats supervised output connector */ - void best_cats(const APIData &ad_out, SupervisedOutput &bcats, + void best_cats(SupervisedOutput &bcats, const int &output_param_best, const int &nclasses, const bool &has_bbox, const bool &has_roi, const bool &has_mask) const { int best = _best; - if (ad_out.has("best")) - best = ad_out.get("best").get(); + if (output_param_best != -1) + best = output_param_best; if (best == -1) best = nclasses; if (!has_bbox && !has_roi && !has_mask) @@ -399,6 +402,8 @@ namespace dd */ void finalize(const APIData &ad_in, APIData &ad_out, MLModel *mlm) { + auto output_params = ad_in.createSharedDTO(); + #ifndef USE_SIMSEARCH (void)mlm; #endif @@ -443,12 +448,13 @@ namespace dd } if (!timeseries) - best_cats(ad_in, bcats, nclasses, has_bbox, has_roi, has_mask); + best_cats(bcats, output_params->best, nclasses, has_bbox, has_roi, + has_mask); std::unordered_set indexed_uris; #ifdef USE_SIMSEARCH // index - if (ad_in.has("index") && ad_in.get("index").get()) + if (output_params->index) { // check whether index has been created if (!mlm->_se) @@ -553,7 +559,7 @@ namespace dd } // build index - if (ad_in.has("build_index") && ad_in.get("build_index").get()) + if (output_params->build_index) { if (mlm->_se) mlm->build_index(); @@ -562,7 +568,7 @@ namespace dd } // search - if (ad_in.has("search") && ad_in.get("search").get()) + if (output_params->search) { // check whether index has been created if (!mlm->_se) @@ -582,11 +588,11 @@ namespace dd int search_nn = _best; if (has_roi) search_nn = _search_nn; - if (ad_in.has("search_nn")) - search_nn = ad_in.get("search_nn").get(); + if (output_params->search_nn) + search_nn = output_params->search_nn; #ifdef USE_FAISS - if (ad_in.has("nprobe")) - mlm->_se->_tse->_nprobe = ad_in.get("nprobe").get(); + if (output_params->nprobe) + mlm->_se->_tse->_nprobe = output_params->nprobe; #endif if (!has_roi) { diff --git a/tests/ut-oatpp.cc b/tests/ut-oatpp.cc index b31a086ad..72e8b502d 100644 --- a/tests/ut-oatpp.cc +++ b/tests/ut-oatpp.cc @@ -28,7 +28,7 @@ const std::string serv = "very_long_label_service_name_with_😀_inside"; const std::string serv2 = "myserv2"; -#ifdef CPU_ONLY +#if defined(CPU_ONLY) || defined(USE_CAFFE_CPU_ONLY) static std::string iterations_mnist = "10"; #else static std::string iterations_mnist = "10000"; @@ -158,7 +158,7 @@ void test_train(std::shared_ptr client) d.Parse(message.get()->c_str()); ASSERT_TRUE(d.HasMember("body")); ASSERT_TRUE(d["body"].HasMember("measure")); -#ifdef CPU_ONLY +#if defined(CPU_ONLY) || defined(USE_CAFFE_CPU_ONLY) ASSERT_EQ(9, d["body"]["measure"]["iteration"].GetDouble()); #else ASSERT_EQ(9999, d["body"]["measure"]["iteration"].GetDouble()); @@ -297,8 +297,8 @@ void test_multiservices(std::shared_ptr client) d.Parse(jstr.c_str()); ASSERT_TRUE(d["head"].HasMember("services")); ASSERT_EQ(2, d["head"]["services"].Size()); - ASSERT_TRUE(jstr.find("\"name\":\"" + serv + "\"") != std::string::npos); - ASSERT_TRUE(jstr.find("\"name\":\"myserv2\"") != std::string::npos); + ASSERT_EQ(serv2, d["head"]["services"][0]["name"].GetString()); + ASSERT_EQ(serv, d["head"]["services"][1]["name"].GetString()); // remove services and trained model files response = client->delete_services(serv.c_str(), "lib"); @@ -334,7 +334,8 @@ void test_concurrency(std::shared_ptr client) std::string train_post = "{\"service\":\"" + serv + "\",\"async\":true,\"parameters\":{\"mllib\":{\"gpu\":true," - "\"solver\":{\"iterations\":10000}}}}"; + "\"solver\":{\"iterations\": " + + iterations_mnist + "}}}}"; response = client->post_train(train_post.c_str()); message = response->readBodyToString(); ASSERT_TRUE(message != nullptr); @@ -365,13 +366,14 @@ void test_concurrency(std::shared_ptr client) d.Parse(jstr.c_str()); ASSERT_TRUE(d["head"].HasMember("services")); ASSERT_EQ(2, d["head"]["services"].Size()); - ASSERT_TRUE(jstr.find("\"name\":\"" + serv + "\"") != std::string::npos); - ASSERT_TRUE(jstr.find("\"name\":\"" + serv2 + "\"") != std::string::npos); + ASSERT_EQ(serv2, d["head"]["services"][0]["name"].GetString()); + ASSERT_EQ(serv, d["head"]["services"][1]["name"].GetString()); // train async second job train_post = "{\"service\":\"" + serv2 + "\",\"async\":true,\"parameters\":{\"mllib\":{\"gpu\":true," - "\"solver\":{\"iterations\":10000}}}}"; + "\"solver\":{\"iterations\":" + + iterations_mnist + "}}}}"; response = client->post_train(train_post.c_str()); message = response->readBodyToString(); ASSERT_TRUE(message != nullptr); diff --git a/tests/ut-oatpp.h b/tests/ut-oatpp.h index 335c1cf88..358365629 100644 --- a/tests/ut-oatpp.h +++ b/tests/ut-oatpp.h @@ -140,7 +140,10 @@ class DedeControllerTest : public oatpp::test::UnitTest dd::OatppJsonAPI oja; TestComponent component; oatpp::test::web::ClientServerTestRunner runner; - runner.addController(std::make_shared(&oja, nullptr)); + std::shared_ptr defaultObjectMapper + = oatpp::parser::json::mapping::ObjectMapper::createShared(); + runner.addController( + std::make_shared(&oja, defaultObjectMapper)); runner.run( [this, &runner] { OATPP_COMPONENT(