STM-Graph is a Python framework for analyzing spatial-temporal urban data and making predictions with Graph Neural Networks (GNNs). It provides an end-to-end pipeline from raw event data to trained GNN models, making it easier to understand and predict urban events.
- Complete Pipeline: From raw data to trained models in a unified framework
- Flexible Spatial Mapping: Grid-based, Voronoi, or administrative boundary mapping
- Urban Features Graph: Extract features from OpenStreetMap for urban context
- Multiple GNN Models: Support for various graph neural networks
- Visualization Tools: Comprehensive spatial and temporal visualizations
- Integration: Weights & Biases integration for experiment tracking
STM-Graph requires PyTorch with the appropriate CUDA version for your system.
# First install the base package
pip install stm-gnn
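# Then install PyTorch built for your CUDA version (cu118 shown; change the suffix to match your system)
pip install torch==2.4.0 --extra-index-url https://download.pytorch.org/whl/cu118
# Finally install the PyTorch extensions (the quotes stop shells such as zsh from expanding the brackets)
pip install "stm-gnn[torch-extensions]"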
# Alternatively, install from source: clone the repository
git clone https://github.com/Ahghaffari/stm_graph.git
cd stm_graph
# Install dependencies
pip install -r requirements.txt
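After installing, it can help to confirm that the pieces of the stack are importable and that the PyTorch wheel sees a GPU. The sketch below is a generic environment check, not part of STM-Graph; the module list (`stm_graph`, `torch`, `torch_geometric_temporal`, `networkx`, `osmnx`, `wandb`) is an assumption based on the packages named in this README, and `check_environment` is a hypothetical helper.

```python
import importlib.util

# Import names assumed for the packages mentioned in this README;
# edit the list if your installation differs.
EXPECTED_MODULES = ["stm_graph", "torch", "torch_geometric_temporal",
                    "networkx", "osmnx", "wandb"]


def check_environment(modules=EXPECTED_MODULES):
    """Map each top-level module name to True if it can be imported."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    status = check_environment()
    for name, found in status.items():
        print(f"{name:<26} {'found' if found else 'MISSING'}")
    if status["torch"]:
        import torch
        # True means the installed wheel can use a GPU; False means CPU-only.
        print("CUDA available:", torch.cuda.is_available())
```

`find_spec` only locates a module without importing it, so the check stays fast even when heavy packages are installed.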
import stm_graph
# 1. Preprocess your data
gdf = stm_graph.preprocess_dataset(
data_path="data/",
dataset="events.csv",
time_col="timestamp",
lat_col="latitude",
lng_col="longitude"
)
# 2. Create spatial mapping
mapper = stm_graph.GridMapping(cell_size=1000.0)
district_gdf, point_to_partition = mapper.create_mapping(gdf)
# 3. Extract urban features
osm_features = stm_graph.extract_osm_features(
regions_gdf=district_gdf,
feature_types=['poi', 'road', 'junction']
)
# 4. Build graph representation
graph_data = stm_graph.build_graph_and_augment(
grid_gdf=district_gdf,
points_gdf=gdf,
point_to_cell=point_to_partition,
static_features=osm_features
)
# 5. Create temporal dataset
temporal_dataset, _, _ = stm_graph.create_temporal_dataset(
edge_index=graph_data["edge_index"],
augmented_df=graph_data["augmented_df"],
node_ids=graph_data["node_ids"],
static_features=osm_features,
time_col="timestamp",
bin_type="daily",
task="classification"
)
# 6. Train a model
model = stm_graph.create_model("stgcn", task="classification")
results = stm_graph.train_model(
model=model,
dataset=temporal_dataset,
task="classification"
)
The repository includes two example notebooks in the examples/ folder that demonstrate the complete workflow:
- NYC 311 Service Request Analysis (examples/nyc_311_example.ipynb): analyzing and predicting 311 service requests across New York City
- NYC Traffic Crash Analysis (examples/nyc_crash_example.ipynb): analyzing and predicting traffic crashes across New York City
These notebooks cover every stage from data preprocessing to model training and visualization, and are good starting points for learning how to apply STM-Graph to real-world datasets.
We evaluated STM-Graph on two publicly available urban datasets from New York City:
- NYC 311 Service Requests dataset (link), which includes various citizen-reported non-emergency issues
- Motor Vehicle Collisions dataset (link), detailing traffic collision incidents across the city
These datasets were selected for their richness, widespread availability, and relevance to urban planning and public safety use cases. You can use these datasets directly with the provided notebook examples.
For administrative boundary mapping, you can use NYC's administrative divisions such as:
Load and preprocess spatial-temporal data:
gdf = stm_graph.preprocess_dataset(
data_path="data",
dataset="events.csv",
time_col="timestamp",
lat_col="latitude",
lng_col="longitude",
filter_dates=("2020-01-01", "2020-12-31")
)
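`preprocess_dataset` hides the cleaning details. As a rough illustration of what a step like this typically does (not STM-Graph's implementation), the sketch below parses timestamps, drops rows with missing or out-of-range coordinates, and applies an inclusive date filter similar in spirit to `filter_dates`. The function name `preprocess_rows` and its output keys are hypothetical.

```python
import csv
import io
from datetime import datetime


def preprocess_rows(csv_text, time_col, lat_col, lng_col, filter_dates=None):
    """Hypothetical cleaning step: parse timestamps, drop rows with missing or
    out-of-range coordinates, and keep only rows inside an inclusive date range."""
    start = end = None
    if filter_dates is not None:
        start = datetime.fromisoformat(filter_dates[0]).date()
        end = datetime.fromisoformat(filter_dates[1]).date()

    rows = []
    for record in csv.DictReader(io.StringIO(csv_text)):
        try:
            ts = datetime.fromisoformat(record[time_col])
            lat = float(record[lat_col])
            lng = float(record[lng_col])
        except (KeyError, TypeError, ValueError):
            continue  # unparseable timestamp or missing coordinate
        if not (-90.0 <= lat <= 90.0 and -180.0 <= lng <= 180.0):
            continue  # coordinates outside valid latitude/longitude ranges
        if start is not None and not (start <= ts.date() <= end):
            continue  # outside the requested date window
        rows.append({"timestamp": ts, "lat": lat, "lng": lng})
    return rows


sample = """timestamp,latitude,longitude
2020-03-05 14:30:00,40.71,-74.00
2019-12-31 23:00:00,40.72,-73.99
2020-06-01 08:00:00,,-73.95
"""
clean = preprocess_rows(sample, "timestamp", "latitude", "longitude",
                        filter_dates=("2020-01-01", "2020-12-31"))
print(len(clean))  # 1: the second row is out of range, the third lacks a latitude
```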
Divide the area into spatial regions:
# Grid-based mapping
mapper = stm_graph.GridMapping(cell_size=1000.0)
district_gdf, point_to_partition = mapper.create_mapping(gdf)
# Degree-based Voronoi mapping
mapper = stm_graph.VoronoiDegreeMapping()
district_gdf, point_to_partition = mapper.create_mapping(gdf)
# Administrative boundary mapping
mapper = stm_graph.AdministrativeMapping(districts_file="districts.shp")
district_gdf, point_to_partition = mapper.create_mapping(gdf)
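The idea behind the first two mappings can be sketched in plain Python. This assumes coordinates are already projected to metres (so `cell_size=1000.0` means 1 km cells) and returns simple dicts instead of GeoDataFrames; `grid_mapping` and `voronoi_mapping` are hypothetical helpers, not the library's classes. For the Voronoi case the seed points are taken as given, since how `VoronoiDegreeMapping` chooses its seeds is not described here.

```python
import math


def grid_mapping(points, cell_size):
    """Assign projected (x, y) points, in metres, to square cells of side cell_size.

    Returns (cells, point_to_cell): the sorted list of occupied (col, row) cells and
    a dict from point index to its (col, row) cell. Empty cells are not listed here.
    """
    min_x = min(x for x, _ in points)
    min_y = min(y for _, y in points)
    point_to_cell = {}
    for i, (x, y) in enumerate(points):
        col = int((x - min_x) // cell_size)
        row = int((y - min_y) // cell_size)
        point_to_cell[i] = (col, row)
    return sorted(set(point_to_cell.values())), point_to_cell


def voronoi_mapping(points, seeds):
    """Assign each point to its nearest seed; the points sharing a seed form that
    seed's Voronoi cell. Seed selection is outside this sketch."""
    return {
        i: min(range(len(seeds)), key=lambda s: math.dist(p, seeds[s]))
        for i, p in enumerate(points)
    }


points = [(0.0, 0.0), (999.0, 10.0), (1000.0, 0.0), (2500.0, 1500.0)]
cells, point_to_cell = grid_mapping(points, cell_size=1000.0)
print(cells)  # [(0, 0), (1, 0), (2, 1)]
print(voronoi_mapping(points, seeds=[(0.0, 0.0), (2000.0, 2000.0)]))  # {0: 0, 1: 0, 2: 0, 3: 1}
```

A grid gives regular, equally sized regions; a Voronoi partition adapts region shapes to wherever the seeds are, so dense areas can get smaller regions.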
Extract urban features from OpenStreetMap:
osm_features = stm_graph.extract_osm_features(
regions_gdf=district_gdf,
feature_types=['poi', 'road', 'junction'],
normalize=True
)
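This README does not say which scaling `normalize=True` applies. A common choice, assumed here, is min-max scaling of each feature across regions so counts of very different magnitudes (for example POIs versus junctions) land in [0, 1]. `min_max_normalize` and the nested-dict input format are hypothetical.

```python
def min_max_normalize(features):
    """features: {region_id: {feature_name: value}}. Returns the same shape with each
    feature scaled to [0, 1] across regions; a feature missing in a region counts as 0."""
    names = sorted({name for feats in features.values() for name in feats})
    out = {region: {} for region in features}
    for name in names:
        values = [features[r].get(name, 0.0) for r in features]
        lo, hi = min(values), max(values)
        span = hi - lo
        for r in features:
            v = features[r].get(name, 0.0)
            # A constant feature carries no information, so map it to 0.
            out[r][name] = (v - lo) / span if span > 0 else 0.0
    return out


raw = {"A": {"poi": 10, "road": 2}, "B": {"poi": 30, "road": 2}, "C": {"poi": 20}}
print(min_max_normalize(raw))
```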
Build a graph representation:
graph_data = stm_graph.build_graph_and_augment(
grid_gdf=district_gdf,
points_gdf=gdf,
point_to_cell=point_to_partition,
static_features=osm_features
)
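The `edge_index` returned here is used by the GNN. PyTorch Geometric stores edges in COO form, a pair of lists `[sources, targets]`, with both directions listed for an undirected graph. The sketch below builds such an index for grid cells, assuming (this README does not specify the rule) that two cells are neighbours when they share a side; `grid_edge_index` is a hypothetical helper.

```python
def grid_edge_index(cells):
    """cells: iterable of (col, row). Returns (node_ids, edge_index) where node_ids is the
    sorted cell list and edge_index = [sources, targets] with edges in both directions."""
    node_ids = sorted(set(cells))
    index = {cell: i for i, cell in enumerate(node_ids)}
    sources, targets = [], []
    for (col, row), i in index.items():
        # Side-sharing ("rook") neighbours only; diagonal cells are not connected.
        for dc, dr in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            j = index.get((col + dc, row + dr))
            if j is not None:
                sources.append(i)
                targets.append(j)
    return node_ids, [sources, targets]


node_ids, edge_index = grid_edge_index([(0, 0), (1, 0), (2, 1)])
print(edge_index)  # [[0, 1], [1, 0]]: cells (0, 0) and (1, 0) touch; (2, 1) is isolated
```

Isolated nodes like `(2, 1)` still get their own time series; they just receive no messages from neighbours during graph convolution.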
Create a temporal dataset for model training:
temporal_dataset, _, _ = stm_graph.create_temporal_dataset(
edge_index=graph_data["edge_index"],
augmented_df=graph_data["augmented_df"],
node_ids=graph_data["node_ids"],
static_features=osm_features,
time_col="timestamp",
bin_type="daily",
history_window=3,
task="classification"
)
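Conceptually, a temporal dataset like this counts events per node per time bin and then slides a window over the series: `history_window` past bins form the input and the following bin is the target. The sketch below assumes daily bins and, for classification, a binary target of "at least one event in the next bin"; both the target definition and the helper names `bin_daily` and `make_windows` are assumptions, not STM-Graph's documented behaviour.

```python
from datetime import date


def bin_daily(events, num_nodes):
    """events: list of (date, node_index). Returns counts[t][n] for every day from the
    first to the last event date, including days with no events."""
    start = min(d for d, _ in events)
    end = max(d for d, _ in events)
    counts = [[0] * num_nodes for _ in range((end - start).days + 1)]
    for d, n in events:
        counts[(d - start).days][n] += 1
    return counts


def make_windows(counts, history_window, task="classification"):
    """Return (x, y) pairs: x holds history_window past bins, y the next bin's target per node."""
    samples = []
    for t in range(history_window, len(counts)):
        x = counts[t - history_window:t]
        nxt = counts[t]
        y = [int(c > 0) for c in nxt] if task == "classification" else list(nxt)
        samples.append((x, y))
    return samples


counts = [[0, 2], [1, 0], [3, 1], [0, 0], [2, 5]]  # 5 days, 2 nodes
samples = make_windows(counts, history_window=3)
print(len(samples), samples[1][1])  # 2 [1, 1]
```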
Visualize spatial and temporal patterns:
# Plot time series
stm_graph.plot_node_time_series(
temporal_dataset,
num_nodes=5,
selection_method="highest_activity"
)
# Plot spatial network
stm_graph.plot_spatial_network(
regions_gdf=district_gdf,
edge_index=graph_data["edge_index"],
node_values=node_counts,  # per-node values to color by (e.g. event counts per node), computed beforehand
node_ids=graph_data["node_ids"]
)
# Plot temporal heatmap
stm_graph.plot_temporal_heatmap(
temporal_dataset,
n_steps=168
)
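The plotting calls rely on per-node activity: `node_counts` must exist before `plot_spatial_network` is called, and `selection_method="highest_activity"` presumably ranks nodes by how many events they have. One plausible way to compute both is sketched below, assuming `point_to_partition` maps each event to a node id and that a plain list aligned with `node_ids` is acceptable as `node_values` (the expected type is not documented here). `compute_node_counts` and `top_k_nodes` are hypothetical helpers.

```python
from collections import Counter


def compute_node_counts(point_to_cell, node_ids):
    """Total number of events per node, in the same order as node_ids."""
    counts = Counter(point_to_cell.values())
    return [counts.get(n, 0) for n in node_ids]


def top_k_nodes(node_ids, node_counts, k=5):
    """Ids of the k most active nodes; ties keep their original order (stable sort)."""
    order = sorted(range(len(node_ids)), key=lambda i: node_counts[i], reverse=True)
    return [node_ids[i] for i in order[:k]]


point_to_cell = {0: "a", 1: "b", 2: "a", 3: "c", 4: "a", 5: "b"}
node_ids = ["a", "b", "c", "d"]
node_counts = compute_node_counts(point_to_cell, node_ids)
print(node_counts, top_k_nodes(node_ids, node_counts, k=2))  # [3, 2, 1, 0] ['a', 'b']
```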
Train a GNN model:
# Create a model
model = stm_graph.create_model(
model_name="stgcn",
task="classification"
)
# Train the model
results = stm_graph.train_model(
model=model,
dataset=temporal_dataset,
task="classification",
num_epochs=100,
learning_rate=0.001
)
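How `train_model` splits data is not stated in this README. For time series, a common safeguard is a chronological split: train on the earliest windows, validate and test on later ones, and never shuffle across the boundary, so the model is not evaluated on periods it has already seen. The sketch below shows that idea with illustrative 70/15/15 ratios plus a simple accuracy metric for the binary task; `chronological_split` and `binary_accuracy` are hypothetical helpers.

```python
def chronological_split(samples, train_ratio=0.7, val_ratio=0.15):
    """Split time-ordered samples into (train, val, test) without shuffling."""
    n = len(samples)
    n_train = round(n * train_ratio)
    n_val = round(n * val_ratio)
    return samples[:n_train], samples[n_train:n_train + n_val], samples[n_train + n_val:]


def binary_accuracy(y_true, y_pred):
    """Fraction of positions where the predicted 0/1 label matches the true label."""
    return sum(int(t == p) for t, p in zip(y_true, y_pred)) / len(y_true)


train, val, test = chronological_split(list(range(20)))  # stand-in for 20 time-ordered windows
print(len(train), len(val), len(test))  # 14 3 3
```

With imbalanced event data (most nodes quiet on most days), accuracy alone can look good for a model that always predicts "no event", so per-class metrics are worth tracking too.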
A graphical user interface (GUI) for users without programming experience is available in the STM Graph GUI Repository and can be installed from its releases section.
Contributions to STM-Graph are welcome! Please feel free to submit a Pull Request.
This project is licensed under the MIT License - see the LICENSE file for details.
If you use STM-Graph in your research, please cite the repo and our article:
@software{stm_graph,
author = {Amirhossein Ghaffari},
title = {STM-Graph: A Python Framework for Spatio-Temporal Mapping and Graph Neural Network Predictions},
year = {2025},
publisher = {GitHub},
url = {https://github.com/Ahghaffari/stm_graph}
}
- NetworkX
- OSMnx
- PyTorch Geometric Temporal
- Weights & Biases