Skip to content

Commit 579f717

Browse files
authored
MySmallModel
1 parent 0b5cbad commit 579f717

File tree

3 files changed

+46
-0
lines changed

3 files changed

+46
-0
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# MySmallModel
2+
3+
This repository contains a minimal implementation of a GPT-style chatbot from scratch using PyTorch. It supports:
4+
5+
* **Data loading**: extraction from PDF files, tokenization, vocabulary building.
6+
* **Model**: a lightweight Transformer-based GPT implemented with `nn.TransformerEncoder`.
7+
* **Training**: training loop with checkpoint saving.
8+
* **Generation**: simple text generation utility.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
torch>=1.13.0
2+
numpy
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""Train a minimal GPT-style language model and save a checkpoint after every epoch."""

import os

import torch
import torch.nn.functional as F
from torch.optim import Adam

from config import Config
from data_loader import get_loader
from model import GPT


def main() -> None:
    """Run the full training loop, printing per-epoch loss and saving checkpoints.

    Reads all hyperparameters (device, lr, epochs, vocab size, paths) from
    ``Config``. Writes one ``model_epoch{N}.pt`` state-dict file per epoch
    into ``Config.OUTPUT_DIR``.
    """
    # Data: loader yields (input, target) batches; vocab is built/loaded from
    # Config.VOCAB_PATH. NOTE(review): vocab is currently unused here —
    # presumably kept for later generation/decoding; confirm before removing.
    loader, vocab = get_loader(vocab_path=Config.VOCAB_PATH)

    # Model and optimizer.
    model = GPT().to(Config.device)
    optimizer = Adam(model.parameters(), lr=Config.lr)

    # Create the checkpoint directory once, up front — not once per epoch.
    os.makedirs(Config.OUTPUT_DIR, exist_ok=True)

    for epoch in range(1, Config.epochs + 1):
        model.train()
        total_loss = 0.0
        for x, y in loader:
            x = x.to(Config.device)
            y = y.to(Config.device)
            logits = model(x)
            # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for
            # token-level cross-entropy against the flattened targets.
            loss = F.cross_entropy(
                logits.view(-1, Config.vocab_size), y.view(-1)
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Assumes the loader is non-empty; an empty dataset would raise
        # ZeroDivisionError here, which is the right loud failure for training.
        avg_loss = total_loss / len(loader)
        print(f"Epoch {epoch}/{Config.epochs}, Loss: {avg_loss:.4f}")

        # Save a checkpoint every epoch so training can be resumed or the
        # best epoch picked afterwards.
        ckpt_path = os.path.join(Config.OUTPUT_DIR, f"model_epoch{epoch}.pt")
        torch.save(model.state_dict(), ckpt_path)
        print(f"Saved checkpoint: {ckpt_path}")


if __name__ == "__main__":
    # Guard the entry point so importing this module does not start training.
    main()

0 commit comments

Comments
 (0)