Skip to content

Commit 00036b1

Browse files
fantesmergify[bot]
authored andcommitted
fix(torch/txt): correclty handle test sets in case of no splitting
1 parent 33aee72 commit 00036b1

File tree

2 files changed

+61
-2
lines changed

2 files changed

+61
-2
lines changed

src/txtinputfileconn.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,8 @@ namespace dd
430430
{
431431
DataEl<DDTxt> dtxt(this->_input_timeout);
432432
dtxt._ctype._ctfc = this;
433-
if (dtxt.read_element(_uris[i], this->_logger, i)
433+
_tests_txt.resize(i);
434+
if (dtxt.read_element(_uris[i], this->_logger, i - 1)
434435
|| (_txt.empty() && _db_fname.empty() && _ndbed == 0))
435436
{
436437
throw InputConnectorBadParamException("no data for text in "
@@ -440,7 +441,7 @@ namespace dd
440441
if (_ndbed == 0)
441442
{
442443
if (_db_fname.empty())
443-
_tests_txt[i].back()->_uri = _uris[i];
444+
_tests_txt[i - 1].back()->_uri = _uris[i];
444445
else
445446
return; // single db
446447
}

tests/ut-torchapi.cc

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -919,6 +919,64 @@ TEST(torchapi, service_train_txt_classification)
919919
ASSERT_TRUE(!fileops::file_exists(bert_train_repo + "solver-3.pt"));
920920
}
921921

922+
TEST(torchapi, service_train_txt_classification_nosplit)
923+
{
924+
setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", true);
925+
torch::manual_seed(torch_seed);
926+
at::globalContext().setDeterministic(true);
927+
928+
// create service
929+
JsonAPI japi;
930+
std::string sname = "txtserv";
931+
std::string jstr = "{\"mllib\":\"torch\",\"description\":\"bert\",\"type\":"
932+
"\"supervised\",\"model\":{\"repository\":\""
933+
+ bert_train_repo
934+
+ "\"},\"parameters\":{\"input\":{\"connector\":\"txt\","
935+
"\"ordered_words\":true,"
936+
"\"wordpiece_tokens\":true,\"punctuation_tokens\":true,"
937+
"\"sequence\":512},\"mllib\":{\"template\":\"bert\","
938+
"\"nclasses\":2,\"finetuning\":true,\"gpu\":true}}}";
939+
std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
940+
ASSERT_EQ(created_str, joutstr);
941+
942+
// train
943+
std::string jtrainstr
944+
= "{\"service\":\"txtserv\",\"async\":false,\"parameters\":{"
945+
"\"mllib\":{\"solver\":{\"iterations\":3,\"base_lr\":"
946+
+ torch_lr
947+
+ ",\"iter_"
948+
"size\":2,\"solver_type\":\"ADAM\"},\"net\":{\"batch_size\":2}},"
949+
"\"input\":{\"seed\":12345,\"shuffle\":true},"
950+
"\"output\":{\"measure\":[\"f1\",\"acc\",\"mcll\",\"cmdiag\","
951+
"\"cmfull\"]}},\"data\":[\""
952+
+ bert_train_data + "\",\"" + bert_train_data + "\"]}";
953+
joutstr = japi.jrender(japi.service_train(jtrainstr));
954+
JDoc jd;
955+
std::cout << "joutstr=" << joutstr << std::endl;
956+
jd.Parse<rapidjson::kParseNanAndInfFlag>(joutstr.c_str());
957+
ASSERT_TRUE(!jd.HasParseError());
958+
ASSERT_EQ(201, jd["status"]["code"]);
959+
ASSERT_TRUE(abs(jd["body"]["measure"]["iteration"].GetDouble() - 3)
960+
< 0.00001)
961+
<< "iterations";
962+
// This assertion is non-deterministic
963+
// ASSERT_TRUE(jd["body"]["measure"]["train_loss"].GetDouble() > 1.0) <<
964+
// "train_loss";
965+
ASSERT_TRUE(jd["body"]["measure"]["acc"].GetDouble() <= 1) << "accuracy";
966+
ASSERT_TRUE(jd["body"]["measure"]["f1"].GetDouble() <= 1) << "f1";
967+
968+
std::unordered_set<std::string> lfiles;
969+
fileops::list_directory(bert_train_repo, true, false, false, lfiles);
970+
for (std::string ff : lfiles)
971+
{
972+
if (ff.find("checkpoint") != std::string::npos
973+
|| ff.find("solver") != std::string::npos)
974+
remove(ff.c_str());
975+
}
976+
ASSERT_TRUE(!fileops::file_exists(bert_train_repo + "checkpoint-3.pt"));
977+
ASSERT_TRUE(!fileops::file_exists(bert_train_repo + "solver-3.pt"));
978+
}
979+
922980
#endif
923981

924982
TEST(torchapi, service_train_csvts_nbeats)

0 commit comments

Comments
 (0)