Skip to content

added_preparing_data file #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added __pycache__/preparing_data.cpython-38.pyc
Binary file not shown.
4 changes: 2 additions & 2 deletions captions.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
a happy dog
a big red house
grasper retract peritoneum
bipolar coagulate blood_vesse
1 change: 1 addition & 0 deletions dict.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID230/000000.png": [["grasper,retract,gallbladder"], ["grasper,retract,gut"], ["hook,dissect,omentum"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID230/000001.png": [["grasper,retract,omentum"], ["hook,dissect,omentum"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID230/000002.png": [["grasper,retract,gallbladder"], ["hook,dissect,cystic_plate"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID230/000003.png": [["grasper,retract,gallbladder"], ["clipper,clip,cystic_duct"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID230/000004.png": [["clipper,clip,cystic_artery"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID230/000005.png": [["scissors,cut,cystic_duct"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID231/000006.png": [["hook,dissect,gallbladder"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID231/000007.png": [["hook,dissect,gallbladder"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID231/000008.png": [["hook,dissect,gallbladder"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID231/000009.png": [["grasper,retract,cystic_plate"], ["grasper,retract,gallbladder"], ["hook,dissect,gallbladder"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID231/000010.png": [["grasper,grasp,specimen_bag"], ["grasper,null_verb,null_target"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID231/000011.png": [["grasper,retract,liver"], ["bipolar,coagulate,liver"], ["grasper,null_verb,null_target"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID231/000012.png": [["grasper,retract,liver"], ["bipolar,coagulate,liver"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID231/000013.png": [["grasper,retract,gut"], ["grasper,retract,liver"], ["irrigator,irrigate,abdominal_wall_cavity"]]}
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Run training on small test Imagen
subprocess.call(["venv/Scripts/python", "train.py", "-test", "-ts", timestamp])
subprocess.call([ "/home/kareemelgohary/Desktop/minImagen/train.py", "-test", "-ts", timestamp])

# Use small test Imagen to generate image
subprocess.call(["venv/Scripts/python", "inference.py", "-d", f"training_{timestamp}"])
2 changes: 1 addition & 1 deletion minimagen/Imagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def _p_mean_variance(self,
lowres_cond_img: torch.tensor = None,
lowres_noise_times: torch.tensor = None,
cond_scale: float = 1.,
model_output: torch.tensor = None) -> tuple[torch.tensor, torch.tensor, torch.tensor]:
model_output: torch.tensor = None) -> tuple([torch.tensor, torch.tensor, torch.tensor]):
"""
Predicts noise component of `x` with `unet`, and then returns the corresponding forward process posterior
parameters given the predictions.
Expand Down
4 changes: 2 additions & 2 deletions minimagen/Unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ def _generate_t_tokens(
self,
time: torch.tensor,
lowres_noise_times: torch.tensor
) -> tuple[torch.tensor, torch.tensor]:
) -> tuple([torch.tensor, torch.tensor]):
'''
Generate t and time_tokens

Expand Down Expand Up @@ -544,7 +544,7 @@ def _text_condition(
text_mask: torch.tensor,
t: torch.tensor,
time_tokens: torch.tensor
) -> tuple[torch.tensor, torch.tensor]:
) -> tuple([torch.tensor, torch.tensor]):
'''
Condition on text.

Expand Down
Binary file added minimagen/__pycache__/Imagen.cpython-38.pyc
Binary file not shown.
Binary file added minimagen/__pycache__/Unet.cpython-38.pyc
Binary file not shown.
Binary file added minimagen/__pycache__/__init__.cpython-38.pyc
Binary file not shown.
Binary file not shown.
Binary file added minimagen/__pycache__/generate.cpython-38.pyc
Binary file not shown.
Binary file added minimagen/__pycache__/helpers.cpython-38.pyc
Binary file not shown.
Binary file added minimagen/__pycache__/layers.cpython-38.pyc
Binary file not shown.
Binary file added minimagen/__pycache__/t5.cpython-38.pyc
Binary file not shown.
Binary file added minimagen/__pycache__/training.cpython-38.pyc
Binary file not shown.
6 changes: 3 additions & 3 deletions minimagen/diffusion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,17 @@ def _sample_random_times(self, batch_size: int, *, device: torch.device) -> torc
"""
return torch.randint(0, self.num_timesteps, (batch_size,), device=device, dtype=torch.long)

def _get_sampling_timesteps(self, batch: int, *, device: torch.device) -> list[torch.tensor]:
def _get_sampling_timesteps(self, batch: int, *, device: torch.device) -> list([torch.tensor]):
time_transitions = []

for i in reversed(range(self.num_timesteps)):
time_transitions.append((torch.full((batch,), i, device=device, dtype=torch.long)))

return time_transitions

def q_posterior(self, x_start: torch.tensor, x_t: torch.tensor, t: torch.tensor) -> tuple[torch.tensor,
def q_posterior(self, x_start: torch.tensor, x_t: torch.tensor, t: torch.tensor) -> tuple([torch.tensor,
torch.tensor,
torch.tensor]:
torch.tensor]):
"""
Calculates q_posterior parameters given a starting image :code:`x_start` (x_0) and a noised image :code:`x_t`.

Expand Down
6 changes: 3 additions & 3 deletions minimagen/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __init__(
self.activation = nn.SiLU()
self.project = nn.Conv2d(dim, dim_out, 3, padding=1)

def forward(self, x: torch.tensor, scale_shift: tuple[torch.tensor, torch.tensor] = None) -> torch.tensor:
def forward(self, x: torch.tensor, scale_shift: tuple([torch.tensor, torch.tensor]) = None) -> torch.tensor:
"""
Forward pass

Expand Down Expand Up @@ -269,7 +269,7 @@ class CrossEmbedLayer(nn.Module):
def __init__(
self,
dim_in: int,
kernel_sizes: tuple[int, ...],
kernel_sizes: tuple([int, ...]),
dim_out: int = None,
stride: int = 2
):
Expand Down Expand Up @@ -347,7 +347,7 @@ class Parallel(nn.Module):
"""
Passes input through parallel functions and then sums the result.
"""
def __init__(self, *fns: tuple[Callable, ...]):
def __init__(self, *fns: tuple([Callable, ...])):
super().__init__()
self.fns = nn.ModuleList(fns)

Expand Down
95 changes: 57 additions & 38 deletions minimagen/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import Literal

from tqdm import tqdm

from torchvision import transforms, utils
import datasets
import PIL.Image
from einops import rearrange
Expand All @@ -22,7 +22,7 @@
from datasets import load_dataset
from datasets.utils.file_utils import get_datasets_user_agent
from resize_right import resize

from preparing_data import MinimagenDatasetNew , Rescale, ToTensor
from minimagen import Unet
from minimagen.helpers import exists
from minimagen.t5 import t5_encode_text
Expand Down Expand Up @@ -68,6 +68,7 @@ def __call__(self, batch):

# If the batch is empty after filtering
if not batch:
print("there is no images####################### jii from collator class")
return None

# Expand mask and encodings to len of elt in batch with greatest number of words
Expand All @@ -83,8 +84,10 @@ def __call__(self, batch):
elt['encoding'] = F.pad(elt['encoding'], (0, 0, 0, rem), 'constant', False)

# TODO: Should really be passing in `device` - find a more elegant way to do this
print(batch)
for didx, datum in enumerate(batch):
for tensor in datum.keys():
print(batch[didx][tensor])
batch[didx][tensor] = batch[didx][tensor].to(self.device)

return torch.utils.data.dataloader.default_collate(batch)
Expand Down Expand Up @@ -123,22 +126,23 @@ def _fetch_images(batch, num_threads, timeout=None, retries=0):
fetch_single_image_with_args = partial(_fetch_single_image, timeout=timeout, retries=retries)
with ThreadPoolExecutor(max_workers=num_threads) as executor:
batch["image"] = list(executor.map(fetch_single_image_with_args, batch["image_url"]))
print("This is the batch ######################",batch)
return batch


def _fetch_single_image(image_url, timeout=None, retries=0):
for _ in range(retries + 1):
try:
request = urllib.request.Request(
image_url,
data=None,
headers={"user-agent": USER_AGENT},
)
with urllib.request.urlopen(request, timeout=timeout) as req:
# try:
# request = urllib.request.Request(
# image_url,
# data=None,
# headers={"user-agent": USER_AGENT},
# )
with urllib.request.urlopen(image_url, timeout=timeout) as req:
image = PIL.Image.open(io.BytesIO(req.read()))
break
except Exception:
image = None
# except Exception:
# image = None
return image


Expand Down Expand Up @@ -187,7 +191,7 @@ def get_minimagen_parser():
help="Maximum number of words allowed in a caption", default=64, type=int)
parser.add_argument("-s", "--IMG_SIDE_LEN", dest="IMG_SIDE_LEN", help="Side length of square Imagen output images",
default=128, type=int)
parser.add_argument("-e", "--EPOCHS", dest="EPOCHS", help="Number of training epochs", default=5, type=int)
parser.add_argument("-e", "--EPOCHS", dest="EPOCHS", help="Number of training epochs", default=10, type=int)
parser.add_argument("-t5", "--T5_NAME", dest="T5_NAME", help="Name of T5 encoder to use", default='t5_base',
type=str)
parser.add_argument("-f", "--TRAIN_VALID_FRAC", dest="TRAIN_VALID_FRAC",
Expand Down Expand Up @@ -236,6 +240,7 @@ def __init__(self, hf_dataset, *, encoder_name: str, max_length: int,
split = "train" if train else "validation"

self.urls = hf_dataset[f"{split}"]['image_url']
print("self.urls",self.urls)
self.captions = hf_dataset[f"{split}"]['caption']

if img_transform is None:
Expand All @@ -247,13 +252,14 @@ def __init__(self, hf_dataset, *, encoder_name: str, max_length: int,

def __len__(self):
return len(self.urls)

# need to change the gititem with my own data set
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()

img = _fetch_single_image(self.urls[idx])
img = _fetch_single_image(self.urls[idx],5,5)
if img is None:
print("we are here from class MinimagenDataset nothing data######################## the problem ",idx,img)
return None
elif self.img_transform:
img = self.img_transform(img)
Expand All @@ -269,7 +275,7 @@ def __getitem__(self, idx):
return {'image': img, 'encoding': enc, 'mask': msk}


def ConceptualCaptions(args, smalldata=False, testset=False):
def ConceptualCaptions(args, smalldata=True, testset=False):
"""
Load `conceptual captions dataset <https://ai.google.com/research/ConceptualCaptions/>`_

Expand All @@ -278,38 +284,49 @@ def ConceptualCaptions(args, smalldata=False, testset=False):
:param testset: Whether to return the testing set (vs training/valid)
:return: test_dataset if :code:`testset` else (train_dataset, valid_dataset)
"""
dset = load_dataset("conceptual_captions")
if smalldata:
num = 16
vi = dset['validation']['image_url'][:num]
vc = dset['validation']['caption'][:num]
ti = dset['train']['image_url'][:num]
tc = dset['train']['caption'][:num]
dset = datasets.Dataset = {'train': {
'image_url': ti,
'caption': tc,
}, 'num_rows': num,
'validation': {
'image_url': vi,
'caption': vc, }, 'num_rows': num}
# dset = load_dataset("conceptual_captions")
print("## dset from conceptualCaption ###########")
# print(len(dset),"## dset from conceptualCaption ###########")
# if smalldata:
# num = 16
# vi = dset['validation']['image_url'][:num]
# vc = dset['validation']['caption'][:num]
# ti = dset['train']['image_url'][:num]
# tc = dset['train']['caption'][:num]
# dset = datasets.Dataset = {'train': {
# 'image_url': ti,
# 'caption': tc,
# }, 'num_rows': num,
# 'validation': {
# 'image_url': vi,
# 'caption': vc, }, 'num_rows': num}
# print(dset,"hiii from dset")
with open('/home/kareemelgohary/Desktop/minImagen/MinImagen/dict.json') as f:
d = json.load(f)


if testset:
# Torch test dataset
test_dataset = MinimagenDataset(dset, max_length=args.MAX_NUM_WORDS, train=False, encoder_name=args.T5_NAME,
side_length=args.IMG_SIDE_LEN)
# test_dataset = MinimagenDataset(dset, max_length=args.MAX_NUM_WORDS, train=False, encoder_name=args.T5_NAME,
# side_length=args.IMG_SIDE_LEN)
test_dataset=MinimagenDatasetNew(d,"t5_base",28,transform=transforms.Compose([Rescale(256),ToTensor()]))
return test_dataset
else:
# Torch train/valid dataset
dataset_train_valid = MinimagenDataset(dset, max_length=args.MAX_NUM_WORDS, encoder_name=args.T5_NAME,
train=True,
side_length=args.IMG_SIDE_LEN)
## I need to pass here a csv file and it all processing happened the class of the MinimagenDataset
# dataset_train_valid = MinimagenDataset(dset, max_length=args.MAX_NUM_WORDS, encoder_name=args.T5_NAME,
# train=True,
# side_length=args.IMG_SIDE_LEN)
dataset_train_valid = MinimagenDatasetNew(d,"t5_base",28,transform=transforms.Compose([Rescale(256),ToTensor()]))
print("###########################Here we are at conceptualCaptions Function#############################",dataset_train_valid)

# Split into train/valid
train_size = int(args.TRAIN_VALID_FRAC * len(dataset_train_valid))
valid_size = len(dataset_train_valid) - train_size
train_dataset, valid_dataset = torch.utils.data.random_split(dataset_train_valid, [train_size, valid_size])
if args.VALID_NUM is not None:
valid_dataset.indices = valid_dataset.indices[:args.VALID_NUM + 1]
print("Train dataset conceptualCaptions Function ####################", type(train_dataset))
return train_dataset, valid_dataset


Expand Down Expand Up @@ -435,6 +452,7 @@ def train():
f'U-Nets Best Valid Losses: {[round(i.item(), 3) for i in best_loss]}\n\n')

best_loss = [torch.tensor(9999999) for i in range(len(unets))]
# train loop here
for epoch in range(args.EPOCHS):
print(f'\n{"-" * 20} EPOCH {epoch + 1} {"-" * 20}')
with training_dir():
Expand All @@ -450,6 +468,7 @@ def train():
with _Timeout(timeout):
# If batch is empty, move on to the next one
if not batch:
print("############it is empty batch from training loop")
continue

train()
Expand Down Expand Up @@ -541,11 +560,11 @@ def load_testing_parameters(args):
:param args: Arguments Namespace returned from parsing :func:`~.minimagen.training.get_minimagen_parser`.
"""
d = dict(
BATCH_SIZE=2,
MAX_NUM_WORDS=32,
BATCH_SIZE=4,
MAX_NUM_WORDS=28,
IMG_SIDE_LEN=128,
EPOCHS=2,
T5_NAME='t5_small',
EPOCHS=200,
T5_NAME='t5_base',
TRAIN_VALID_FRAC=0.5,
TIMESTEPS=25, # Do not make less than 20
OPTIM_LR=0.0001
Expand Down
67 changes: 67 additions & 0 deletions package_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os
from datetime import datetime

import torch.utils.data
from torch import optim

from minimagen.Imagen import Imagen
from minimagen.Unet import Unet, Base, Super, BaseTest, SuperTest
from minimagen.generate import load_minimagen, load_params
from minimagen.t5 import get_encoded_dim
from minimagen.training import get_minimagen_parser, ConceptualCaptions, get_minimagen_dl_opts, \
create_directory, get_model_size, save_training_info, get_default_args, MinimagenTrain, \
load_testing_parameters

# Get device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Command line argument parser
parser = get_minimagen_parser()
args = parser.parse_args()

# Create training directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
dir_path = f"./training_{timestamp}"
training_dir = create_directory(dir_path)

# Replace some cmd line args to lower computational load.
args = load_testing_parameters(args)

# Load subset of Conceptual Captions dataset.
train_dataset, valid_dataset = ConceptualCaptions(args, smalldata=True)

# Create dataloaders
dl_opts = {**get_minimagen_dl_opts(device), 'batch_size': args.BATCH_SIZE, 'num_workers': args.NUM_WORKERS}
train_dataloader = torch.utils.data.DataLoader(train_dataset, **dl_opts)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, **dl_opts)

# Use small U-Nets to lower computational load.
unets_params = [get_default_args(BaseTest), get_default_args(SuperTest)]
unets = [Unet(**unet_params).to(device) for unet_params in unets_params]

# Specify MinImagen parameters
imagen_params = dict(
image_sizes=(int(args.IMG_SIDE_LEN / 2), args.IMG_SIDE_LEN),
timesteps=args.TIMESTEPS,
cond_drop_prob=0.15,
text_encoder_name=args.T5_NAME
)

# Create MinImagen from UNets with specified imagen parameters
imagen = Imagen(unets=unets, **imagen_params).to(device)

# Fill in unspecified arguments with defaults to record complete config (parameters) file
unets_params = [{**get_default_args(Unet), **i} for i in unets_params]
imagen_params = {**get_default_args(Imagen), **imagen_params}

# Get the size of the Imagen model in megabytes
model_size_MB = get_model_size(imagen)

# Save all training info (config files, model size, etc.)
save_training_info(args, timestamp, unets_params, imagen_params, model_size_MB, training_dir)

# Create optimizer
optimizer = optim.Adam(imagen.parameters(), lr=args.OPTIM_LR)

# Train the MinImagen instance
MinimagenTrain(timestamp, args, unets, imagen, train_dataloader, valid_dataloader, training_dir, optimizer, timeout=30)
Loading