
Commit 8925aca

Init
1 parent f179087 commit 8925aca

34 files changed: +3028, -1 lines changed

.github/dependabot.yml

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+# To get started with Dependabot version updates, you'll need to specify which
+# package ecosystems to update and where the package manifests are located.
+# Please see the documentation for all configuration options:
+# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
+
+version: 2
+updates:
+  - package-ecosystem: "github-actions"
+    directory: "/"
+    schedule:
+      interval: "monthly"
+    groups:
+      github-actions:
+        patterns:
+          - "*"
+  - package-ecosystem: "pip"
+    directory: "/"
+    schedule:
+      interval: "monthly"
+    groups:
+      python:
+        patterns:
+          - "*"

.github/workflows/actions.yml

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+# Ref: https://github.com/keras-team/keras/blob/master/.github/workflows/actions.yml
+name: Tests
+
+on:
+  push:
+    branches: [ master ]
+  pull_request:
+  release:
+    types: [created]
+
+permissions:
+  contents: read
+
+jobs:
+  build:
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: [3.9]
+        backend: [tensorflow, jax, torch, numpy]
+    name: Run tests
+    runs-on: ubuntu-latest
+    env:
+      PYTHON: ${{ matrix.python-version }}
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Get pip cache dir
+        id: pip-cache
+        run: |
+          python -m pip install --upgrade pip setuptools
+          echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
+      - name: pip cache
+        uses: actions/cache@v3
+        with:
+          path: ${{ steps.pip-cache.outputs.dir }}
+          key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}-${{ hashFiles('requirements.txt') }}
+      - name: Install dependencies
+        run: |
+          pip install -r requirements.txt --progress-bar off --upgrade
+          pip install -e ".[tests]" --progress-bar off --upgrade
+      - name: Test with pytest
+        run: |
+          pytest
+
+  format:
+    name: Check the code format
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python 3.9
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.9'
+      - name: Get pip cache dir
+        id: pip-cache
+        run: |
+          python -m pip install --upgrade pip setuptools
+          echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
+      - name: pip cache
+        uses: actions/cache@v3
+        with:
+          path: ${{ steps.pip-cache.outputs.dir }}
+          key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}-${{ hashFiles('requirements.txt') }}
+      - name: Install dependencies
+        run: |
+          pip install -r requirements.txt --progress-bar off --upgrade
+          pip install -e ".[tests]" --progress-bar off --upgrade
+      - name: Lint
+        run: bash shell/lint.sh

.gitignore

Lines changed: 4 additions & 0 deletions
@@ -158,3 +158,7 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+# Keras
+*.keras
+exported

README.md

Lines changed: 10 additions & 1 deletion
@@ -1 +1,10 @@
-# kimm
+# Keras Image Models
+
+## Unit Tests
+
+```bash
+# KERAS_BACKEND=jax|numpy|tensorflow|torch
+CUDA_VISIBLE_DEVICES= KERAS_BACKEND=tensorflow pytest
+```
+
+## Acknowledgments

conftest.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+import os
+
+
+def pytest_configure():
+    # disable jax gpu memory preallocation
+    # https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
+    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

kimm/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+__version__ = "0.1.0"

kimm/blocks/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+from kimm.blocks.base_block import apply_activation
+from kimm.blocks.base_block import apply_conv2d_block
+from kimm.blocks.base_block import apply_se_block

kimm/blocks/base_block.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+from keras import layers
+
+from kimm.utils import make_divisible
+
+
+def apply_activation(x, activation=None, name="activation"):
+    if activation is not None:
+        if isinstance(activation, str):
+            x = layers.Activation(activation, name=name)(x)
+        elif isinstance(activation, layers.Layer):
+            x = activation(x)
+        else:
+            raise NotImplementedError(
+                f"Unsupported activation type: {type(activation)}"
+            )
+    return x
+
+
+def apply_conv2d_block(
+    inputs,
+    filters,
+    kernel_size,
+    strides=1,
+    groups=1,
+    activation=None,
+    use_depthwise=False,
+    bn_momentum=0.9,
+    bn_epsilon=1e-5,
+    name="conv2d_block",
+):
+    x = inputs
+
+    padding = "same"
+    if strides > 1:
+        padding = "valid"
+        x = layers.ZeroPadding2D(kernel_size // 2, name=f"{name}_pad")(x)
+
+    if not use_depthwise:
+        x = layers.Conv2D(
+            filters,
+            kernel_size,
+            strides,
+            padding=padding,
+            groups=groups,
+            use_bias=False,
+            name=f"{name}_conv2d",
+        )(x)
+    else:
+        x = layers.DepthwiseConv2D(
+            kernel_size,
+            strides,
+            padding=padding,
+            use_bias=False,
+            name=f"{name}_dwconv2d",
+        )(x)
+    x = layers.BatchNormalization(
+        name=f"{name}_bn", momentum=bn_momentum, epsilon=bn_epsilon
+    )(x)
+    x = apply_activation(x, activation, name=name)
+    return x
+
+
+def apply_se_block(
+    inputs,
+    se_ratio=0.25,
+    activation="relu",
+    gate_activation="sigmoid",
+    make_divisible_number=None,
+    name="se_block",
+):
+    input_channels = inputs.shape[-1]
+    if make_divisible_number is None:
+        se_channels = round(input_channels * se_ratio)
+    else:
+        se_channels = make_divisible(
+            input_channels * se_ratio, make_divisible_number
+        )
+
+    ori_x = inputs
+    x = inputs
+    x = layers.GlobalAveragePooling2D(keepdims=True, name=f"{name}_mean")(x)
+    x = layers.Conv2D(
+        se_channels, 1, use_bias=True, name=f"{name}_reduce_conv2d"
+    )(x)
+    x = apply_activation(x, activation, name=f"{name}_act")
+    x = layers.Conv2D(
+        input_channels, 1, use_bias=True, name=f"{name}_expand_conv2d"
+    )(x)
+    x = apply_activation(x, gate_activation, name=f"{name}_gate_act")
+    out = layers.Multiply(name=name)([ori_x, x])
+    return out
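The three helpers above are plain Keras functional-API builders, so they compose directly. Below is a minimal usage sketch, not part of this commit; the 32-filter stem, the `swish` activation, and the layer names are hypothetical, and it assumes `kimm` is importable.

```python
# Hypothetical usage sketch; assumes `kimm` is installed and a channels-last
# image tensor. Filter counts, activation, and names are illustrative only.
from keras import layers, models

from kimm.blocks import apply_conv2d_block, apply_se_block

inputs = layers.Input(shape=(224, 224, 3))
# strided 3x3 conv + batch norm + activation, as built by apply_conv2d_block
x = apply_conv2d_block(
    inputs, filters=32, kernel_size=3, strides=2, activation="swish", name="stem"
)
# squeeze-and-excitation gate over the stem features
x = apply_se_block(x, se_ratio=0.25, name="stem_se")
models.Model(inputs, x).summary()
```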

kimm/layers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+from kimm.layers.attention import Attention
+from kimm.layers.layer_scale import LayerScale
+from kimm.layers.position_embedding import PositionEmbedding

kimm/layers/attention.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
+from keras import layers
+from keras import ops
+
+
+class Attention(layers.Layer):
+    def __init__(
+        self,
+        hidden_dim,
+        num_heads=8,
+        use_qkv_bias=False,
+        use_qk_norm=False,
+        attention_dropout_rate=0.0,
+        projection_dropout_rate=0.0,
+        name="attention",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_dim = hidden_dim
+        self.num_heads = num_heads
+        self.head_dim = hidden_dim // num_heads
+        self.scale = self.head_dim ** (-0.5)
+        self.use_qkv_bias = use_qkv_bias
+        self.use_qk_norm = use_qk_norm
+        self.attention_dropout_rate = attention_dropout_rate
+        self.projection_dropout_rate = projection_dropout_rate
+        self.name = name
+
+        self.qkv = layers.Dense(
+            hidden_dim * 3, use_bias=use_qkv_bias, name=f"{name}_qkv"
+        )
+        if use_qk_norm:
+            self.q_norm = layers.LayerNormalization(name=f"{name}_q_norm")
+            self.k_norm = layers.LayerNormalization(name=f"{name}_k_norm")
+        else:
+            self.q_norm = layers.Identity()
+            self.k_norm = layers.Identity()
+
+        self.attention_dropout = layers.Dropout(
+            attention_dropout_rate, name=f"{name}_attn_drop"
+        )
+        self.projection = layers.Dense(hidden_dim, name=f"{name}_proj")
+        self.projection_dropout = layers.Dropout(
+            projection_dropout_rate, name=f"{name}_proj_drop"
+        )
+
+    def call(self, inputs, training=None, mask=None):
+        input_shape = ops.shape(inputs)
+        qkv = self.qkv(inputs)
+        qkv = ops.reshape(
+            qkv,
+            [
+                input_shape[0],
+                input_shape[1],
+                3,
+                self.num_heads,
+                self.head_dim,
+            ],
+        )
+        qkv = ops.transpose(qkv, [2, 0, 3, 1, 4])
+        q, k, v = ops.unstack(qkv, 3, axis=0)
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+
+        # attention
+        q = ops.multiply(q, self.scale)
+        attn = ops.matmul(q, ops.swapaxes(k, -2, -1))
+        attn = ops.softmax(attn)
+        attn = self.attention_dropout(attn)
+        x = ops.matmul(attn, v)
+
+        x = ops.swapaxes(x, 1, 2)
+        x = ops.reshape(x, input_shape)
+        x = self.projection(x)
+        x = self.projection_dropout(x)
+        return x
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+                "num_heads": self.num_heads,
+                "use_qkv_bias": self.use_qkv_bias,
+                "use_qk_norm": self.use_qk_norm,
+                "attention_dropout_rate": self.attention_dropout_rate,
+                "projection_dropout_rate": self.projection_dropout_rate,
+                "name": self.name,
+            }
+        )
+        return config
+
+
+if __name__ == "__main__":
+    from keras import models
+    from keras import random
+
+    inputs = layers.Input(shape=[197, 768])
+    outputs = Attention(768)(inputs)
+
+    model = models.Model(inputs, outputs)
+    model.summary()
+
+    inputs = random.uniform([1, 197, 768])
+    outputs = model(inputs)
+    print(outputs.shape)
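Because `Attention` implements `get_config`, a model containing it can be saved and reloaded. The sketch below is a hypothetical round trip, not part of this commit; the file name and the `custom_objects` registration are illustrative.

```python
# Hypothetical save/load round trip; the file name "attention_demo.keras" is
# illustrative. get_config() supplies the constructor arguments on reload.
from keras import layers, models, random, saving

from kimm.layers import Attention

inputs = layers.Input(shape=[197, 768])
outputs = Attention(768, num_heads=8)(inputs)
model = models.Model(inputs, outputs)

model.save("attention_demo.keras")
restored = saving.load_model(
    "attention_demo.keras", custom_objects={"Attention": Attention}
)
x = random.uniform([1, 197, 768])
print(restored(x).shape)  # (1, 197, 768)
```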
