diff --git a/metagraph/api/python/metagraph/client.py b/metagraph/api/python/metagraph/client.py index faaaa06b1f..b7027c4cb0 100644 --- a/metagraph/api/python/metagraph/client.py +++ b/metagraph/api/python/metagraph/client.py @@ -44,6 +44,7 @@ def search(self, sequence: Union[str, Iterable[str]], abundance_sum: bool = False, query_counts: bool = False, query_coords: bool = False, + graphs: Union[None, List[str]] = None, align: bool = False, **align_params) -> Tuple[JsonDict, str]: """See parameters for alignment `align_params` in align()""" @@ -80,6 +81,8 @@ def search(self, sequence: Union[str, Iterable[str]], "abundance_sum": abundance_sum, "query_counts": query_counts, "query_coords": query_coords} + if graphs is not None: + param_dict["graphs"] = graphs search_results = self._json_seq_query(sequence, param_dict, "search") @@ -178,6 +181,7 @@ def search(self, sequence: Union[str, Iterable[str]], abundance_sum: bool = False, query_counts: bool = False, query_coords: bool = False, + graphs: Union[None, List[str]] = None, align: bool = False, **align_params) -> pd.DataFrame: """ @@ -195,6 +199,8 @@ def search(self, sequence: Union[str, Iterable[str]], :type query_counts: bool :param query_coords: Query k-mer coordinates :type query_coords: bool + :param graphs: List of graph names to search. If None, search all graphs + :type graphs: Union[None, List[str]] :param align: Align the query sequence to the joint graph and query labels for that alignment instead of the original sequence :type align: bool :param align_params: The parameters for alignment (see method align()) @@ -207,7 +213,7 @@ def search(self, sequence: Union[str, Iterable[str]], json_obj = self._json_client.search(sequence, top_labels, discovery_fraction, with_signature, abundance_sum, query_counts, query_coords, - align, **align_params) + graphs, align, **align_params) return helpers.df_from_search_result(json_obj) @@ -272,6 +278,7 @@ def search(self, sequence: Union[str, Iterable[str]], abundance_sum: bool = False, query_counts: bool = False, query_coords: bool = False, + graphs: Union[None, List[str]] = None, align: bool = False, **align_params) -> Dict[str, Union[pd.DataFrame, Future]]: """ @@ -290,7 +297,7 @@ def search(self, sequence: Union[str, Iterable[str]], result[name] = graph_client.search(sequence, top_labels, discovery_fraction, with_signature, abundance_sum, query_counts, query_coords, - align, **align_params) + graphs, align, **align_params) return result @@ -305,7 +312,7 @@ def search(self, sequence: Union[str, Iterable[str]], futures[name] = executor.submit(graph_client.search, sequence, top_labels, discovery_fraction, with_signature, abundance_sum, query_counts, query_coords, - align, **align_params) + graphs, align, **align_params) print(f'Made {len(self.graphs)} requests with {num_processes} threads...') diff --git a/metagraph/integration_tests/test_api.py b/metagraph/integration_tests/test_api.py index 4ec637cbe8..9f0dbb5540 100644 --- a/metagraph/integration_tests/test_api.py +++ b/metagraph/integration_tests/test_api.py @@ -9,12 +9,16 @@ import pandas as pd -from metagraph.client import GraphClientJson, MultiGraphClient +from metagraph.client import GraphClientJson, MultiGraphClient, GraphClient from concurrent.futures import Future from parameterized import parameterized, parameterized_class +from itertools import product from base import PROTEIN_MODE, TestingBase, METAGRAPH, TEST_DATA_DIR +GRAPH_MODES = ['basic'] if PROTEIN_MODE else ['basic', 'canonical', 'primary'] + + class TestAPIBase(TestingBase): @classmethod def setUpClass(cls, fasta_path, mode='basic', anno_repr='column'): @@ -53,14 +57,13 @@ def tearDownClass(cls): def _start_server(self, graph, annotation): construct_command = f'{METAGRAPH} server_query -i {graph} -a {annotation} \ - --port {self.port} --address {self.host} -p {2}' + --port {self.port} --address {self.host} -p 2' return subprocess.Popen(shlex.split(construct_command)) # No canonical mode for Protein alphabets -@parameterized_class(('mode',), - input_values=[('basic',)] + ([] if PROTEIN_MODE else [('canonical',), ('primary',)])) +@parameterized_class(('mode',), input_values=[(mode,) for mode in GRAPH_MODES]) class TestAPIRaw(TestAPIBase): @classmethod def setUpClass(cls): @@ -249,8 +252,7 @@ def test_api_raw_search_no_count_support(self): # No canonical mode for Protein alphabets -@parameterized_class(('mode',), - input_values=[('basic',)] + ([] if PROTEIN_MODE else [('canonical',), ('primary',)])) +@parameterized_class(('mode',), input_values=[(mode,) for mode in GRAPH_MODES]) class TestAPIClient(TestAPIBase): graph_name = 'test_graph' @@ -356,10 +358,172 @@ def test_api_search_no_count_support(self): ret = self.graph_client.search(self.sample_query, parallel=False, discovery_fraction=0.01, abundance_sum=True) + @unittest.expectedFailure + def test_api_search_bad_graphs(self): + df = self.graph_client.search(self.sample_query, parallel=False, graphs=['X']) + + +# No canonical mode for Protein alphabets +@parameterized_class(('mode', 'threads_each'), + input_values=[(mode, 1) for mode in GRAPH_MODES] + [('basic', 4)]) +class TestAPIClientMultiple(TestingBase): + @classmethod + def setUpClass( + cls, + fasta_paths=[ + TEST_DATA_DIR + '/transcripts_100.fa', + TEST_DATA_DIR + '/transcripts_1000.fa' + ], + anno_repr='column', + ): + super().setUpClass() + + graph_paths = [ + cls.tempdir.name + '/graph_1.dbg', + cls.tempdir.name + '/graph_2.dbg', + ] + annotation_path_bases = [ + cls.tempdir.name + '/annotation_1', + cls.tempdir.name + '/annotation_2', + ] + + cls._build_graph(fasta_paths[0], graph_paths[0], 6, 'succinct', mode=cls.mode) + cls._annotate_graph(fasta_paths[0], graph_paths[0], annotation_path_bases[0], anno_repr) + + cls._build_graph(fasta_paths[1], graph_paths[1], 6, 'succinct', mode=cls.mode) + cls._annotate_graph(fasta_paths[1], graph_paths[1], annotation_path_bases[1], anno_repr) + + cls.mult = 1 if cls.threads_each < 2 else 200 # duplicate graphs so that there are more to query + + graphs_file = cls.tempdir.name + '/graphs.txt' + with open(graphs_file, 'w') as f: + f.write('\n'.join([ + f'G1,{graph_paths[0]},{annotation_path_bases[0]}.{anno_repr}.annodbg', + f'G2,{graph_paths[0]},{annotation_path_bases[0]}.{anno_repr}.annodbg', + f'G2,{graph_paths[0]},{annotation_path_bases[0]}.{anno_repr}.annodbg', + f'G3,{graph_paths[1]},{annotation_path_bases[1]}.{anno_repr}.annodbg', + ] * cls.mult)) + + cls.host = '127.0.0.1' + os.environ['NO_PROXY'] = cls.host + cls.port = 3456 + num_retries = 100 + while num_retries > 0: + cls.server_process = subprocess.Popen(shlex.split( + f'{METAGRAPH} server_query {graphs_file} \ + --port {cls.port} --address {cls.host} -p 2 --threads-each {cls.threads_each}' + )) + try: + cls.server_process.wait(timeout=1) + except subprocess.TimeoutExpired: + break + cls.port += 1 + num_retries -= 1 + if num_retries == 0: + raise "Couldn't start server" + + wait_time_sec = 1 + print("Waiting {} sec for the server (PID {}) to start up".format( + wait_time_sec, cls.server_process.pid), flush=True) + time.sleep(wait_time_sec) + + cls.graph_client = MultiGraphClient() + cls.graph_name = 'G1,G2x2,G3' + cls.graph_client.add_graph(cls.host, cls.port, cls.graph_name) + cls.client = GraphClient(cls.host, cls.port, cls.graph_name) + + cls.sample_query = 'CCTCTGTGGAATCCAATCTGTCTTCCATCCTGCGTGGCCGAGGG' + # 'canonical' and 'primary' graphs represent more k-mers than 'basic', so + # they get more matches + cls.expected_rows_1 = (98 if cls.mode == 'basic' else 99) * cls.mult + cls.expected_matches_1 = (840 if cls.mode == 'basic' else 1381) * cls.mult + + cls.expected_rows_2 = cls.expected_rows_1 * 2 + cls.expected_matches_2 = cls.expected_matches_1 * 2 + + cls.expected_rows_3 = (100) * cls.mult + cls.expected_matches_3 = (2843 if cls.mode == 'basic' else 3496) * cls.mult + + @classmethod + def tearDownClass(cls): + cls.server_process.kill() + + def test_api_query_df_multiple_1(self): + ret = self.graph_client.search(self.sample_query, parallel=False, + discovery_threshold=0.01, graphs=['G1']) + df = ret[self.graph_name] + self.assertEqual((self.expected_rows_1, 3), df.shape) + self.assertEqual(df['kmer_count'].sum(), self.expected_matches_1) + + def test_api_query_df_multiple_2(self): + ret = self.graph_client.search(self.sample_query, parallel=False, + discovery_threshold=0.01, graphs=['G2']) + df = ret[self.graph_name] + self.assertEqual((self.expected_rows_2, 3), df.shape) + self.assertEqual(df['kmer_count'].sum(), self.expected_matches_2) + + def test_api_query_df_multiple_3(self): + ret = self.graph_client.search(self.sample_query, parallel=False, + discovery_threshold=0.01, graphs=['G3']) + df = ret[self.graph_name] + self.assertEqual((self.expected_rows_3, 3), df.shape) + self.assertEqual(df['kmer_count'].sum(), self.expected_matches_3) + + def test_api_query_df_multiple_12(self): + ret = self.graph_client.search(self.sample_query, parallel=False, + discovery_threshold=0.01, graphs=['G1', 'G2']) + df = ret[self.graph_name] + self.assertEqual((self.expected_rows_1 + self.expected_rows_2, 3), df.shape) + self.assertEqual(df['kmer_count'].sum(), + self.expected_matches_1 + self.expected_matches_2) + + def test_api_query_df_multiple_all(self): + ret = self.graph_client.search(self.sample_query, parallel=False, + discovery_threshold=0.01, graphs=['G1', 'G2', 'G3']) + df = ret[self.graph_name] + self.assertEqual((self.expected_rows_1 + self.expected_rows_2 + self.expected_rows_3, 3), df.shape) + self.assertEqual(df['kmer_count'].sum(), + self.expected_matches_1 + self.expected_matches_2 + self.expected_matches_3) + + def test_api_query_df_multiple(self): + ret = self.graph_client.search(self.sample_query, parallel=False, + discovery_threshold=0.01) + df = ret[self.graph_name] + self.assertEqual((self.expected_rows_1 + self.expected_rows_2 + self.expected_rows_3, 3), df.shape) + self.assertEqual(df['kmer_count'].sum(), + self.expected_matches_1 + self.expected_matches_2 + self.expected_matches_3) + + @unittest.expectedFailure + def test_api_query_df_multiple_bad(self): + df = self.graph_client.search(self.sample_query, parallel=False, graphs=['G1', 'X']) + + def test_api_stats_multiple_graphs(self): + """Test /stats endpoint with multiple graphs returns aggregated stats""" + ret = self.client._json_client.stats() + + # Should have annotation section with aggregated labels + self.assertIn("annotation", ret.keys()) + self.assertIn("labels", ret["annotation"]) + + # Total labels should be sum of all graphs + self.assertEqual(ret["annotation"]["labels"], 1300 * self.mult) + + def test_api_column_labels_multiple_graphs(self): + """Test /column_labels endpoint with multiple graphs returns all labels""" + ret = self.graph_client.column_labels() + + self.assertIn(self.graph_name, ret.keys()) + label_list = ret[self.graph_name] + # Should have labels from all graphs combined + self.assertEqual(len(label_list), 1300 * self.mult) + # All labels should start with 'ENST' (from the test data) + self.assertTrue(all(l.startswith('ENST') for l in label_list)) + # Check that the total number of deduplicated labels equals to 1000 + self.assertEqual(len(set(label_list)), 1000) + # No canonical mode for Protein alphabets -@parameterized_class(('mode',), - input_values=[('basic',)] + ([] if PROTEIN_MODE else [('canonical',), ('primary',)])) +@parameterized_class(('mode',), input_values=[(mode,) for mode in GRAPH_MODES]) class TestAPIJson(TestAPIBase): graph_name = 'test_graph' @@ -417,8 +581,7 @@ def test_api_stats(self): # No canonical mode for Protein alphabets -@parameterized_class(('mode',), - input_values=[('basic',)] + ([] if PROTEIN_MODE else [('canonical',), ('primary',)])) +@parameterized_class(('mode',), input_values=[(mode,) for mode in GRAPH_MODES]) class TestAPIClientWithProperties(TestAPIBase): """ Testing whether properties encoded in sample name are properly processed @@ -449,8 +612,7 @@ def test_api_search_property_df_empty(self): # No canonical mode for Protein alphabets -@parameterized_class(('mode',), - input_values=[('basic',)] + ([] if PROTEIN_MODE else [('canonical',), ('primary',)])) +@parameterized_class(('mode',), input_values=[(mode,) for mode in GRAPH_MODES]) class TestAPIClientWithCoordinates(TestAPIBase): """ Testing whether API works well given coordinate aware annotations @@ -509,8 +671,7 @@ def test_api_simple_query_coords_df(self): # No canonical mode for Protein alphabets -@parameterized_class(('mode',), - input_values=[('basic',)] + ([] if PROTEIN_MODE else [('canonical',), ('primary',)])) +@parameterized_class(('mode',), input_values=[(mode,) for mode in GRAPH_MODES]) class TestAPIClientWithCounts(TestAPIBase): """ Testing whether API works well given k-mer count aware annotations @@ -564,8 +725,7 @@ def test_api_search_no_coordinate_support(self): # No canonical mode for Protein alphabets -@parameterized_class(('mode',), - input_values=[('basic',)] + ([] if PROTEIN_MODE else [('canonical',), ('primary',)])) +@parameterized_class(('mode',), input_values=[(mode,) for mode in GRAPH_MODES]) class TestAPIClientParallel(TestAPIBase): """ Testing whether or not parallel requests work diff --git a/metagraph/src/cli/config/config.cpp b/metagraph/src/cli/config/config.cpp index f0c65d2ecb..c9d57a4853 100644 --- a/metagraph/src/cli/config/config.cpp +++ b/metagraph/src/cli/config/config.cpp @@ -61,6 +61,7 @@ Config::Config(int argc, char *argv[]) { identity = QUERY; } else if (!strcmp(argv[1], "server_query")) { identity = SERVER_QUERY; + num_top_labels = 10'000; } else if (!strcmp(argv[1], "transform")) { identity = TRANSFORM; } else if (!strcmp(argv[1], "transform_anno")) { @@ -543,7 +544,7 @@ Config::Config(int argc, char *argv[]) { if (count_kmers || query_presence) map_sequences = true; - if ((identity == QUERY || identity == SERVER_QUERY) && infbase.empty()) + if (identity == QUERY && infbase.empty()) print_usage_and_exit = true; if ((identity == QUERY || identity == SERVER_QUERY || identity == ALIGN) @@ -575,9 +576,15 @@ Config::Config(int argc, char *argv[]) { if (identity == EXTEND && infbase.empty()) print_usage_and_exit = true; - if ((identity == QUERY || identity == SERVER_QUERY) && infbase_annotators.size() != 1) + if (identity == QUERY && infbase_annotators.size() != 1) print_usage_and_exit = true; + if (identity == SERVER_QUERY + && (fnames.size() > 1 + || (fnames.size() && (infbase.size() || infbase_annotators.size())) + || (fnames.empty() && (infbase.empty() || infbase_annotators.size() != 1)))) + print_usage_and_exit = true; // only one of fnames or (infbase & annotator) must be used + if ((identity == TRANSFORM || identity == CLEAN || identity == ASSEMBLE @@ -1369,7 +1376,10 @@ if (advanced) { } } break; case SERVER_QUERY: { - fprintf(stderr, "Usage: %s server_query -i -a [options]\n\n", prog_name.c_str()); + fprintf(stderr, "Usage: %s server_query (-i -a | ) [options]\n\n" + "\tThe index must be passed with flags -i -a or with a file GRAPHS.csv listing one\n" + "\tor more indexes, a file with rows: ',,\\n'.\n" + "\t(If multiple rows have the same name, all those graphs will be queried for that name.)\n\n", prog_name.c_str()); fprintf(stderr, "Available options for server_query:\n"); fprintf(stderr, "\t --port [INT] \tTCP port for incoming connections [5555]\n"); @@ -1379,6 +1389,7 @@ if (advanced) { // fprintf(stderr, "\t-d --distance [INT] \tmax allowed alignment distance [0]\n"); fprintf(stderr, "\t-p --parallel [INT] \tmaximum number of parallel connections [1]\n"); // fprintf(stderr, "\t --cache-size [INT] \tnumber of uncompressed rows to store in the cache [0]\n"); + fprintf(stderr, "\n\t --num-top-labels [INT] \tmaximum number of top labels per query by default [10'000]\n"); } break; } diff --git a/metagraph/src/cli/load/load_annotation.cpp b/metagraph/src/cli/load/load_annotation.cpp index 4a682db6d7..6e5197f9d6 100644 --- a/metagraph/src/cli/load/load_annotation.cpp +++ b/metagraph/src/cli/load/load_annotation.cpp @@ -193,5 +193,22 @@ initialize_annotation(Config::AnnotationType anno_type, return annotation; } +// read labels from annotation +std::vector read_labels(const std::string &anno_fname) { + annot::LabelEncoder label_encoder; + + std::ifstream instream(anno_fname, std::ios::binary); + // TODO: make this cleaner + if (parse_annotation_type(anno_fname) == Config::ColumnCompressed) { + // Column compressed dumps the number of rows first + // skipping it... + load_number(instream); + } + if (!label_encoder.load(instream)) + throw std::ios_base::failure("Cannot read label encoder from file " + anno_fname); + + return label_encoder.get_labels(); +} + } // namespace cli } // namespace mtg diff --git a/metagraph/src/cli/load/load_annotation.hpp b/metagraph/src/cli/load/load_annotation.hpp index ea73b7dddc..9368fc5671 100644 --- a/metagraph/src/cli/load/load_annotation.hpp +++ b/metagraph/src/cli/load/load_annotation.hpp @@ -40,6 +40,9 @@ initialize_annotation(const std::string &filename, const Args &... args) { return initialize_annotation(parse_annotation_type(filename), args...); } +// read annotation labels from a serialized annotation +std::vector read_labels(const std::string &anno_fname); + } // namespace cli } // namespace mtg diff --git a/metagraph/src/cli/server.cpp b/metagraph/src/cli/server.cpp index f0f14f65ae..c27b03d622 100644 --- a/metagraph/src/cli/server.cpp +++ b/metagraph/src/cli/server.cpp @@ -1,4 +1,5 @@ #include +#include #include #include "common/logger.hpp" @@ -16,6 +17,7 @@ #include "query.hpp" #include "align.hpp" #include "server_utils.hpp" +#include "cli/load/load_annotation.hpp" namespace mtg { @@ -27,11 +29,9 @@ using namespace mtg::graph; using HttpServer = SimpleWeb::Server; -std::string process_search_request(const std::string &received_message, +Json::Value process_search_request(const Json::Value &json, const graph::AnnotatedDBG &anno_graph, const Config &config_orig) { - Json::Value json = parse_json_string(received_message); - const auto &fasta = json["FASTA"]; if (fasta.isNull()) throw std::domain_error("No input sequences received from client"); @@ -132,13 +132,11 @@ std::string process_search_request(const std::string &received_message, search_response.append(seq_result.to_json(config.verbose_output, anno_graph)); } - // Return JSON string - Json::StreamWriterBuilder builder; - return Json::writeString(builder, search_response); + return search_response; } // TODO: implement alignment_result.to_json as in process_search_request -std::string process_align_request(const std::string &received_message, +Json::Value process_align_request(const std::string &received_message, const graph::DeBruijnGraph &graph, const Config &config_orig) { Json::Value json = parse_json_string(received_message); @@ -205,48 +203,7 @@ std::string process_align_request(const std::string &received_message, root.append(align_entry); }); - Json::StreamWriterBuilder builder; - return Json::writeString(builder, root); -} - -std::string process_column_label_request(const graph::AnnotatedDBG &anno_graph) { - auto labels = anno_graph.get_annotator().get_label_encoder().get_labels(); - - Json::Value root = Json::Value(Json::arrayValue); - - for (const std::string &label : labels) { - Json::Value entry = label; - root.append(entry); - } - - Json::StreamWriterBuilder builder; - return Json::writeString(builder, root); -} - -std::string process_stats_request(const graph::AnnotatedDBG &anno_graph, - const std::string &graph_filename, - const std::string &annotation_filename) { - Json::Value root; - - Json::Value graph_stats; - graph_stats["filename"] = std::filesystem::path(graph_filename).filename().string(); - graph_stats["k"] = static_cast(anno_graph.get_graph().get_k()); - graph_stats["nodes"] = anno_graph.get_graph().num_nodes(); - graph_stats["is_canonical_mode"] = (anno_graph.get_graph().get_mode() - == graph::DeBruijnGraph::CANONICAL); - root["graph"] = graph_stats; - - Json::Value annotation_stats; - const auto &annotation = anno_graph.get_annotator(); - annotation_stats["filename"] = std::filesystem::path(annotation_filename).filename().string(); - annotation_stats["labels"] = static_cast(annotation.num_labels()); - annotation_stats["objects"] = static_cast(annotation.num_objects()); - annotation_stats["relations"] = static_cast(annotation.num_relations()); - - root["annotation"] = annotation_stats; - - Json::StreamWriterBuilder builder; - return Json::writeString(builder, root); + return root; } std::thread start_server(HttpServer &server_startup, Config &config) { @@ -257,8 +214,9 @@ std::thread start_server(HttpServer &server_startup, Config &config) { } server_startup.config.port = config.port; - logger->info("[Server] Will listen on {} port {}", config.host_address, - server_startup.config.port); + logger->info("[Server] Will listen on {} port {}", + server_startup.config.address, server_startup.config.port); + logger->info("[Server] Maximum connections: {}", get_num_threads()); return std::thread([&server_startup]() { server_startup.start(); }); } @@ -274,64 +232,256 @@ bool check_data_ready(std::shared_future &data, shared_ptr filter_graphs_from_list( + const tsl::hopscotch_map>> &indexes, + const Json::Value &content_json, + size_t request_id, + size_t max_names_without_filtering = 10) { + std::vector graphs_to_query; + + if (content_json.isMember("graphs") && content_json["graphs"].isArray()) { + for (const auto &item : content_json["graphs"]) { + graphs_to_query.push_back(item.asString()); + if (!indexes.count(graphs_to_query.back())) + throw std::invalid_argument("Request with an uninitialized graph " + graphs_to_query.back()); + } + // deduplicate + std::sort(graphs_to_query.begin(), graphs_to_query.end()); + graphs_to_query.erase(std::unique(graphs_to_query.begin(), graphs_to_query.end()), graphs_to_query.end()); + } else { + if (indexes.size() > max_names_without_filtering) { + throw std::invalid_argument( + fmt::format("Bad request: requests without names (no \"graphs\" field) are " + "only supported for small indexes (<={} names)", + max_names_without_filtering)); + } + // query all graphs from list `config->fnames` + for (const auto &[name, _] : indexes) { + graphs_to_query.push_back(name); + } + } + return graphs_to_query; +} + + int run_server(Config *config) { assert(config); - - assert(config->infbase_annotators.size() == 1); + std::atomic num_requests = 0; ThreadPool graph_loader(1, 1); + std::shared_future> anno_graph; - logger->info("[Server] Loading graph..."); + tsl::hopscotch_map>> indexes; - auto anno_graph = graph_loader.enqueue([&]() { - auto graph = load_critical_dbg(config->infbase); - logger->info("[Server] Graph loaded. Current mem usage: {} MiB", get_curr_RSS() >> 20); + if (config->infbase_annotators.size() == 1) { + assert(config->fnames.empty()); + anno_graph = graph_loader.enqueue([&]() { + logger->info("[Server] Loading graph..."); + auto graph = load_critical_dbg(config->infbase); + logger->info("[Server] Graph loaded. Current mem usage: {} MiB", get_curr_RSS() >> 20); - auto anno_graph = initialize_annotated_dbg(graph, *config); - logger->info("[Server] Annotated graph loaded too. Current mem usage: {} MiB", get_curr_RSS() >> 20); - return anno_graph; - }); + auto anno_graph = initialize_annotated_dbg(graph, *config); + logger->info("[Server] Annotated graph loaded too. Current mem usage: {} MiB", get_curr_RSS() >> 20); + return anno_graph; + }); + } else { + assert(config->fnames.size() == 1); - // defaults for the server - config->num_top_labels = 10000; + std::ifstream file(config->fnames[0]); + + if (!file.is_open()) { + logger->error("[Server] Could not open file {} for reading", config->fnames[0]); + std::exit(1); + } + + size_t num_indexes = 0; + std::string line; + while (std::getline(file, line)) { + if (line.empty()) + continue; // skip empty lines + + std::stringstream ss(line); + + std::string name; + std::string graph_path; + std::string annotation_path; + + std::getline(ss, name, ','); + std::getline(ss, graph_path, ','); + if (ss.eof()) { + logger->error("[Server] Invalid line in the csv file: `{}`", line); + std::exit(1); + } + std::getline(ss, annotation_path, ','); + + indexes[name].emplace_back(std::move(graph_path), std::move(annotation_path)); + num_indexes++; + } + std::vector names; + for (const auto &[name, _] : indexes) { + names.push_back(name); + } + logger->info("[Server] Loaded paths for {} graphs for {} names: {}", + num_indexes, indexes.size(), fmt::join(names, ", ")); + if (!utils::with_mmap()) { + logger->warn("[Server] --mmap wasn't passed but all indexes will be loaded with mmap." + " Make sure they're on a fast disk."); + utils::with_mmap(true); + } + } + + ThreadPool graphs_pool(get_num_threads()); + + logger->info("Collecting graph stats..."); + tsl::hopscotch_map> name_labels; + for (const auto &[name, graphs] : indexes) { + for (const auto &[graph_fname, anno_fname] : graphs) { + auto &out = name_labels[name]; + const auto &labels = read_labels(anno_fname); + out.insert(out.end(), labels.begin(), labels.end()); + } + } + + logger->info("All graphs were loaded and stats collected. Ready to serve queries."); // the actual server HttpServer server; server.resource["^/search"]["POST"] = [&](shared_ptr response, shared_ptr request) { - if (check_data_ready(anno_graph, response)) { - process_request(response, request, [&](const std::string &content) { - return process_search_request(content, *anno_graph.get(), *config); - }); - } + size_t request_id = num_requests++; + logger->info("[Server] {} request {} from {}", request->path, request_id, + request->remote_endpoint().address().to_string()); + + if (!config->fnames.size() && !check_data_ready(anno_graph, response)) + return; // the index is not loaded yet, so we can't process the request + + process_request(response, request, [&](const std::string &content) { + Json::Value content_json = parse_json_string(content); + logger->info("Request {}: {}", request_id, content_json.toStyledString()); + Json::Value result; + + // simple case with a single graph pair + if (!config->fnames.size()) { + if (content_json.isMember("graphs")) + throw std::invalid_argument("Bad request: no support for filtering graphs on this server"); + result = process_search_request(content_json, *anno_graph.get(), *config); + } else { + std::vector graphs_to_query + = filter_graphs_from_list(indexes, content_json, request_id); + std::mutex mu; + std::vector> futures; + for (const auto &name : graphs_to_query) { + for (const auto &[graph_fname, anno_fname] : indexes[name]) { + Config config_copy = *config; + config_copy.infbase = graph_fname; + config_copy.infbase_annotators = { anno_fname }; + futures.push_back(graphs_pool.enqueue([config_copy{std::move(config_copy)},&content_json,&result,&mu]() { + auto index = initialize_annotated_dbg(config_copy); + auto json = process_search_request(content_json, *index, config_copy); + std::lock_guard lock(mu); + if (result.empty()) { + result = std::move(json); + } else { + assert(json.size() == result.size()); + for (Json::ArrayIndex i = 0; i < result.size(); ++i) { + if (result[i][SeqSearchResult::SEQ_DESCRIPTION_JSON_FIELD] + != json[i][SeqSearchResult::SEQ_DESCRIPTION_JSON_FIELD]) { + throw std::logic_error("ERROR: Results for different sequences can't be merged"); + } + for (auto&& value : json[i]["results"]) { + result[i]["results"].append(std::move(value)); + } + } + } + })); + } + } + for (auto &future : futures) { + future.wait(); + } + } + logger->info("Request {} finished", request_id); + return result; + }); }; server.resource["^/align"]["POST"] = [&](shared_ptr response, shared_ptr request) { - if (check_data_ready(anno_graph, response)) { - process_request(response, request, [&](const std::string &content) { + size_t request_id = num_requests++; + logger->info("[Server] {} request {} from {}", request->path, request_id, + request->remote_endpoint().address().to_string()); + + if (!config->fnames.size() && !check_data_ready(anno_graph, response)) + return; // the index is not loaded yet, so we can't process the request + + process_request(response, request, [&](const std::string &content) { + if (!config->fnames.size()) return process_align_request(content, anno_graph.get()->get_graph(), *config); - }); - } + + throw std::invalid_argument("Bad request: alignment requests are not yet supported for " + "servers with multiple graphs"); + }); }; server.resource["^/column_labels"]["GET"] = [&](shared_ptr response, shared_ptr request) { - if (check_data_ready(anno_graph, response)) { - process_request(response, request, [&](const std::string &) { - return process_column_label_request(*anno_graph.get()); - }); - } + size_t request_id = num_requests++; + logger->info("[Server] {} request {} from {}", request->path, request_id, + request->remote_endpoint().address().to_string()); + + if (!config->fnames.size() && !check_data_ready(anno_graph, response)) + return; // the index is not loaded yet, so we can't process the request + + process_request(response, request, [&](const std::string &) { + Json::Value root(Json::arrayValue); + if (!config->fnames.size()) { + auto labels = anno_graph.get()->get_annotator().get_label_encoder().get_labels(); + for (const std::string &label : labels) { + root.append(label); + } + } else { + for (const auto &[name, labels] : name_labels) { + for (const std::string &label : labels) { + root.append(label); + } + } + } + return root; + }); }; server.resource["^/stats"]["GET"] = [&](shared_ptr response, shared_ptr request) { - if (check_data_ready(anno_graph, response)) { - process_request(response, request, [&](const std::string &) { - return process_stats_request(*anno_graph.get(), config->infbase, - config->infbase_annotators.front()); - }); - } + size_t request_id = num_requests++; + logger->info("[Server] {} request {} from {}", request->path, request_id, + request->remote_endpoint().address().to_string()); + + if (!config->fnames.size() && !check_data_ready(anno_graph, response)) + return; // the index is not loaded yet, so we can't process the request + + process_request(response, request, [&](const std::string &) { + Json::Value root; + if (config->fnames.size()) { + // for scenarios with multiple graphs + uint64_t num_labels = 0; + for (const auto &[name, labels] : name_labels) { + num_labels += labels.size(); + } + root["annotation"]["labels"] = num_labels; + } else { + root["graph"]["filename"] = std::filesystem::path(config->infbase).filename().string(); + root["graph"]["k"] = static_cast(anno_graph.get()->get_graph().get_k()); + root["graph"]["nodes"] = anno_graph.get()->get_graph().num_nodes(); + root["graph"]["is_canonical_mode"] = (anno_graph.get()->get_graph().get_mode() + == graph::DeBruijnGraph::CANONICAL); + const auto &annotation = anno_graph.get()->get_annotator(); + root["annotation"]["filename"] = std::filesystem::path(config->infbase_annotators.front()).filename().string(); + root["annotation"]["labels"] = static_cast(annotation.num_labels()); + root["annotation"]["objects"] = static_cast(annotation.num_objects()); + root["annotation"]["relations"] = static_cast(annotation.num_relations()); + } + return root; + }); }; server.default_resource["GET"] = [](shared_ptr response, diff --git a/metagraph/src/cli/server_utils.cpp b/metagraph/src/cli/server_utils.cpp index 62aa59dd65..dac6547535 100644 --- a/metagraph/src/cli/server_utils.cpp +++ b/metagraph/src/cli/server_utils.cpp @@ -99,14 +99,14 @@ std::string json_str_with_error_msg(const std::string &msg) { void process_request(std::shared_ptr &response, const std::shared_ptr &request, - const std::function &process) { + const std::function &process) { // Retrieve string: std::string content = request->content.string(); - logger->info("[Server] {} request from {}", request->path, - request->remote_endpoint().address().to_string()); try { - std::string ret = process(content); + // Return JSON string + Json::StreamWriterBuilder builder; + std::string ret = Json::writeString(builder, process(content)); write_response(SimpleWeb::StatusCode::success_ok, ret, response, is_compression_requested(request)); } catch (const std::exception &e) { diff --git a/metagraph/src/cli/server_utils.hpp b/metagraph/src/cli/server_utils.hpp index 20b6ffdb03..925dd0fc94 100644 --- a/metagraph/src/cli/server_utils.hpp +++ b/metagraph/src/cli/server_utils.hpp @@ -11,7 +11,7 @@ using HttpServer = SimpleWeb::Server; void process_request(std::shared_ptr &response, const std::shared_ptr &request, - const std::function &process); + const std::function &process); Json::Value parse_json_string(const std::string &msg); diff --git a/metagraph/src/cli/stats.cpp b/metagraph/src/cli/stats.cpp index c29ef07c09..7051ca0002 100644 --- a/metagraph/src/cli/stats.cpp +++ b/metagraph/src/cli/stats.cpp @@ -187,30 +187,17 @@ void print_annotation_stats(const std::string &fname, const Config &config) { logger->info("Scanning annotation '{}'", fname); try { - std::ifstream instream(fname, std::ios::binary); - - // TODO: make this more reliable - if (parse_annotation_type(fname) == Config::ColumnCompressed) { - // Column compressed dumps the number of rows first - // skipping it... - load_number(instream); + auto labels = read_labels(fname); + std::cout << "Number of columns: " << labels.size() << '\n'; + for (const auto &label : labels) { + std::cout << label << '\n'; } - - if (!label_encoder.load(instream)) - throw std::ios_base::failure(""); - + std::cout << std::flush; + return; } catch (...) { logger->error("Cannot read label encoder from file '{}'", fname); exit(1); } - - std::cout << "Number of columns: " << label_encoder.size() << '\n'; - for (size_t c = 0; c < label_encoder.size(); ++c) { - std::cout << label_encoder.decode(c) << '\n'; - } - - std::cout << std::flush; - return; } logger->info("Statistics for annotation '{}'", fname);