@@ -919,6 +919,64 @@ TEST(torchapi, service_train_txt_classification)
919
919
ASSERT_TRUE (!fileops::file_exists (bert_train_repo + " solver-3.pt" ));
920
920
}
921
921
922
+ TEST (torchapi, service_train_txt_classification_nosplit)
923
+ {
924
+ setenv (" CUBLAS_WORKSPACE_CONFIG" , " :4096:8" , true );
925
+ torch::manual_seed (torch_seed);
926
+ at::globalContext ().setDeterministic (true );
927
+
928
+ // create service
929
+ JsonAPI japi;
930
+ std::string sname = " txtserv" ;
931
+ std::string jstr = " {\" mllib\" :\" torch\" ,\" description\" :\" bert\" ,\" type\" :"
932
+ " \" supervised\" ,\" model\" :{\" repository\" :\" "
933
+ + bert_train_repo
934
+ + " \" },\" parameters\" :{\" input\" :{\" connector\" :\" txt\" ,"
935
+ " \" ordered_words\" :true,"
936
+ " \" wordpiece_tokens\" :true,\" punctuation_tokens\" :true,"
937
+ " \" sequence\" :512},\" mllib\" :{\" template\" :\" bert\" ,"
938
+ " \" nclasses\" :2,\" finetuning\" :true,\" gpu\" :true}}}" ;
939
+ std::string joutstr = japi.jrender (japi.service_create (sname, jstr));
940
+ ASSERT_EQ (created_str, joutstr);
941
+
942
+ // train
943
+ std::string jtrainstr
944
+ = " {\" service\" :\" txtserv\" ,\" async\" :false,\" parameters\" :{"
945
+ " \" mllib\" :{\" solver\" :{\" iterations\" :3,\" base_lr\" :"
946
+ + torch_lr
947
+ + " ,\" iter_"
948
+ " size\" :2,\" solver_type\" :\" ADAM\" },\" net\" :{\" batch_size\" :2}},"
949
+ " \" input\" :{\" seed\" :12345,\" shuffle\" :true},"
950
+ " \" output\" :{\" measure\" :[\" f1\" ,\" acc\" ,\" mcll\" ,\" cmdiag\" ,"
951
+ " \" cmfull\" ]}},\" data\" :[\" "
952
+ + bert_train_data + " \" ,\" " + bert_train_data + " \" ]}" ;
953
+ joutstr = japi.jrender (japi.service_train (jtrainstr));
954
+ JDoc jd;
955
+ std::cout << " joutstr=" << joutstr << std::endl;
956
+ jd.Parse <rapidjson::kParseNanAndInfFlag >(joutstr.c_str ());
957
+ ASSERT_TRUE (!jd.HasParseError ());
958
+ ASSERT_EQ (201 , jd[" status" ][" code" ]);
959
+ ASSERT_TRUE (abs (jd[" body" ][" measure" ][" iteration" ].GetDouble () - 3 )
960
+ < 0.00001 )
961
+ << " iterations" ;
962
+ // This assertion is non-deterministic
963
+ // ASSERT_TRUE(jd["body"]["measure"]["train_loss"].GetDouble() > 1.0) <<
964
+ // "train_loss";
965
+ ASSERT_TRUE (jd[" body" ][" measure" ][" acc" ].GetDouble () <= 1 ) << " accuracy" ;
966
+ ASSERT_TRUE (jd[" body" ][" measure" ][" f1" ].GetDouble () <= 1 ) << " f1" ;
967
+
968
+ std::unordered_set<std::string> lfiles;
969
+ fileops::list_directory (bert_train_repo, true , false , false , lfiles);
970
+ for (std::string ff : lfiles)
971
+ {
972
+ if (ff.find (" checkpoint" ) != std::string::npos
973
+ || ff.find (" solver" ) != std::string::npos)
974
+ remove (ff.c_str ());
975
+ }
976
+ ASSERT_TRUE (!fileops::file_exists (bert_train_repo + " checkpoint-3.pt" ));
977
+ ASSERT_TRUE (!fileops::file_exists (bert_train_repo + " solver-3.pt" ));
978
+ }
979
+
922
980
#endif
923
981
924
982
TEST (torchapi, service_train_csvts_nbeats)
0 commit comments