diff --git a/graphdatascience/arrow_client/v2/mutation_client.py b/graphdatascience/arrow_client/v2/mutation_client.py index 44bde5c5e..a02d8e61e 100644 --- a/graphdatascience/arrow_client/v2/mutation_client.py +++ b/graphdatascience/arrow_client/v2/mutation_client.py @@ -11,7 +11,41 @@ class MutationClient: @staticmethod def mutate_node_property(client: AuthenticatedArrowClient, job_id: str, mutate_property: str) -> MutateResult: - mutate_config = {"jobId": job_id, "mutateProperty": mutate_property} + return MutationClient._mutate( + client=client, + job_id=job_id, + mutate_property=mutate_property, + ) + + @staticmethod + def mutate_relationship_property( + client: AuthenticatedArrowClient, + job_id: str, + mutate_relationship_type: str, + mutate_property: str | None, + ) -> MutateResult: + return MutationClient._mutate( + client=client, + job_id=job_id, + mutate_property=mutate_property, + mutate_relationship_type=mutate_relationship_type, + ) + + @staticmethod + def _mutate( + client: AuthenticatedArrowClient, + job_id: str, + mutate_property: str | None = None, + mutate_relationship_type: str | None = None, + ) -> MutateResult: + mutate_config = { + "jobId": job_id, + } + if mutate_relationship_type: + mutate_config["mutateRelationshipType"] = mutate_relationship_type + if mutate_property: + mutate_config["mutateProperty"] = mutate_property + start_time = time.time() mutate_arrow_res = client.do_action_with_retry(MutationClient.MUTATE_ENDPOINT, mutate_config) mutate_millis = math.ceil((time.time() - start_time) * 1000) diff --git a/graphdatascience/procedure_surface/api/community/local_clustering_coefficient_endpoints.py b/graphdatascience/procedure_surface/api/community/local_clustering_coefficient_endpoints.py index ff1f3ab6f..efe07615d 100644 --- a/graphdatascience/procedure_surface/api/community/local_clustering_coefficient_endpoints.py +++ b/graphdatascience/procedure_surface/api/community/local_clustering_coefficient_endpoints.py @@ -166,7 +166,6 @@ def write( triangle_count_property: str | None = None, username: str | None = None, write_concurrency: int | None = None, - write_to_result_store: bool | None = None, ) -> "LocalClusteringCoefficientWriteResult": """ Executes the LocalClusteringCoefficient algorithm and writes results to the database. @@ -195,8 +194,6 @@ def write( Username for authentication write_concurrency : int | None, default=None Concurrency for writing back to the database - write_to_result_store : bool | None, default=None - Whether to write to the result store Returns ------- diff --git a/graphdatascience/procedure_surface/api/community/modularity_optimization_endpoints.py b/graphdatascience/procedure_surface/api/community/modularity_optimization_endpoints.py index e61a2c3d1..b0cbe2055 100644 --- a/graphdatascience/procedure_surface/api/community/modularity_optimization_endpoints.py +++ b/graphdatascience/procedure_surface/api/community/modularity_optimization_endpoints.py @@ -218,7 +218,6 @@ def write( tolerance: float | None = None, username: str | None = None, write_concurrency: int | None = None, - write_to_result_store: bool | None = None, ) -> ModularityOptimizationWriteResult: """ Executes the Modularity Optimization algorithm and writes the results back to the database. 
@@ -259,8 +258,6 @@ def write( Username for authentication write_concurrency : int | None, default=None The number of concurrent threads for writing - write_to_result_store : bool | None, default=None - Whether to write results to the result store Returns ------- diff --git a/graphdatascience/procedure_surface/api/similarity/__init__.py b/graphdatascience/procedure_surface/api/similarity/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/graphdatascience/procedure_surface/api/similarity/knn_endpoints.py b/graphdatascience/procedure_surface/api/similarity/knn_endpoints.py new file mode 100644 index 000000000..af97c5609 --- /dev/null +++ b/graphdatascience/procedure_surface/api/similarity/knn_endpoints.py @@ -0,0 +1,384 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod + +from typing import Any + +from pandas import DataFrame + +from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2 +from graphdatascience.procedure_surface.api.estimation_result import EstimationResult +from graphdatascience.procedure_surface.api.similarity.knn_filtered_endpoints import KnnFilteredEndpoints +from graphdatascience.procedure_surface.api.similarity.knn_results import ( + KnnMutateResult, + KnnStatsResult, + KnnWriteResult, +) + + +class KnnEndpoints(ABC): + @property + @abstractmethod + def filtered(self) -> KnnFilteredEndpoints: + pass + + @abstractmethod + def mutate( + self, + G: GraphV2, + mutate_relationship_type: str, + mutate_property: str, + node_properties: str | list[str] | dict[str, str], + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + concurrency: int | None = None, + job_id: str | None = None, + ) -> KnnMutateResult: + """ + Runs the K-Nearest Neighbors algorithm and stores the results as new relationships in the graph catalog. + + Parameters + ---------- + G : GraphV2 + The graph to run the algorithm on + mutate_relationship_type : str + The relationship type to use for the new relationships. + mutate_property : str + The relationship property to store the similarity score in. + node_properties : str | list[str] | dict[str, str] + The node properties to use for similarity computation. + top_k : int, default=10 + The number of nearest neighbors to find for each node. + similarity_cutoff : float, default=0.0 + The threshold for similarity scores. + delta_threshold : float, default=0.001 + The threshold for convergence assessment. + max_iterations : int, default=100 + The maximum number of iterations to run. + sample_rate : float, default=0.5 + The sampling rate for the algorithm. + perturbation_rate : float, default=0.0 + The rate at which to perturb the similarity graph. + random_joins : int, default=10 + The number of random joins to perform. + random_seed : int | None, default=None + The seed for the random number generator. + initial_sampler : str, default="UNIFORM" + The initial sampling strategy. + relationship_types : list[str] | None, default=None + Filter on relationship types. + node_labels : list[str] | None, default=None + Filter on node labels. + sudo : bool, default=False + Run the algorithm with elevated privileges.
+ log_progress : bool, default=True + Whether to log progress. + username : str | None, default=None + Username for the operation. + concurrency : int | None, default=None + Concurrency configuration. + job_id : str | None, default=None + Job ID for the operation. + + Returns + ------- + KnnMutateResult + Object containing metadata from the execution. + """ + + @abstractmethod + def stats( + self, + G: GraphV2, + node_properties: str | list[str] | dict[str, str], + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + concurrency: int | None = None, + job_id: str | None = None, + ) -> KnnStatsResult: + """ + Runs the K-Nearest Neighbors algorithm and returns execution statistics. + + Parameters + ---------- + G : GraphV2 + The graph to run the algorithm on + node_properties : str | list[str] | dict[str, str] + The node properties to use for similarity computation. + top_k : int, default=10 + The number of nearest neighbors to find for each node. + similarity_cutoff : float, default=0.0 + The threshold for similarity scores. + delta_threshold : float, default=0.001 + The threshold for convergence assessment. + max_iterations : int, default=100 + The maximum number of iterations to run. + sample_rate : float, default=0.5 + The sampling rate for the algorithm. + perturbation_rate : float, default=0.0 + The rate at which to perturb the similarity graph. + random_joins : int, default=10 + The number of random joins to perform. + random_seed : int | None, default=None + The seed for the random number generator. + initial_sampler : str, default="UNIFORM" + The initial sampling strategy. + relationship_types : list[str] | None, default=None + Filter on relationship types. + node_labels : list[str] | None, default=None + Filter on node labels. + sudo : bool, default=False + Run the algorithm with elevated privileges. + log_progress : bool, default=True + Whether to log progress. + username : str | None, default=None + Username for the operation. + concurrency : int | None, default=None + Concurrency configuration. + job_id : str | None, default=None + Job ID for the operation. + + Returns + ------- + KnnStatsResult + Object containing execution statistics and algorithm-specific results. + """ + + @abstractmethod + def stream( + self, + G: GraphV2, + node_properties: str | list[str] | dict[str, str], + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + concurrency: int | None = None, + job_id: str | None = None, + ) -> DataFrame: + """ + Runs the K-Nearest Neighbors algorithm and returns the result as a DataFrame. + + Parameters + ---------- + G : GraphV2 + The graph to run the algorithm on + node_properties : str | list[str] | dict[str, str] + The node properties to use for similarity computation.
+ top_k : int, default=10 + The number of nearest neighbors to find for each node. + similarity_cutoff : float, default=0.0 + The threshold for similarity scores. + delta_threshold : float, default=0.001 + The threshold for convergence assessment. + max_iterations : int, default=100 + The maximum number of iterations to run. + sample_rate : float, default=0.5 + The sampling rate for the algorithm. + perturbation_rate : float, default=0.0 + The rate at which to perturb the similarity graph. + random_joins : int, default=10 + The number of random joins to perform. + random_seed : int | None, default=None + The seed for the random number generator. + initial_sampler : str, default="UNIFORM" + The initial sampling strategy. + relationship_types : list[str] | None, default=None + Filter on relationship types. + node_labels : list[str] | None, default=None + Filter on node labels. + sudo : bool, default=False + Run the algorithm with elevated privileges. + log_progress : bool, default=True + Whether to log progress. + username : str | None, default=None + Username for the operation. + concurrency : int | None, default=None + Concurrency configuration. + job_id : str | None, default=None + Job ID for the operation. + + Returns + ------- + DataFrame + The similarity results as a DataFrame with columns 'node1', 'node2', and 'similarity'. + """ + + @abstractmethod + def write( + self, + G: GraphV2, + write_relationship_type: str, + write_property: str, + node_properties: str | list[str] | dict[str, str], + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + concurrency: int | None = None, + job_id: str | None = None, + write_concurrency: int | None = None, + ) -> KnnWriteResult: + """ + Runs the K-Nearest Neighbors algorithm and writes the results back to the database. + + Parameters + ---------- + G : GraphV2 + The graph to run the algorithm on + write_relationship_type : str + The relationship type to use for the new relationships. + write_property : str + The relationship property to store the similarity score in. + node_properties : str | list[str] | dict[str, str] + The node properties to use for similarity computation. + top_k : int, default=10 + The number of nearest neighbors to find for each node. + similarity_cutoff : float, default=0.0 + The threshold for similarity scores. + delta_threshold : float, default=0.001 + The threshold for convergence assessment. + max_iterations : int, default=100 + The maximum number of iterations to run. + sample_rate : float, default=0.5 + The sampling rate for the algorithm. + perturbation_rate : float, default=0.0 + The rate at which to perturb the similarity graph. + random_joins : int, default=10 + The number of random joins to perform. + random_seed : int | None, default=None + The seed for the random number generator. + initial_sampler : str, default="UNIFORM" + The initial sampling strategy. + relationship_types : list[str] | None, default=None + Filter on relationship types. + node_labels : list[str] | None, default=None + Filter on node labels. + sudo : bool, default=False + Run the algorithm with elevated privileges.
+ log_progress : bool, default=True + Whether to log progress. + username : str | None, default=None + Username for the operation. + concurrency : int | None, default=None + Concurrency configuration. + job_id : str | None, default=None + Job ID for the operation. + write_concurrency : int | None, default=None + Concurrency for writing results. + + Returns + ------- + KnnWriteResult + Object containing metadata from the execution. + """ + + @abstractmethod + def estimate( + self, + G: GraphV2 | dict[str, Any], + node_properties: str | list[str] | dict[str, str], + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + concurrency: int | None = None, + ) -> EstimationResult: + """ + Estimates the memory requirements for running the K-Nearest Neighbors algorithm. + + Parameters + ---------- + G : GraphV2 | dict[str, Any] + The graph to run the algorithm on + node_properties : str | list[str] | dict[str, str] + The node properties to use for similarity computation. + top_k : int, default=10 + The number of nearest neighbors to find for each node. + similarity_cutoff : float, default=0.0 + The threshold for similarity scores. + delta_threshold : float, default=0.001 + The threshold for convergence assessment. + max_iterations : int, default=100 + The maximum number of iterations to run. + sample_rate : float, default=0.5 + The sampling rate for the algorithm. + perturbation_rate : float, default=0.0 + The rate at which to perturb the similarity graph. + random_joins : int, default=10 + The number of random joins to perform. + random_seed : int | None, default=None + The seed for the random number generator. + initial_sampler : str, default="UNIFORM" + The initial sampling strategy. + relationship_types : list[str] | None, default=None + Filter on relationship types. + node_labels : list[str] | None, default=None + Filter on node labels. + sudo : bool, default=False + Run the algorithm with elevated privileges. + log_progress : bool, default=True + Whether to log progress. + username : str | None, default=None + Username for the operation. + concurrency : int | None, default=None + Concurrency configuration. + + Returns + ------- + EstimationResult + Object containing the estimated memory requirements.
+ """ diff --git a/graphdatascience/procedure_surface/api/similarity/knn_filtered_endpoints.py b/graphdatascience/procedure_surface/api/similarity/knn_filtered_endpoints.py new file mode 100644 index 000000000..187ce4a72 --- /dev/null +++ b/graphdatascience/procedure_surface/api/similarity/knn_filtered_endpoints.py @@ -0,0 +1,435 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + +from pandas import DataFrame + +from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2 +from graphdatascience.procedure_surface.api.estimation_result import EstimationResult +from graphdatascience.procedure_surface.api.similarity.knn_results import ( + KnnMutateResult, + KnnStatsResult, + KnnWriteResult, +) + + +class KnnFilteredEndpoints(ABC): + @abstractmethod + def mutate( + self, + G: GraphV2, + mutate_relationship_type: str, + mutate_property: str, + node_properties: str | list[str] | dict[str, str], + source_node_filter: str, + target_node_filter: str, + seed_target_nodes: bool | None = None, + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + concurrency: int | None = None, + job_id: str | None = None, + ) -> KnnMutateResult: + """ + Runs the Filtered K-Nearest Neighbors algorithm and stores the results as new relationships in the graph catalog. + + The Filtered K-Nearest Neighbors algorithm computes a distance value for node pairs in the graph with customizable source and target node filters, creating new relationships between each node and its k nearest neighbors within the filtered subset. + + Parameters + ---------- + G : GraphV2 + The graph to run the algorithm on + mutate_relationship_type : str + The relationship type to use for the new relationships. + mutate_property : str + The relationship property to store the similarity score in. + node_properties : str | list[str] | dict[str, str] + The node properties to use for similarity computation. + source_node_filter : str + A Cypher expression to filter which nodes can be sources in the similarity computation. + target_node_filter : str + A Cypher expression to filter which nodes can be targets in the similarity computation. + seed_target_nodes : bool | None, default=None + Whether to use a seeded approach for target node selection. + top_k : int, default=10 + The number of nearest neighbors to find for each node. + similarity_cutoff : float, default=0.0 + The threshold for similarity scores. + delta_threshold : float, default=0.001 + The threshold for convergence assessment. + max_iterations : int, default=100 + The maximum number of iterations to run. + sample_rate : float, default=0.5 + The sampling rate for the algorithm. + perturbation_rate : float, default=0.0 + The rate at which to perturb the similarity graph. + random_joins : int, default=10 + The number of random joins to perform. + random_seed : int | None, default=None + The seed for the random number generator. + initial_sampler : str, default="UNIFORM" + The initial sampling strategy. + relationship_types : list[str] | None, default=None + Filter on relationship types. 
+ node_labels : list[str] | None, default=None + Filter on node labels. + sudo : bool, default=False + Run the algorithm with elevated privileges. + log_progress : bool, default=True + Whether to log progress. + username : str | None, default=None + Username for the operation. + concurrency : int | None, default=None + Concurrency configuration. + job_id : str | None, default=None + Job ID for the operation. + + Returns + ------- + KnnMutateResult + Object containing metadata from the execution. + """ + ... + + @abstractmethod + def stats( + self, + G: GraphV2, + node_properties: str | list[str] | dict[str, str], + source_node_filter: str, + target_node_filter: str, + seed_target_nodes: bool | None = None, + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + concurrency: int | None = None, + job_id: str | None = None, + ) -> KnnStatsResult: + """ + Runs the Filtered K-Nearest Neighbors algorithm and returns execution statistics. + + The Filtered K-Nearest Neighbors algorithm computes a distance value for node pairs in the graph with customizable source and target node filters, creating new relationships between each node and its k nearest neighbors within the filtered subset. + + Parameters + ---------- + G : GraphV2 + The graph to run the algorithm on + node_properties : str | list[str] | dict[str, str] + The node properties to use for similarity computation. + source_node_filter : str + A Cypher expression to filter which nodes can be sources in the similarity computation. + target_node_filter : str + A Cypher expression to filter which nodes can be targets in the similarity computation. + seed_target_nodes : bool | None, default=None + Whether to use a seeded approach for target node selection. + top_k : int, default=10 + The number of nearest neighbors to find for each node. + similarity_cutoff : float, default=0.0 + The threshold for similarity scores. + delta_threshold : float, default=0.001 + The threshold for convergence assessment. + max_iterations : int, default=100 + The maximum number of iterations to run. + sample_rate : float, default=0.5 + The sampling rate for the algorithm. + perturbation_rate : float, default=0.0 + The rate at which to perturb the similarity graph. + random_joins : int, default=10 + The number of random joins to perform. + random_seed : int | None, default=None + The seed for the random number generator. + initial_sampler : str, default="UNIFORM" + The initial sampling strategy. + relationship_types : list[str] | None, default=None + Filter on relationship types. + node_labels : list[str] | None, default=None + Filter on node labels. + sudo : bool, default=False + Run the algorithm with elevated privileges. + log_progress : bool, default=True + Whether to log progress. + username : str | None, default=None + Username for the operation. + concurrency : int | None, default=None + Concurrency configuration. + job_id : str | None, default=None + Job ID for the operation. + + Returns + ------- + KnnStatsResult + Object containing execution statistics and algorithm-specific results. + """ + ... 
+ + @abstractmethod + def stream( + self, + G: GraphV2, + node_properties: str | list[str] | dict[str, str], + source_node_filter: str, + target_node_filter: str, + seed_target_nodes: bool | None = None, + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + concurrency: int | None = None, + job_id: str | None = None, + ) -> DataFrame: + """ + Runs the Filtered K-Nearest Neighbors algorithm and returns the result as a DataFrame. + + The Filtered K-Nearest Neighbors algorithm computes a distance value for node pairs in the graph with customizable source and target node filters, creating new relationships between each node and its k nearest neighbors within the filtered subset. + + Parameters + ---------- + G : GraphV2 + The graph to run the algorithm on + node_properties : str | list[str] | dict[str, str] + The node properties to use for similarity computation. + source_node_filter : str + A Cypher expression to filter which nodes can be sources in the similarity computation. + target_node_filter : str + A Cypher expression to filter which nodes can be targets in the similarity computation. + seed_target_nodes : bool | None, default=None + Whether to use a seeded approach for target node selection. + top_k : int, default=10 + The number of nearest neighbors to find for each node. + similarity_cutoff : float, default=0.0 + The threshold for similarity scores. + delta_threshold : float, default=0.001 + The threshold for convergence assessment. + max_iterations : int, default=100 + The maximum number of iterations to run. + sample_rate : float, default=0.5 + The sampling rate for the algorithm. + perturbation_rate : float, default=0.0 + The rate at which to perturb the similarity graph. + random_joins : int, default=10 + The number of random joins to perform. + random_seed : int | None, default=None + The seed for the random number generator. + initial_sampler : str, default="UNIFORM" + The initial sampling strategy. + relationship_types : list[str] | None, default=None + Filter on relationship types. + node_labels : list[str] | None, default=None + Filter on node labels. + sudo : bool, default=False + Run the algorithm with elevated privileges. + log_progress : bool, default=True + Whether to log progress. + username : str | None, default=None + Username for the operation. + concurrency : int | None, default=None + Concurrency configuration. + job_id : str | None, default=None + Job ID for the operation. + + Returns + ------- + DataFrame + The similarity results as a DataFrame with columns 'node1', 'node2', and 'similarity'. + """ + ... 
+ + @abstractmethod + def write( + self, + G: GraphV2, + write_relationship_type: str, + write_property: str, + node_properties: str | list[str] | dict[str, str], + source_node_filter: str, + target_node_filter: str, + seed_target_nodes: bool | None = None, + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + write_concurrency: int | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + concurrency: int | None = None, + job_id: str | None = None, + ) -> KnnWriteResult: + """ + Runs the Filtered K-Nearest Neighbors algorithm and writes the results back to the database. + + The Filtered K-Nearest Neighbors algorithm computes a distance value for node pairs in the graph with customizable source and target node filters, creating new relationships between each node and its k nearest neighbors within the filtered subset. + + Parameters + ---------- + G : GraphV2 + The graph to run the algorithm on + write_relationship_type : str + The relationship type to use for the new relationships. + write_property : str + The relationship property to store the similarity score in. + node_properties : str | list[str] | dict[str, str] + The node properties to use for similarity computation. + source_node_filter : str + A Cypher expression to filter which nodes can be sources in the similarity computation. + target_node_filter : str + A Cypher expression to filter which nodes can be targets in the similarity computation. + seed_target_nodes : bool | None, default=None + Whether to use a seeded approach for target node selection. + top_k : int, default=10 + The number of nearest neighbors to find for each node. + similarity_cutoff : float, default=0.0 + The threshold for similarity scores. + delta_threshold : float, default=0.001 + The threshold for convergence assessment. + max_iterations : int, default=100 + The maximum number of iterations to run. + sample_rate : float, default=0.5 + The sampling rate for the algorithm. + perturbation_rate : float, default=0.0 + The rate at which to perturb the similarity graph. + random_joins : int, default=10 + The number of random joins to perform. + random_seed : int | None, default=None + The seed for the random number generator. + initial_sampler : str, default="UNIFORM" + The initial sampling strategy. + relationship_types : list[str] | None, default=None + Filter on relationship types. + node_labels : list[str] | None, default=None + Filter on node labels. + write_concurrency : int | None, default=None + Concurrency for writing results. + sudo : bool, default=False + Run the algorithm with elevated privileges. + log_progress : bool, default=True + Whether to log progress. + username : str | None, default=None + Username for the operation. + concurrency : int | None, default=None + Concurrency configuration. + job_id : str | None, default=None + Job ID for the operation. + + Returns + ------- + KnnWriteResult + Object containing metadata from the execution. + """ + ... 
+ + @abstractmethod + def estimate( + self, + G: GraphV2 | dict[str, Any], + node_properties: str | list[str] | dict[str, str], + source_node_filter: str, + target_node_filter: str, + seed_target_nodes: bool | None = None, + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + sudo: bool = False, + username: str | None = None, + concurrency: int | None = None, + ) -> EstimationResult: + """ + Estimates the memory requirements for running the Filtered K-Nearest Neighbors algorithm. + + The Filtered K-Nearest Neighbors algorithm computes a distance value for node pairs in the graph with customizable source and target node filters, creating new relationships between each node and its k nearest neighbors within the filtered subset. + + Parameters + ---------- + G : GraphV2 | dict[str, Any] + The graph to run the algorithm on. + node_properties : str | list[str] | dict[str, str] + The node properties to use for similarity computation. + source_node_filter : str + A Cypher expression to filter which nodes can be sources in the similarity computation. + target_node_filter : str + A Cypher expression to filter which nodes can be targets in the similarity computation. + seed_target_nodes : bool | None, default=None + Whether to use a seeded approach for target node selection. + top_k : int, default=10 + The number of nearest neighbors to find for each node. + similarity_cutoff : float, default=0.0 + The threshold for similarity scores. + delta_threshold : float, default=0.001 + The threshold for convergence assessment. + max_iterations : int, default=100 + The maximum number of iterations to run. + sample_rate : float, default=0.5 + The sampling rate for the algorithm. + perturbation_rate : float, default=0.0 + The rate at which to perturb the similarity graph. + random_joins : int, default=10 + The number of random joins to perform. + random_seed : int | None, default=None + The seed for the random number generator. + initial_sampler : str, default="UNIFORM" + The initial sampling strategy. + relationship_types : list[str] | None, default=None + Filter on relationship types. + node_labels : list[str] | None, default=None + Filter on node labels. + sudo : bool, default=False + Run the algorithm with elevated privileges. + username : str | None, default=None + Username for the operation. + concurrency : int | None, default=None + Concurrency configuration. + + Returns + ------- + EstimationResult + Object containing the estimated memory requirements. + """ + ... 
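For orientation, a minimal usage sketch of the new KNN surface defined in the two files above, assuming a concrete `KnnEndpoints` implementation is exposed on a client object (the handle `gds.knn`, the graph `G`, and the filter values are illustrative, not part of this diff; parameter semantics follow the server-side GDS KNN procedure):

    # Stream the top 5 most similar neighbours per node, compared on an embedding property.
    similarities = gds.knn.stream(
        G,
        node_properties="embedding",  # str here; list[str] and dict[str, str] forms are also accepted
        top_k=5,
        similarity_cutoff=0.4,
    )
    # -> DataFrame with columns 'node1', 'node2', 'similarity'

    # The filtered variant additionally restricts source and target nodes,
    # and materialises the result as new relationships in the graph catalog.
    result = gds.knn.filtered.mutate(
        G,
        mutate_relationship_type="SIMILAR",
        mutate_property="score",
        node_properties=["embedding"],
        source_node_filter="Person",
        target_node_filter="Person",
    )
    print(result.relationships_written, result.did_converge)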
diff --git a/graphdatascience/procedure_surface/api/similarity/knn_results.py b/graphdatascience/procedure_surface/api/similarity/knn_results.py new file mode 100644 index 000000000..3d8ac36d0 --- /dev/null +++ b/graphdatascience/procedure_surface/api/similarity/knn_results.py @@ -0,0 +1,44 @@ +from typing import Any + +from graphdatascience.procedure_surface.api.base_result import BaseResult + + +class KnnMutateResult(BaseResult): + pre_processing_millis: int + compute_millis: int + mutate_millis: int + post_processing_millis: int + nodes_compared: int + relationships_written: int + similarity_distribution: dict[str, int | float] + did_converge: bool + ran_iterations: int + node_pairs_considered: int + configuration: dict[str, Any] + + +class KnnStatsResult(BaseResult): + pre_processing_millis: int + compute_millis: int + post_processing_millis: int + nodes_compared: int + similarity_pairs: int + similarity_distribution: dict[str, int | float] + did_converge: bool + ran_iterations: int + node_pairs_considered: int + configuration: dict[str, Any] + + +class KnnWriteResult(BaseResult): + pre_processing_millis: int + compute_millis: int + write_millis: int + post_processing_millis: int + nodes_compared: int + relationships_written: int + did_converge: bool + ran_iterations: int + node_pairs_considered: int + similarity_distribution: dict[str, int | float] + configuration: dict[str, Any] diff --git a/graphdatascience/procedure_surface/arrow/catalog/node_label_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/catalog/node_label_arrow_endpoints.py index 3c31d4126..d7dd8a537 100644 --- a/graphdatascience/procedure_surface/arrow/catalog/node_label_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/catalog/node_label_arrow_endpoints.py @@ -9,7 +9,7 @@ NodeLabelMutateResult, NodeLabelWriteResult, ) -from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints +from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter @@ -20,11 +20,10 @@ def __init__( self, write_back_client: RemoteWriteBackClient | None = None, show_progress: bool = True, ): - self._node_property_endpoints = NodePropertyEndpoints( + self._node_property_endpoints = NodePropertyEndpointsHelper( arrow_client, write_back_client, show_progress=show_progress ) self._arrow_client = arrow_client - self._node_property_endpoints = NodePropertyEndpoints(arrow_client, write_back_client) self._show_progress = show_progress def mutate( @@ -83,6 +83,11 @@ def write( ) result = self._node_property_endpoints.run_job_and_write( - "v2/graph.nodeLabel.stream", G, config, write_concurrency, concurrency + "v2/graph.nodeLabel.stream", + G, + config, + property_overwrites={}, + write_concurrency=write_concurrency, + concurrency=concurrency, ) return NodeLabelWriteResult(**result) diff --git a/graphdatascience/procedure_surface/arrow/catalog/node_properties_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/catalog/node_properties_arrow_endpoints.py index 13461d4e6..8aece03ca 100644 --- a/graphdatascience/procedure_surface/arrow/catalog/node_properties_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/catalog/node_properties_arrow_endpoints.py @@ -14,7 +14,7 @@ NodePropertiesWriteResult, NodePropertySpec, ) -from
graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints +from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter from graphdatascience.procedure_surface.utils.result_utils import join_db_node_properties @@ -31,7 +31,7 @@ def __init__( self._write_back_client: RemoteWriteBackClient | None = ( RemoteWriteBackClient(arrow_client, query_runner) if query_runner is not None else None ) - self._node_property_endpoints = NodePropertyEndpoints( + self._node_property_endpoints = NodePropertyEndpointsHelper( arrow_client, self._write_back_client, show_progress=show_progress ) self._show_progress = show_progress diff --git a/graphdatascience/procedure_surface/arrow/centrality/articlerank_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/centrality/articlerank_arrow_endpoints.py index aed314d12..fd595d75c 100644 --- a/graphdatascience/procedure_surface/arrow/centrality/articlerank_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/centrality/articlerank_arrow_endpoints.py @@ -12,7 +12,7 @@ ArticleRankWriteResult, ) from graphdatascience.procedure_surface.api.estimation_result import EstimationResult -from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints +from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper class ArticleRankArrowEndpoints(ArticleRankEndpoints): @@ -22,7 +22,7 @@ def __init__( write_back_client: RemoteWriteBackClient | None = None, show_progress: bool = True, ): - self._node_property_endpoints = NodePropertyEndpoints( + self._node_property_endpoints = NodePropertyEndpointsHelper( arrow_client, write_back_client, show_progress=show_progress ) @@ -60,9 +60,7 @@ def mutate( tolerance=tolerance, ) - result = self._node_property_endpoints.run_job_and_mutate( - "v2/centrality.articleRank", G, config, mutate_property - ) + result = self._node_property_endpoints.run_job_and_mutate("v2/centrality.articleRank", config, mutate_property) return ArticleRankMutateResult(**result) @@ -99,9 +97,7 @@ def stats( tolerance=tolerance, ) - computation_result = self._node_property_endpoints.run_job_and_get_summary( - "v2/centrality.articleRank", G, config - ) + computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/centrality.articleRank", config) return ArticleRankStatsResult(**computation_result) @@ -176,7 +172,12 @@ def write( ) result = self._node_property_endpoints.run_job_and_write( - "v2/centrality.articleRank", G, config, write_concurrency, concurrency, write_property + "v2/centrality.articleRank", + G, + config, + property_overwrites=write_property, + write_concurrency=write_concurrency, + concurrency=concurrency, ) return ArticleRankWriteResult(**result) diff --git a/graphdatascience/procedure_surface/arrow/centrality/articulationpoints_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/centrality/articulationpoints_arrow_endpoints.py index edf490541..23193a251 100644 --- a/graphdatascience/procedure_surface/arrow/centrality/articulationpoints_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/centrality/articulationpoints_arrow_endpoints.py @@ -12,7 +12,7 @@ ArticulationPointsWriteResult, ) from graphdatascience.procedure_surface.api.estimation_result import EstimationResult -from graphdatascience.procedure_surface.arrow.node_property_endpoints import 
NodePropertyEndpoints +from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper class ArticulationPointsArrowEndpoints(ArticulationPointsEndpoints): @@ -24,7 +24,7 @@ def __init__( write_back_client: RemoteWriteBackClient | None = None, show_progress: bool = True, ): - self._node_property_endpoints = NodePropertyEndpoints( + self._node_property_endpoints = NodePropertyEndpointsHelper( arrow_client, write_back_client, show_progress=show_progress ) @@ -52,7 +52,7 @@ def mutate( ) result = self._node_property_endpoints.run_job_and_mutate( - "v2/centrality.articulationPoints", G, config, mutate_property + "v2/centrality.articulationPoints", config, mutate_property ) return ArticulationPointsMutateResult(**result) @@ -80,7 +80,7 @@ def stats( ) computation_result = self._node_property_endpoints.run_job_and_get_summary( - "v2/centrality.articulationPoints", G, config + "v2/centrality.articulationPoints", config ) return ArticulationPointsStatsResult(**computation_result) @@ -127,7 +127,12 @@ def write( ) result = self._node_property_endpoints.run_job_and_write( - "v2/centrality.articulationPoints", G, config, write_concurrency, concurrency, write_property + "v2/centrality.articulationPoints", + G, + config, + property_overwrites=write_property, + write_concurrency=write_concurrency, + concurrency=concurrency, ) return ArticulationPointsWriteResult(**result) diff --git a/graphdatascience/procedure_surface/arrow/centrality/betweenness_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/centrality/betweenness_arrow_endpoints.py index 5aef9072a..92c7bf6b9 100644 --- a/graphdatascience/procedure_surface/arrow/centrality/betweenness_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/centrality/betweenness_arrow_endpoints.py @@ -12,7 +12,7 @@ BetweennessWriteResult, ) from graphdatascience.procedure_surface.api.estimation_result import EstimationResult -from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints +from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper class BetweennessArrowEndpoints(BetweennessEndpoints): @@ -24,7 +24,7 @@ def __init__( write_back_client: RemoteWriteBackClient | None = None, show_progress: bool = True, ): - self._node_property_endpoints = NodePropertyEndpoints( + self._node_property_endpoints = NodePropertyEndpointsHelper( arrow_client, write_back_client, show_progress=show_progress ) @@ -57,9 +57,7 @@ def mutate( username=username, ) - result = self._node_property_endpoints.run_job_and_mutate( - "v2/centrality.betweenness", G, config, mutate_property - ) + result = self._node_property_endpoints.run_job_and_mutate("v2/centrality.betweenness", config, mutate_property) return BetweennessMutateResult(**result) @@ -91,9 +89,7 @@ def stats( username=username, ) - computation_result = self._node_property_endpoints.run_job_and_get_summary( - "v2/centrality.betweenness", G, config - ) + computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/centrality.betweenness", config) return BetweennessStatsResult(**computation_result) @@ -158,7 +154,12 @@ def write( ) result = self._node_property_endpoints.run_job_and_write( - "v2/centrality.betweenness", G, config, write_concurrency, concurrency, write_property + "v2/centrality.betweenness", + G, + config, + property_overwrites=write_property, + write_concurrency=write_concurrency, + concurrency=concurrency, ) return BetweennessWriteResult(**result) diff 
--git a/graphdatascience/procedure_surface/arrow/centrality/celf_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/centrality/celf_arrow_endpoints.py index ebc325aef..35d03eefb 100644 --- a/graphdatascience/procedure_surface/arrow/centrality/celf_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/centrality/celf_arrow_endpoints.py @@ -12,7 +12,7 @@ CelfWriteResult, ) from graphdatascience.procedure_surface.api.estimation_result import EstimationResult -from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints +from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper class CelfArrowEndpoints(CelfEndpoints): @@ -22,7 +22,7 @@ def __init__( write_back_client: RemoteWriteBackClient | None = None, show_progress: bool = True, ): - self._node_property_endpoints = NodePropertyEndpoints( + self._node_property_endpoints = NodePropertyEndpointsHelper( arrow_client, write_back_client, show_progress=show_progress ) @@ -56,7 +56,7 @@ def mutate( random_seed=random_seed, ) - result = self._node_property_endpoints.run_job_and_mutate("v2/centrality.celf", G, config, mutate_property) + result = self._node_property_endpoints.run_job_and_mutate("v2/centrality.celf", config, mutate_property) return CelfMutateResult(**result) @@ -89,7 +89,7 @@ def stats( random_seed=random_seed, ) - computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/centrality.celf", G, config) + computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/centrality.celf", config) return CelfStatsResult(**computation_result) diff --git a/graphdatascience/procedure_surface/arrow/centrality/closeness_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/centrality/closeness_arrow_endpoints.py index 0757b0600..f74d3412a 100644 --- a/graphdatascience/procedure_surface/arrow/centrality/closeness_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/centrality/closeness_arrow_endpoints.py @@ -12,7 +12,7 @@ ClosenessWriteResult, ) from graphdatascience.procedure_surface.api.estimation_result import EstimationResult -from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints +from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper class ClosenessArrowEndpoints(ClosenessEndpoints): @@ -24,7 +24,7 @@ def __init__( write_back_client: RemoteWriteBackClient | None = None, show_progress: bool = True, ): - self._node_property_endpoints = NodePropertyEndpoints( + self._node_property_endpoints = NodePropertyEndpointsHelper( arrow_client, write_back_client, show_progress=show_progress ) @@ -53,7 +53,7 @@ def mutate( job_id=job_id, ) - result = self._node_property_endpoints.run_job_and_mutate("v2/centrality.closeness", G, config, mutate_property) + result = self._node_property_endpoints.run_job_and_mutate("v2/centrality.closeness", config, mutate_property) return ClosenessMutateResult(**result) @@ -81,7 +81,7 @@ def stats( job_id=job_id, ) - computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/centrality.closeness", G, config) + computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/centrality.closeness", config) return ClosenessStatsResult(**computation_result) @@ -138,7 +138,12 @@ def write( ) result = self._node_property_endpoints.run_job_and_write( - "v2/centrality.closeness", G, config, write_concurrency, concurrency, write_property + 
"v2/centrality.closeness", + G, + config, + write_concurrency=write_concurrency, + concurrency=concurrency, + property_overwrites=write_property, ) return ClosenessWriteResult(**result) diff --git a/graphdatascience/procedure_surface/arrow/centrality/closeness_harmonic_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/centrality/closeness_harmonic_arrow_endpoints.py index 5b543c74f..88d26d715 100644 --- a/graphdatascience/procedure_surface/arrow/centrality/closeness_harmonic_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/centrality/closeness_harmonic_arrow_endpoints.py @@ -12,7 +12,7 @@ ClosenessHarmonicWriteResult, ) from graphdatascience.procedure_surface.api.estimation_result import EstimationResult -from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints +from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper class ClosenessHarmonicArrowEndpoints(ClosenessHarmonicEndpoints): @@ -22,7 +22,7 @@ def __init__( write_back_client: RemoteWriteBackClient | None = None, show_progress: bool = True, ): - self._node_property_endpoints = NodePropertyEndpoints( + self._node_property_endpoints = NodePropertyEndpointsHelper( arrow_client, write_back_client, show_progress=show_progress ) @@ -49,7 +49,7 @@ def mutate( username=username, ) - result = self._node_property_endpoints.run_job_and_mutate("v2/centrality.harmonic", G, config, mutate_property) + result = self._node_property_endpoints.run_job_and_mutate("v2/centrality.harmonic", config, mutate_property) return ClosenessHarmonicMutateResult(**result) @@ -75,7 +75,7 @@ def stats( username=username, ) - computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/centrality.harmonic", G, config) + computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/centrality.harmonic", config) return ClosenessHarmonicStatsResult(**computation_result) @@ -128,7 +128,12 @@ def write( ) result = self._node_property_endpoints.run_job_and_write( - "v2/centrality.harmonic", G, config, write_concurrency, concurrency, write_property + "v2/centrality.harmonic", + G, + config, + property_overwrites=write_property, + write_concurrency=write_concurrency, + concurrency=concurrency, ) return ClosenessHarmonicWriteResult(**result) diff --git a/graphdatascience/procedure_surface/arrow/centrality/degree_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/centrality/degree_arrow_endpoints.py index fe5e5c98f..631d0eeaa 100644 --- a/graphdatascience/procedure_surface/arrow/centrality/degree_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/centrality/degree_arrow_endpoints.py @@ -12,7 +12,7 @@ DegreeWriteResult, ) from graphdatascience.procedure_surface.api.estimation_result import EstimationResult -from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints +from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper class DegreeArrowEndpoints(DegreeEndpoints): @@ -22,7 +22,7 @@ def __init__( write_back_client: RemoteWriteBackClient | None = None, show_progress: bool = True, ): - self._node_property_endpoints = NodePropertyEndpoints( + self._node_property_endpoints = NodePropertyEndpointsHelper( arrow_client, write_back_client, show_progress=show_progress ) @@ -52,7 +52,7 @@ def mutate( sudo=sudo, ) - result = self._node_property_endpoints.run_job_and_mutate("v2/centrality.degree", G, config, mutate_property) + 
result = self._node_property_endpoints.run_job_and_mutate("v2/centrality.degree", config, mutate_property) return DegreeMutateResult(**result) @@ -81,7 +81,7 @@ def stats( sudo=sudo, ) - computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/centrality.degree", G, config) + computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/centrality.degree", config) return DegreeStatsResult(**computation_result) @@ -140,7 +140,12 @@ def write( ) result = self._node_property_endpoints.run_job_and_write( - "v2/centrality.degree", G, config, write_concurrency, concurrency, write_property + "v2/centrality.degree", + G, + config, + property_overwrites=write_property, + write_concurrency=write_concurrency, + concurrency=concurrency, ) return DegreeWriteResult(**result) diff --git a/graphdatascience/procedure_surface/arrow/centrality/eigenvector_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/centrality/eigenvector_arrow_endpoints.py index 5709acf07..487e6a62d 100644 --- a/graphdatascience/procedure_surface/arrow/centrality/eigenvector_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/centrality/eigenvector_arrow_endpoints.py @@ -12,7 +12,7 @@ EigenvectorWriteResult, ) from graphdatascience.procedure_surface.api.estimation_result import EstimationResult -from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints +from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper class EigenvectorArrowEndpoints(EigenvectorEndpoints): @@ -22,7 +22,7 @@ def __init__( write_back_client: RemoteWriteBackClient | None = None, show_progress: bool = True, ): - self._node_property_endpoints = NodePropertyEndpoints( + self._node_property_endpoints = NodePropertyEndpointsHelper( arrow_client, write_back_client, show_progress=show_progress ) @@ -59,9 +59,7 @@ def mutate( username=username, ) - result = self._node_property_endpoints.run_job_and_mutate( - "v2/centrality.eigenvector", G, config, mutate_property - ) + result = self._node_property_endpoints.run_job_and_mutate("v2/centrality.eigenvector", config, mutate_property) return EigenvectorMutateResult(**result) @@ -97,9 +95,7 @@ def stats( username=username, ) - computation_result = self._node_property_endpoints.run_job_and_get_summary( - "v2/centrality.eigenvector", G, config - ) + computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/centrality.eigenvector", config) return EigenvectorStatsResult(**computation_result) @@ -172,7 +168,12 @@ def write( ) result = self._node_property_endpoints.run_job_and_write( - "v2/centrality.eigenvector", G, config, write_concurrency, concurrency, write_property + "v2/centrality.eigenvector", + G, + config, + property_overwrites=write_property, + write_concurrency=write_concurrency, + concurrency=concurrency, ) return EigenvectorWriteResult(**result) diff --git a/graphdatascience/procedure_surface/arrow/centrality/pagerank_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/centrality/pagerank_arrow_endpoints.py index 4ed30d3b8..66eab3bbb 100644 --- a/graphdatascience/procedure_surface/arrow/centrality/pagerank_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/centrality/pagerank_arrow_endpoints.py @@ -12,7 +12,7 @@ PageRankWriteResult, ) from graphdatascience.procedure_surface.api.estimation_result import EstimationResult -from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints +from 
diff --git a/graphdatascience/procedure_surface/arrow/centrality/pagerank_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/centrality/pagerank_arrow_endpoints.py
index 4ed30d3b8..66eab3bbb 100644
--- a/graphdatascience/procedure_surface/arrow/centrality/pagerank_arrow_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/centrality/pagerank_arrow_endpoints.py
@@ -12,7 +12,7 @@
     PageRankWriteResult,
 )
 from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
-from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints
+from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


 class PageRankArrowEndpoints(PageRankEndpoints):
@@ -22,7 +22,7 @@ def __init__(
         write_back_client: RemoteWriteBackClient | None = None,
         show_progress: bool = False,
     ):
-        self._node_property_endpoints = NodePropertyEndpoints(
+        self._node_property_endpoints = NodePropertyEndpointsHelper(
             arrow_client, write_back_client, show_progress=show_progress
         )
@@ -60,7 +60,7 @@ def mutate(
             tolerance=tolerance,
         )

-        result = self._node_property_endpoints.run_job_and_mutate("v2/centrality.pageRank", G, config, mutate_property)
+        result = self._node_property_endpoints.run_job_and_mutate("v2/centrality.pageRank", config, mutate_property)

         return PageRankMutateResult(**result)
@@ -97,7 +97,7 @@ def stats(
             tolerance=tolerance,
         )

-        computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/centrality.pageRank", G, config)
+        computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/centrality.pageRank", config)

         return PageRankStatsResult(**computation_result)
@@ -172,7 +172,12 @@ def write(
         )

         result = self._node_property_endpoints.run_job_and_write(
-            "v2/centrality.pageRank", G, config, write_concurrency, concurrency, write_property
+            "v2/centrality.pageRank",
+            G,
+            config,
+            property_overwrites=write_property,
+            write_concurrency=write_concurrency,
+            concurrency=concurrency,
         )

         return PageRankWriteResult(**result)
diff --git a/graphdatascience/procedure_surface/arrow/community/clique_counting_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/community/clique_counting_arrow_endpoints.py
index 889f6b294..085deb399 100644
--- a/graphdatascience/procedure_surface/arrow/community/clique_counting_arrow_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/community/clique_counting_arrow_endpoints.py
@@ -12,7 +12,7 @@
     CliqueCountingWriteResult,
 )
 from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
-from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints
+from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


 class CliqueCountingArrowEndpoints(CliqueCountingEndpoints):
@@ -22,7 +22,7 @@ def __init__(
         write_back_client: RemoteWriteBackClient | None = None,
         show_progress: bool = True,
     ):
-        self._node_property_endpoints = NodePropertyEndpoints(
+        self._node_property_endpoints = NodePropertyEndpointsHelper(
             arrow_client, write_back_client, show_progress=show_progress
         )
@@ -50,7 +50,7 @@ def mutate(
         )

         result = self._node_property_endpoints.run_job_and_mutate(
-            "v2/community.cliquecounting", G, config, mutate_property
+            "v2/community.cliquecounting", config, mutate_property
         )

         return CliqueCountingMutateResult(**result)
@@ -78,7 +78,7 @@ def stats(
         )

         computation_result = self._node_property_endpoints.run_job_and_get_summary(
-            "v2/community.cliquecounting", G, config
+            "v2/community.cliquecounting", config
         )

         return CliqueCountingStatsResult(**computation_result)
@@ -132,7 +132,12 @@ def write(
         )

         result = self._node_property_endpoints.run_job_and_write(
-            "v2/community.cliquecounting", G, config, write_concurrency, concurrency, write_property
+            "v2/community.cliquecounting",
+            G,
+            config,
+            property_overwrites=write_property,
+            write_concurrency=write_concurrency,
+            concurrency=concurrency,
         )

         return CliqueCountingWriteResult(**result)
diff --git a/graphdatascience/procedure_surface/arrow/community/hdbscan_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/community/hdbscan_arrow_endpoints.py
index 78f105cab..1b46b7454 100644
--- a/graphdatascience/procedure_surface/arrow/community/hdbscan_arrow_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/community/hdbscan_arrow_endpoints.py
@@ -12,7 +12,7 @@
     HdbscanWriteResult,
 )
 from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
-from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints
+from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


 class HdbscanArrowEndpoints(HdbscanEndpoints):
@@ -25,7 +25,7 @@ def __init__(
         self._arrow_client = arrow_client
         self._write_back_client = write_back_client
         self._show_progress = show_progress
-        self._node_property_endpoints = NodePropertyEndpoints(arrow_client, write_back_client, show_progress)
+        self._node_property_endpoints = NodePropertyEndpointsHelper(arrow_client, write_back_client, show_progress)

     def mutate(
         self,
@@ -59,7 +59,7 @@ def mutate(
             username=username,
         )

-        result = self._node_property_endpoints.run_job_and_mutate("v2/community.hdbscan", G, config, mutate_property)
+        result = self._node_property_endpoints.run_job_and_mutate("v2/community.hdbscan", config, mutate_property)

         return HdbscanMutateResult(**result)
@@ -94,7 +94,7 @@ def stats(
             username=username,
         )

-        result = self._node_property_endpoints.run_job_and_get_summary("v2/community.hdbscan", G, config)
+        result = self._node_property_endpoints.run_job_and_get_summary("v2/community.hdbscan", config)

         return HdbscanStatsResult(**result)
@@ -166,7 +166,12 @@ def write(
         )

         result = self._node_property_endpoints.run_job_and_write(
-            "v2/community.hdbscan", G, config, write_concurrency, concurrency, write_property
+            "v2/community.hdbscan",
+            G,
+            config,
+            property_overwrites=write_property,
+            write_concurrency=write_concurrency,
+            concurrency=concurrency,
         )

         return HdbscanWriteResult(**result)
diff --git a/graphdatascience/procedure_surface/arrow/community/k1coloring_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/community/k1coloring_arrow_endpoints.py
index 97d30c1bc..ed03dc7db 100644
--- a/graphdatascience/procedure_surface/arrow/community/k1coloring_arrow_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/community/k1coloring_arrow_endpoints.py
@@ -12,7 +12,7 @@
     K1ColoringWriteResult,
 )
 from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
-from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints
+from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


 class K1ColoringArrowEndpoints(K1ColoringEndpoints):
@@ -22,7 +22,7 @@ def __init__(
         write_back_client: RemoteWriteBackClient | None = None,
         show_progress: bool = True,
     ):
-        self._node_property_endpoints = NodePropertyEndpoints(
+        self._node_property_endpoints = NodePropertyEndpointsHelper(
             arrow_client, write_back_client, show_progress=show_progress
         )
@@ -53,7 +53,7 @@ def mutate(
             username=username,
         )

-        result = self._node_property_endpoints.run_job_and_mutate("v2/community.k1coloring", G, config, mutate_property)
+        result = self._node_property_endpoints.run_job_and_mutate("v2/community.k1coloring", config, mutate_property)

         return K1ColoringMutateResult(**result)
@@ -83,7 +83,7 @@ def stats(
             username=username,
         )

-        computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/community.k1coloring", G, config)
+        computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/community.k1coloring", config)

         return K1ColoringStatsResult(**computation_result)
@@ -148,7 +148,12 @@ def write(
         )

         result = self._node_property_endpoints.run_job_and_write(
-            "v2/community.k1coloring", G, config, write_concurrency, concurrency, write_property
+            "v2/community.k1coloring",
+            G,
+            config,
+            property_overwrites=write_property,
+            write_concurrency=write_concurrency,
+            concurrency=concurrency,
         )

         return K1ColoringWriteResult(**result)
diff --git a/graphdatascience/procedure_surface/arrow/community/kcore_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/community/kcore_arrow_endpoints.py
index bf3646d19..58dab35f5 100644
--- a/graphdatascience/procedure_surface/arrow/community/kcore_arrow_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/community/kcore_arrow_endpoints.py
@@ -12,7 +12,7 @@
     KCoreWriteResult,
 )
 from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
-from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints
+from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


 class KCoreArrowEndpoints(KCoreEndpoints):
@@ -22,7 +22,7 @@ def __init__(
         write_back_client: RemoteWriteBackClient | None = None,
         show_progress: bool = True,
     ):
-        self._node_property_endpoints = NodePropertyEndpoints(
+        self._node_property_endpoints = NodePropertyEndpointsHelper(
             arrow_client, write_back_client, show_progress=show_progress
         )
@@ -48,7 +48,7 @@ def mutate(
             sudo=sudo,
         )

-        result = self._node_property_endpoints.run_job_and_mutate("v2/community.kcore", G, config, mutate_property)
+        result = self._node_property_endpoints.run_job_and_mutate("v2/community.kcore", config, mutate_property)

         return KCoreMutateResult(**result)
@@ -73,7 +73,7 @@ def stats(
             sudo=sudo,
         )

-        computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/community.kcore", G, config)
+        computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/community.kcore", config)

         return KCoreStatsResult(**computation_result)
@@ -128,7 +128,12 @@ def write(
         )

         result = self._node_property_endpoints.run_job_and_write(
-            "v2/community.kcore", G, config, write_concurrency, concurrency, write_property
+            "v2/community.kcore",
+            G,
+            config,
+            property_overwrites=write_property,
+            write_concurrency=write_concurrency,
+            concurrency=concurrency,
         )

         return KCoreWriteResult(**result)
diff --git a/graphdatascience/procedure_surface/arrow/community/kmeans_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/community/kmeans_arrow_endpoints.py
index f723ae73a..f032808ed 100644
--- a/graphdatascience/procedure_surface/arrow/community/kmeans_arrow_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/community/kmeans_arrow_endpoints.py
@@ -12,7 +12,7 @@
     KMeansWriteResult,
 )
 from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
-from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints
+from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


 class KMeansArrowEndpoints(KMeansEndpoints):
@@ -22,7 +22,7 @@ def __init__(
         write_back_client: RemoteWriteBackClient | None = None,
         show_progress: bool = True,
     ):
-        self._node_property_endpoints = NodePropertyEndpoints(
+        self._node_property_endpoints = NodePropertyEndpointsHelper(
             arrow_client, write_back_client, show_progress=show_progress
         )
@@ -68,7 +68,7 @@ def mutate(
             username=username,
         )

-        result = self._node_property_endpoints.run_job_and_mutate("v2/community.kmeans", G, config, mutate_property)
+        result = self._node_property_endpoints.run_job_and_mutate("v2/community.kmeans", config, mutate_property)

         return KMeansMutateResult(**result)
@@ -113,7 +113,7 @@ def stats(
             username=username,
         )

-        computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/community.kmeans", G, config)
+        computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/community.kmeans", config)

         return KMeansStatsResult(**computation_result)
@@ -204,7 +204,12 @@ def write(
         )

         result = self._node_property_endpoints.run_job_and_write(
-            "v2/community.kmeans", G, config, write_concurrency, concurrency, write_property
+            "v2/community.kmeans",
+            G,
+            config,
+            property_overwrites=write_property,
+            write_concurrency=write_concurrency,
+            concurrency=concurrency,
         )

         return KMeansWriteResult(**result)
diff --git a/graphdatascience/procedure_surface/arrow/community/labelpropagation_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/community/labelpropagation_arrow_endpoints.py
index 03cf69191..9824403e1 100644
--- a/graphdatascience/procedure_surface/arrow/community/labelpropagation_arrow_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/community/labelpropagation_arrow_endpoints.py
@@ -12,7 +12,7 @@
     LabelPropagationWriteResult,
 )
 from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
-from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints
+from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


 class LabelPropagationArrowEndpoints(LabelPropagationEndpoints):
@@ -22,7 +22,7 @@ def __init__(
         write_back_client: RemoteWriteBackClient | None = None,
         show_progress: bool = True,
     ):
-        self._node_property_endpoints = NodePropertyEndpoints(
+        self._node_property_endpoints = NodePropertyEndpointsHelper(
             arrow_client, write_back_client, show_progress=show_progress
         )
@@ -61,7 +61,7 @@ def mutate(
         )

         result = self._node_property_endpoints.run_job_and_mutate(
-            "v2/community.labelPropagation", G, config, mutate_property
+            "v2/community.labelPropagation", config, mutate_property
         )

         return LabelPropagationMutateResult(**result)
@@ -100,7 +100,7 @@ def stats(
         )

         computation_result = self._node_property_endpoints.run_job_and_get_summary(
-            "v2/community.labelPropagation", G, config
+            "v2/community.labelPropagation", config
         )

         return LabelPropagationStatsResult(**computation_result)
@@ -180,7 +180,12 @@ def write(
         )

         result = self._node_property_endpoints.run_job_and_write(
-            "v2/community.labelPropagation", G, config, write_concurrency, concurrency, write_property
+            "v2/community.labelPropagation",
+            G,
+            config,
+            property_overwrites=write_property,
+            write_concurrency=write_concurrency,
+            concurrency=concurrency,
         )

         return LabelPropagationWriteResult(**result)
diff --git a/graphdatascience/procedure_surface/arrow/community/leiden_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/community/leiden_arrow_endpoints.py
index 8cca6720d..1c3ea4d0f 100644
--- a/graphdatascience/procedure_surface/arrow/community/leiden_arrow_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/community/leiden_arrow_endpoints.py
@@ -12,7 +12,7 @@
     LeidenWriteResult,
 )
 from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
-from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints
+from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


 class LeidenArrowEndpoints(LeidenEndpoints):
@@ -25,7 +25,7 @@ def __init__(
         self._arrow_client = arrow_client
         self._write_back_client = write_back_client
         self._show_progress = show_progress
-        self._node_property_endpoints = NodePropertyEndpoints(arrow_client, write_back_client, show_progress)
+        self._node_property_endpoints = NodePropertyEndpointsHelper(arrow_client, write_back_client, show_progress)

     def mutate(
         self,
@@ -69,7 +69,7 @@ def mutate(
             username=username,
         )

-        result = self._node_property_endpoints.run_job_and_mutate("v2/community.leiden", G, config, mutate_property)
+        result = self._node_property_endpoints.run_job_and_mutate("v2/community.leiden", config, mutate_property)

         return LeidenMutateResult(**result)
@@ -114,7 +114,7 @@ def stats(
             username=username,
         )

-        computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/community.leiden", G, config)
+        computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/community.leiden", config)

         return LeidenStatsResult(**computation_result)
@@ -209,7 +209,12 @@ def write(
         )

         result = self._node_property_endpoints.run_job_and_write(
-            "v2/community.leiden", G, config, write_concurrency, concurrency, write_property
+            "v2/community.leiden",
+            G,
+            config,
+            property_overwrites=write_property,
+            write_concurrency=write_concurrency,
+            concurrency=concurrency,
         )

         return LeidenWriteResult(**result)
diff --git a/graphdatascience/procedure_surface/arrow/community/local_clustering_coefficient_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/community/local_clustering_coefficient_arrow_endpoints.py
index 014ab12f7..974c1e940 100644
--- a/graphdatascience/procedure_surface/arrow/community/local_clustering_coefficient_arrow_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/community/local_clustering_coefficient_arrow_endpoints.py
@@ -10,7 +10,7 @@
     LocalClusteringCoefficientWriteResult,
 )
 from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
-from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints
+from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


 class LocalClusteringCoefficientArrowEndpoints(LocalClusteringCoefficientEndpoints):
@@ -20,7 +20,7 @@ def __init__(
         remote_write_back_client: RemoteWriteBackClient | None = None,
         show_progress: bool = True,
     ):
-        self._node_property_endpoints = NodePropertyEndpoints(
+        self._node_property_endpoints = NodePropertyEndpointsHelper(
             client,
             remote_write_back_client,
             show_progress,
@@ -53,7 +53,7 @@ def mutate(
         )

         result = self._node_property_endpoints.run_job_and_mutate(
-            "v2/community.localClusteringCoefficient", G, config, mutate_property
+            "v2/community.localClusteringCoefficient", config, mutate_property
         )

         return LocalClusteringCoefficientMutateResult(**result)
@@ -85,7 +85,6 @@ def stats(

         result = self._node_property_endpoints.run_job_and_get_summary(
             "v2/community.localClusteringCoefficient",
-            G,
             config,
         )
@@ -136,7 +135,6 @@ def write(
         triangle_count_property: str | None = None,
         username: str | None = None,
         write_concurrency: int | None = None,
-        write_to_result_store: bool | None = None,
     ) -> LocalClusteringCoefficientWriteResult:
         config = self._node_property_endpoints.create_base_config(
             G,
@@ -150,16 +148,15 @@ def write(
             triangle_count_property=triangle_count_property,
             username=username,
             write_concurrency=write_concurrency,
-            write_to_result_store=write_to_result_store,
         )

         result = self._node_property_endpoints.run_job_and_write(
             "v2/community.localClusteringCoefficient",
             G,
             config,
-            write_concurrency,
-            concurrency,
-            write_property,
+            property_overwrites=write_property,
+            write_concurrency=write_concurrency,
+            concurrency=concurrency,
         )

         return LocalClusteringCoefficientWriteResult(**result)
diff --git a/graphdatascience/procedure_surface/arrow/community/louvain_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/community/louvain_arrow_endpoints.py
index cfd74e6b3..865991276 100644
--- a/graphdatascience/procedure_surface/arrow/community/louvain_arrow_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/community/louvain_arrow_endpoints.py
@@ -12,7 +12,7 @@
     LouvainWriteResult,
 )
 from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
-from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints
+from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


 class LouvainArrowEndpoints(LouvainEndpoints):
@@ -22,7 +22,7 @@ def __init__(
         write_back_client: RemoteWriteBackClient | None = None,
         show_progress: bool = True,
     ):
-        self._node_property_endpoints = NodePropertyEndpoints(
+        self._node_property_endpoints = NodePropertyEndpointsHelper(
             arrow_client, write_back_client, show_progress=show_progress
         )
@@ -62,7 +62,7 @@ def mutate(
             tolerance=tolerance,
         )

-        result = self._node_property_endpoints.run_job_and_mutate("v2/community.louvain", G, config, mutate_property)
+        result = self._node_property_endpoints.run_job_and_mutate("v2/community.louvain", config, mutate_property)

         return LouvainMutateResult(**result)
@@ -101,7 +101,7 @@ def stats(
             tolerance=tolerance,
         )

-        computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/community.louvain", G, config)
+        computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/community.louvain", config)

         return LouvainStatsResult(**computation_result)
@@ -185,7 +185,12 @@ def write(
         )

         result = self._node_property_endpoints.run_job_and_write(
-            "v2/community.louvain", G, config, write_concurrency, concurrency, write_property
+            "v2/community.louvain",
+            G,
+            config,
+            property_overwrites=write_property,
+            write_concurrency=write_concurrency,
+            concurrency=concurrency,
         )

         return LouvainWriteResult(**result)
diff --git a/graphdatascience/procedure_surface/arrow/community/maxkcut_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/community/maxkcut_arrow_endpoints.py
index 30840149a..93be872ac 100644
--- a/graphdatascience/procedure_surface/arrow/community/maxkcut_arrow_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/community/maxkcut_arrow_endpoints.py
@@ -10,7 +10,7 @@
     MaxKCutMutateResult,
 )
 from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
-from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints
+from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


 class MaxKCutArrowEndpoints(MaxKCutEndpoints):
@@ -23,7 +23,7 @@ def __init__(
         self._arrow_client = arrow_client
         self._write_back_client = write_back_client
         self._show_progress = show_progress
-        self._node_property_endpoints = NodePropertyEndpoints(arrow_client, write_back_client, show_progress)
+        self._node_property_endpoints = NodePropertyEndpointsHelper(arrow_client, write_back_client, show_progress)

     def mutate(
         self,
@@ -59,7 +59,7 @@ def mutate(
             vns_max_neighborhood_order=vns_max_neighborhood_order,
         )

-        result = self._node_property_endpoints.run_job_and_mutate("v2/community.maxkcut", G, config, mutate_property)
+        result = self._node_property_endpoints.run_job_and_mutate("v2/community.maxkcut", config, mutate_property)

         return MaxKCutMutateResult(**result)
diff --git a/graphdatascience/procedure_surface/arrow/community/modularity_optimization_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/community/modularity_optimization_arrow_endpoints.py
index ba319e809..b003a494f 100644
--- a/graphdatascience/procedure_surface/arrow/community/modularity_optimization_arrow_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/community/modularity_optimization_arrow_endpoints.py
@@ -12,7 +12,7 @@
     ModularityOptimizationWriteResult,
 )
 from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
-from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints
+from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


 class ModularityOptimizationArrowEndpoints(ModularityOptimizationEndpoints):
@@ -26,7 +26,7 @@ def __init__(
         write_back_client: RemoteWriteBackClient | None = None,
         show_progress: bool = False,
     ):
-        self._node_property_endpoints = NodePropertyEndpoints(
+        self._node_property_endpoints = NodePropertyEndpointsHelper(
             arrow_client, write_back_client, show_progress=show_progress
         )
@@ -67,7 +67,7 @@ def mutate(
         )

         result = self._node_property_endpoints.run_job_and_mutate(
-            "v2/community.modularityOptimization", G, config, mutate_property
+            "v2/community.modularityOptimization", config, mutate_property
         )

         return ModularityOptimizationMutateResult(**result)
@@ -107,7 +107,7 @@ def stats(
             username=username,
         )

-        result = self._node_property_endpoints.run_job_and_get_summary("v2/community.modularityOptimization", G, config)
+        result = self._node_property_endpoints.run_job_and_get_summary("v2/community.modularityOptimization", config)

         return ModularityOptimizationStatsResult(**result)
@@ -170,7 +170,6 @@ def write(
         tolerance: float | None = None,
         username: str | None = None,
         write_concurrency: int | None = None,
-        write_to_result_store: bool | None = None,
     ) -> ModularityOptimizationWriteResult:
         config = self._node_property_endpoints.create_base_config(
             G,
@@ -189,11 +188,15 @@ def write(
             tolerance=tolerance,
             username=username,
             write_concurrency=write_concurrency,
-            write_to_result_store=write_to_result_store,
         )

         result = self._node_property_endpoints.run_job_and_write(
-            "v2/community.modularityOptimization", G, config, write_concurrency, concurrency, write_property
+            "v2/community.modularityOptimization",
+            G,
+            config,
+            property_overwrites=write_property,
+            write_concurrency=write_concurrency,
+            concurrency=concurrency,
         )

         return ModularityOptimizationWriteResult(**result)
diff --git a/graphdatascience/procedure_surface/arrow/community/scc_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/community/scc_arrow_endpoints.py
index 7b8a70024..933c19c5c 100644
--- a/graphdatascience/procedure_surface/arrow/community/scc_arrow_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/community/scc_arrow_endpoints.py
@@ -12,7 +12,7 @@
     SccWriteResult,
 )
 from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
-from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints
+from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


 class SccArrowEndpoints(SccEndpoints):
@@ -22,7 +22,7 @@ def __init__(
         write_back_client: RemoteWriteBackClient | None = None,
         show_progress: bool = True,
     ):
-        self._node_property_endpoints = NodePropertyEndpoints(
+        self._node_property_endpoints = NodePropertyEndpointsHelper(
             arrow_client, write_back_client, show_progress=show_progress
         )
@@ -50,7 +50,7 @@ def mutate(
             sudo=sudo,
         )

-        result = self._node_property_endpoints.run_job_and_mutate("v2/community.scc", G, config, mutate_property)
+        result = self._node_property_endpoints.run_job_and_mutate("v2/community.scc", config, mutate_property)

         return SccMutateResult(**result)
@@ -77,7 +77,7 @@ def stats(
             sudo=sudo,
         )

-        computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/community.scc", G, config)
+        computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/community.scc", config)

         return SccStatsResult(**computation_result)
@@ -132,7 +132,12 @@ def write(
         )

         result = self._node_property_endpoints.run_job_and_write(
-            "v2/community.scc", G, config, write_concurrency, concurrency, write_property
+            "v2/community.scc",
+            G,
+            config,
+            property_overwrites=write_property,
+            write_concurrency=write_concurrency,
+            concurrency=concurrency,
         )

         return SccWriteResult(**result)
diff --git a/graphdatascience/procedure_surface/arrow/community/sllpa_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/community/sllpa_arrow_endpoints.py
index 11e29f575..4e9478a89 100644
--- a/graphdatascience/procedure_surface/arrow/community/sllpa_arrow_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/community/sllpa_arrow_endpoints.py
@@ -12,7 +12,7 @@
     SllpaWriteResult,
 )
 from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
-from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints
+from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


 class SllpaArrowEndpoints(SllpaEndpoints):
@@ -22,7 +22,7 @@ def __init__(
         write_back_client: RemoteWriteBackClient | None = None,
         show_progress: bool = False,
     ):
-        self._node_property_endpoints = NodePropertyEndpoints(
+        self._node_property_endpoints = NodePropertyEndpointsHelper(
             arrow_client, write_back_client, show_progress=show_progress
         )
@@ -56,7 +56,7 @@ def mutate(
             username=username,
         )

-        result = self._node_property_endpoints.run_job_and_mutate("v2/community.sllpa", G, config, mutate_property)
+        result = self._node_property_endpoints.run_job_and_mutate("v2/community.sllpa", config, mutate_property)

         return SllpaMutateResult(**result)
@@ -89,7 +89,7 @@ def stats(
             username=username,
         )

-        computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/community.sllpa", G, config)
+        computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/community.sllpa", config)

         return SllpaStatsResult(**computation_result)
@@ -156,7 +156,12 @@ def write(
         )

         result = self._node_property_endpoints.run_job_and_write(
-            "v2/community.sllpa", G, config, write_concurrency, concurrency, write_property
+            "v2/community.sllpa",
+            G,
+            config,
+            property_overwrites=write_property,
+            write_concurrency=write_concurrency,
+            concurrency=concurrency,
         )

         return SllpaWriteResult(**result)
diff --git a/graphdatascience/procedure_surface/arrow/community/triangle_count_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/community/triangle_count_arrow_endpoints.py
index 641c280cc..54e8b803a 100644
--- a/graphdatascience/procedure_surface/arrow/community/triangle_count_arrow_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/community/triangle_count_arrow_endpoints.py
@@ -12,7 +12,7 @@
     TriangleCountWriteResult,
 )
 from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
-from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints
+from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


 class TriangleCountArrowEndpoints(TriangleCountEndpoints):
@@ -25,7 +25,7 @@ def __init__(
         self._arrow_client = arrow_client
         self._write_back_client = write_back_client
         self._show_progress = show_progress
-        self._node_property_endpoints = NodePropertyEndpoints(arrow_client, write_back_client, show_progress)
+        self._node_property_endpoints = NodePropertyEndpointsHelper(arrow_client, write_back_client, show_progress)

     def mutate(
         self,
@@ -55,9 +55,7 @@ def mutate(
             username=username,
         )

-        result = self._node_property_endpoints.run_job_and_mutate(
-            "v2/community.triangleCount", G, config, mutate_property
-        )
+        result = self._node_property_endpoints.run_job_and_mutate("v2/community.triangleCount", config, mutate_property)

         return TriangleCountMutateResult(**result)
@@ -88,9 +86,7 @@ def stats(
             username=username,
         )

-        computation_result = self._node_property_endpoints.run_job_and_get_summary(
-            "v2/community.triangleCount", G, config
-        )
+        computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/community.triangleCount", config)

         return TriangleCountStatsResult(**computation_result)
@@ -154,7 +150,12 @@ def write(
         )

         result = self._node_property_endpoints.run_job_and_write(
-            "v2/community.triangleCount", G, config, write_concurrency, concurrency, write_property
+            "v2/community.triangleCount",
+            G,
+            config,
+            property_overwrites=write_property,
+            write_concurrency=write_concurrency,
+            concurrency=concurrency,
         )

         return TriangleCountWriteResult(**result)
diff --git a/graphdatascience/procedure_surface/arrow/community/wcc_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/community/wcc_arrow_endpoints.py
index 682902a75..2acc61157 100644
--- a/graphdatascience/procedure_surface/arrow/community/wcc_arrow_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/community/wcc_arrow_endpoints.py
@@ -12,7 +12,7 @@
     WccWriteResult,
 )
 from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
-from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints
+from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


 class WccArrowEndpoints(WccEndpoints):
@@ -22,7 +22,7 @@ def __init__(
         write_back_client: RemoteWriteBackClient | None = None,
         show_progress: bool = True,
     ):
-        self._node_property_endpoints = NodePropertyEndpoints(
+        self._node_property_endpoints = NodePropertyEndpointsHelper(
             arrow_client, write_back_client, show_progress=show_progress
         )
@@ -56,7 +56,7 @@ def mutate(
             threshold=threshold,
         )

-        result = self._node_property_endpoints.run_job_and_mutate("v2/community.wcc", G, config, mutate_property)
+        result = self._node_property_endpoints.run_job_and_mutate("v2/community.wcc", config, mutate_property)

         return WccMutateResult(**result)
@@ -89,7 +89,7 @@ def stats(
             threshold=threshold,
         )

-        computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/community.wcc", G, config)
+        computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/community.wcc", config)

         return WccStatsResult(**computation_result)
@@ -160,7 +160,12 @@ def write(
         )

         result = self._node_property_endpoints.run_job_and_write(
-            "v2/community.wcc", G, config, write_concurrency, concurrency, write_property
+            "v2/community.wcc",
+            G,
+            config,
+            property_overwrites=write_property,
+            write_concurrency=write_concurrency,
+            concurrency=concurrency,
         )

         return WccWriteResult(**result)
diff --git a/graphdatascience/procedure_surface/arrow/endpoints_helper_base.py b/graphdatascience/procedure_surface/arrow/endpoints_helper_base.py
new file mode 100644
index 000000000..72b31505d
--- /dev/null
+++ b/graphdatascience/procedure_surface/arrow/endpoints_helper_base.py
@@ -0,0 +1,148 @@
+from typing import Any
+
+from pandas import DataFrame
+
+from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
+
+from ...arrow_client.authenticated_flight_client import AuthenticatedArrowClient
+from ...arrow_client.v2.data_mapper_utils import deserialize_single
+from ...arrow_client.v2.job_client import JobClient
+from ...arrow_client.v2.mutation_client import MutationClient
+from ...arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
+from ..api.estimation_result import EstimationResult
+from ..utils.config_converter import ConfigConverter
+
+
+class EndpointsHelperBase:
+    def __init__(
+        self,
+        arrow_client: AuthenticatedArrowClient,
+        write_back_client: RemoteWriteBackClient | None = None,
+        show_progress: bool = True,
+    ):
+        self._arrow_client = arrow_client
+        self._write_back_client = write_back_client
+        self._show_progress = show_progress
+
+    def run_job_and_get_summary(self, endpoint: str, config: dict[str, Any]) -> dict[str, Any]:
+        """Run a job and return the computation summary."""
+        show_progress: bool = config.get("logProgress", True) and self._show_progress
+
+        job_id = JobClient.run_job_and_wait(self._arrow_client, endpoint, config, show_progress)
+        result = JobClient.get_summary(self._arrow_client, job_id)
+        if config := result.get("configuration", None):
+            self._drop_write_internals(config)
+        return result
+
+    def _run_job_and_mutate(
+        self,
+        endpoint: str,
+        config: dict[str, Any],
+        *,
+        mutate_property: str | None = None,
+        mutate_relationship_type: str | None = None,
+    ) -> dict[str, Any]:
+        """Run a job, mutate node or relationship properties, and return summary with mutation result."""
+        show_progress = config.get("logProgress", True) and self._show_progress
+        job_id = JobClient.run_job_and_wait(self._arrow_client, endpoint, config, show_progress)
+
+        if mutate_relationship_type:
+            mutate_result = MutationClient.mutate_relationship_property(
+                self._arrow_client, job_id, mutate_relationship_type, mutate_property
+            )
+        elif mutate_property:
+            mutate_result = MutationClient.mutate_node_property(self._arrow_client, job_id, mutate_property)
+        else:
+            raise ValueError("Either mutate_property or mutate_relationship_type must be provided for mutation.")
+
+        computation_result = JobClient.get_summary(self._arrow_client, job_id)
+
+        # modify computation result to include mutation details
+        computation_result["nodePropertiesWritten"] = mutate_result.node_properties_written
+        computation_result["mutateMillis"] = mutate_result.mutate_millis
+
+        if (config := computation_result.get("configuration", None)) is not None:
+            config["mutateProperty"] = mutate_property
+            if mutate_relationship_type is not None:
+                config["mutateRelationshipType"] = mutate_relationship_type
+            self._drop_write_internals(config)
+
+        return computation_result
+
+    def run_job_and_stream(self, endpoint: str, G: GraphV2, config: dict[str, Any]) -> DataFrame:
+        """Run a job and return streamed results."""
+        show_progress = config.get("logProgress", True) and self._show_progress
+        job_id = JobClient.run_job_and_wait(self._arrow_client, endpoint, config, show_progress=show_progress)
+        return JobClient.stream_results(self._arrow_client, G.name(), job_id)
+
+    def _run_job_and_write(
+        self,
+        endpoint: str,
+        G: GraphV2,
+        config: dict[str, Any],
+        *,
+        relationship_type_overwrite: str | None = None,
+        property_overwrites: str | dict[str, str] | None = None,
+        write_concurrency: int | None,
+        concurrency: int | None,
+    ) -> dict[str, Any]:
+        """Run a job, write results, and return summary with write time."""
+        show_progress = config.get("logProgress", True) and self._show_progress
+        job_id = JobClient.run_job_and_wait(self._arrow_client, endpoint, config, show_progress=show_progress)
+        computation_result = JobClient.get_summary(self._arrow_client, job_id)
+
+        if self._write_back_client is None:
+            raise Exception("Write back client is not initialized")
+
+        if isinstance(property_overwrites, str):
+            # The remote write back procedure allows specifying a single overwrite. The key is ignored.
+            property_overwrites = {property_overwrites: property_overwrites}
+
+        write_result = self._write_back_client.write(
+            G.name(),
+            job_id,
+            concurrency=write_concurrency if write_concurrency is not None else concurrency,
+            property_overwrites=property_overwrites,
+            relationship_type_overwrite=relationship_type_overwrite,
+            log_progress=show_progress,
+        )
+
+        # modify computation result to include write details
+        computation_result["writeMillis"] = write_result.write_millis
+
+        return computation_result
+
+    def create_base_config(self, G: GraphV2, **kwargs: Any) -> dict[str, Any]:
+        """Create base configuration with common parameters."""
+        return ConfigConverter.convert_to_gds_config(graph_name=G.name(), **kwargs)
+
+    def create_estimate_config(self, **kwargs: Any) -> dict[str, Any]:
+        """Create configuration for estimation."""
+        return ConfigConverter.convert_to_gds_config(**kwargs)
+
+    def estimate(
+        self,
+        estimate_endpoint: str,
+        G: GraphV2 | dict[str, Any],
+        algo_config: dict[str, Any] | None = None,
+    ) -> EstimationResult:
+        """Estimate memory requirements for the algorithm."""
+        if isinstance(G, GraphV2):
+            payload = {"graphName": G.name()}
+        elif isinstance(G, dict):
+            payload = G
+        else:
+            raise ValueError("Either graph_name or projection_config must be provided.")
+
+        payload.update(algo_config or {})
+
+        res = self._arrow_client.do_action_with_retry(estimate_endpoint, payload)
+
+        return EstimationResult(**deserialize_single(res))
+
+    def _drop_write_internals(self, config: dict[str, Any]) -> None:
+        config.pop("writeConcurrency", None)
+        config.pop("writeToResultStore", None)
+        config.pop("writeProperty", None)
+        config.pop("writeRelationshipType", None)
+        config.pop("writeMillis", None)
diff --git a/graphdatascience/procedure_surface/arrow/node_embedding/fastrp_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/node_embedding/fastrp_arrow_endpoints.py
index 946d81a66..f70d148b8 100644
--- a/graphdatascience/procedure_surface/arrow/node_embedding/fastrp_arrow_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/node_embedding/fastrp_arrow_endpoints.py
@@ -12,7 +12,7 @@
     FastRPStatsResult,
     FastRPWriteResult,
 )
-from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints
+from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


 class FastRPArrowEndpoints(FastRPEndpoints):
@@ -22,7 +22,7 @@ def __init__(
         write_back_client: RemoteWriteBackClient | None = None,
         show_progress: bool = True,
     ):
-        self._node_property_endpoints = NodePropertyEndpoints(
+        self._node_property_endpoints = NodePropertyEndpointsHelper(
             arrow_client, write_back_client, show_progress=show_progress
         )
@@ -64,7 +64,7 @@ def mutate(
             sudo=sudo,
         )

-        result = self._node_property_endpoints.run_job_and_mutate("v2/embeddings.fastrp", G, config, mutate_property)
+        result = self._node_property_endpoints.run_job_and_mutate("v2/embeddings.fastrp", config, mutate_property)

         return FastRPMutateResult(**result)
@@ -105,7 +105,7 @@ def stats(
             sudo=sudo,
         )

-        computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/embeddings.fastrp", G, config)
+        computation_result = self._node_property_endpoints.run_job_and_get_summary("v2/embeddings.fastrp", config)

         return FastRPStatsResult(**computation_result)
@@ -188,7 +188,12 @@ def write(
         )

         result = self._node_property_endpoints.run_job_and_write(
-            "v2/embeddings.fastrp", G, config, write_concurrency, concurrency, write_property
+            "v2/embeddings.fastrp",
+            G,
+            config,
+            property_overwrites=write_property,
+            write_concurrency=write_concurrency,
+            concurrency=concurrency,
         )

         return FastRPWriteResult(**result)
diff --git a/graphdatascience/procedure_surface/arrow/node_embedding/graphsage_predict_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/node_embedding/graphsage_predict_arrow_endpoints.py
index e2999e1af..72a398f21 100644
--- a/graphdatascience/procedure_surface/arrow/node_embedding/graphsage_predict_arrow_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/node_embedding/graphsage_predict_arrow_endpoints.py
@@ -12,7 +12,7 @@
     GraphSageWriteResult,
 )
 from graphdatascience.procedure_surface.arrow.model_api_arrow import ModelApiArrow
-from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints
+from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


 class GraphSagePredictArrowEndpoints(GraphSagePredictEndpoints):
@@ -23,7 +23,7 @@ def __init__(
         show_progress: bool = True,
     ):
         self._arrow_client = arrow_client
-        self._node_property_endpoints = NodePropertyEndpoints(
+        self._node_property_endpoints = NodePropertyEndpointsHelper(
             arrow_client, write_back_client, show_progress=show_progress
         )
         self._model_api = ModelApiArrow(arrow_client)
@@ -86,7 +86,12 @@ def write(
         )

         raw_result = self._node_property_endpoints.run_job_and_write(
-            "v2/embeddings.graphSage", G, config, write_concurrency, concurrency, write_property
+            "v2/embeddings.graphSage",
+            G,
+            config,
+            property_overwrites=write_property,
+            write_concurrency=write_concurrency,
+            concurrency=concurrency,
         )

         return GraphSageWriteResult(**raw_result)
@@ -120,7 +125,6 @@ def mutate(

         raw_result = self._node_property_endpoints.run_job_and_mutate(
             "v2/embeddings.graphSage",
-            G,
             config,
             mutate_property,
         )
diff --git a/graphdatascience/procedure_surface/arrow/node_embedding/graphsage_train_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/node_embedding/graphsage_train_arrow_endpoints.py
index 8936b9b4c..5a4eb9002 100644
--- a/graphdatascience/procedure_surface/arrow/node_embedding/graphsage_train_arrow_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/node_embedding/graphsage_train_arrow_endpoints.py
@@ -13,7 +13,7 @@
 from graphdatascience.procedure_surface.arrow.node_embedding.graphsage_predict_arrow_endpoints import (
     GraphSagePredictArrowEndpoints,
 )
-from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints
+from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


 class GraphSageTrainArrowEndpoints(GraphSageTrainEndpoints):
@@ -25,7 +25,7 @@ def __init__(
     ):
         self._arrow_client = arrow_client
         self._write_back_client = write_back_client
-        self._node_property_endpoints = NodePropertyEndpoints(
+        self._node_property_endpoints = NodePropertyEndpointsHelper(
             arrow_client, write_back_client=write_back_client, show_progress=show_progress
         )
         self._model_api = ModelApiArrow(arrow_client)
@@ -91,7 +91,7 @@ def __call__(
             random_seed=random_seed,
         )

-        result = self._node_property_endpoints.run_job_and_get_summary("v2/embeddings.graphSage.train", G, config)
+        result = self._node_property_endpoints.run_job_and_get_summary("v2/embeddings.graphSage.train", config)

         model = GraphSageModelV2(
             model_name,
diff --git a/graphdatascience/procedure_surface/arrow/node_embedding/hashgnn_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/node_embedding/hashgnn_arrow_endpoints.py
index eb6d82cda..a4367e02f 100644
--- a/graphdatascience/procedure_surface/arrow/node_embedding/hashgnn_arrow_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/node_embedding/hashgnn_arrow_endpoints.py
@@ -11,7 +11,7 @@
     HashGNNMutateResult,
     HashGNNWriteResult,
 )
-from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints
+from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


 class HashGNNArrowEndpoints(HashGNNEndpoints):
@@ -26,7 +26,7 @@ def __init__(
         write_back_client: RemoteWriteBackClient | None = None,
         show_progress: bool = True,
     ):
-        self._node_property_endpoints = NodePropertyEndpoints(
+        self._node_property_endpoints = NodePropertyEndpointsHelper(
             arrow_client, write_back_client, show_progress=show_progress
         )
@@ -75,7 +75,7 @@ def mutate(
             job_id=job_id,
         )

-        result = self._node_property_endpoints.run_job_and_mutate("v2/embeddings.hashgnn", G, config, mutate_property)
+        result = self._node_property_endpoints.run_job_and_mutate("v2/embeddings.hashgnn", config, mutate_property)

         return HashGNNMutateResult(**result)
@@ -173,7 +173,12 @@ def write(
         )

         result = self._node_property_endpoints.run_job_and_write(
-            "v2/embeddings.hashgnn", G, config, write_concurrency, concurrency, write_property
+            "v2/embeddings.hashgnn",
+            G,
+            config,
+            property_overwrites=write_property,
+            write_concurrency=write_concurrency,
+            concurrency=concurrency,
         )

         return HashGNNWriteResult(**result)
diff --git a/graphdatascience/procedure_surface/arrow/node_embedding/node2vec_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/node_embedding/node2vec_arrow_endpoints.py
index 9f103ed27..41f89218d 100644
--- a/graphdatascience/procedure_surface/arrow/node_embedding/node2vec_arrow_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/node_embedding/node2vec_arrow_endpoints.py
@@ -11,7 +11,7 @@
     Node2VecMutateResult,
     Node2VecWriteResult,
 )
-from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpoints
+from graphdatascience.procedure_surface.arrow.node_property_endpoints import NodePropertyEndpointsHelper


 class Node2VecArrowEndpoints(Node2VecEndpoints):
@@ -21,7 +21,7 @@ def __init__(
         write_back_client: RemoteWriteBackClient | None = None,
         show_progress: bool = True,
     ):
-        self._node_property_endpoints = NodePropertyEndpoints(
+        self._node_property_endpoints = NodePropertyEndpointsHelper(
             arrow_client, write_back_client, show_progress=show_progress
         )
@@ -80,7 +80,7 @@ def mutate(
             random_seed=random_seed,
         )

-        result = self._node_property_endpoints.run_job_and_mutate("v2/embeddings.node2vec", G, config, mutate_property)
+        result = self._node_property_endpoints.run_job_and_mutate("v2/embeddings.node2vec", config, mutate_property)

         return Node2VecMutateResult(**result)
@@ -199,7 +199,12 @@ def write(
         )

         result = self._node_property_endpoints.run_job_and_write(
-            "v2/embeddings.node2vec", G, config, write_concurrency, concurrency, write_property
+            "v2/embeddings.node2vec",
+            G,
+            config,
+            property_overwrites=write_property,
+            write_concurrency=write_concurrency,
+            concurrency=concurrency,
         )

         return Node2VecWriteResult(**result)
diff --git a/graphdatascience/procedure_surface/arrow/node_property_endpoints.py b/graphdatascience/procedure_surface/arrow/node_property_endpoints.py
index 9ccabc0e4..caee6cd5d 100644
--- a/graphdatascience/procedure_surface/arrow/node_property_endpoints.py
+++ b/graphdatascience/procedure_surface/arrow/node_property_endpoints.py
@@ -1,133 +1,33 @@
 from typing import Any

-from pandas import DataFrame
-
 from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
-
-from ...arrow_client.authenticated_flight_client import AuthenticatedArrowClient
-from ...arrow_client.v2.data_mapper_utils import deserialize_single
-from ...arrow_client.v2.job_client import JobClient
-from ...arrow_client.v2.mutation_client import MutationClient
-from ...arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
-from ..api.estimation_result import EstimationResult
-from ..utils.config_converter import ConfigConverter
+from graphdatascience.procedure_surface.arrow.endpoints_helper_base import EndpointsHelperBase


-class NodePropertyEndpoints:
+class NodePropertyEndpointsHelper(EndpointsHelperBase):
     """
     Helper class for Arrow algorithm endpoints that work with node properties.
     Provides common functionality for job execution, mutation, streaming, and writing.
     """

-    def __init__(
-        self,
-        arrow_client: AuthenticatedArrowClient,
-        write_back_client: RemoteWriteBackClient | None = None,
-        show_progress: bool = True,
-    ):
-        self._arrow_client = arrow_client
-        self._write_back_client = write_back_client
-        self._show_progress = show_progress
-
-    def run_job_and_get_summary(self, endpoint: str, G: GraphV2, config: dict[str, Any]) -> dict[str, Any]:
-        """Run a job and return the computation summary."""
-        show_progress: bool = config.get("logProgress", True) and self._show_progress
-
-        job_id = JobClient.run_job_and_wait(self._arrow_client, endpoint, config, show_progress)
-        result = JobClient.get_summary(self._arrow_client, job_id)
-        if config := result.get("configuration", None):
-            self._drop_write_internals(config)
-        return result
-
-    def run_job_and_mutate(
-        self, endpoint: str, G: GraphV2, config: dict[str, Any], mutate_property: str
-    ) -> dict[str, Any]:
-        """Run a job, mutate node properties, and return summary with mutation result."""
-        show_progress = config.get("logProgress", True) and self._show_progress
-        job_id = JobClient.run_job_and_wait(self._arrow_client, endpoint, config, show_progress)
-        mutate_result = MutationClient.mutate_node_property(self._arrow_client, job_id, mutate_property)
-        computation_result = JobClient.get_summary(self._arrow_client, job_id)
-
-        # modify computation result to include mutation details
-        computation_result["nodePropertiesWritten"] = mutate_result.node_properties_written
-        computation_result["mutateMillis"] = mutate_result.mutate_millis
-
-        if (config := computation_result.get("configuration", None)) is not None:
-            config["mutateProperty"] = mutate_property
-            self._drop_write_internals(config)
-
-        return computation_result
-
-    def run_job_and_stream(self, endpoint: str, G: GraphV2, config: dict[str, Any]) -> DataFrame:
-        """Run a job and return streamed results."""
-        show_progress = config.get("logProgress", True) and self._show_progress
-        job_id = JobClient.run_job_and_wait(self._arrow_client, endpoint, config, show_progress=show_progress)
-        return JobClient.stream_results(self._arrow_client, G.name(), job_id)
+    def run_job_and_mutate(self, endpoint: str, config: dict[str, Any], mutate_property: str) -> dict[str, Any]:
+        return self._run_job_and_mutate(endpoint, config, mutate_property=mutate_property)

     def run_job_and_write(
         self,
         endpoint: str,
         G: GraphV2,
         config: dict[str, Any],
+        property_overwrites: str | dict[str, str],
         write_concurrency: int | None = None,
         concurrency: int | None = None,
-        property_overwrites: str | dict[str, str] | None = None,
     ) -> dict[str, Any]:
-        """Run a job, write results, and return summary with write time."""
-        show_progress = config.get("logProgress", True) and self._show_progress
-        job_id = JobClient.run_job_and_wait(self._arrow_client, endpoint, config, show_progress=show_progress)
-        computation_result = JobClient.get_summary(self._arrow_client, job_id)
-
-        if self._write_back_client is None:
-            raise Exception("Write back client is not initialized")
-
-        if isinstance(property_overwrites, str):
-            # The remote write back procedure allows specifying a single overwrite. The key is ignored.
-            property_overwrites = {property_overwrites: property_overwrites}
-
-        write_result = self._write_back_client.write(
-            G.name(),
-            job_id,
-            concurrency=write_concurrency if write_concurrency is not None else concurrency,
+        return self._run_job_and_write(
+            endpoint,
+            G,
+            config,
             property_overwrites=property_overwrites,
-            log_progress=show_progress,
+            relationship_type_overwrite=None,
+            write_concurrency=write_concurrency,
+            concurrency=concurrency,
         )
-
-        # modify computation result to include write details
-        computation_result["writeMillis"] = write_result.write_millis
-
-        return computation_result
-
-    def create_base_config(self, G: GraphV2, **kwargs: Any) -> dict[str, Any]:
-        """Create base configuration with common parameters."""
-        return ConfigConverter.convert_to_gds_config(graph_name=G.name(), **kwargs)
-
-    def create_estimate_config(self, **kwargs: Any) -> dict[str, Any]:
-        """Create configuration for estimation."""
-        return ConfigConverter.convert_to_gds_config(**kwargs)
-
-    def estimate(
-        self,
-        estimate_endpoint: str,
-        G: GraphV2 | dict[str, Any],
-        algo_config: dict[str, Any] | None = None,
-    ) -> EstimationResult:
-        """Estimate memory requirements for the algorithm."""
-        if isinstance(G, GraphV2):
-            payload = {"graphName": G.name()}
-        elif isinstance(G, dict):
-            payload = G
-        else:
-            raise ValueError("Either graph_name or projection_config must be provided.")
-
-        payload.update(algo_config or {})
-
-        res = self._arrow_client.do_action_with_retry(estimate_endpoint, payload)
-
-        return EstimationResult(**deserialize_single(res))
-
-    def _drop_write_internals(self, config: dict[str, Any]) -> None:
-        config.pop("writeConcurrency", None)
-        config.pop("writeToResultStore", None)
-        config.pop("writeProperty", None)
-        config.pop("writeMillis", None)
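`NodePropertyEndpointsHelper` is now a thin shim over the base class: `run_job_and_mutate` forwards with only `mutate_property` set, and `run_job_and_write` pins `relationship_type_overwrite=None`. A usage sketch, assuming `helper`, `G`, and `config` were prepared the way the endpoint classes above prepare them ("score" is a placeholder property name):

```python
# Mutate: the summary dict is enriched with the mutation details.
summary = helper.run_job_and_mutate("v2/centrality.pageRank", config, "score")
print(summary["nodePropertiesWritten"], summary["mutateMillis"])

# Write-back: a plain string overwrite is expanded to {"score": "score"}
# before the remote write, and the relationship overwrite stays None for
# node-property algorithms.
written = helper.run_job_and_write("v2/centrality.pageRank", G, config, property_overwrites="score")
print(written["writeMillis"])
```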
+ """ + + def run_job_and_mutate( + self, endpoint: str, config: Dict[str, Any], mutate_property: str, mutate_relationship_type: str + ) -> Dict[str, Any]: + """Run a job, mutate node properties, and return summary with mutation result.""" + return self._run_job_and_mutate( + endpoint, + config, + mutate_property=mutate_property, + mutate_relationship_type=mutate_relationship_type, + ) + + def run_job_and_write( + self, + endpoint: str, + G: GraphV2, + config: dict[str, Any], + *, + relationship_type_overwrite: str, + property_overwrites: str | dict[str, str] | None = None, + write_concurrency: int | None, + concurrency: int | None, + ) -> dict[str, Any]: + return self._run_job_and_write( + endpoint, + G, + config, + relationship_type_overwrite=relationship_type_overwrite, + property_overwrites=property_overwrites, + write_concurrency=write_concurrency, + concurrency=concurrency, + ) diff --git a/graphdatascience/procedure_surface/arrow/similarity/__init__.py b/graphdatascience/procedure_surface/arrow/similarity/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/graphdatascience/procedure_surface/arrow/similarity/knn_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/similarity/knn_arrow_endpoints.py new file mode 100644 index 000000000..846d3fdf8 --- /dev/null +++ b/graphdatascience/procedure_surface/arrow/similarity/knn_arrow_endpoints.py @@ -0,0 +1,282 @@ +from __future__ import annotations + +from typing import Any + +from pandas import DataFrame + +from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from graphdatascience.arrow_client.v2.remote_write_back_client import RemoteWriteBackClient +from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2 +from graphdatascience.procedure_surface.api.estimation_result import EstimationResult +from graphdatascience.procedure_surface.api.similarity.knn_endpoints import KnnEndpoints +from graphdatascience.procedure_surface.api.similarity.knn_filtered_endpoints import KnnFilteredEndpoints +from graphdatascience.procedure_surface.api.similarity.knn_results import ( + KnnMutateResult, + KnnStatsResult, + KnnWriteResult, +) +from graphdatascience.procedure_surface.arrow.relationship_endpoints_helper import RelationshipEndpointsHelper +from graphdatascience.procedure_surface.arrow.similarity.knn_filtered_arrow_endpoints import KnnFilteredArrowEndpoints +from graphdatascience.procedure_surface.arrow.stream_result_mapper import rename_similarity_stream_result + + +class KnnArrowEndpoints(KnnEndpoints): + def __init__( + self, + arrow_client: AuthenticatedArrowClient, + write_back_client: RemoteWriteBackClient | None = None, + show_progress: bool = False, + ): + self._endpoints_helper = RelationshipEndpointsHelper( + arrow_client, write_back_client=write_back_client, show_progress=show_progress + ) + + @property + def filtered(self) -> KnnFilteredEndpoints: + return KnnFilteredArrowEndpoints( + self._endpoints_helper._arrow_client, + self._endpoints_helper._write_back_client, + self._endpoints_helper._show_progress, + ) + + def mutate( + self, + G: GraphV2, + mutate_relationship_type: str, + mutate_property: str, + node_properties: str | list[str] | dict[str, str], + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: 
diff --git a/graphdatascience/procedure_surface/arrow/similarity/knn_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/similarity/knn_arrow_endpoints.py
new file mode 100644
index 000000000..846d3fdf8
--- /dev/null
+++ b/graphdatascience/procedure_surface/arrow/similarity/knn_arrow_endpoints.py
@@ -0,0 +1,282 @@
+from __future__ import annotations
+
+from typing import Any
+
+from pandas import DataFrame
+
+from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
+from graphdatascience.arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
+from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
+from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
+from graphdatascience.procedure_surface.api.similarity.knn_endpoints import KnnEndpoints
+from graphdatascience.procedure_surface.api.similarity.knn_filtered_endpoints import KnnFilteredEndpoints
+from graphdatascience.procedure_surface.api.similarity.knn_results import (
+    KnnMutateResult,
+    KnnStatsResult,
+    KnnWriteResult,
+)
+from graphdatascience.procedure_surface.arrow.relationship_endpoints_helper import RelationshipEndpointsHelper
+from graphdatascience.procedure_surface.arrow.similarity.knn_filtered_arrow_endpoints import KnnFilteredArrowEndpoints
+from graphdatascience.procedure_surface.arrow.stream_result_mapper import rename_similarity_stream_result
+
+
+class KnnArrowEndpoints(KnnEndpoints):
+    def __init__(
+        self,
+        arrow_client: AuthenticatedArrowClient,
+        write_back_client: RemoteWriteBackClient | None = None,
+        show_progress: bool = False,
+    ):
+        self._endpoints_helper = RelationshipEndpointsHelper(
+            arrow_client, write_back_client=write_back_client, show_progress=show_progress
+        )
+
+    @property
+    def filtered(self) -> KnnFilteredEndpoints:
+        return KnnFilteredArrowEndpoints(
+            self._endpoints_helper._arrow_client,
+            self._endpoints_helper._write_back_client,
+            self._endpoints_helper._show_progress,
+        )
+
+    def mutate(
+        self,
+        G: GraphV2,
+        mutate_relationship_type: str,
+        mutate_property: str,
+        node_properties: str | list[str] | dict[str, str],
+        top_k: int = 10,
+        similarity_cutoff: float = 0.0,
+        delta_threshold: float = 0.001,
+        max_iterations: int = 100,
+        sample_rate: float = 0.5,
+        perturbation_rate: float = 0.0,
+        random_joins: int = 10,
+        random_seed: int | None = None,
+        initial_sampler: str = "UNIFORM",
+        relationship_types: list[str] | None = None,
+        node_labels: list[str] | None = None,
+        sudo: bool = False,
+        log_progress: bool = True,
+        username: str | None = None,
+        concurrency: int | None = None,
+        job_id: str | None = None,
+    ) -> KnnMutateResult:
+        config = self._endpoints_helper.create_base_config(
+            G,
+            nodeProperties=node_properties,
+            topK=top_k,
+            similarityCutoff=similarity_cutoff,
+            deltaThreshold=delta_threshold,
+            maxIterations=max_iterations,
+            sampleRate=sample_rate,
+            perturbationRate=perturbation_rate,
+            randomJoins=random_joins,
+            randomSeed=random_seed,
+            initialSampler=initial_sampler,
+            relationshipTypes=relationship_types,
+            nodeLabels=node_labels,
+            sudo=sudo,
+            logProgress=log_progress,
+            username=username,
+            concurrency=concurrency,
+            jobId=job_id,
+        )
+
+        result = self._endpoints_helper.run_job_and_mutate(
+            "v2/similarity.knn", config, mutate_property, mutate_relationship_type
+        )
+
+        return KnnMutateResult(**result)
+
+    def stats(
+        self,
+        G: GraphV2,
+        node_properties: str | list[str] | dict[str, str],
+        top_k: int = 10,
+        similarity_cutoff: float = 0.0,
+        delta_threshold: float = 0.001,
+        max_iterations: int = 100,
+        sample_rate: float = 0.5,
+        perturbation_rate: float = 0.0,
+        random_joins: int = 10,
+        random_seed: int | None = None,
+        initial_sampler: str = "UNIFORM",
+        relationship_types: list[str] | None = None,
+        node_labels: list[str] | None = None,
+        sudo: bool = False,
+        log_progress: bool = True,
+        username: str | None = None,
+        concurrency: int | None = None,
+        job_id: str | None = None,
+    ) -> KnnStatsResult:
+        config = self._endpoints_helper.create_base_config(
+            G,
+            nodeProperties=node_properties,
+            topK=top_k,
+            similarityCutoff=similarity_cutoff,
+            deltaThreshold=delta_threshold,
+            maxIterations=max_iterations,
+            sampleRate=sample_rate,
+            perturbationRate=perturbation_rate,
+            randomJoins=random_joins,
+            randomSeed=random_seed,
+            initialSampler=initial_sampler,
+            relationshipTypes=relationship_types,
+            nodeLabels=node_labels,
+            sudo=sudo,
+            logProgress=log_progress,
+            username=username,
+            concurrency=concurrency,
+            jobId=job_id,
+        )
+
+        result = self._endpoints_helper.run_job_and_get_summary("v2/similarity.knn", config)
+        result["similarityPairs"] = result.pop("relationshipsWritten", 0)
+        return KnnStatsResult(**result)
+
+    def stream(
+        self,
+        G: GraphV2,
+        node_properties: str | list[str] | dict[str, str],
+        top_k: int = 10,
+        similarity_cutoff: float = 0.0,
+        delta_threshold: float = 0.001,
+        max_iterations: int = 100,
+        sample_rate: float = 0.5,
+        perturbation_rate: float = 0.0,
+        random_joins: int = 10,
+        random_seed: int | None = None,
+        initial_sampler: str = "UNIFORM",
+        relationship_types: list[str] | None = None,
+        node_labels: list[str] | None = None,
+        sudo: bool = False,
+        log_progress: bool = True,
+        username: str | None = None,
+        concurrency: int | None = None,
+        job_id: str | None = None,
+    ) -> DataFrame:
+        config = self._endpoints_helper.create_base_config(
+            G,
+            nodeProperties=node_properties,
+            topK=top_k,
+            similarityCutoff=similarity_cutoff,
+            deltaThreshold=delta_threshold,
+            maxIterations=max_iterations,
+            sampleRate=sample_rate,
+            perturbationRate=perturbation_rate,
+            randomJoins=random_joins,
+            randomSeed=random_seed,
+            initialSampler=initial_sampler,
+            relationshipTypes=relationship_types,
+            nodeLabels=node_labels,
+            sudo=sudo,
+            logProgress=log_progress,
+            username=username,
+            concurrency=concurrency,
+            jobId=job_id,
+        )
+        result = self._endpoints_helper.run_job_and_stream("v2/similarity.knn", G, config)
+        rename_similarity_stream_result(result)
+
+        return result
+
+    def write(
+        self,
+        G: GraphV2,
+        write_relationship_type: str,
+        write_property: str,
+        node_properties: str | list[str] | dict[str, str],
+        top_k: int = 10,
+        similarity_cutoff: float = 0.0,
+        delta_threshold: float = 0.001,
+        max_iterations: int = 100,
+        sample_rate: float = 0.5,
+        perturbation_rate: float = 0.0,
+        random_joins: int = 10,
+        random_seed: int | None = None,
+        initial_sampler: str = "UNIFORM",
+        relationship_types: list[str] | None = None,
+        node_labels: list[str] | None = None,
+        sudo: bool = False,
+        log_progress: bool = True,
+        username: str | None = None,
+        concurrency: int | None = None,
+        job_id: str | None = None,
+        write_concurrency: int | None = None,
+    ) -> KnnWriteResult:
+        config = self._endpoints_helper.create_base_config(
+            G,
+            nodeProperties=node_properties,
+            topK=top_k,
+            similarityCutoff=similarity_cutoff,
+            deltaThreshold=delta_threshold,
+            maxIterations=max_iterations,
+            sampleRate=sample_rate,
+            perturbationRate=perturbation_rate,
+            randomJoins=random_joins,
+            randomSeed=random_seed,
+            initialSampler=initial_sampler,
+            relationshipTypes=relationship_types,
+            nodeLabels=node_labels,
+            sudo=sudo,
+            logProgress=log_progress,
+            username=username,
+            concurrency=concurrency,
+            jobId=job_id,
+        )
+
+        result = self._endpoints_helper.run_job_and_write(
+            "v2/similarity.knn",
+            G,
+            config,
+            relationship_type_overwrite=write_relationship_type,
+            property_overwrites=write_property,
+            write_concurrency=write_concurrency,
+            concurrency=None,
+        )
+
+        return KnnWriteResult(**result)
+
+    def estimate(
+        self,
+        G: GraphV2 | dict[str, Any],
+        node_properties: str | list[str] | dict[str, str],
+        top_k: int = 10,
+        similarity_cutoff: float = 0.0,
+        delta_threshold: float = 0.001,
+        max_iterations: int = 100,
+        sample_rate: float = 0.5,
+        perturbation_rate: float = 0.0,
+        random_joins: int = 10,
+        random_seed: int | None = None,
+        initial_sampler: str = "UNIFORM",
+        relationship_types: list[str] | None = None,
+        node_labels: list[str] | None = None,
+        sudo: bool = False,
+        log_progress: bool = True,
+        username: str | None = None,
+        concurrency: int | None = None,
+        job_id: str | None = None,
+    ) -> EstimationResult:
+        config = self._endpoints_helper.create_estimate_config(
+            nodeProperties=node_properties,
+            topK=top_k,
+            similarityCutoff=similarity_cutoff,
+            deltaThreshold=delta_threshold,
+            maxIterations=max_iterations,
+            sampleRate=sample_rate,
+            perturbationRate=perturbation_rate,
+            randomJoins=random_joins,
+            randomSeed=random_seed,
+            initialSampler=initial_sampler,
+            relationshipTypes=relationship_types,
+            nodeLabels=node_labels,
+            sudo=sudo,
+            logProgress=log_progress,
+            username=username,
+            concurrency=concurrency,
+            jobId=job_id,
+        )
+
+        return self._endpoints_helper.estimate("v2/similarity.knn.estimate", G, config)
graphdatascience.procedure_surface.api.estimation_result import EstimationResult
+from graphdatascience.procedure_surface.api.similarity.knn_filtered_endpoints import KnnFilteredEndpoints
+from graphdatascience.procedure_surface.api.similarity.knn_results import (
+    KnnMutateResult,
+    KnnStatsResult,
+    KnnWriteResult,
+)
+from graphdatascience.procedure_surface.arrow.relationship_endpoints_helper import RelationshipEndpointsHelper
+from graphdatascience.procedure_surface.arrow.stream_result_mapper import rename_similarity_stream_result
+
+
+class KnnFilteredArrowEndpoints(KnnFilteredEndpoints):
+    def __init__(
+        self,
+        arrow_client: AuthenticatedArrowClient,
+        write_back_client: RemoteWriteBackClient | None = None,
+        show_progress: bool = False,
+    ):
+        self._endpoints_helper = RelationshipEndpointsHelper(
+            arrow_client, write_back_client=write_back_client, show_progress=show_progress
+        )
+
+    def mutate(
+        self,
+        G: GraphV2,
+        mutate_relationship_type: str,
+        mutate_property: str,
+        node_properties: str | list[str] | dict[str, str],
+        source_node_filter: str,
+        target_node_filter: str,
+        seed_target_nodes: bool | None = None,
+        top_k: int = 10,
+        similarity_cutoff: float = 0.0,
+        delta_threshold: float = 0.001,
+        max_iterations: int = 100,
+        sample_rate: float = 0.5,
+        perturbation_rate: float = 0.0,
+        random_joins: int = 10,
+        random_seed: int | None = None,
+        initial_sampler: str = "UNIFORM",
+        relationship_types: list[str] | None = None,
+        node_labels: list[str] | None = None,
+        sudo: bool = False,
+        log_progress: bool = True,
+        username: str | None = None,
+        concurrency: int | None = None,
+        job_id: str | None = None,
+    ) -> KnnMutateResult:
+        config = self._endpoints_helper.create_base_config(
+            G,
+            nodeProperties=node_properties,
+            sourceNodeFilter=source_node_filter,
+            nodeLabels=node_labels,
+            relationshipTypes=relationship_types,
+            targetNodeFilter=target_node_filter,
+            seedTargetNodes=seed_target_nodes,
+            similarityCutoff=similarity_cutoff,
+            perturbationRate=perturbation_rate,
+            deltaThreshold=delta_threshold,
+            sampleRate=sample_rate,
+            randomJoins=random_joins,
+            initialSampler=initial_sampler,
+            maxIterations=max_iterations,
+            topK=top_k,
+            randomSeed=random_seed,
+            concurrency=concurrency,
+            jobId=job_id,
+            logProgress=log_progress,
+            sudo=sudo,
+            username=username,
+        )
+
+        result = self._endpoints_helper.run_job_and_mutate(
+            "v2/similarity.knn.filtered", config, mutate_property, mutate_relationship_type
+        )
+
+        return KnnMutateResult(**result)
+
+    def stats(
+        self,
+        G: GraphV2,
+        node_properties: str | list[str] | dict[str, str],
+        source_node_filter: str,
+        target_node_filter: str,
+        seed_target_nodes: bool | None = None,
+        top_k: int = 10,
+        similarity_cutoff: float = 0.0,
+        delta_threshold: float = 0.001,
+        max_iterations: int = 100,
+        sample_rate: float = 0.5,
+        perturbation_rate: float = 0.0,
+        random_joins: int = 10,
+        random_seed: int | None = None,
+        initial_sampler: str = "UNIFORM",
+        relationship_types: list[str] | None = None,
+        node_labels: list[str] | None = None,
+        sudo: bool = False,
+        log_progress: bool = True,
+        username: str | None = None,
+        concurrency: int | None = None,
+        job_id: str | None = None,
+    ) -> KnnStatsResult:
+        config = self._endpoints_helper.create_base_config(
+            G,
+            relationshipTypes=relationship_types,
+            nodeLabels=node_labels,
+            nodeProperties=node_properties,
+            sourceNodeFilter=source_node_filter,
+            targetNodeFilter=target_node_filter,
+            seedTargetNodes=seed_target_nodes,
+            similarityCutoff=similarity_cutoff,
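+            # Pairs scoring below similarityCutoff are dropped from the result;
+            # with the default of 0.0 every computed neighbor pair is kept.
+            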
perturbationRate=perturbation_rate, + deltaThreshold=delta_threshold, + sampleRate=sample_rate, + randomJoins=random_joins, + initialSampler=initial_sampler, + maxIterations=max_iterations, + topK=top_k, + randomSeed=random_seed, + concurrency=concurrency, + jobId=job_id, + logProgress=log_progress, + sudo=sudo, + username=username, + ) + + result = self._endpoints_helper.run_job_and_get_summary("v2/similarity.knn.filtered", config) + result["similarityPairs"] = result.pop("relationshipsWritten", 0) + return KnnStatsResult(**result) + + def stream( + self, + G: GraphV2, + node_properties: str | list[str] | dict[str, str], + source_node_filter: str, + target_node_filter: str, + seed_target_nodes: bool | None = None, + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + concurrency: int | None = None, + job_id: str | None = None, + ) -> DataFrame: + config = self._endpoints_helper.create_base_config( + G, + nodeProperties=node_properties, + sourceNodeFilter=source_node_filter, + targetNodeFilter=target_node_filter, + seedTargetNodes=seed_target_nodes, + nodeLabels=node_labels, + relationshipTypes=relationship_types, + similarityCutoff=similarity_cutoff, + perturbationRate=perturbation_rate, + deltaThreshold=delta_threshold, + sampleRate=sample_rate, + randomJoins=random_joins, + initialSampler=initial_sampler, + maxIterations=max_iterations, + topK=top_k, + randomSeed=random_seed, + concurrency=concurrency, + jobId=job_id, + logProgress=log_progress, + sudo=sudo, + username=username, + ) + + result = self._endpoints_helper.run_job_and_stream("v2/similarity.knn.filtered", G, config) + rename_similarity_stream_result(result) + + return result + + def write( + self, + G: GraphV2, + write_relationship_type: str, + write_property: str, + node_properties: str | list[str] | dict[str, str], + source_node_filter: str, + target_node_filter: str, + seed_target_nodes: bool | None = None, + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + write_concurrency: int | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + concurrency: int | None = None, + job_id: str | None = None, + ) -> KnnWriteResult: + config = self._endpoints_helper.create_base_config( + G, + nodeProperties=node_properties, + sourceNodeFilter=source_node_filter, + targetNodeFilter=target_node_filter, + seedTargetNodes=seed_target_nodes, + nodeLabels=node_labels, + relationshipTypes=relationship_types, + similarityCutoff=similarity_cutoff, + perturbationRate=perturbation_rate, + deltaThreshold=delta_threshold, + sampleRate=sample_rate, + randomJoins=random_joins, + initialSampler=initial_sampler, + maxIterations=max_iterations, + topK=top_k, + randomSeed=random_seed, + writeConcurrency=write_concurrency, + concurrency=concurrency, + jobId=job_id, + logProgress=log_progress, + sudo=sudo, + username=username, 
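+            # Note: write_concurrency is set here (as writeConcurrency) and is
+            # also passed to run_job_and_write below; the helper is assumed to
+            # forward whichever of the two the write-back path expects.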
+        )
+
+        result = self._endpoints_helper.run_job_and_write(
+            "v2/similarity.knn.filtered",
+            G,
+            config,
+            property_overwrites=write_property,
+            relationship_type_overwrite=write_relationship_type,
+            concurrency=concurrency,
+            write_concurrency=write_concurrency,
+        )
+
+        return KnnWriteResult(**result)
+
+    def estimate(
+        self,
+        G: GraphV2 | dict[str, Any],
+        node_properties: str | list[str] | dict[str, str],
+        source_node_filter: str,
+        target_node_filter: str,
+        seed_target_nodes: bool | None = None,
+        top_k: int = 10,
+        similarity_cutoff: float = 0.0,
+        delta_threshold: float = 0.001,
+        max_iterations: int = 100,
+        sample_rate: float = 0.5,
+        perturbation_rate: float = 0.0,
+        random_joins: int = 10,
+        random_seed: int | None = None,
+        initial_sampler: str = "UNIFORM",
+        relationship_types: list[str] | None = None,
+        node_labels: list[str] | None = None,
+        sudo: bool = False,
+        username: str | None = None,
+        concurrency: int | None = None,
+    ) -> EstimationResult:
+        config = self._endpoints_helper.create_estimate_config(
+            relationshipTypes=relationship_types,
+            nodeLabels=node_labels,
+            nodeProperties=node_properties,
+            sourceNodeFilter=source_node_filter,
+            targetNodeFilter=target_node_filter,
+            seedTargetNodes=seed_target_nodes,
+            similarityCutoff=similarity_cutoff,
+            perturbationRate=perturbation_rate,
+            deltaThreshold=delta_threshold,
+            sampleRate=sample_rate,
+            randomJoins=random_joins,
+            initialSampler=initial_sampler,
+            maxIterations=max_iterations,
+            topK=top_k,
+            randomSeed=random_seed,
+            concurrency=concurrency,
+            sudo=sudo,
+            username=username,
+        )
+
+        return self._endpoints_helper.estimate("v2/similarity.knn.filtered.estimate", G, config)
diff --git a/graphdatascience/procedure_surface/arrow/stream_result_mapper.py b/graphdatascience/procedure_surface/arrow/stream_result_mapper.py
new file mode 100644
index 000000000..be9b84e2e
--- /dev/null
+++ b/graphdatascience/procedure_surface/arrow/stream_result_mapper.py
@@ -0,0 +1,7 @@
+from pandas import DataFrame
+
+
+def rename_similarity_stream_result(result: DataFrame) -> None:
+    result.rename(columns={"sourceNodeId": "node1", "targetNodeId": "node2"}, inplace=True)
+    if "relationshipType" in result.columns:
+        result.drop(columns=["relationshipType"], inplace=True)
diff --git a/graphdatascience/procedure_surface/cypher/community/local_clustering_coefficient_cypher_endpoints.py b/graphdatascience/procedure_surface/cypher/community/local_clustering_coefficient_cypher_endpoints.py
index ee54c99c9..cdeca8e5b 100644
--- a/graphdatascience/procedure_surface/cypher/community/local_clustering_coefficient_cypher_endpoints.py
+++ b/graphdatascience/procedure_surface/cypher/community/local_clustering_coefficient_cypher_endpoints.py
@@ -134,7 +134,6 @@ def write(
         triangle_count_property: str | None = None,
         username: str | None = None,
         write_concurrency: int | None = None,
-        write_to_result_store: bool | None = None,
     ) -> LocalClusteringCoefficientWriteResult:
         config = ConfigConverter.convert_to_gds_config(
             write_property=write_property,
@@ -147,7 +146,6 @@ def write(
             triangle_count_property=triangle_count_property,
             username=username,
             write_concurrency=write_concurrency,
-            write_to_result_store=write_to_result_store,
         )
 
         # Run procedure and return results
diff --git a/graphdatascience/procedure_surface/cypher/community/modularity_optimization_cypher_endpoints.py b/graphdatascience/procedure_surface/cypher/community/modularity_optimization_cypher_endpoints.py
index 8ddc8fddc..34caa7323 100644
--- 
a/graphdatascience/procedure_surface/cypher/community/modularity_optimization_cypher_endpoints.py +++ b/graphdatascience/procedure_surface/cypher/community/modularity_optimization_cypher_endpoints.py @@ -182,7 +182,6 @@ def write( tolerance: float | None = None, username: str | None = None, write_concurrency: int | None = None, - write_to_result_store: bool | None = None, ) -> ModularityOptimizationWriteResult: config = ConfigConverter.convert_to_gds_config( write_property=write_property, @@ -201,7 +200,6 @@ def write( tolerance=tolerance, username=username, write_concurrency=write_concurrency, - write_to_result_store=write_to_result_store, ) params = CallParameters( diff --git a/graphdatascience/procedure_surface/cypher/similarity/__init__.py b/graphdatascience/procedure_surface/cypher/similarity/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/graphdatascience/procedure_surface/cypher/similarity/knn_cypher_endpoints.py b/graphdatascience/procedure_surface/cypher/similarity/knn_cypher_endpoints.py new file mode 100644 index 000000000..40b3008b3 --- /dev/null +++ b/graphdatascience/procedure_surface/cypher/similarity/knn_cypher_endpoints.py @@ -0,0 +1,274 @@ +from __future__ import annotations + +from typing import Any + +from pandas import DataFrame + +from graphdatascience.call_parameters import CallParameters +from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2 +from graphdatascience.procedure_surface.api.estimation_result import EstimationResult +from graphdatascience.procedure_surface.api.similarity.knn_endpoints import KnnEndpoints +from graphdatascience.procedure_surface.api.similarity.knn_filtered_endpoints import KnnFilteredEndpoints +from graphdatascience.procedure_surface.api.similarity.knn_results import ( + KnnMutateResult, + KnnStatsResult, + KnnWriteResult, +) +from graphdatascience.procedure_surface.cypher.estimation_utils import estimate_algorithm +from graphdatascience.procedure_surface.cypher.similarity.knn_filtered_cypher_endpoints import ( + KnnFilteredCypherEndpoints, +) +from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter +from graphdatascience.query_runner.query_runner import QueryRunner + + +class KnnCypherEndpoints(KnnEndpoints): + def __init__(self, query_runner: QueryRunner): + self._query_runner = query_runner + + @property + def filtered(self) -> KnnFilteredEndpoints: + return KnnFilteredCypherEndpoints(self._query_runner) + + def mutate( + self, + G: GraphV2, + mutate_relationship_type: str, + mutate_property: str, + node_properties: str | list[str] | dict[str, str], + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + concurrency: int | None = None, + job_id: str | None = None, + ) -> KnnMutateResult: + config = ConfigConverter.convert_to_gds_config( + mutateRelationshipType=mutate_relationship_type, + mutateProperty=mutate_property, + nodeProperties=node_properties, + topK=top_k, + similarityCutoff=similarity_cutoff, + deltaThreshold=delta_threshold, + maxIterations=max_iterations, + sampleRate=sample_rate, + perturbationRate=perturbation_rate, + randomJoins=random_joins, + 
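# ConfigConverter.convert_to_gds_config is assumed to drop None-valued
+            # entries, so optional settings such as randomSeed below only reach
+            # the gds.knn.mutate call when they are set explicitly.
+            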
randomSeed=random_seed, + initialSampler=initial_sampler, + relationshipTypes=relationship_types, + nodeLabels=node_labels, + sudo=sudo, + logProgress=log_progress, + username=username, + concurrency=concurrency, + jobId=job_id, + ) + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + result = self._query_runner.call_procedure("gds.knn.mutate", params=params).iloc[0] + + return KnnMutateResult(**result.to_dict()) + + def stats( + self, + G: GraphV2, + node_properties: str | list[str] | dict[str, str], + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + concurrency: int | None = None, + job_id: str | None = None, + ) -> KnnStatsResult: + config = ConfigConverter.convert_to_gds_config( + nodeProperties=node_properties, + topK=top_k, + similarityCutoff=similarity_cutoff, + deltaThreshold=delta_threshold, + maxIterations=max_iterations, + sampleRate=sample_rate, + perturbationRate=perturbation_rate, + randomJoins=random_joins, + randomSeed=random_seed, + initialSampler=initial_sampler, + relationshipTypes=relationship_types, + nodeLabels=node_labels, + sudo=sudo, + logProgress=log_progress, + username=username, + concurrency=concurrency, + jobId=job_id, + ) + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + result = self._query_runner.call_procedure("gds.knn.stats", params=params, logging=log_progress).iloc[0] + + return KnnStatsResult(**result.to_dict()) + + def stream( + self, + G: GraphV2, + node_properties: str | list[str] | dict[str, str], + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + concurrency: int | None = None, + job_id: str | None = None, + ) -> DataFrame: + config = ConfigConverter.convert_to_gds_config( + nodeProperties=node_properties, + topK=top_k, + similarityCutoff=similarity_cutoff, + deltaThreshold=delta_threshold, + maxIterations=max_iterations, + sampleRate=sample_rate, + perturbationRate=perturbation_rate, + randomJoins=random_joins, + randomSeed=random_seed, + initialSampler=initial_sampler, + relationshipTypes=relationship_types, + nodeLabels=node_labels, + sudo=sudo, + logProgress=log_progress, + username=username, + concurrency=concurrency, + jobId=job_id, + ) + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + result = self._query_runner.call_procedure("gds.knn.stream", params=params, logging=log_progress) + + return result + + def write( + self, + G: GraphV2, + write_relationship_type: str, + write_property: str, + node_properties: str | list[str] | dict[str, str], + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + 
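# sample_rate, perturbation_rate and random_joins are tuning knobs of the
+        # approximate NN-Descent search that KNN runs on; the defaults here
+        # mirror the stats, stream and mutate modes above.
+        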
random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + concurrency: int | None = None, + job_id: str | None = None, + write_concurrency: int | None = None, + ) -> KnnWriteResult: + config = ConfigConverter.convert_to_gds_config( + writeRelationshipType=write_relationship_type, + writeProperty=write_property, + nodeProperties=node_properties, + topK=top_k, + similarityCutoff=similarity_cutoff, + deltaThreshold=delta_threshold, + maxIterations=max_iterations, + sampleRate=sample_rate, + perturbationRate=perturbation_rate, + randomJoins=random_joins, + randomSeed=random_seed, + initialSampler=initial_sampler, + relationshipTypes=relationship_types, + nodeLabels=node_labels, + sudo=sudo, + logProgress=log_progress, + username=username, + concurrency=concurrency, + jobId=job_id, + writeConcurrency=write_concurrency, + ) + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + result = self._query_runner.call_procedure("gds.knn.write", params=params, logging=log_progress).iloc[0] + + return KnnWriteResult(**result.to_dict()) + + def estimate( + self, + G: GraphV2 | dict[str, Any], + node_properties: str | list[str] | dict[str, str], + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + concurrency: int | None = None, + job_id: str | None = None, + ) -> EstimationResult: + config = ConfigConverter.convert_to_gds_config( + nodeProperties=node_properties, + topK=top_k, + similarityCutoff=similarity_cutoff, + deltaThreshold=delta_threshold, + maxIterations=max_iterations, + sampleRate=sample_rate, + perturbationRate=perturbation_rate, + randomJoins=random_joins, + randomSeed=random_seed, + initialSampler=initial_sampler, + relationshipTypes=relationship_types, + nodeLabels=node_labels, + sudo=sudo, + logProgress=log_progress, + username=username, + concurrency=concurrency, + jobId=job_id, + ) + + return estimate_algorithm( + endpoint="gds.knn.stats.estimate", query_runner=self._query_runner, G=G, algo_config=config + ) diff --git a/graphdatascience/procedure_surface/cypher/similarity/knn_filtered_cypher_endpoints.py b/graphdatascience/procedure_surface/cypher/similarity/knn_filtered_cypher_endpoints.py new file mode 100644 index 000000000..fbf054e63 --- /dev/null +++ b/graphdatascience/procedure_surface/cypher/similarity/knn_filtered_cypher_endpoints.py @@ -0,0 +1,290 @@ +from typing import Any + +from pandas import DataFrame + +from graphdatascience.call_parameters import CallParameters +from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2 +from graphdatascience.procedure_surface.api.estimation_result import EstimationResult +from graphdatascience.procedure_surface.api.similarity.knn_filtered_endpoints import KnnFilteredEndpoints +from graphdatascience.procedure_surface.api.similarity.knn_results import ( + KnnMutateResult, + KnnStatsResult, + KnnWriteResult, +) +from graphdatascience.procedure_surface.cypher.estimation_utils import 
estimate_algorithm +from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter +from graphdatascience.query_runner.query_runner import QueryRunner + + +class KnnFilteredCypherEndpoints(KnnFilteredEndpoints): + def __init__(self, query_runner: QueryRunner): + self._query_runner = query_runner + + def mutate( + self, + G: GraphV2, + mutate_relationship_type: str, + mutate_property: str, + node_properties: str | list[str] | dict[str, str], + source_node_filter: str, + target_node_filter: str, + seed_target_nodes: bool | None = None, + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + concurrency: Any | None = None, + job_id: Any | None = None, + ) -> KnnMutateResult: + config = ConfigConverter.convert_to_gds_config( + mutateRelationshipType=mutate_relationship_type, + mutateProperty=mutate_property, + nodeProperties=node_properties, + sourceNodeFilter=source_node_filter, + targetNodeFilter=target_node_filter, + seedTargetNodes=seed_target_nodes, + topK=top_k, + similarityCutoff=similarity_cutoff, + deltaThreshold=delta_threshold, + maxIterations=max_iterations, + sampleRate=sample_rate, + perturbationRate=perturbation_rate, + randomJoins=random_joins, + randomSeed=random_seed, + initialSampler=initial_sampler, + relationshipTypes=relationship_types, + nodeLabels=node_labels, + sudo=sudo, + logProgress=log_progress, + username=username, + concurrency=concurrency, + jobId=job_id, + ) + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + result = self._query_runner.call_procedure("gds.knn.filtered.mutate", params=params).iloc[0] + + return KnnMutateResult(**result.to_dict()) + + def stats( + self, + G: GraphV2, + node_properties: str | list[str] | dict[str, str], + source_node_filter: str, + target_node_filter: str, + seed_target_nodes: bool | None = None, + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + concurrency: Any | None = None, + job_id: Any | None = None, + ) -> KnnStatsResult: + config = ConfigConverter.convert_to_gds_config( + nodeProperties=node_properties, + sourceNodeFilter=source_node_filter, + targetNodeFilter=target_node_filter, + seedTargetNodes=seed_target_nodes, + topK=top_k, + similarityCutoff=similarity_cutoff, + deltaThreshold=delta_threshold, + maxIterations=max_iterations, + sampleRate=sample_rate, + perturbationRate=perturbation_rate, + randomJoins=random_joins, + randomSeed=random_seed, + initialSampler=initial_sampler, + relationshipTypes=relationship_types, + nodeLabels=node_labels, + sudo=sudo, + logProgress=log_progress, + username=username, + concurrency=concurrency, + jobId=job_id, + ) + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + result = 
self._query_runner.call_procedure("gds.knn.filtered.stats", params=params, logging=log_progress).iloc[ + 0 + ] + + return KnnStatsResult(**result.to_dict()) + + def stream( + self, + G: GraphV2, + node_properties: str | list[str] | dict[str, str], + source_node_filter: str, + target_node_filter: str, + seed_target_nodes: bool | None = None, + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + concurrency: Any | None = None, + job_id: Any | None = None, + ) -> DataFrame: + config = ConfigConverter.convert_to_gds_config( + nodeProperties=node_properties, + sourceNodeFilter=source_node_filter, + targetNodeFilter=target_node_filter, + seedTargetNodes=seed_target_nodes, + topK=top_k, + similarityCutoff=similarity_cutoff, + deltaThreshold=delta_threshold, + maxIterations=max_iterations, + sampleRate=sample_rate, + perturbationRate=perturbation_rate, + randomJoins=random_joins, + randomSeed=random_seed, + initialSampler=initial_sampler, + relationshipTypes=relationship_types, + nodeLabels=node_labels, + sudo=sudo, + logProgress=log_progress, + username=username, + concurrency=concurrency, + jobId=job_id, + ) + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + return self._query_runner.call_procedure("gds.knn.filtered.stream", params=params, logging=log_progress) + + def write( + self, + G: GraphV2, + write_relationship_type: str, + write_property: str, + node_properties: str | list[str] | dict[str, str], + source_node_filter: str, + target_node_filter: str, + seed_target_nodes: bool | None = None, + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + write_concurrency: int | None = None, + sudo: bool = False, + log_progress: bool = True, + username: str | None = None, + concurrency: Any | None = None, + job_id: Any | None = None, + ) -> KnnWriteResult: + config = ConfigConverter.convert_to_gds_config( + writeRelationshipType=write_relationship_type, + writeProperty=write_property, + nodeProperties=node_properties, + sourceNodeFilter=source_node_filter, + targetNodeFilter=target_node_filter, + seedTargetNodes=seed_target_nodes, + topK=top_k, + similarityCutoff=similarity_cutoff, + deltaThreshold=delta_threshold, + maxIterations=max_iterations, + sampleRate=sample_rate, + perturbationRate=perturbation_rate, + randomJoins=random_joins, + randomSeed=random_seed, + initialSampler=initial_sampler, + relationshipTypes=relationship_types, + nodeLabels=node_labels, + writeConcurrency=write_concurrency, + sudo=sudo, + logProgress=log_progress, + username=username, + concurrency=concurrency, + jobId=job_id, + ) + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + result = self._query_runner.call_procedure("gds.knn.filtered.write", params=params, logging=log_progress).iloc[ + 0 + ] + + return 
KnnWriteResult(**result.to_dict()) + + def estimate( + self, + G: GraphV2 | dict[str, Any], + node_properties: str | list[str] | dict[str, str], + source_node_filter: str, + target_node_filter: str, + seed_target_nodes: bool | None = None, + top_k: int = 10, + similarity_cutoff: float = 0.0, + delta_threshold: float = 0.001, + max_iterations: int = 100, + sample_rate: float = 0.5, + perturbation_rate: float = 0.0, + random_joins: int = 10, + random_seed: int | None = None, + initial_sampler: str = "UNIFORM", + relationship_types: list[str] | None = None, + node_labels: list[str] | None = None, + sudo: bool = False, + username: str | None = None, + concurrency: Any | None = None, + ) -> EstimationResult: + config = ConfigConverter.convert_to_gds_config( + nodeProperties=node_properties, + sourceNodeFilter=source_node_filter, + targetNodeFilter=target_node_filter, + seedTargetNodes=seed_target_nodes, + topK=top_k, + similarityCutoff=similarity_cutoff, + deltaThreshold=delta_threshold, + maxIterations=max_iterations, + sampleRate=sample_rate, + perturbationRate=perturbation_rate, + randomJoins=random_joins, + randomSeed=random_seed, + initialSampler=initial_sampler, + relationshipTypes=relationship_types, + nodeLabels=node_labels, + sudo=sudo, + username=username, + concurrency=concurrency, + ) + + return estimate_algorithm("gds.knn.filtered.stats.estimate", self._query_runner, G, config) diff --git a/graphdatascience/session/session_v2_endpoints.py b/graphdatascience/session/session_v2_endpoints.py index 6f366ea65..b2d73fc35 100644 --- a/graphdatascience/session/session_v2_endpoints.py +++ b/graphdatascience/session/session_v2_endpoints.py @@ -15,6 +15,7 @@ from graphdatascience.procedure_surface.api.community.sllpa_endpoints import SllpaEndpoints from graphdatascience.procedure_surface.api.community.triangle_count_endpoints import TriangleCountEndpoints from graphdatascience.procedure_surface.api.node_embedding.graphsage_endpoints import GraphSageEndpoints +from graphdatascience.procedure_surface.api.similarity.knn_endpoints import KnnEndpoints from graphdatascience.procedure_surface.arrow.catalog_arrow_endpoints import CatalogArrowEndpoints from graphdatascience.procedure_surface.arrow.centrality.articlerank_arrow_endpoints import ArticleRankArrowEndpoints from graphdatascience.procedure_surface.arrow.centrality.articulationpoints_arrow_endpoints import ( @@ -63,6 +64,7 @@ ) from graphdatascience.procedure_surface.arrow.node_embedding.hashgnn_arrow_endpoints import HashGNNArrowEndpoints from graphdatascience.procedure_surface.arrow.node_embedding.node2vec_arrow_endpoints import Node2VecArrowEndpoints +from graphdatascience.procedure_surface.arrow.similarity.knn_arrow_endpoints import KnnArrowEndpoints from graphdatascience.query_runner.query_runner import QueryRunner @@ -165,6 +167,10 @@ def k_core_decomposition(self) -> KCoreArrowEndpoints: def kmeans(self) -> KMeansEndpoints: return KMeansArrowEndpoints(self._arrow_client, self._write_back_client, show_progress=self._show_progress) + @property + def knn(self) -> KnnEndpoints: + return KnnArrowEndpoints(self._arrow_client, self._write_back_client, show_progress=self._show_progress) + @property def label_propagation(self) -> LabelPropagationEndpoints: return LabelPropagationArrowEndpoints( diff --git a/graphdatascience/tests/integrationV2/procedure_surface/arrow/conftest.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/conftest.py index 0354e7134..552538095 100644 --- 
a/graphdatascience/tests/integrationV2/procedure_surface/arrow/conftest.py
+++ b/graphdatascience/tests/integrationV2/procedure_surface/arrow/conftest.py
@@ -8,7 +8,9 @@
 from graphdatascience import QueryRunner
 from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
+from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
 from graphdatascience.tests.integrationV2.procedure_surface.conftest import (
+    GdsSessionConnectionInfo,
     create_arrow_client,
     create_db_query_runner,
     start_database,
@@ -19,22 +21,24 @@
 
 
 @pytest.fixture(scope="package")
-def session_container(
+def session_connection(
     network: Network, password_dir: Path, logs_dir: Path, inside_ci: bool
-) -> Generator[DockerContainer, None, None]:
+) -> Generator[GdsSessionConnectionInfo, None, None]:
     yield from start_session(inside_ci, logs_dir, network, password_dir)
 
 
 @pytest.fixture(scope="package")
-def arrow_client(session_container: DockerContainer) -> AuthenticatedArrowClient:
-    return create_arrow_client(session_container)
+def arrow_client(session_connection: GdsSessionConnectionInfo) -> AuthenticatedArrowClient:
+    return create_arrow_client(session_connection)
 
 
 @pytest.fixture(scope="package")
-def neo4j_container(network: Network, logs_dir: Path, inside_ci: bool) -> Generator[DockerContainer, None, None]:
+def neo4j_connection(network: Network, logs_dir: Path, inside_ci: bool) -> Generator[DbmsConnectionInfo, None, None]:
+    if inside_ci:
+        raise RuntimeError("Communication between Session and DB is not supported yet in CI.")
     yield from start_database(inside_ci, logs_dir, network)
 
 
 @pytest.fixture(scope="package")
-def query_runner(neo4j_container: DockerContainer) -> Generator[QueryRunner, None, None]:
-    yield from create_db_query_runner(neo4j_container)
+def query_runner(neo4j_connection: DbmsConnectionInfo) -> Generator[QueryRunner, None, None]:
+    yield from create_db_query_runner(neo4j_connection)
diff --git a/graphdatascience/tests/integrationV2/procedure_surface/arrow/similarity/__init__.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/similarity/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/graphdatascience/tests/integrationV2/procedure_surface/arrow/similarity/test_knn_arrow_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/similarity/test_knn_arrow_endpoints.py
new file mode 100644
index 000000000..dbcc08d1e
--- /dev/null
+++ b/graphdatascience/tests/integrationV2/procedure_surface/arrow/similarity/test_knn_arrow_endpoints.py
@@ -0,0 +1,132 @@
+from typing import Generator
+
+import pytest
+
+from graphdatascience import QueryRunner
+from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
+from graphdatascience.arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
+from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
+from graphdatascience.procedure_surface.arrow.similarity.knn_arrow_endpoints import KnnArrowEndpoints
+from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import (
+    create_graph,
+    create_graph_from_db,
+)
+
+graph = """
+    CREATE
+    (a: Node {prop: [1.0, 2.0, 3.0]}),
+    (b: Node {prop: [2.0, 2.0, 4.0]}),
+    (c: Node {prop: [3.0, 2.0, 1.0]}),
+    (d: Node {prop: [4.0, 2.0, 0.0]})
+    """
+
+
+@pytest.fixture
+def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[GraphV2, None, None]:
+    with create_graph(arrow_client, "g", graph) as G:
+        yield G
+
+
+@pytest.fixture
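+# The db_graph fixture below projects the same data through a Neo4j database via
+# gds.graph.project.remote; the write-back test needs this, while sample_graph
+# above lives purely in the GDS session's graph catalog.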
+def db_graph(arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner) -> Generator[GraphV2, None, None]: + with create_graph_from_db( + arrow_client, + query_runner, + "g", + graph, + """ + MATCH (n) + WITH gds.graph.project.remote(n, null, {sourceNodeProperties: properties(n)}) as g + RETURN g + """, + ) as g: + yield g + + +@pytest.fixture +def knn_endpoints(arrow_client: AuthenticatedArrowClient) -> Generator[KnnArrowEndpoints, None, None]: + yield KnnArrowEndpoints(arrow_client) + + +def test_knn_stats(knn_endpoints: KnnArrowEndpoints, sample_graph: GraphV2) -> None: + """Test KNN stats operation.""" + result = knn_endpoints.stats(G=sample_graph, node_properties=["prop"], top_k=2) + + assert result.ran_iterations > 0 + assert result.did_converge + assert result.compute_millis >= 0 + assert result.pre_processing_millis >= 0 + assert result.post_processing_millis >= 0 + assert result.nodes_compared > 0 + assert result.similarity_pairs > 0 + assert result.node_pairs_considered > 0 + assert "p50" in result.similarity_distribution + + +def test_knn_stream(knn_endpoints: KnnArrowEndpoints, sample_graph: GraphV2) -> None: + """Test KNN stream operation.""" + result_df = knn_endpoints.stream( + G=sample_graph, + node_properties=["prop"], + top_k=2, + ) + + assert set(result_df.columns) == {"node1", "node2", "similarity"} + assert len(result_df) == 8 + + +def test_knn_mutate(knn_endpoints: KnnArrowEndpoints, sample_graph: GraphV2) -> None: + """Test KNN mutate operation.""" + result = knn_endpoints.mutate( + G=sample_graph, + mutate_relationship_type="SIMILAR", + mutate_property="similarity", + node_properties=["prop"], + top_k=2, + ) + + assert result.ran_iterations > 0 + assert result.did_converge + assert result.pre_processing_millis >= 0 + assert result.compute_millis >= 0 + assert result.post_processing_millis >= 0 + assert result.mutate_millis >= 0 + assert result.relationships_written == sample_graph.node_count() * 2 + assert result.node_pairs_considered > 0 + + +@pytest.mark.db_integration +def test_knn_write(arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner, db_graph: GraphV2) -> None: + """Test KNN write operation.""" + endpoints = KnnArrowEndpoints( + arrow_client, write_back_client=RemoteWriteBackClient(arrow_client, query_runner), show_progress=False + ) + + result = endpoints.write( + G=db_graph, write_relationship_type="SIMILAR", write_property="similarity", node_properties=["prop"], top_k=2 + ) + + assert result.ran_iterations > 0 + assert result.did_converge + assert result.pre_processing_millis >= 0 + assert result.compute_millis >= 0 + assert result.post_processing_millis >= 0 + assert result.write_millis >= 0 + assert result.relationships_written == db_graph.node_count() * 2 + assert result.node_pairs_considered > 0 + + # Check that relationships were written to the database + count_result = query_runner.run_cypher("MATCH ()-[r:SIMILAR]->() RETURN COUNT(r) AS count") + assert count_result.squeeze() >= result.relationships_written + + +def test_knn_estimate(knn_endpoints: KnnArrowEndpoints, sample_graph: GraphV2) -> None: + result = knn_endpoints.estimate(sample_graph, node_properties=["prop"], top_k=2) + + assert result.node_count == 4 + assert result.relationship_count == 0 # No relationships in this graph + assert "Bytes" in result.required_memory + assert result.bytes_min > 0 + assert result.bytes_max > 0 + assert result.heap_percentage_min > 0 + assert result.heap_percentage_max > 0 diff --git 
a/graphdatascience/tests/integrationV2/procedure_surface/arrow/similarity/test_knn_filtered_arrow_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/similarity/test_knn_filtered_arrow_endpoints.py new file mode 100644 index 000000000..7c5383738 --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/arrow/similarity/test_knn_filtered_arrow_endpoints.py @@ -0,0 +1,151 @@ +from typing import Generator + +import pytest + +from graphdatascience import QueryRunner +from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from graphdatascience.arrow_client.v2.remote_write_back_client import RemoteWriteBackClient +from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2 +from graphdatascience.procedure_surface.arrow.similarity.knn_filtered_arrow_endpoints import KnnFilteredArrowEndpoints +from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import ( + create_graph, + create_graph_from_db, +) + +graph = """ + CREATE + (a: SourceNode {prop: [1.0, 2.0, 3.0]}), + (b: SourceNode {prop: [2.0, 2.0, 4.0]}), + (c: TargetNode {prop: [3.0, 2.0, 1.0]}), + (d: TargetNode {prop: [4.0, 2.0, 0.0]}) + """ + + +@pytest.fixture +def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[GraphV2, None, None]: + with create_graph(arrow_client, "g", graph) as G: + yield G + + +@pytest.fixture +def db_graph(arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner) -> Generator[GraphV2, None, None]: + with create_graph_from_db( + arrow_client, + query_runner, + "g", + graph, + """ + MATCH (n) + WITH gds.graph.project.remote(n, null, {sourceNodeProperties: properties(n), sourceNodeLabels: labels(n)}) as g + RETURN g + """, + ) as g: + yield g + + +@pytest.fixture +def knn_filtered_endpoints(arrow_client: AuthenticatedArrowClient) -> Generator[KnnFilteredArrowEndpoints, None, None]: + yield KnnFilteredArrowEndpoints(arrow_client) + + +def test_knn_filtered_stats(knn_filtered_endpoints: KnnFilteredArrowEndpoints, sample_graph: GraphV2) -> None: + result = knn_filtered_endpoints.stats( + sample_graph, + node_properties="prop", + top_k=2, + source_node_filter="SourceNode", + target_node_filter="TargetNode", + ) + + assert result.pre_processing_millis >= 0 + assert result.compute_millis >= 0 + assert result.post_processing_millis >= 0 + assert result.nodes_compared >= 0 + assert result.similarity_pairs >= 0 + assert "p50" in result.similarity_distribution + assert result.did_converge + assert result.ran_iterations >= 0 + assert result.node_pairs_considered >= 0 + assert "concurrency" in result.configuration + + +def test_knn_filtered_stream(knn_filtered_endpoints: KnnFilteredArrowEndpoints, sample_graph: GraphV2) -> None: + result_df = knn_filtered_endpoints.stream( + G=sample_graph, + node_properties=["prop"], + top_k=2, + source_node_filter="SourceNode", + target_node_filter="TargetNode", + ) + + assert set(result_df.columns) == {"node1", "node2", "similarity"} + assert len(result_df) == 4 + + +def test_knn_filtered_mutate(knn_filtered_endpoints: KnnFilteredArrowEndpoints, sample_graph: GraphV2) -> None: + result = knn_filtered_endpoints.mutate( + sample_graph, + node_properties="prop", + mutate_property="score", + mutate_relationship_type="SIMILAR_TO", + top_k=2, + source_node_filter="SourceNode", + target_node_filter="TargetNode", + ) + + assert result.pre_processing_millis >= 0 + assert result.compute_millis >= 0 + assert result.mutate_millis >= 0 + assert 
result.post_processing_millis >= 0 + assert result.nodes_compared >= 0 + assert result.relationships_written > 0 + assert "p50" in result.similarity_distribution + assert result.did_converge + assert result.ran_iterations > 0 + assert result.node_pairs_considered >= 0 + assert "concurrency" in result.configuration + + +@pytest.mark.db_integration +def test_knn_filtered_write( + arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner, db_graph: GraphV2 +) -> None: + endpoints = KnnFilteredArrowEndpoints( + arrow_client, write_back_client=RemoteWriteBackClient(arrow_client, query_runner), show_progress=False + ) + + result = endpoints.write( + db_graph, + node_properties="prop", + write_property="score", + write_relationship_type="SIMILAR_TO", + top_k=2, + source_node_filter="SourceNode", + target_node_filter="TargetNode", + ) + + assert result.pre_processing_millis >= 0 + assert result.compute_millis >= 0 + assert result.write_millis >= 0 + assert result.post_processing_millis >= 0 + assert result.nodes_compared >= 0 + assert result.relationships_written > 0 + assert result.did_converge + assert result.ran_iterations > 0 + assert result.node_pairs_considered >= 0 + assert "p50" in result.similarity_distribution + assert "concurrency" in result.configuration + + +def test_knn_filtered_estimate(knn_filtered_endpoints: KnnFilteredArrowEndpoints, sample_graph: GraphV2) -> None: + result = knn_filtered_endpoints.estimate( + sample_graph, + node_properties="prop", + top_k=2, + source_node_filter="SourceNode", + target_node_filter="TargetNode", + ) + + assert result.required_memory is not None + assert result.tree_view is not None + assert result.map_view is not None diff --git a/graphdatascience/tests/integrationV2/procedure_surface/conftest.py b/graphdatascience/tests/integrationV2/procedure_surface/conftest.py index a2a6513eb..938c37b9b 100644 --- a/graphdatascience/tests/integrationV2/procedure_surface/conftest.py +++ b/graphdatascience/tests/integrationV2/procedure_surface/conftest.py @@ -1,5 +1,6 @@ import logging import os +from dataclasses import dataclass from datetime import datetime from pathlib import Path from typing import Generator @@ -14,10 +15,18 @@ from graphdatascience.arrow_client.arrow_info import ArrowInfo from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner +from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo LOGGER = logging.getLogger(__name__) +@dataclass +class GdsSessionConnectionInfo: + host: str + arrow_port: int + bolt_port: int + + @pytest.fixture(scope="package") def password_dir(tmp_path_factory: pytest.TempPathFactory) -> Generator[Path, None, None]: """Create a temporary file and return its path.""" @@ -47,7 +56,12 @@ def latest_neo4j_version() -> str: def start_session( inside_ci: bool, logs_dir: Path, network: Network, password_dir: Path -) -> Generator[DockerContainer, None, None]: +) -> Generator[GdsSessionConnectionInfo, None, None]: + if (session_uri := os.environ.get("GDS_SESSION_URI")) is not None: + uri_parts = session_uri.split(":") + yield GdsSessionConnectionInfo(host=uri_parts[0], arrow_port=8491, bolt_port=int(uri_parts[1])) + return + session_image = os.getenv( "GDS_SESSION_IMAGE", "europe-west1-docker.pkg.dev/gds-aura-artefacts/gds/gds-session:latest" ) @@ -66,7 +80,11 @@ def start_session( session_container = session_container.with_network(network).with_network_aliases("gds-session") with 
session_container as session_container:
         wait_for_logs(session_container, "Running GDS tasks: 0")
-        yield session_container
+        yield GdsSessionConnectionInfo(
+            host=session_container.get_container_host_ip(),
+            arrow_port=int(session_container.get_exposed_port(8491)),
+            bolt_port=-1,  # not used in tests
+        )
 
         stdout, stderr = session_container.get_logs()
         if stderr:
@@ -80,19 +98,18 @@
             f.write(stdout.decode("utf-8"))
 
 
-def create_arrow_client(session_container: DockerContainer) -> AuthenticatedArrowClient:
+def create_arrow_client(session_connection: GdsSessionConnectionInfo) -> AuthenticatedArrowClient:
     """Create an authenticated Arrow client connected to the GDS session."""
-    host = session_container.get_container_host_ip()
-    port = session_container.get_exposed_port(8491)
+
     return AuthenticatedArrowClient.create(
-        arrow_info=ArrowInfo(f"{host}:{port}", True, True, ["v1", "v2"]),
+        arrow_info=ArrowInfo(f"{session_connection.host}:{session_connection.arrow_port}", True, True, ["v1", "v2"]),
         auth=UsernamePasswordAuthentication("neo4j", "password"),
         encrypted=False,
         advertised_listen_address=("gds-session", 8491),
     )
 
 
-def start_database(inside_ci: bool, logs_dir: Path, network: Network) -> Generator[DockerContainer, None, None]:
+def start_database(inside_ci: bool, logs_dir: Path, network: Network) -> Generator[DbmsConnectionInfo, None, None]:
     default_neo4j_image = (
         f"europe-west1-docker.pkg.dev/neo4j-aura-image-artifacts/aura/neo4j-enterprise:{latest_neo4j_version()}"
     )
@@ -115,7 +132,11 @@
     )
     with db_container as db_container:
         wait_for_logs(db_container, "Started.")
-        yield db_container
+        yield DbmsConnectionInfo(
+            uri=f"{db_container.get_container_host_ip()}:{db_container.get_exposed_port(7687)}",
+            username="neo4j",
+            password="password",
+        )
 
         stdout, stderr = db_container.get_logs()
         if stderr:
@@ -129,11 +150,9 @@
             f.write(stdout.decode("utf-8"))
 
 
-def create_db_query_runner(neo4j_container: DockerContainer) -> Generator[Neo4jQueryRunner, None, None]:
-    host = neo4j_container.get_container_host_ip()
-    port = 7687
+def create_db_query_runner(neo4j_connection: DbmsConnectionInfo) -> Generator[Neo4jQueryRunner, None, None]:
     query_runner = Neo4jQueryRunner.create_for_db(
         f"bolt://{neo4j_connection.uri}",
-        ("neo4j", "password"),
+        (neo4j_connection.username, neo4j_connection.password),
     )
     yield query_runner
diff --git a/graphdatascience/tests/integrationV2/procedure_surface/cypher/similarity/__init__.py b/graphdatascience/tests/integrationV2/procedure_surface/cypher/similarity/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/graphdatascience/tests/integrationV2/procedure_surface/cypher/similarity/test_knn_cypher_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/cypher/similarity/test_knn_cypher_endpoints.py
new file mode 100644
index 000000000..a85b904ae
--- /dev/null
+++ b/graphdatascience/tests/integrationV2/procedure_surface/cypher/similarity/test_knn_cypher_endpoints.py
@@ -0,0 +1,94 @@
+from typing import Generator
+
+import pytest
+
+from graphdatascience import QueryRunner
+from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
+from graphdatascience.procedure_surface.cypher.similarity.knn_cypher_endpoints import KnnCypherEndpoints
+from graphdatascience.tests.integrationV2.procedure_surface.cypher.cypher_graph_helper import create_graph
+
+
+@pytest.fixture
+def sample_graph(query_runner: 
QueryRunner) -> Generator[GraphV2, None, None]: + create_statement = """ + CREATE + (a: Node {prop: [1.0, 2.0, 3.0]}), + (b: Node {prop: [2.0, 3.0, 4.0]}), + (c: Node {prop: [3.0, 4.0, 5.0]}), + (d: Node {prop: [1.0, 1.0, 1.0]}) + """ + + projection_query = """ + MATCH (n) + WITH gds.graph.project('g', n, null, {sourceNodeProperties: properties(n), targetNodeProperties: null}) AS G + RETURN G + """ + + with create_graph( + query_runner, + "g", + create_statement, + projection_query, + ) as g: + yield g + + +@pytest.fixture +def knn_endpoints(query_runner: QueryRunner) -> Generator[KnnCypherEndpoints, None, None]: + yield KnnCypherEndpoints(query_runner) + + +def test_knn_stats(knn_endpoints: KnnCypherEndpoints, sample_graph: GraphV2) -> None: + result = knn_endpoints.stats(G=sample_graph, node_properties=["prop"], top_k=2) + + assert result.ran_iterations > 0 + assert result.did_converge + assert result.compute_millis > 0 + assert result.pre_processing_millis >= 0 + assert result.post_processing_millis >= 0 + assert result.nodes_compared > 0 + assert result.similarity_pairs == 8 + assert result.node_pairs_considered > 0 + assert "p50" in result.similarity_distribution + + +def test_knn_stream(knn_endpoints: KnnCypherEndpoints, sample_graph: GraphV2) -> None: + result_df = knn_endpoints.stream( + G=sample_graph, + node_properties=["prop"], + top_k=2, + ) + + assert set(result_df.columns) == {"node1", "node2", "similarity"} + assert len(result_df) == 8 + + +def test_knn_mutate(knn_endpoints: KnnCypherEndpoints, sample_graph: GraphV2) -> None: + result = knn_endpoints.mutate( + G=sample_graph, + mutate_relationship_type="SIMILAR", + mutate_property="similarity", + node_properties=["prop"], + top_k=2, + ) + + assert result.ran_iterations > 0 + assert result.did_converge + assert result.pre_processing_millis >= 0 + assert result.compute_millis >= 0 + assert result.post_processing_millis >= 0 + assert result.mutate_millis >= 0 + assert result.relationships_written == 8 + assert result.node_pairs_considered > 0 + + +def test_knn_estimate(knn_endpoints: KnnCypherEndpoints, sample_graph: GraphV2) -> None: + result = knn_endpoints.estimate(sample_graph, node_properties=["prop"], top_k=2) + + assert result.node_count == 4 + assert result.relationship_count == 0 # No relationships in this graph + assert "Bytes" in result.required_memory + assert result.bytes_min > 0 + assert result.bytes_max > 0 + assert result.heap_percentage_min > 0 + assert result.heap_percentage_max > 0 diff --git a/graphdatascience/tests/integrationV2/procedure_surface/cypher/similarity/test_knn_filtered_cypher_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/cypher/similarity/test_knn_filtered_cypher_endpoints.py new file mode 100644 index 000000000..1f096892c --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/cypher/similarity/test_knn_filtered_cypher_endpoints.py @@ -0,0 +1,136 @@ +from typing import Generator + +import pytest + +from graphdatascience import QueryRunner +from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2 +from graphdatascience.procedure_surface.cypher.similarity.knn_filtered_cypher_endpoints import ( + KnnFilteredCypherEndpoints, +) +from graphdatascience.tests.integrationV2.procedure_surface.cypher.cypher_graph_helper import create_graph + + +@pytest.fixture +def sample_graph(query_runner: QueryRunner) -> Generator[GraphV2, None, None]: + create_statement = """ + CREATE + (a: SourceNode {prop: [1.0, 2.0, 3.0]}), + (b: SourceNode {prop: 
[2.0, 3.0, 4.0]}), + (c: TargetNode {prop: [3.0, 4.0, 5.0]}), + (d: TargetNode {prop: [1.0, 1.0, 1.0]}) + """ + + projection_query = """ + MATCH (n) + WITH gds.graph.project('g', n, null, {sourceNodeProperties: properties(n), sourceNodeLabels: labels(n), targetNodeProperties: null, targetNodeLabels: null}) AS G + RETURN G + """ + + with create_graph( + query_runner, + "g", + create_statement, + projection_query, + ) as g: + yield g + + +@pytest.fixture +def knn_filtered_endpoints(query_runner: QueryRunner) -> Generator[KnnFilteredCypherEndpoints, None, None]: + yield KnnFilteredCypherEndpoints(query_runner) + + +def test_knn_filtered_stats(knn_filtered_endpoints: KnnFilteredCypherEndpoints, sample_graph: GraphV2) -> None: + result = knn_filtered_endpoints.stats( + G=sample_graph, + node_properties=["prop"], + top_k=2, + source_node_filter="SourceNode", + target_node_filter="TargetNode", + ) + + assert result.ran_iterations > 0 + assert result.did_converge + assert result.compute_millis > 0 + assert result.pre_processing_millis >= 0 + assert result.post_processing_millis >= 0 + assert result.nodes_compared > 0 + assert result.similarity_pairs > 0 + assert "p50" in result.similarity_distribution + assert result.node_pairs_considered > 0 + assert "concurrency" in result.configuration + + +def test_knn_filtered_stream(knn_filtered_endpoints: KnnFilteredCypherEndpoints, sample_graph: GraphV2) -> None: + result = knn_filtered_endpoints.stream( + G=sample_graph, + node_properties=["prop"], + top_k=2, + source_node_filter="SourceNode", + target_node_filter="TargetNode", + ) + + assert set(result.columns) == {"node1", "node2", "similarity"} + assert len(result) >= 4 + + +def test_knn_filtered_mutate(knn_filtered_endpoints: KnnFilteredCypherEndpoints, sample_graph: GraphV2) -> None: + result = knn_filtered_endpoints.mutate( + G=sample_graph, + node_properties=["prop"], + mutate_property="score", + mutate_relationship_type="SIMILAR_TO", + top_k=2, + source_node_filter="SourceNode", + target_node_filter="TargetNode", + ) + + assert result.ran_iterations > 0 + assert result.did_converge + assert result.compute_millis > 0 + assert result.mutate_millis >= 0 + assert result.pre_processing_millis >= 0 + assert result.post_processing_millis >= 0 + assert result.nodes_compared > 0 + assert result.relationships_written > 0 + assert "p50" in result.similarity_distribution + assert result.node_pairs_considered > 0 + assert "concurrency" in result.configuration + + +def test_knn_filtered_write(knn_filtered_endpoints: KnnFilteredCypherEndpoints, sample_graph: GraphV2) -> None: + result = knn_filtered_endpoints.write( + G=sample_graph, + node_properties=["prop"], + write_property="score", + write_relationship_type="SIMILAR_TO", + top_k=2, + source_node_filter="SourceNode", + target_node_filter="TargetNode", + ) + + assert result.ran_iterations > 0 + assert result.did_converge + assert result.compute_millis > 0 + assert result.write_millis >= 0 + assert result.pre_processing_millis >= 0 + assert result.post_processing_millis >= 0 + assert result.nodes_compared > 0 + assert result.relationships_written > 0 + assert "p50" in result.similarity_distribution + assert result.node_pairs_considered > 0 + assert "concurrency" in result.configuration + + +def test_knn_filtered_estimate(knn_filtered_endpoints: KnnFilteredCypherEndpoints, sample_graph: GraphV2) -> None: + result = knn_filtered_endpoints.estimate( + G=sample_graph, + node_properties=["prop"], + top_k=2, + source_node_filter="SourceNode", + 
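# The node filters are plain label strings here, restricting comparisons to
+        # SourceNode -> TargetNode pairs (GDS also accepts node ids for these
+        # filters, though this test does not exercise that).
+        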
+@pytest.fixture
+def knn_filtered_endpoints(query_runner: QueryRunner) -> Generator[KnnFilteredCypherEndpoints, None, None]:
+    yield KnnFilteredCypherEndpoints(query_runner)
+
+
+def test_knn_filtered_stats(knn_filtered_endpoints: KnnFilteredCypherEndpoints, sample_graph: GraphV2) -> None:
+    result = knn_filtered_endpoints.stats(
+        G=sample_graph,
+        node_properties=["prop"],
+        top_k=2,
+        source_node_filter="SourceNode",
+        target_node_filter="TargetNode",
+    )
+
+    assert result.ran_iterations > 0
+    assert result.did_converge
+    assert result.compute_millis > 0
+    assert result.pre_processing_millis >= 0
+    assert result.post_processing_millis >= 0
+    assert result.nodes_compared > 0
+    assert result.similarity_pairs > 0
+    assert "p50" in result.similarity_distribution
+    assert result.node_pairs_considered > 0
+    assert "concurrency" in result.configuration
+
+
+def test_knn_filtered_stream(knn_filtered_endpoints: KnnFilteredCypherEndpoints, sample_graph: GraphV2) -> None:
+    result = knn_filtered_endpoints.stream(
+        G=sample_graph,
+        node_properties=["prop"],
+        top_k=2,
+        source_node_filter="SourceNode",
+        target_node_filter="TargetNode",
+    )
+
+    assert set(result.columns) == {"node1", "node2", "similarity"}
+    assert len(result) >= 4
+
+
+def test_knn_filtered_mutate(knn_filtered_endpoints: KnnFilteredCypherEndpoints, sample_graph: GraphV2) -> None:
+    result = knn_filtered_endpoints.mutate(
+        G=sample_graph,
+        node_properties=["prop"],
+        mutate_property="score",
+        mutate_relationship_type="SIMILAR_TO",
+        top_k=2,
+        source_node_filter="SourceNode",
+        target_node_filter="TargetNode",
+    )
+
+    assert result.ran_iterations > 0
+    assert result.did_converge
+    assert result.compute_millis > 0
+    assert result.mutate_millis >= 0
+    assert result.pre_processing_millis >= 0
+    assert result.post_processing_millis >= 0
+    assert result.nodes_compared > 0
+    assert result.relationships_written > 0
+    assert "p50" in result.similarity_distribution
+    assert result.node_pairs_considered > 0
+    assert "concurrency" in result.configuration
+
+
+def test_knn_filtered_write(knn_filtered_endpoints: KnnFilteredCypherEndpoints, sample_graph: GraphV2) -> None:
+    result = knn_filtered_endpoints.write(
+        G=sample_graph,
+        node_properties=["prop"],
+        write_property="score",
+        write_relationship_type="SIMILAR_TO",
+        top_k=2,
+        source_node_filter="SourceNode",
+        target_node_filter="TargetNode",
+    )
+
+    assert result.ran_iterations > 0
+    assert result.did_converge
+    assert result.compute_millis > 0
+    assert result.write_millis >= 0
+    assert result.pre_processing_millis >= 0
+    assert result.post_processing_millis >= 0
+    assert result.nodes_compared > 0
+    assert result.relationships_written > 0
+    assert "p50" in result.similarity_distribution
+    assert result.node_pairs_considered > 0
+    assert "concurrency" in result.configuration
+
+
+def test_knn_filtered_estimate(knn_filtered_endpoints: KnnFilteredCypherEndpoints, sample_graph: GraphV2) -> None:
+    result = knn_filtered_endpoints.estimate(
+        G=sample_graph,
+        node_properties=["prop"],
+        top_k=2,
+        source_node_filter="SourceNode",
+        target_node_filter="TargetNode",
+    )
+
+    assert result.required_memory is not None
+    assert result.tree_view is not None
+    assert result.map_view is not None
diff --git a/graphdatascience/tests/integrationV2/procedure_surface/session/conftest.py b/graphdatascience/tests/integrationV2/procedure_surface/session/conftest.py
index 7ce9806a5..373071230 100644
--- a/graphdatascience/tests/integrationV2/procedure_surface/session/conftest.py
+++ b/graphdatascience/tests/integrationV2/procedure_surface/session/conftest.py
@@ -2,12 +2,13 @@
 from typing import Generator
 
 import pytest
-from testcontainers.core.container import DockerContainer
 from testcontainers.core.network import Network
 
 from graphdatascience import QueryRunner
 from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
+from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
 from graphdatascience.tests.integrationV2.procedure_surface.conftest import (
+    GdsSessionConnectionInfo,
     create_arrow_client,
     create_db_query_runner,
     start_database,
@@ -16,22 +17,22 @@
 
 
 @pytest.fixture(scope="package")
-def session_container(
+def session_connection(
     network: Network, password_dir: Path, logs_dir: Path, inside_ci: bool
-) -> Generator[DockerContainer, None, None]:
+) -> Generator[GdsSessionConnectionInfo, None, None]:
     yield from start_session(inside_ci, logs_dir, network, password_dir)
 
 
 @pytest.fixture(scope="package")
-def arrow_client(session_container: DockerContainer) -> AuthenticatedArrowClient:
-    return create_arrow_client(session_container)
+def arrow_client(session_connection: GdsSessionConnectionInfo) -> AuthenticatedArrowClient:
+    return create_arrow_client(session_connection)
 
 
 @pytest.fixture(scope="package")
-def neo4j_container(network: Network, logs_dir: Path, inside_ci: bool) -> Generator[DockerContainer, None, None]:
+def neo4j_connection(network: Network, logs_dir: Path, inside_ci: bool) -> Generator[DbmsConnectionInfo, None, None]:
     yield from start_database(inside_ci, logs_dir, network)
 
 
 @pytest.fixture(scope="package")
-def db_query_runner(neo4j_container: DockerContainer) -> Generator[QueryRunner, None, None]:
-    yield from create_db_query_runner(neo4j_container)
+def db_query_runner(neo4j_connection: DbmsConnectionInfo) -> Generator[QueryRunner, None, None]:
+    yield from create_db_query_runner(neo4j_connection)
diff --git a/graphdatascience/tests/integrationV2/procedure_surface/session/test_session_endpoint_coverage.py b/graphdatascience/tests/integrationV2/procedure_surface/session/test_session_endpoint_coverage.py
index 6ebf69d66..222775a76 100644
--- a/graphdatascience/tests/integrationV2/procedure_surface/session/test_session_endpoint_coverage.py
+++ b/graphdatascience/tests/integrationV2/procedure_surface/session/test_session_endpoint_coverage.py
@@ -3,19 +3,13 @@
 
 import pytest
 
-from graphdatascience import QueryRunner, ServerVersion
 from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
-from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience
 from graphdatascience.session.session_v2_endpoints import SessionV2Endpoints
 
 MISSING_ALGO_ENDPOINTS = {
-    "similarity.knn.filtered",
-    "similarity.knn.filtered.estimate",
     "similarity.nodeSimilarity.filtered",
     "similarity.nodeSimilarity.filtered.estimate",
     "similarity.nodeSimilarity",
-    "similarity.knn",
-    "similarity.knn.estimate",
     "similarity.nodeSimilarity.estimate",
     "pathfinding.sourceTarget.dijkstra.estimate",
     "pathfinding.sourceTarget.aStar",
@@ -61,13 +55,8 @@
 
 
 @pytest.fixture
-def gds(arrow_client: AuthenticatedArrowClient, db_query_runner: QueryRunner) -> AuraGraphDataScience:
-    return AuraGraphDataScience(
-        query_runner=db_query_runner,
-        delete_fn=lambda: True,
-        gds_version=ServerVersion.from_string("2.7.0"),
-        v2_endpoints=SessionV2Endpoints(arrow_client, db_query_runner, show_progress=False),
-    )
+def endpoints(arrow_client: AuthenticatedArrowClient) -> SessionV2Endpoints:
+    return SessionV2Endpoints(arrow_client, db_client=None, show_progress=False)
 
 
 def to_snake(camel: str) -> str:
@@ -102,8 +91,8 @@ def check_gds_v2_availability(endpoints: SessionV2Endpoints, algo: str) -> bool:
 
 
 @pytest.mark.db_integration
-def test_algo_coverage(gds: AuraGraphDataScience) -> None:
-    arrow_client = gds.v2._arrow_client
+def test_algo_coverage(endpoints: SessionV2Endpoints) -> None:
+    arrow_client = endpoints._arrow_client
 
     # Get all available Arrow actions
     available_v2_actions = [
@@ -128,7 +117,7 @@ def test_algo_coverage(gds: AuraGraphDataScience) -> None:
     for category, algos in algos_per_category.items():
         for algo in algos:
             is_available = check_gds_v2_availability(
-                gds.v2,
+                endpoints,
                 algo,
             )
             action = f"{category}.{algo}"