pytorch-geometric
Overview
PyTorch Geometric (PyG) is built on top of PyTorch to simplify the implementation of Graph Neural Networks. It treats graphs as Data objects containing node features and edge indices, and provides a powerful MessagePassing base class for custom layer development.
When to Use
Use PyG for data that is naturally represented as a graph, such as social networks, molecular structures, or point clouds. Use it when you need to perform node classification, edge prediction, or graph-level regression.
Decision Tree
- Do you have a list of small graphs?
  - USE: `torch_geometric.loader.DataLoader` to create a single giant disjoint graph.
- Do you need to pool node features into a graph-level feature?
  - USE: `global_mean_pool` or `global_max_pool` using the `batch` vector.
- Are you building a custom convolution?
  - INHERIT: from `torch_geometric.nn.MessagePassing`.
Workflows
- Defining a Custom GNN Layer
  - Inherit from `torch_geometric.nn.MessagePassing`.
  - Set the aggregation scheme (`aggr='add'`, `'mean'`, or `'max'`) in `__init__`.
  - Implement the forward pass using `self.propagate()`.
  - Define the `message()` function to compute the transformation for each edge.
  - Optionally define the `update()` function to transform aggregated results.
- Mini-batching Large Graphs
  - Use `torch_geometric.loader.DataLoader` instead of the standard PyTorch version.
  - PyG automatically creates a single giant disjoint graph from a list of small graphs.
  - The `batch` vector in the resulting `Data` object tracks which original graph each node belongs to.
  - Use global pooling (e.g., `global_mean_pool`) to aggregate node features into graph-level representations.
- Constructing Graphs from Point Clouds
  - Represent point clouds as node features in a tensor of shape `[N, F]`.
  - Apply `knn_graph()` to dynamically compute an `edge_index` based on spatial proximity.
  - Pass the results into an `EdgeConv` or `DynamicEdgeConv` layer for feature extraction.
Non-Obvious Insights
- Auto-Indexing: The `_i` and `_j` suffixes in `MessagePassing` methods automatically index tensors into destination and source nodes, respectively, during propagation, with no manual slice logic.
- Disjoint Representation: PyG batches multiple graphs by stacking them into a single block-diagonal adjacency matrix. This lets standard sparse matrix operations process multiple graphs in parallel without zero-padding.
- Modular Aggregation: Aggregation is a first-class concept in PyG; users can swap a simple sum/mean for an advanced learnable scheme such as `SoftmaxAggregation` by simply changing the `aggr` parameter.
Evidence
- "PyG provides the MessagePassing base class, which helps in creating such kinds of message passing graph neural networks by automatically taking care of message propagation." (https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_gnn.html)
- "Tensors passed to propagate() can be mapped to the respective nodes i and j by appending _i or _j to the variable name." (https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_gnn.html)
Scripts
- scripts/pytorch-geometric_tool.py: Template for a custom GNN layer and graph data loader.
- scripts/pytorch-geometric_tool.js: Node.js script to run PyG training experiments.
Dependencies
- torch
- torch-geometric
- torch-scatter / torch-sparse (optional but recommended)