Skip to content

Commit be8454d

Browse files
authored
fix: base and trainer save & from pretrained
1 parent 7639e28 commit be8454d

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+708
-482
lines changed

.github/workflows/codeql-analysis.yml

Lines changed: 0 additions & 52 deletions
This file was deleted.

.github/workflows/lint.yml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ jobs:
1212
runs-on: ubuntu-latest
1313
steps:
1414
- name: Checkout
15-
uses: actions/checkout@v2
15+
uses: actions/checkout@v4
1616
- name: Set up Python 3.8
17-
uses: actions/setup-python@v2
17+
uses: actions/setup-python@v5
1818
with:
1919
python-version: 3.8
2020
- name: Install Black
@@ -27,9 +27,9 @@ jobs:
2727
runs-on: ubuntu-latest
2828
steps:
2929
- name: Checkout
30-
uses: actions/checkout@v2
30+
uses: actions/checkout@v4
3131
- name: Set up Python 3.8
32-
uses: actions/setup-python@v2
32+
uses: actions/setup-python@v5
3333
with:
3434
python-version: 3.8
3535
- name: Install isort
@@ -43,8 +43,8 @@ jobs:
4343
timeout-minutes: 10
4444
steps:
4545
- name: Checkout
46-
uses: actions/checkout@v2
47-
- uses: actions/setup-python@v2
46+
uses: actions/checkout@v4
47+
- uses: actions/setup-python@v5
4848
with:
4949
python-version: 3.8
5050

@@ -59,8 +59,8 @@ jobs:
5959
timeout-minutes: 10
6060
steps:
6161
- name: Checkout
62-
uses: actions/checkout@v2
63-
- uses: actions/setup-python@v2
62+
uses: actions/checkout@v4
63+
- uses: actions/setup-python@v5
6464
with:
6565
python-version: 3.8
6666
- name: Install dependencies

.github/workflows/test.yml

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -45,25 +45,17 @@ jobs:
4545
shell: bash
4646
run: poetry config virtualenvs.in-project true
4747

48-
- name: Set up cache
49-
uses: actions/cache@v4
50-
id: cache
51-
with:
52-
path: .venv
53-
key: venv-${{ runner.os }}-${{ steps.full-python-version.outputs.version }}-${{ hashFiles('**/poetry.lock') }}
54-
55-
- name: Ensure cache is healthy
56-
if: steps.cache.outputs.cache-hit == 'true'
57-
shell: bash
58-
run: poetry run pip --version >/dev/null 2>&1 || rm -rf .venv
59-
60-
- name: Upgrade pip
48+
- name: Check for healthy virtualenv manually
6149
shell: bash
62-
run: poetry run python -m pip install pip -U
50+
run: |
51+
if [ -d .venv ]; then
52+
poetry run pip --version >/dev/null 2>&1 || rm -rf .venv
53+
fi
6354
6455
- name: Install dependencies
6556
shell: bash
6657
run: |
58+
poetry run python -m pip install pip -U
6759
poetry install --no-interaction --no-root
6860
poetry run python -m pip install tensorflow==${{ matrix.tf-version }}
6961
poetry run python -m pip install matplotlib
@@ -102,13 +94,11 @@ jobs:
10294
with:
10395
python-version: '3.10'
10496

105-
- name: Cache pip
106-
uses: actions/cache@v4
107-
with:
108-
path: ~/.cache/pip
109-
key: ${{ runner.os }}-pip-${{ hashFiles('docs/requirements_docs.txt') }}
110-
restore-keys: |
111-
${{ runner.os }}-pip-
97+
- name: Create pip cache directory manually
98+
run: |
99+
CACHE_DIR="$HOME/.cache/pip"
100+
mkdir -p "$CACHE_DIR"
101+
echo "Using manual pip cache at $CACHE_DIR"
112102
113103
- name: Install dependencies
114104
run: |

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33

44
repos:
55
- repo: https://github.com/pre-commit/pre-commit-hooks
6-
rev: v4.6.0
6+
rev: v4.2.0
77
hooks:
88
- id: trailing-whitespace
99
- id: end-of-file-fixer
1010
- id: check-yaml
1111
- id: check-ast
1212
- repo: https://github.com/PyCQA/flake8
13-
rev: "3.9.2"
13+
rev: "5.0.4"
1414
hooks:
1515
- id: flake8
1616
exclude: 'tests|.venv|docs'

Makefile

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
1-
.PHONY: style test docs pre-release
1+
.PHONY: style test docs pre-release help
22

3-
check_dirs := tfts examples tests
3+
# Directories to run style checks on
4+
CHECK_DIRS := tfts examples tests
45

5-
# run checks on all files and potentially modifies some of them
6-
7-
style:
8-
black $(check_dirs)
9-
isort $(check_dirs)
10-
flake8
6+
## Format code and run linting tools
7+
style: ## Run formatters and linters (black, isort, flake8, pre-commit)
8+
black $(CHECK_DIRS)
9+
isort $(CHECK_DIRS)
10+
flake8 $(check_dirs)
1111
pre-commit run --all-files
1212

13-
# run tests for the library
14-
15-
test:
16-
python -m unittest
13+
## Run all unit tests
14+
test: ## Run unit tests using unittest
15+
python -m unittest discover
1716

18-
# run tests for the docs
19-
20-
docs:
17+
## Build the documentation
18+
docs: ## Build HTML documentation using Sphinx
2119
make -C docs clean M=$(shell pwd)
2220
make -C docs html M=$(shell pwd)
21+
22+
## Display help for make targets
23+
help: ## Show this help
24+
@awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[33m<target>\033[0m\n\nTargets:\n"} /^[a-zA-Z\/_-]+:.*?##/ { printf " \033[36m%-20s\033[0m %s\n", $$1, $$2 }' $(MAKEFILE_LIST)

codecov.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ coverage:
55
status:
66
project:
77
default:
8-
threshold: 1%
8+
threshold: 3%
99

1010
patch:
1111
default:

docker/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
From tensorflow/tensorflow:2.8.3-gpu
1+
From tensorflow/tensorflow:2.16.1-gpu
22

33
RUN apt-get update
44
RUN apt-get install -y libgl1-mesa-dev wget vim python3.8

docs/requirements_docs.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ cloudpickle
1212

1313
pandas >= 1.3
1414
numpy < 2
15-
tensorflow==2.10.0
15+
tensorflow==2.16.1
1616
matplotlib
1717
optuna>=2.0
1818
scikit-learn>0.23

docs/source/_static/logo.svg

Lines changed: 1 addition & 1 deletion
Loading

docs/source/index.rst

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,30 @@ The tfts library supports the SOTA deep learning models for time series.
4444
- `TFTS Seq2Seq model <https://github.com/LongxingTan/Data-competitions/tree/master/tianchi-enso-prediction>`_ wins the 4th place in `Alibaba Tianchi ENSO prediction <https://tianchi.aliyun.com/competition/entrance/531871/introduction>`_
4545
- :ref:`Learn more models <models>`
4646

47+
.. code-block:: python
48+
49+
import tensorflow as tf
50+
from tfts import AutoConfig, AutoModel
51+
52+
53+
def build_model(use_model, input_shape):
54+
inputs = tf.keras.layers.Input(input_shape)
55+
config = AutoConfig.for_model(use_model)
56+
57+
backbone = AutoModel.from_config(config)
58+
outputs = backbone(inputs)
59+
model = tf.keras.Model(inputs, outputs=outputs)
60+
61+
optimizer = tf.keras.optimizers.Adam(0.003)
62+
loss_fn = tf.keras.losses.MeanSquaredError()
63+
64+
model.compile(optimizer, loss_fn)
65+
return model
66+
67+
68+
model = build_model(use_model="bert", input_shape=(24, 3))
69+
model.summary()
70+
4771
4872
Tricks
4973
-------

docs/source/models.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ You can you below models in ``AutoModel``
1919
* WaveNet
2020
* Bert
2121
* Transformer
22-
* Dlinear
22+
* DLinear
2323
* NBeats
2424
* AutoFormer
2525
* Informer

0 commit comments

Comments
 (0)