# gnn-mp-kernel

- **Repository Path**: chen_lin_k/gnn-mp-kernel
- **Default Branch**: master
- **Created**: 2026-05-01
- **Last Updated**: 2026-05-02

# GNN Message Passing: Fused Kernel Benchmark

## Overview

Comparison of two approaches for GNN message-passing edge-feature aggregation:

- **GAS (Gather-Apply-Scatter)**: Standard PyTorch baseline
- **Fused**: Custom Triton kernel, single-pass atomic-add

## Operation

For each edge `i` (`src → dst`):

```
dst_feat[dst_idx[i], :] += src_feat[src_idx[i], :] * edge_w[i, :]
```

### Shapes

| Tensor | Shape | Description |
|--------|-------|-------------|
| `src_feat` | `[N, D]` | Node features (QKV-mapped V) |
| `edge_w` | `[E, D]` | Edge weights (attn × dv) |
| `src_idx` | `[E]` int64 | Source node per edge |
| `dst_idx` | `[E]` int64 | Destination node per edge |
| `output` | `[N, D]` | Aggregated result |

### Typical Scale

- N = 4096 nodes (~4000 atoms)
- E = 200k edges
- D = 64 to 4096 features

---

## File Structure

```
gnn-mp-kernel/
├── fused_mp.py            # Triton fused kernel + launcher
├── benchmark.py           # GAS baseline + benchmark suite
├── profile_target.py      # Entry point for ncu/nsys profiling (arg: gas|fused)
├── profile_analysis.py    # PyTorch profiler + bandwidth estimation
├── RESULTS.md             # Benchmark results
├── results.csv            # Raw data
└── .gitignore
```

---

## Implementations

### Baseline: Gather-Apply-Scatter

File: `benchmark.py:7-27`

```python
import torch

src_gathered = src_feat[src_idx]            # [N,D] → [E,D]  (GATHER)
msg = src_gathered * edge_w                 # [E,D] × [E,D]  (APPLY)
out = torch.zeros(N, D, device='cuda')
# Assumed definition (not shown in this excerpt): row i repeats dst_idx[i] D times
index = dst_idx.unsqueeze(1).expand(-1, D)
out.scatter_add_(0, index, msg)             # [E,D] → [N,D]  (SCATTER)
```

Creates two `[E, D]` intermediate tensors
(`src_gathered`, `msg`). Total memory traffic = ~3× `E × D` plus overhead for indices.

**PyTorch kernels launched per iteration:**

| Kernel | Operation | Grid | Regs | Duration (D=1024) |
|--------|-----------|------|------|-------------------|
| `index_elementwise_kernel` | Gather | `(400000,1,1)` | 40 | 5.10 ms |
| `vectorized_elementwise_kernel` | Multiply | `(400000,1,1)` | 22 | 7.31 ms |
| `vectorized_elementwise_kernel` | Zero init | `(8192,1,1)` | 16 | 0.05 ms |
| `_scatter_gather_elementwise_kernel<_cuda_scatter_add>` | ScatterAdd | `(400000,1,1)` | 32 | 7.06 ms |

Grid = `400000` = `E × D / vector_width / block_x` = `200000 × 1024 / 4 / 128`.

### Fused Kernel (Triton)

File: `fused_mp.py:6-62`

**Grid:** `(E, ceil(D / BLOCK_D))`, a 2D grid over edge × feature-block.

**Each program (one edge × one D-block):**

```
src_id   = load(src_idx[edge_id])
dst_id   = load(dst_idx[edge_id])
src_val  = load(src_feat[src_id, d_offsets])
edge_val = load(edge_w[edge_id, d_offsets])
result   = src_val * edge_val
atomic_add(dst_feat[dst_id, d_offsets], result)
```

**Kernels launched per iteration:**

| Kernel | Grid | Block | Regs | Duration |
|--------|------|-------|------|----------|
| `vectorized_elementwise_kernel` | `(8192,1,1)` | 128 | 16 | 0.05 ms |
| `_fused_mp_kernel` | `(200000, 16, 1)` | 128 | 16 | 11.78 ms |

Grid `(200000, 16, 1)` = `E × ceil(D/BLOCK_D)` = `200000 × 16` (D=1024, BLOCK_D=64).

No intermediate `[E, D]` tensors. Total memory traffic ≈ `1× E × D` plus indices.
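The grid sizes in both kernel tables can be reproduced with a few lines of arithmetic. This is a minimal sketch, assuming the vector width of 4 and block size of 128 used in the grid formula above, and `BLOCK_D = 64` for the fused kernel. The helper names `elementwise_grid` and `fused_grid` are hypothetical and do not come from the repository.

```python
def ceil_div(a, b):
    """Integer ceiling division."""
    return (a + b - 1) // b

def elementwise_grid(num_elements, vector_width=4, block_x=128):
    """Blocks launched by an elementwise kernel (assumed: 4 elements per thread, 128 threads per block)."""
    return ceil_div(num_elements, vector_width * block_x)

def fused_grid(E, D, block_d=64):
    """2D grid of the fused kernel: one program per (edge, feature block)."""
    return (E, ceil_div(D, block_d))

E, N, D = 200_000, 4096, 1024
print(elementwise_grid(E * D))  # gather, multiply and scatter-add kernels
print(elementwise_grid(N * D))  # zero-init kernel over the [N, D] output
print(fused_grid(E, D))         # fused Triton kernel
```

With E=200k, N=4096 and D=1024 this yields 400000, 8192 and `(200000, 16)`, matching the `Grid` columns above.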
---

## Benchmark Results

### Speedup by D (E=200k fixed)

| D | GAS (ms) | Fused (ms) | Speedup | GAS (MB) | Fused (MB) | Mem saved |
|---|----------|------------|---------|----------|------------|-----------|
| 64 | 1.01 | 0.80 | 1.26× | 152.9 | 54.9 | 64.1% |
| 128 | 2.11 | 1.43 | 1.47× | 303.1 | 107.1 | 64.7% |
| 256 | 4.54 | 2.74 | 1.66× | 603.1 | 211.1 | 65.0% |
| 512 | 9.52 | 5.36 | 1.78× | 1199.9 | 418.6 | 65.1% |
| 1024 | 19.48 | 10.63 | 1.83× | 2398.0 | 834.0 | 65.2% |
| 2048 | 39.24 | 21.19 | 1.85× | 4786.6 | 1661.6 | 65.3% |
| 4096 | 78.78 | 42.23 | 1.87× | 9573.1 | 3321.1 | 65.3% |

### Speedup by E (D=128 fixed)

| E | GAS (ms) | Fused (ms) | Speedup | GAS (MB) | Fused (MB) | Mem saved |
|---|----------|------------|---------|----------|------------|-----------|
| 50k | 0.56 | 0.43 | 1.30× | 80.0 | 31.2 | 61.0% |
| 100k | 1.07 | 0.76 | 1.41× | 154.0 | 56.4 | 63.4% |
| 200k | 2.08 | 1.42 | 1.47× | 302.4 | 107.1 | 64.6% |
| 500k | 5.13 | 3.41 | 1.51× | 746.1 | 257.8 | 65.4% |

### Full Sweep: Speedup Matrix

| E \ D | 64 | 128 | 256 | 512 | 1024 | 2048 | 4096 |
|-------|-----|------|------|------|------|------|------|
| 50k | 1.02× | 1.30× | 1.55× | 1.72× | 1.81× | 1.84× | 1.85× |
| 100k | 1.19× | 1.41× | 1.62× | 1.76× | 1.83× | 1.85× | 1.86× |
| 200k | 1.26× | 1.47× | 1.66× | 1.78× | 1.83× | 1.85× | 1.87× |
| 500k | 1.35× | 1.51× | 1.69× | 1.80× | 1.84× | 1.86× | OOM* |

\* GAS runs out of memory at E=500k, D=4096 (`edge_w` alone = 8 GB + intermediates ~16 GB > 12 GB VRAM). Fused kernel fits (~6.6 GB).

---

## Analysis

### Memory

- **Peak memory saving: 61% to 65% across all measured configs**, settling at about 65% once E ≥ 200k. Fusion eliminates the two `[E, D]` intermediates.
- **GAS traffic:** ~4.94 GB per iteration (E=200k, D=1024). Composed of:
  - Gather: read `[E,D]` src_feat via index + write `[E,D]` gathered = ~1.64 GB
  - Apply: read gathered + edge_w + write msg = ~2.46 GB
  - Scatter: read msg + write `[N,D]` output = ~0.84 GB
- Of the 4.94 GB, ~3.28 GB (66%) is intermediate reads/writes that fusion eliminates.
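The per-stage traffic figures above can be recomputed from the tensor shapes. The sketch below assumes fp32 elements (4 bytes, consistent with `edge_w` being ~8 GB at E=500k, D=4096), N=4096, and 1 GB = 1e9 bytes. It ignores index traffic. The function name `gas_traffic_gb` is hypothetical.

```python
BYTES_FP32 = 4  # assumed element size

def gas_traffic_gb(E, N, D, elem_bytes=BYTES_FP32):
    """Estimate DRAM traffic of each GAS stage in GB (indices ignored)."""
    edge_tensor = E * D * elem_bytes   # bytes in one [E, D] tensor
    node_tensor = N * D * elem_bytes   # bytes in one [N, D] tensor
    stages = {
        "gather": 2 * edge_tensor,             # read src rows via index + write gathered
        "apply": 3 * edge_tensor,              # read gathered + edge_w, write msg
        "scatter": edge_tensor + node_tensor,  # read msg + write output
    }
    # gathered and msg are each written once and read once
    intermediate = 4 * edge_tensor
    return {name: b / 1e9 for name, b in stages.items()}, intermediate / 1e9

stages, intermediate = gas_traffic_gb(E=200_000, N=4096, D=1024)
total = sum(stages.values())
for name, gb in stages.items():
    print(f"{name:8s} {gb:.2f} GB")
print(f"total {total:.2f} GB, intermediate {intermediate:.2f} GB ({intermediate / total:.0%})")
```

The stages come out at 1.64, 2.46 and 0.84 GB, with 3.28 GB (66%) of intermediate traffic. The unrounded total is 4.93 GB; summing the rounded stage values gives the 4.94 GB quoted above.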
### Compute

- **Both GAS and Fused are memory-bound.** Arithmetic intensity is 0.08 (GAS) and 0.25 (Fused) FLOPs/Byte, far below the RTX 3060 ridge point of ~35 FLOPs/Byte.
- **GAS Apply kernel achieves ~93% DRAM bandwidth utilization** (337 GB/s of 360 GB/s peak). PyTorch's vectorized elementwise kernels are near-optimal.
- **GAS ScatterAdd is the bottleneck:** its random write pattern causes L2 cache thrashing. Effective bandwidth is ~120 GB/s.
- **Fused kernel achieves ~41% BW utilization** (148 GB/s). The gap to GAS's 70% overall is due to atomic-add contention: multiple edges writing to the same destination serialize at the memory controller.

### Speedup Scaling

Speedup grows with D because:

- Atomic-add contention is per-address: larger D spreads writes across more addresses, reducing collision probability.
- Work per edge scales with D, which better amortizes kernel launch overhead.

If the fused kernel matched GAS's 70% BW utilization, the theoretical speedup would be ~3× (from the 3× traffic reduction). The measured peak is ~1.87×, and the shortfall is attributed to atomic contention.

---

## Profiling Tools

### Environment Constraints

- **ncu (Nsight Compute):** Unavailable. Root cause: `RmProfilingAdminOnly: 1` in the NVIDIA kernel module, which locks the GPU performance counters (PMU) at driver level. The container has a read-only filesystem, so the module cannot be reloaded with `NVreg_RestrictProfilingToAdmin=0`. Even `sudo` cannot bypass this in this environment, because the restriction is enforced by the loaded kernel module.
- **nsys (Nsight Systems):** Available at `/opt/nvidia/nsight-compute/2023.1.1/host/target-linux-x64/nsys`. It uses the CUPTI tracing API, not PMU counters.
  - Can capture: kernel timeline, grid/block dimensions, register count, duration.
  - Cannot capture: DRAM bandwidth, cache hit rates, warp stall reasons, occupancy.
- **nvprof:** Not supported on Ampere (CC 8.x+) GPUs.
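Whether the counter restriction is active can be checked before attempting an `ncu` run. The sketch below assumes the driver exposes its parameters as `Key: value` lines in `/proc/driver/nvidia/params` on Linux. Both that path and the helper name `profiling_admin_only` are assumptions for illustration, not taken from the repository.

```python
from pathlib import Path

PARAMS_PATH = Path("/proc/driver/nvidia/params")  # assumed location on Linux

def profiling_admin_only(params_text):
    """Return True/False for RmProfilingAdminOnly, or None if the key is absent."""
    for line in params_text.splitlines():
        key, sep, value = line.partition(":")
        if sep and key.strip() == "RmProfilingAdminOnly":
            return value.strip() == "1"
    return None

if PARAMS_PATH.exists():
    print("admin-only profiling:", profiling_admin_only(PARAMS_PATH.read_text()))
else:
    print("NVIDIA driver params not found at", PARAMS_PATH)
```

A result of `True` corresponds to the `RmProfilingAdminOnly: 1` state described above, in which PMU-based tools are blocked for non-admin users.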
### nsys Methodology

```bash
# Profile (CUDA trace only, 5s cap)
nsys profile --trace=cuda --duration=5 -o output_name python3 script.py

# Export kernel trace
nsys stats --report gputrace output_name.nsys-rep
```

Analyze the tabular output: each row is a kernel event, sorted by `Start(ns)`. Four key columns (`GrdX`, `BlkX`, `Reg/Trd`, `Duration`) reveal each kernel's purpose and cost.

### PyTorch Profiler (Fallback for BW Estimation)

```python
import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    out = model()             # the call under test (GAS or fused)
    torch.cuda.synchronize()  # wait for queued kernels before the profiler exits

# Access kernel times via prof.key_averages()
```

Use the `device_time_total` attribute in PyTorch 2.4+. DRAM bandwidth must be estimated analytically: `bytes = Σ(tensor_shapes × dtype_size × num_accesses)`, then `BW = bytes / kernel_time`.

---

## Key Insights Summary

1. **GAS is memory-bound, not compute-bound.** Individual PyTorch kernels are well-optimized (Apply hits 93% of peak BW). The waste is the 66% of total traffic that goes to intermediate tensors.
2. **Fused kernel reduces traffic by 66%, but BW utilization drops from 70% to 41%** due to atomic-add contention on random write patterns.
3. **Speedup increases with D**, as larger feature dimensions amortize contention and kernel launch overhead.
4. **The gap to ideal (~3×) lies in atomic optimization.** Potential improvements: sort edges by destination for coalesced atomics, use a shared-memory or warp-level reduction before the global atomic, or partition edges by destination to eliminate conflicts.
5. **The small-D case (D=64) benefits least from fusion** (1.02× to 1.35× in the sweep). BLOCK_D=64 wastes half the threads. It needs an adaptive BLOCK_D or an alternate strategy for small D.
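The first suggestion in insight 4, sorting edges by destination, can be illustrated on the CPU. This is a pure-Python sketch under simplifying assumptions, not the repository's kernel: once edges are sorted, those sharing a destination are adjacent, so each run can be accumulated locally and written once. That replaces one atomic add per edge with one write per distinct destination. The function names and the list-of-lists layout are hypothetical.

```python
def aggregate_per_edge(src_feat, edge_w, src_idx, dst_idx, num_nodes):
    """Reference: accumulate into the output once per edge (the atomic-add pattern)."""
    D = len(src_feat[0])
    out = [[0.0] * D for _ in range(num_nodes)]
    for e, (s, d) in enumerate(zip(src_idx, dst_idx)):
        for k in range(D):
            out[d][k] += src_feat[s][k] * edge_w[e][k]
    return out

def aggregate_sorted(src_feat, edge_w, src_idx, dst_idx, num_nodes):
    """Sort edges by destination, reduce each run locally, write once per run."""
    D = len(src_feat[0])
    out = [[0.0] * D for _ in range(num_nodes)]
    order = sorted(range(len(dst_idx)), key=lambda e: dst_idx[e])
    writes = 0
    i = 0
    while i < len(order):
        d = dst_idx[order[i]]
        acc = [0.0] * D                      # local accumulator for this destination
        while i < len(order) and dst_idx[order[i]] == d:
            e = order[i]
            for k in range(D):
                acc[k] += src_feat[src_idx[e]][k] * edge_w[e][k]
            i += 1
        out[d] = acc                         # single write per distinct destination
        writes += 1
    return out, writes

src_feat = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
edge_w = [[1.0, 1.0], [2.0, 0.5], [1.0, 2.0], [0.5, 1.0]]
src_idx = [0, 1, 2, 1]
dst_idx = [2, 0, 2, 0]

reference = aggregate_per_edge(src_feat, edge_w, src_idx, dst_idx, num_nodes=3)
result, writes = aggregate_sorted(src_feat, edge_w, src_idx, dst_idx, num_nodes=3)
print(result == reference, writes)
```

In this toy case four edges collapse into two output writes. On the GPU the sort itself costs time and memory, so whether it pays off would have to be measured against the atomic version.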