From ed8dd536ed024e0b559ca58e35d9e7de5692c5e8 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sun, 9 Feb 2025 10:39:22 +0000 Subject: [PATCH 1/6] update --- .github/actions/setup/action.yml | 2 +- .github/workflows/linting.yml | 3 +++ .github/workflows/testing.yml | 1 + 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/actions/setup/action.yml b/.github/actions/setup/action.yml index c434d0d4..20b6b2bf 100644 --- a/.github/actions/setup/action.yml +++ b/.github/actions/setup/action.yml @@ -6,7 +6,7 @@ inputs: default: '3.9' torch-version: required: false - default: '2.5' + default: '2.6' cuda-version: required: false default: cpu diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 80a691d3..652a1a9d 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -12,13 +12,16 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v4 + - name: Set up Python uses: actions/setup-python@v5 with: python-version: 3.9 + - name: Install dependencies run: | pip install -e '.[full,test]' -f https://download.pytorch.org/whl/cpu pip list + - name: Check type hints run: mypy diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index b7fe56ec..aec0982b 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -24,6 +24,7 @@ jobs: - '2.3' - '2.4' - '2.5' + - '2.6' - 'nightly' steps: From 85eca16c4a470b985d96c3f02b4fa82478616a10 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sun, 9 Feb 2025 10:40:48 +0000 Subject: [PATCH 2/6] update --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 35eddc3d..88e01f97 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added support for PyTorch 2.6 ([#494](https://github.com/pyg-team/pytorch-frame/pull/494)) + ### Changed ### Deprecated From 5210ccde229d6c24e6c4c08d4556481f141f8b88 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sun, 9 Feb 2025 10:43:07 +0000 Subject: [PATCH 3/6] . --- .github/workflows/testing.yml | 2 +- .pre-commit-config.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index aec0982b..f126602e 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -1,4 +1,4 @@ -name: Testing PyTorch 2.5 +name: Testing on: # yamllint disable-line rule:truthy push: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 009b3c52..601b9c12 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -85,7 +85,7 @@ repos: hooks: - id: mypy name: Check types - additional_dependencies: [torch==2.5.0] + additional_dependencies: [torch==2.6.*] exclude: "^test/|^examples/|^benchmark/" - repo: https://github.com/executablebooks/mdformat From 996b667b9e2015edb186779b402472b16911edd7 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sun, 9 Feb 2025 11:23:49 +0000 Subject: [PATCH 4/6] fix types --- docs/requirements.txt | 2 +- torch_frame/nn/conv/excelformer_conv.py | 1 + torch_frame/nn/encoding/cyclic_encoding.py | 7 +++++-- torch_frame/nn/encoding/positional_encoding.py | 11 ++++++++--- torch_frame/nn/models/tabnet.py | 10 ++++++++-- 5 files changed, 23 insertions(+), 8 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 907cd381..2cd365c0 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,2 +1,2 @@ -https://download.pytorch.org/whl/cpu/torch-2.5.0%2Bcpu-cp39-cp39-linux_x86_64.whl +https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp39-cp39-linux_x86_64.whl git+https://github.com/pyg-team/pyg_sphinx_theme.git diff --git a/torch_frame/nn/conv/excelformer_conv.py b/torch_frame/nn/conv/excelformer_conv.py index a897fa2b..d9d88ade 100644 --- a/torch_frame/nn/conv/excelformer_conv.py +++ b/torch_frame/nn/conv/excelformer_conv.py @@ -58,6 +58,7 @@ def __init__(self, channels: int, num_cols: int, num_heads: int, self.lin_out = Linear(channels, channels) if num_heads > 1 else None self.num_heads = num_heads self.dropout = Dropout(dropout) + self.seq_ids: Tensor self.register_buffer('seq_ids', torch.arange(num_cols)) self.reset_parameters() diff --git a/torch_frame/nn/encoding/cyclic_encoding.py b/torch_frame/nn/encoding/cyclic_encoding.py index 70b3981a..11317f42 100644 --- a/torch_frame/nn/encoding/cyclic_encoding.py +++ b/torch_frame/nn/encoding/cyclic_encoding.py @@ -23,8 +23,11 @@ def __init__(self, out_size: int) -> None: raise ValueError( f"out_size should be divisible by 2 (got {out_size}).") self.out_size = out_size - mult_term = torch.arange(1, self.out_size // 2 + 1) - self.register_buffer("mult_term", mult_term) + self.mult_term: Tensor + self.register_buffer( + "mult_term", + torch.arange(1, self.out_size // 2 + 1), + ) def forward(self, input_tensor: Tensor) -> Tensor: assert torch.all((input_tensor >= 0) & (input_tensor <= 1)) diff --git a/torch_frame/nn/encoding/positional_encoding.py b/torch_frame/nn/encoding/positional_encoding.py index 7b7de732..eb451e13 100644 --- a/torch_frame/nn/encoding/positional_encoding.py +++ b/torch_frame/nn/encoding/positional_encoding.py @@ -18,9 +18,14 @@ def __init__(self, out_size: int) -> None: raise ValueError( f"out_size should be divisible by 2 (got {out_size}).") self.out_size = out_size - mult_term = torch.pow(1 / 10000.0, - torch.arange(0, self.out_size, 2) / out_size) - self.register_buffer("mult_term", mult_term) + self.mult_term: Tensor + self.register_buffer( + "mult_term", + torch.pow( + 1 / 10000.0, + torch.arange(0, self.out_size, 2) / out_size, + ), + ) def forward(self, input_tensor: Tensor) -> Tensor: assert torch.all(input_tensor >= 0) diff --git a/torch_frame/nn/models/tabnet.py b/torch_frame/nn/models/tabnet.py index db14065f..7916986d 100644 --- a/torch_frame/nn/models/tabnet.py +++ b/torch_frame/nn/models/tabnet.py @@ -257,10 +257,16 @@ def forward(self, x: Tensor) -> Tensor: return x def reset_parameters(self) -> None: + # TODO: Remove this type cast when PyTorch fixes typing issue where + # reset_parameters is typed as: + # Union[torch._tensor.Tensor, torch.nn.modules.module.Module] + # This issue was first observed on PyTorch 2.6.0. if not isinstance(self.shared_glu_block, Identity): - self.shared_glu_block.reset_parameters() + from typing import Callable, cast + cast(Callable, self.shared_glu_block.reset_parameters)() if not isinstance(self.dependent, Identity): - self.dependent.reset_parameters() + from typing import Callable, cast + cast(Callable, self.dependent.reset_parameters)() class GLUBlock(Module): From 395eed837755a379056f2d8046720ab46fd9c611 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sun, 9 Feb 2025 11:24:23 +0000 Subject: [PATCH 5/6] fix types --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 601b9c12..a62d539b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -81,7 +81,7 @@ repos: args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.14.1 + rev: v1.15.0 hooks: - id: mypy name: Check types From 8fb9e3ff7f68ebde7d987567ae7e97b81ef5f595 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sun, 9 Feb 2025 11:28:59 +0000 Subject: [PATCH 6/6] fix types --- torch_frame/nn/models/tabnet.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torch_frame/nn/models/tabnet.py b/torch_frame/nn/models/tabnet.py index 7916986d..406c6e7d 100644 --- a/torch_frame/nn/models/tabnet.py +++ b/torch_frame/nn/models/tabnet.py @@ -1,7 +1,7 @@ from __future__ import annotations import math -from typing import Any +from typing import Any, Callable, cast import torch import torch.nn.functional as F @@ -262,10 +262,8 @@ def reset_parameters(self) -> None: # Union[torch._tensor.Tensor, torch.nn.modules.module.Module] # This issue was first observed on PyTorch 2.6.0. if not isinstance(self.shared_glu_block, Identity): - from typing import Callable, cast cast(Callable, self.shared_glu_block.reset_parameters)() if not isinstance(self.dependent, Identity): - from typing import Callable, cast cast(Callable, self.dependent.reset_parameters)()