13 changes: 10 additions & 3 deletions metagraph/api/python/metagraph/client.py
@@ -44,6 +44,7 @@ def search(self, sequence: Union[str, Iterable[str]],
abundance_sum: bool = False,
query_counts: bool = False,
query_coords: bool = False,
graphs: Union[None, List[str]] = None,
align: bool = False,
**align_params) -> Tuple[JsonDict, str]:
"""See parameters for alignment `align_params` in align()"""
@@ -80,6 +81,8 @@ def search(self, sequence: Union[str, Iterable[str]],
"abundance_sum": abundance_sum,
"query_counts": query_counts,
"query_coords": query_coords}
if graphs is not None:
param_dict["graphs"] = graphs

search_results = self._json_seq_query(sequence, param_dict, "search")

@@ -178,6 +181,7 @@ def search(self, sequence: Union[str, Iterable[str]],
abundance_sum: bool = False,
query_counts: bool = False,
query_coords: bool = False,
graphs: Union[None, List[str]] = None,
align: bool = False,
**align_params) -> pd.DataFrame:
"""
@@ -195,6 +199,8 @@ def search(self, sequence: Union[str, Iterable[str]],
:type query_counts: bool
:param query_coords: Query k-mer coordinates
:type query_coords: bool
:param graphs: List of graph names to search. If None, search all graphs
:type graphs: Union[None, List[str]]
:param align: Align the query sequence to the joint graph and query labels for that alignment instead of the original sequence
:type align: bool
:param align_params: The parameters for alignment (see method align())
@@ -207,7 +213,7 @@ def search(self, sequence: Union[str, Iterable[str]],
json_obj = self._json_client.search(sequence, top_labels,
discovery_fraction, with_signature,
abundance_sum, query_counts, query_coords,
align, **align_params)
graphs, align, **align_params)

return helpers.df_from_search_result(json_obj)

@@ -272,6 +278,7 @@ def search(self, sequence: Union[str, Iterable[str]],
abundance_sum: bool = False,
query_counts: bool = False,
query_coords: bool = False,
graphs: Union[None, List[str]] = None,
Contributor:
This might be a bit confusing compared to self.graphs in this class. Are we sure we want to keep this naming? Maybe rename the parameter graphs in API to labels, or similar?

That aside, how are MultiGraphClients actually used? Do we theoretically expect a situation in which e.g. first graph serves label A while the second graph serves labels B and C, and we want to query just label A or just labels A and B?

Judging by the unit test below, this would fail. Is it the outcome we want, rather than e.g. returning empty output (and possibly some kind of warning in logs) when querying a label that we don't know?

Member Author:
Yeah, true. We can try to come up with a different name. Otherwise, MultiGraphClient is essentially a GraphClient with a ThreadPool, and it's not used all that much right now. Maybe it's simple enough to remove; anyone can always add a ThreadPool on top at any time.
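
A minimal sketch of that "GraphClient with a ThreadPool" idea (not part of this PR; the helper name and structure are illustrative only, assuming a dict of GraphClient instances keyed by name):

from concurrent.futures import ThreadPoolExecutor

def search_in_parallel(clients, sequence, **params):
    """Run the same search against several GraphClient instances concurrently."""
    with ThreadPoolExecutor(max_workers=len(clients)) as executor:
        futures = {name: executor.submit(client.search, sequence, **params)
                   for name, client in clients.items()}
        # Collect the per-graph results keyed by graph name.
        return {name: future.result() for name, future in futures.items()}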

align: bool = False,
**align_params) -> Dict[str, Union[pd.DataFrame, Future]]:
"""
@@ -290,7 +297,7 @@ def search(self, sequence: Union[str, Iterable[str]],
result[name] = graph_client.search(sequence, top_labels,
discovery_fraction, with_signature,
abundance_sum, query_counts, query_coords,
align, **align_params)
graphs, align, **align_params)

return result

@@ -305,7 +312,7 @@ def search(self, sequence: Union[str, Iterable[str]],
futures[name] = executor.submit(graph_client.search, sequence,
top_labels, discovery_fraction, with_signature,
abundance_sum, query_counts, query_coords,
align, **align_params)
graphs, align, **align_params)

print(f'Made {len(self.graphs)} requests with {num_processes} threads...')

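A brief usage sketch of the new graphs keyword (host, port, and index/graph names are placeholders; the calls follow the signatures in this diff and the tests below):

from metagraph.client import GraphClient, MultiGraphClient

# Placeholder server coordinates and index name.
client = GraphClient('127.0.0.1', 5555, 'my_index')

# Previous behaviour: query all graphs served under this index.
df_all = client.search('CCTCTGTGGAATCCAATCTGTCTTCCATC')

# New in this PR: restrict the query to a named subset of the served graphs.
df_subset = client.search('CCTCTGTGGAATCCAATCTGTCTTCCATC', graphs=['G1', 'G2'])

# MultiGraphClient forwards the same keyword to every registered GraphClient.
multi = MultiGraphClient()
multi.add_graph('127.0.0.1', 5555, 'my_index')
results = multi.search('CCTCTGTGGAATCCAATCTGTCTTCCATC', parallel=False, graphs=['G1'])
df_g1 = results['my_index']
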
192 changes: 176 additions & 16 deletions metagraph/integration_tests/test_api.py
@@ -9,12 +9,16 @@

import pandas as pd

from metagraph.client import GraphClientJson, MultiGraphClient
from metagraph.client import GraphClientJson, MultiGraphClient, GraphClient
from concurrent.futures import Future
from parameterized import parameterized, parameterized_class
from itertools import product

from base import PROTEIN_MODE, TestingBase, METAGRAPH, TEST_DATA_DIR

GRAPH_MODES = ['basic'] if PROTEIN_MODE else ['basic', 'canonical', 'primary']


class TestAPIBase(TestingBase):
@classmethod
def setUpClass(cls, fasta_path, mode='basic', anno_repr='column'):
@@ -53,14 +57,13 @@ def tearDownClass(cls):

def _start_server(self, graph, annotation):
construct_command = f'{METAGRAPH} server_query -i {graph} -a {annotation} \
--port {self.port} --address {self.host} -p {2}'
--port {self.port} --address {self.host} -p 2'

return subprocess.Popen(shlex.split(construct_command))


# No canonical mode for Protein alphabets
@parameterized_class(('mode',),
input_values=[('basic',)] + ([] if PROTEIN_MODE else [('canonical',), ('primary',)]))
@parameterized_class(('mode',), input_values=[(mode,) for mode in GRAPH_MODES])
class TestAPIRaw(TestAPIBase):
@classmethod
def setUpClass(cls):
@@ -249,8 +252,7 @@ def test_api_raw_search_no_count_support(self):


# No canonical mode for Protein alphabets
@parameterized_class(('mode',),
input_values=[('basic',)] + ([] if PROTEIN_MODE else [('canonical',), ('primary',)]))
@parameterized_class(('mode',), input_values=[(mode,) for mode in GRAPH_MODES])
class TestAPIClient(TestAPIBase):
graph_name = 'test_graph'

@@ -356,10 +358,172 @@ def test_api_search_no_count_support(self):
ret = self.graph_client.search(self.sample_query, parallel=False,
discovery_fraction=0.01, abundance_sum=True)

@unittest.expectedFailure
def test_api_search_bad_graphs(self):
df = self.graph_client.search(self.sample_query, parallel=False, graphs=['X'])


# No canonical mode for Protein alphabets
@parameterized_class(('mode', 'threads_each'),
input_values=[(mode, 1) for mode in GRAPH_MODES] + [('basic', 4)])
class TestAPIClientMultiple(TestingBase):
@classmethod
def setUpClass(
cls,
fasta_paths=[
TEST_DATA_DIR + '/transcripts_100.fa',
TEST_DATA_DIR + '/transcripts_1000.fa'
],
anno_repr='column',
):
super().setUpClass()

graph_paths = [
cls.tempdir.name + '/graph_1.dbg',
cls.tempdir.name + '/graph_2.dbg',
]
annotation_path_bases = [
cls.tempdir.name + '/annotation_1',
cls.tempdir.name + '/annotation_2',
]

cls._build_graph(fasta_paths[0], graph_paths[0], 6, 'succinct', mode=cls.mode)
cls._annotate_graph(fasta_paths[0], graph_paths[0], annotation_path_bases[0], anno_repr)

cls._build_graph(fasta_paths[1], graph_paths[1], 6, 'succinct', mode=cls.mode)
cls._annotate_graph(fasta_paths[1], graph_paths[1], annotation_path_bases[1], anno_repr)

cls.mult = 1 if cls.threads_each < 2 else 200 # duplicate graphs so that there are more to query

graphs_file = cls.tempdir.name + '/graphs.txt'
with open(graphs_file, 'w') as f:
f.write('\n'.join([
f'G1,{graph_paths[0]},{annotation_path_bases[0]}.{anno_repr}.annodbg',
f'G2,{graph_paths[0]},{annotation_path_bases[0]}.{anno_repr}.annodbg',
f'G2,{graph_paths[0]},{annotation_path_bases[0]}.{anno_repr}.annodbg',
f'G3,{graph_paths[1]},{annotation_path_bases[1]}.{anno_repr}.annodbg',
] * cls.mult))

cls.host = '127.0.0.1'
os.environ['NO_PROXY'] = cls.host
cls.port = 3456
num_retries = 100
while num_retries > 0:
cls.server_process = subprocess.Popen(shlex.split(
f'{METAGRAPH} server_query {graphs_file} \
--port {cls.port} --address {cls.host} -p 2 --threads-each {cls.threads_each}'
))
try:
cls.server_process.wait(timeout=1)
except subprocess.TimeoutExpired:
break
cls.port += 1
num_retries -= 1
if num_retries == 0:
            raise RuntimeError("Couldn't start server")

wait_time_sec = 1
print("Waiting {} sec for the server (PID {}) to start up".format(
wait_time_sec, cls.server_process.pid), flush=True)
time.sleep(wait_time_sec)

cls.graph_client = MultiGraphClient()
cls.graph_name = 'G1,G2x2,G3'
cls.graph_client.add_graph(cls.host, cls.port, cls.graph_name)
cls.client = GraphClient(cls.host, cls.port, cls.graph_name)

cls.sample_query = 'CCTCTGTGGAATCCAATCTGTCTTCCATCCTGCGTGGCCGAGGG'
# 'canonical' and 'primary' graphs represent more k-mers than 'basic', so
# they get more matches
cls.expected_rows_1 = (98 if cls.mode == 'basic' else 99) * cls.mult
cls.expected_matches_1 = (840 if cls.mode == 'basic' else 1381) * cls.mult

cls.expected_rows_2 = cls.expected_rows_1 * 2
cls.expected_matches_2 = cls.expected_matches_1 * 2

cls.expected_rows_3 = (100) * cls.mult
cls.expected_matches_3 = (2843 if cls.mode == 'basic' else 3496) * cls.mult

@classmethod
def tearDownClass(cls):
cls.server_process.kill()

def test_api_query_df_multiple_1(self):
ret = self.graph_client.search(self.sample_query, parallel=False,
discovery_threshold=0.01, graphs=['G1'])
df = ret[self.graph_name]
self.assertEqual((self.expected_rows_1, 3), df.shape)
self.assertEqual(df['kmer_count'].sum(), self.expected_matches_1)

def test_api_query_df_multiple_2(self):
ret = self.graph_client.search(self.sample_query, parallel=False,
discovery_threshold=0.01, graphs=['G2'])
df = ret[self.graph_name]
self.assertEqual((self.expected_rows_2, 3), df.shape)
self.assertEqual(df['kmer_count'].sum(), self.expected_matches_2)

def test_api_query_df_multiple_3(self):
ret = self.graph_client.search(self.sample_query, parallel=False,
discovery_threshold=0.01, graphs=['G3'])
df = ret[self.graph_name]
self.assertEqual((self.expected_rows_3, 3), df.shape)
self.assertEqual(df['kmer_count'].sum(), self.expected_matches_3)

def test_api_query_df_multiple_12(self):
ret = self.graph_client.search(self.sample_query, parallel=False,
discovery_threshold=0.01, graphs=['G1', 'G2'])
df = ret[self.graph_name]
self.assertEqual((self.expected_rows_1 + self.expected_rows_2, 3), df.shape)
self.assertEqual(df['kmer_count'].sum(),
self.expected_matches_1 + self.expected_matches_2)

def test_api_query_df_multiple_all(self):
ret = self.graph_client.search(self.sample_query, parallel=False,
discovery_threshold=0.01, graphs=['G1', 'G2', 'G3'])
df = ret[self.graph_name]
self.assertEqual((self.expected_rows_1 + self.expected_rows_2 + self.expected_rows_3, 3), df.shape)
self.assertEqual(df['kmer_count'].sum(),
self.expected_matches_1 + self.expected_matches_2 + self.expected_matches_3)

def test_api_query_df_multiple(self):
ret = self.graph_client.search(self.sample_query, parallel=False,
discovery_threshold=0.01)
df = ret[self.graph_name]
self.assertEqual((self.expected_rows_1 + self.expected_rows_2 + self.expected_rows_3, 3), df.shape)
self.assertEqual(df['kmer_count'].sum(),
self.expected_matches_1 + self.expected_matches_2 + self.expected_matches_3)

@unittest.expectedFailure
def test_api_query_df_multiple_bad(self):
df = self.graph_client.search(self.sample_query, parallel=False, graphs=['G1', 'X'])

def test_api_stats_multiple_graphs(self):
"""Test /stats endpoint with multiple graphs returns aggregated stats"""
ret = self.client._json_client.stats()

# Should have annotation section with aggregated labels
self.assertIn("annotation", ret.keys())
self.assertIn("labels", ret["annotation"])

# Total labels should be sum of all graphs
self.assertEqual(ret["annotation"]["labels"], 1300 * self.mult)

def test_api_column_labels_multiple_graphs(self):
"""Test /column_labels endpoint with multiple graphs returns all labels"""
ret = self.graph_client.column_labels()

self.assertIn(self.graph_name, ret.keys())
label_list = ret[self.graph_name]
# Should have labels from all graphs combined
self.assertEqual(len(label_list), 1300 * self.mult)
# All labels should start with 'ENST' (from the test data)
self.assertTrue(all(l.startswith('ENST') for l in label_list))
        # Check that the total number of deduplicated labels equals 1000
self.assertEqual(len(set(label_list)), 1000)


# No canonical mode for Protein alphabets
@parameterized_class(('mode',),
input_values=[('basic',)] + ([] if PROTEIN_MODE else [('canonical',), ('primary',)]))
@parameterized_class(('mode',), input_values=[(mode,) for mode in GRAPH_MODES])
class TestAPIJson(TestAPIBase):
graph_name = 'test_graph'

@@ -417,8 +581,7 @@ def test_api_stats(self):


# No canonical mode for Protein alphabets
@parameterized_class(('mode',),
input_values=[('basic',)] + ([] if PROTEIN_MODE else [('canonical',), ('primary',)]))
@parameterized_class(('mode',), input_values=[(mode,) for mode in GRAPH_MODES])
class TestAPIClientWithProperties(TestAPIBase):
"""
Testing whether properties encoded in sample name are properly processed
Expand Down Expand Up @@ -449,8 +612,7 @@ def test_api_search_property_df_empty(self):


# No canonical mode for Protein alphabets
@parameterized_class(('mode',),
input_values=[('basic',)] + ([] if PROTEIN_MODE else [('canonical',), ('primary',)]))
@parameterized_class(('mode',), input_values=[(mode,) for mode in GRAPH_MODES])
class TestAPIClientWithCoordinates(TestAPIBase):
"""
Testing whether API works well given coordinate aware annotations
@@ -509,8 +671,7 @@ def test_api_simple_query_coords_df(self):


# No canonical mode for Protein alphabets
@parameterized_class(('mode',),
input_values=[('basic',)] + ([] if PROTEIN_MODE else [('canonical',), ('primary',)]))
@parameterized_class(('mode',), input_values=[(mode,) for mode in GRAPH_MODES])
class TestAPIClientWithCounts(TestAPIBase):
"""
Testing whether API works well given k-mer count aware annotations
@@ -564,8 +725,7 @@ def test_api_search_no_coordinate_support(self):


# No canonical mode for Protein alphabets
@parameterized_class(('mode',),
input_values=[('basic',)] + ([] if PROTEIN_MODE else [('canonical',), ('primary',)]))
@parameterized_class(('mode',), input_values=[(mode,) for mode in GRAPH_MODES])
class TestAPIClientParallel(TestAPIBase):
"""
Testing whether or not parallel requests work
17 changes: 14 additions & 3 deletions metagraph/src/cli/config/config.cpp
@@ -61,6 +61,7 @@ Config::Config(int argc, char *argv[]) {
identity = QUERY;
} else if (!strcmp(argv[1], "server_query")) {
identity = SERVER_QUERY;
num_top_labels = 10'000;
} else if (!strcmp(argv[1], "transform")) {
identity = TRANSFORM;
} else if (!strcmp(argv[1], "transform_anno")) {
@@ -543,7 +544,7 @@ Config::Config(int argc, char *argv[]) {
if (count_kmers || query_presence)
map_sequences = true;

if ((identity == QUERY || identity == SERVER_QUERY) && infbase.empty())
if (identity == QUERY && infbase.empty())
print_usage_and_exit = true;

if ((identity == QUERY || identity == SERVER_QUERY || identity == ALIGN)
@@ -575,9 +576,15 @@
if (identity == EXTEND && infbase.empty())
print_usage_and_exit = true;

if ((identity == QUERY || identity == SERVER_QUERY) && infbase_annotators.size() != 1)
if (identity == QUERY && infbase_annotators.size() != 1)
print_usage_and_exit = true;

if (identity == SERVER_QUERY
&& (fnames.size() > 1
|| (fnames.size() && (infbase.size() || infbase_annotators.size()))
|| (fnames.empty() && (infbase.empty() || infbase_annotators.size() != 1))))
print_usage_and_exit = true; // only one of fnames or (infbase & annotator) must be used

if ((identity == TRANSFORM
|| identity == CLEAN
|| identity == ASSEMBLE
@@ -1369,7 +1376,10 @@ if (advanced) {
}
} break;
case SERVER_QUERY: {
fprintf(stderr, "Usage: %s server_query -i <GRAPH> -a <ANNOTATION> [options]\n\n", prog_name.c_str());
fprintf(stderr, "Usage: %s server_query (-i <GRAPH> -a <ANNOTATION> | <GRAPHS.csv>) [options]\n\n"
"\tThe index must be passed with flags -i -a or with a file GRAPHS.csv listing one\n"
"\tor more indexes, a file with rows: '<name>,<graph_path>,<annotation_path>\\n'.\n"
"\t(If multiple rows have the same name, all those graphs will be queried for that name.)\n\n", prog_name.c_str());

fprintf(stderr, "Available options for server_query:\n");
fprintf(stderr, "\t --port [INT] \tTCP port for incoming connections [5555]\n");
@@ -1379,6 +1389,7 @@
// fprintf(stderr, "\t-d --distance [INT] \tmax allowed alignment distance [0]\n");
fprintf(stderr, "\t-p --parallel [INT] \tmaximum number of parallel connections [1]\n");
// fprintf(stderr, "\t --cache-size [INT] \tnumber of uncompressed rows to store in the cache [0]\n");
fprintf(stderr, "\n\t --num-top-labels [INT] \tmaximum number of top labels per query by default [10'000]\n");
} break;
}

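To tie the CLI and client changes together, an end-to-end sketch (the binary path, file locations, and graph names are hypothetical; the row format and flags follow the usage text above and the integration test):

import shlex
import subprocess

from metagraph.client import GraphClient

# Hypothetical GRAPHS file: one '<name>,<graph_path>,<annotation_path>' row per index.
# Rows that share a name (here G2) are all queried whenever that name is requested.
rows = [
    'G1,/data/graph_1.dbg,/data/annotation_1.column.annodbg',
    'G2,/data/graph_2.dbg,/data/annotation_2.column.annodbg',
    'G2,/data/graph_3.dbg,/data/annotation_3.column.annodbg',
]
with open('/data/graphs.txt', 'w') as f:
    f.write('\n'.join(rows))

# Serve all listed indexes instead of a single -i/-a pair.
server = subprocess.Popen(shlex.split(
    'metagraph server_query /data/graphs.txt --port 5555 --address 127.0.0.1 -p 2'))

# Query only the graphs registered under the names G1 and G2.
client = GraphClient('127.0.0.1', 5555, 'my_index')
df = client.search('CCTCTGTGGAATCCAATCTGTCTTCCATCCTGCGTGGCCGAGGG', graphs=['G1', 'G2'])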