diff --git a/__pycache__/preparing_data.cpython-38.pyc b/__pycache__/preparing_data.cpython-38.pyc
new file mode 100644
index 0000000..949ad30
Binary files /dev/null and b/__pycache__/preparing_data.cpython-38.pyc differ
diff --git a/captions.txt b/captions.txt
index f6c9d0c..5c8914a 100644
--- a/captions.txt
+++ b/captions.txt
@@ -1,2 +1,2 @@
-a happy dog
-a big red house
\ No newline at end of file
+grasper retract peritoneum
+bipolar coagulate blood_vessel
diff --git a/dict.json b/dict.json
new file mode 100644
index 0000000..df848b4
--- /dev/null
+++ b/dict.json
@@ -0,0 +1 @@
+{"/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID230/000000.png": [["grasper,retract,gallbladder"], ["grasper,retract,gut"], ["hook,dissect,omentum"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID230/000001.png": [["grasper,retract,omentum"], ["hook,dissect,omentum"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID230/000002.png": [["grasper,retract,gallbladder"], ["hook,dissect,cystic_plate"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID230/000003.png": [["grasper,retract,gallbladder"], ["clipper,clip,cystic_duct"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID230/000004.png": [["clipper,clip,cystic_artery"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID230/000005.png": [["scissors,cut,cystic_duct"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID231/000006.png": [["hook,dissect,gallbladder"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID231/000007.png": [["hook,dissect,gallbladder"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID231/000008.png": [["hook,dissect,gallbladder"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID231/000009.png": [["grasper,retract,cystic_plate"], ["grasper,retract,gallbladder"], ["hook,dissect,gallbladder"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID231/000010.png": [["grasper,grasp,specimen_bag"], ["grasper,null_verb,null_target"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID231/000011.png": [["grasper,retract,liver"], ["bipolar,coagulate,liver"], ["grasper,null_verb,null_target"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID231/000012.png": [["grasper,retract,liver"], ["bipolar,coagulate,liver"]], "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample/data/VID231/000013.png": [["grasper,retract,gut"], ["grasper,retract,liver"], ["irrigator,irrigate,abdominal_wall_cavity"]]}
\ No newline at end of file
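Note: the new dict.json maps each CholecT50 sample frame (an absolute path) to a list of
<instrument,verb,target> triplet strings; preparing_data.py below flattens these into text
captions. A minimal sketch of reading one entry, mirroring the join/replace logic that
MinimagenDatasetNew uses later in this patch:

    import json

    with open("dict.json") as f:
        data = json.load(f)

    # Each key is a frame path; each value is a list of triplet lists.
    for path, triplets in list(data.items())[:2]:
        caption = ' '.join(triplets[0]).replace(',', ' ')
        print(path, '->', caption)  # e.g. ".../000000.png -> grasper retract gallbladder"
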
"-test", "-ts", timestamp]) # Use small test Imagen to generate image subprocess.call(["venv/Scripts/python", "inference.py", "-d", f"training_{timestamp}"]) \ No newline at end of file diff --git a/minimagen/Imagen.py b/minimagen/Imagen.py index e5bafe1..2716e75 100644 --- a/minimagen/Imagen.py +++ b/minimagen/Imagen.py @@ -269,7 +269,7 @@ def _p_mean_variance(self, lowres_cond_img: torch.tensor = None, lowres_noise_times: torch.tensor = None, cond_scale: float = 1., - model_output: torch.tensor = None) -> tuple[torch.tensor, torch.tensor, torch.tensor]: + model_output: torch.tensor = None) -> tuple([torch.tensor, torch.tensor, torch.tensor]): """ Predicts noise component of `x` with `unet`, and then returns the corresponding forward process posterior parameters given the predictions. diff --git a/minimagen/Unet.py b/minimagen/Unet.py index 96afdf3..60067ee 100644 --- a/minimagen/Unet.py +++ b/minimagen/Unet.py @@ -509,7 +509,7 @@ def _generate_t_tokens( self, time: torch.tensor, lowres_noise_times: torch.tensor - ) -> tuple[torch.tensor, torch.tensor]: + ) -> tuple([torch.tensor, torch.tensor]): ''' Generate t and time_tokens @@ -544,7 +544,7 @@ def _text_condition( text_mask: torch.tensor, t: torch.tensor, time_tokens: torch.tensor - ) -> tuple[torch.tensor, torch.tensor]: + ) -> tuple([torch.tensor, torch.tensor]): ''' Condition on text. diff --git a/minimagen/__pycache__/Imagen.cpython-38.pyc b/minimagen/__pycache__/Imagen.cpython-38.pyc new file mode 100644 index 0000000..9595e2d Binary files /dev/null and b/minimagen/__pycache__/Imagen.cpython-38.pyc differ diff --git a/minimagen/__pycache__/Unet.cpython-38.pyc b/minimagen/__pycache__/Unet.cpython-38.pyc new file mode 100644 index 0000000..99c941e Binary files /dev/null and b/minimagen/__pycache__/Unet.cpython-38.pyc differ diff --git a/minimagen/__pycache__/__init__.cpython-38.pyc b/minimagen/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..60af119 Binary files /dev/null and b/minimagen/__pycache__/__init__.cpython-38.pyc differ diff --git a/minimagen/__pycache__/diffusion_model.cpython-38.pyc b/minimagen/__pycache__/diffusion_model.cpython-38.pyc new file mode 100644 index 0000000..c3f7843 Binary files /dev/null and b/minimagen/__pycache__/diffusion_model.cpython-38.pyc differ diff --git a/minimagen/__pycache__/generate.cpython-38.pyc b/minimagen/__pycache__/generate.cpython-38.pyc new file mode 100644 index 0000000..828db7d Binary files /dev/null and b/minimagen/__pycache__/generate.cpython-38.pyc differ diff --git a/minimagen/__pycache__/helpers.cpython-38.pyc b/minimagen/__pycache__/helpers.cpython-38.pyc new file mode 100644 index 0000000..3310e42 Binary files /dev/null and b/minimagen/__pycache__/helpers.cpython-38.pyc differ diff --git a/minimagen/__pycache__/layers.cpython-38.pyc b/minimagen/__pycache__/layers.cpython-38.pyc new file mode 100644 index 0000000..420e8e5 Binary files /dev/null and b/minimagen/__pycache__/layers.cpython-38.pyc differ diff --git a/minimagen/__pycache__/t5.cpython-38.pyc b/minimagen/__pycache__/t5.cpython-38.pyc new file mode 100644 index 0000000..aea4661 Binary files /dev/null and b/minimagen/__pycache__/t5.cpython-38.pyc differ diff --git a/minimagen/__pycache__/training.cpython-38.pyc b/minimagen/__pycache__/training.cpython-38.pyc new file mode 100644 index 0000000..3d34e03 Binary files /dev/null and b/minimagen/__pycache__/training.cpython-38.pyc differ diff --git a/minimagen/diffusion_model.py b/minimagen/diffusion_model.py index aad0769..650432c 100644 --- 
diff --git a/minimagen/diffusion_model.py b/minimagen/diffusion_model.py
index aad0769..650432c 100644
--- a/minimagen/diffusion_model.py
+++ b/minimagen/diffusion_model.py
@@ -78,7 +78,7 @@ def _sample_random_times(self, batch_size: int, *, device: torch.device) -> torch.tensor:
         """
         return torch.randint(0, self.num_timesteps, (batch_size,), device=device, dtype=torch.long)
 
-    def _get_sampling_timesteps(self, batch: int, *, device: torch.device) -> list[torch.tensor]:
+    def _get_sampling_timesteps(self, batch: int, *, device: torch.device) -> list([torch.tensor]):
         time_transitions = []
 
         for i in reversed(range(self.num_timesteps)):
@@ -86,9 +86,9 @@ def _get_sampling_timesteps(self, batch: int, *, device: torch.device) -> list[torch.tensor]:
 
         return time_transitions
 
-    def q_posterior(self, x_start: torch.tensor, x_t: torch.tensor, t: torch.tensor) -> tuple[torch.tensor,
+    def q_posterior(self, x_start: torch.tensor, x_t: torch.tensor, t: torch.tensor) -> tuple([torch.tensor,
                                                                                                torch.tensor,
-                                                                                               torch.tensor]:
+                                                                                               torch.tensor]):
         """
         Calculates q_posterior parameters given a starting image :code:`x_start` (x_0) and a noised image :code:`x_t`.
diff --git a/minimagen/layers.py b/minimagen/layers.py
index 0dafdab..9872b21 100644
--- a/minimagen/layers.py
+++ b/minimagen/layers.py
@@ -128,7 +128,7 @@ def __init__(
         self.activation = nn.SiLU()
         self.project = nn.Conv2d(dim, dim_out, 3, padding=1)
 
-    def forward(self, x: torch.tensor, scale_shift: tuple[torch.tensor, torch.tensor] = None) -> torch.tensor:
+    def forward(self, x: torch.tensor, scale_shift: tuple([torch.tensor, torch.tensor]) = None) -> torch.tensor:
         """
         Forward pass
 
@@ -269,7 +269,7 @@ class CrossEmbedLayer(nn.Module):
     def __init__(
             self,
             dim_in: int,
-            kernel_sizes: tuple[int, ...],
+            kernel_sizes: tuple([int, ...]),
             dim_out: int = None,
             stride: int = 2
     ):
@@ -347,7 +347,7 @@ class Parallel(nn.Module):
     """
     Passes input through parallel functions and then sums the result.
     """
-    def __init__(self, *fns: tuple[Callable, ...]):
+    def __init__(self, *fns: tuple([Callable, ...])):
         super().__init__()
         self.fns = nn.ModuleList(fns)
diff --git a/minimagen/training.py b/minimagen/training.py
index 43452a8..9bf4d77 100644
--- a/minimagen/training.py
+++ b/minimagen/training.py
@@ -11,7 +11,7 @@
 from typing import Literal
 
 from tqdm import tqdm
-
+from torchvision import transforms, utils
 import datasets
 import PIL.Image
 from einops import rearrange
@@ -22,7 +22,7 @@
 from datasets import load_dataset
 from datasets.utils.file_utils import get_datasets_user_agent
 from resize_right import resize
-
+from preparing_data import MinimagenDatasetNew, Rescale, ToTensor
 from minimagen import Unet
 from minimagen.helpers import exists
 from minimagen.t5 import t5_encode_text
@@ -68,6 +68,7 @@ def __call__(self, batch):
 
         # If the batch is empty after filtering
         if not batch:
+            print("Collator: batch is empty after filtering")
             return None
 
         # Expand mask and encodings to len of elt in batch with greatest number of words
@@ -83,8 +84,10 @@ def __call__(self, batch):
             elt['encoding'] = F.pad(elt['encoding'], (0, 0, 0, rem), 'constant', False)
 
         # TODO: Should really be passing in `device` - find a more elegant way to do this
+        print(batch)
         for didx, datum in enumerate(batch):
             for tensor in datum.keys():
+                print(batch[didx][tensor])
                 batch[didx][tensor] = batch[didx][tensor].to(self.device)
 
         return torch.utils.data.dataloader.default_collate(batch)
@@ -123,22 +126,23 @@ def _fetch_images(batch, num_threads, timeout=None, retries=0):
     fetch_single_image_with_args = partial(_fetch_single_image, timeout=timeout, retries=retries)
     with ThreadPoolExecutor(max_workers=num_threads) as executor:
         batch["image"] = list(executor.map(fetch_single_image_with_args, batch["image_url"]))
+    print("Fetched batch:", batch)
     return batch
 
 
 def _fetch_single_image(image_url, timeout=None, retries=0):
     for _ in range(retries + 1):
-        try:
-            request = urllib.request.Request(
-                image_url,
-                data=None,
-                headers={"user-agent": USER_AGENT},
-            )
-            with urllib.request.urlopen(request, timeout=timeout) as req:
+        # try:
+        #     request = urllib.request.Request(
+        #         image_url,
+        #         data=None,
+        #         headers={"user-agent": USER_AGENT},
+        #     )
+        with urllib.request.urlopen(image_url, timeout=timeout) as req:
             image = PIL.Image.open(io.BytesIO(req.read()))
             break
-        except Exception:
-            image = None
+        # except Exception:
+        #     image = None
     return image
@@ -187,7 +191,7 @@ def get_minimagen_parser():
                         help="Maximum number of words allowed in a caption", default=64, type=int)
     parser.add_argument("-s", "--IMG_SIDE_LEN", dest="IMG_SIDE_LEN",
                         help="Side length of square Imagen output images", default=128, type=int)
-    parser.add_argument("-e", "--EPOCHS", dest="EPOCHS", help="Number of training epochs", default=5, type=int)
+    parser.add_argument("-e", "--EPOCHS", dest="EPOCHS", help="Number of training epochs", default=10, type=int)
     parser.add_argument("-t5", "--T5_NAME", dest="T5_NAME", help="Name of T5 encoder to use", default='t5_base',
                         type=str)
     parser.add_argument("-f", "--TRAIN_VALID_FRAC", dest="TRAIN_VALID_FRAC",
@@ -236,6 +240,7 @@ def __init__(self, hf_dataset, *, encoder_name: str, max_length: int,
 
         split = "train" if train else "validation"
         self.urls = hf_dataset[f"{split}"]['image_url']
+        print("self.urls", self.urls)
         self.captions = hf_dataset[f"{split}"]['caption']
 
         if img_transform is None:
@@ -247,13 +252,14 @@ def __init__(self, hf_dataset, *, encoder_name: str, max_length: int,
     def __len__(self):
         return len(self.urls)
 
-
+    # __getitem__ needs to be adapted to the custom dataset
     def __getitem__(self, idx):
         if torch.is_tensor(idx):
             idx = idx.tolist()
 
-        img = _fetch_single_image(self.urls[idx])
+        img = _fetch_single_image(self.urls[idx], timeout=5, retries=5)
         if img is None:
+            print("MinimagenDataset.__getitem__: no image fetched for index", idx)
             return None
         elif self.img_transform:
             img = self.img_transform(img)
@@ -269,7 +275,7 @@ def __getitem__(self, idx):
         return {'image': img, 'encoding': enc, 'mask': msk}
 
 
-def ConceptualCaptions(args, smalldata=False, testset=False):
+def ConceptualCaptions(args, smalldata=True, testset=False):
     """
     Load `conceptual captions dataset <https://huggingface.co/datasets/conceptual_captions>`_
 
@@ -278,31 +284,41 @@ def ConceptualCaptions(args, smalldata=False, testset=False):
     :param testset: Whether to return the testing set (vs training/valid)
     :return: test_dataset if :code:`testset` else (train_dataset, valid_dataset)
     """
-    dset = load_dataset("conceptual_captions")
-    if smalldata:
-        num = 16
-        vi = dset['validation']['image_url'][:num]
-        vc = dset['validation']['caption'][:num]
-        ti = dset['train']['image_url'][:num]
-        tc = dset['train']['caption'][:num]
-        dset = datasets.Dataset = {'train': {
-            'image_url': ti,
-            'caption': tc,
-        }, 'num_rows': num,
-            'validation': {
-                'image_url': vi,
-                'caption': vc, }, 'num_rows': num}
+    # dset = load_dataset("conceptual_captions")
+    print("Loading the custom dict.json dataset instead of Conceptual Captions")
+    # if smalldata:
+    #     num = 16
+    #     vi = dset['validation']['image_url'][:num]
+    #     vc = dset['validation']['caption'][:num]
+    #     ti = dset['train']['image_url'][:num]
+    #     tc = dset['train']['caption'][:num]
+    #     dset = datasets.Dataset = {'train': {
+    #         'image_url': ti,
+    #         'caption': tc,
+    #     }, 'num_rows': num,
+    #         'validation': {
+    #             'image_url': vi,
+    #             'caption': vc, }, 'num_rows': num}
+    with open('/home/kareemelgohary/Desktop/minImagen/MinImagen/dict.json') as f:
+        d = json.load(f)
 
     if testset:
         # Torch test dataset
-        test_dataset = MinimagenDataset(dset, max_length=args.MAX_NUM_WORDS, train=False, encoder_name=args.T5_NAME,
-                                        side_length=args.IMG_SIDE_LEN)
+        # test_dataset = MinimagenDataset(dset, max_length=args.MAX_NUM_WORDS, train=False, encoder_name=args.T5_NAME,
+        #                                 side_length=args.IMG_SIDE_LEN)
+        test_dataset = MinimagenDatasetNew(d, "t5_base", 28, transform=transforms.Compose([Rescale(256), ToTensor()]))
         return test_dataset
     else:
         # Torch train/valid dataset
+        # The dict.json mapping is passed here; all processing happens inside MinimagenDatasetNew
-        dataset_train_valid = MinimagenDataset(dset, max_length=args.MAX_NUM_WORDS, encoder_name=args.T5_NAME,
-                                               train=True,
-                                               side_length=args.IMG_SIDE_LEN)
+        # dataset_train_valid = MinimagenDataset(dset, max_length=args.MAX_NUM_WORDS, encoder_name=args.T5_NAME,
+        #                                        train=True,
+        #                                        side_length=args.IMG_SIDE_LEN)
+        dataset_train_valid = MinimagenDatasetNew(d, "t5_base", 28, transform=transforms.Compose([Rescale(256), ToTensor()]))
+        print("ConceptualCaptions: train/valid dataset:", dataset_train_valid)
 
         # Split into train/valid
         train_size = int(args.TRAIN_VALID_FRAC * len(dataset_train_valid))
@@ -310,6 +326,7 @@ def ConceptualCaptions(args, smalldata=False, testset=False):
         train_dataset, valid_dataset = torch.utils.data.random_split(dataset_train_valid, [train_size, valid_size])
         if args.VALID_NUM is not None:
             valid_dataset.indices = valid_dataset.indices[:args.VALID_NUM + 1]
+        print("ConceptualCaptions: train dataset type:", type(train_dataset))
         return train_dataset, valid_dataset
@@ -435,6 +452,7 @@ def train():
                 f'U-Nets Best Valid Losses: {[round(i.item(), 3) for i in best_loss]}\n\n')
 
     best_loss = [torch.tensor(9999999) for i in range(len(unets))]
+    # Training loop starts here
     for epoch in range(args.EPOCHS):
         print(f'\n{"-" * 20} EPOCH {epoch + 1} {"-" * 20}')
         with training_dir():
@@ -450,6 +468,7 @@ def train():
             with _Timeout(timeout):
                 # If batch is empty, move on to the next one
                 if not batch:
+                    print("Training loop: skipping empty batch")
                     continue
 
                 train()
@@ -541,11 +560,11 @@ def load_testing_parameters(args):
     :param args: Arguments Namespace returned from parsing :func:`~.minimagen.training.get_minimagen_parser`.
     """
     d = dict(
-        BATCH_SIZE=2,
-        MAX_NUM_WORDS=32,
+        BATCH_SIZE=4,
+        MAX_NUM_WORDS=28,
         IMG_SIDE_LEN=128,
-        EPOCHS=2,
-        T5_NAME='t5_small',
+        EPOCHS=200,
+        T5_NAME='t5_base',
         TRAIN_VALID_FRAC=0.5,
         TIMESTEPS=25,  # Do not make less than 20
         OPTIM_LR=0.0001
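Note: the change to _fetch_single_image above comments out the try/except and retry
fallback, so any network error or unreadable URL now raises instead of returning None,
which is the contract that __getitem__ and the collator's None-filtering rely on. With the
dataset now read from local files this path is rarely exercised, but if remote URLs ever
come back, a sketch that keeps both the direct urlopen call and the failure handling
(hypothetical helper name fetch_single_image_safe):

    import io
    import urllib.request
    import PIL.Image

    def fetch_single_image_safe(image_url, timeout=None, retries=0):
        image = None
        for _ in range(retries + 1):
            try:
                with urllib.request.urlopen(image_url, timeout=timeout) as req:
                    image = PIL.Image.open(io.BytesIO(req.read()))
                break
            except Exception:
                image = None  # callers filter out None samples
        return image
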
diff --git a/package_train.py b/package_train.py
new file mode 100644
index 0000000..9d6450a
--- /dev/null
+++ b/package_train.py
@@ -0,0 +1,67 @@
+import os
+from datetime import datetime
+
+import torch.utils.data
+from torch import optim
+
+from minimagen.Imagen import Imagen
+from minimagen.Unet import Unet, Base, Super, BaseTest, SuperTest
+from minimagen.generate import load_minimagen, load_params
+from minimagen.t5 import get_encoded_dim
+from minimagen.training import get_minimagen_parser, ConceptualCaptions, get_minimagen_dl_opts, \
+    create_directory, get_model_size, save_training_info, get_default_args, MinimagenTrain, \
+    load_testing_parameters
+
+# Get device
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+# Command line argument parser
+parser = get_minimagen_parser()
+args = parser.parse_args()
+
+# Create training directory
+timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+dir_path = f"./training_{timestamp}"
+training_dir = create_directory(dir_path)
+
+# Replace some cmd line args to lower computational load.
+args = load_testing_parameters(args)
+
+# Load the dataset (ConceptualCaptions is patched above to return the dict.json dataset).
+train_dataset, valid_dataset = ConceptualCaptions(args, smalldata=True)
+
+# Create dataloaders
+dl_opts = {**get_minimagen_dl_opts(device), 'batch_size': args.BATCH_SIZE, 'num_workers': args.NUM_WORKERS}
+train_dataloader = torch.utils.data.DataLoader(train_dataset, **dl_opts)
+valid_dataloader = torch.utils.data.DataLoader(valid_dataset, **dl_opts)
+
+# Use small U-Nets to lower computational load.
+unets_params = [get_default_args(BaseTest), get_default_args(SuperTest)]
+unets = [Unet(**unet_params).to(device) for unet_params in unets_params]
+
+# Specify MinImagen parameters
+imagen_params = dict(
+    image_sizes=(int(args.IMG_SIDE_LEN / 2), args.IMG_SIDE_LEN),
+    timesteps=args.TIMESTEPS,
+    cond_drop_prob=0.15,
+    text_encoder_name=args.T5_NAME
+)
+
+# Create MinImagen from UNets with specified imagen parameters
+imagen = Imagen(unets=unets, **imagen_params).to(device)
+
+# Fill in unspecified arguments with defaults to record complete config (parameters) file
+unets_params = [{**get_default_args(Unet), **i} for i in unets_params]
+imagen_params = {**get_default_args(Imagen), **imagen_params}
+
+# Get the size of the Imagen model in megabytes
+model_size_MB = get_model_size(imagen)
+
+# Save all training info (config files, model size, etc.)
+save_training_info(args, timestamp, unets_params, imagen_params, model_size_MB, training_dir)
+
+# Create optimizer
+optimizer = optim.Adam(imagen.parameters(), lr=args.OPTIM_LR)
+
+# Train the MinImagen instance
+MinimagenTrain(timestamp, args, unets, imagen, train_dataloader, valid_dataloader, training_dir, optimizer, timeout=30)
\ No newline at end of file
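Note: package_train.py imports load_minimagen but stops after training. For completeness,
generating from the resulting training directory would look roughly like this, assuming
minimagen.generate keeps the upstream load_minimagen(directory) signature and that
Imagen.sample accepts texts and return_pil_images (neither is verified by this patch):

    from minimagen.generate import load_minimagen

    # Load the trained MinImagen from its training directory.
    minimagen = load_minimagen(f"./training_{timestamp}")

    # Generate one image per caption; the captions mirror captions.txt.
    images = minimagen.sample(texts=["grasper retract peritoneum",
                                     "bipolar coagulate blood_vessel"],
                              return_pil_images=True)
    for i, img in enumerate(images):
        img.save(f"generated_{i}.png")
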
diff --git a/preparing_data.py b/preparing_data.py
new file mode 100644
index 0000000..3d5b037
--- /dev/null
+++ b/preparing_data.py
@@ -0,0 +1,233 @@
+# Importing necessary packages
+from PIL import Image
+from skimage import io, transform
+import pandas as pd
+import csv
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+import json
+from itertools import compress
+import torch
+from torch.utils.data import Dataset, DataLoader
+from torchvision import transforms, utils
+from minimagen.t5 import t5_encode_text
+
+
+dataset_path = "/home/kareemelgohary/Downloads/Sample_data/CholecT50_sample-20221208T171010Z-001/CholecT50_sample"
+data_path = os.path.join(dataset_path, 'data')
+triplet_path = os.path.join(dataset_path, 'triplet')
+dict_path = os.path.join(dataset_path, 'dict')
+video_names = os.listdir(data_path)
+video_path = os.path.join(data_path, video_names[0])
+print("Dataset paths successfully defined!")
+
+
+def video_pathes(data_path, video_names):
+    """
+    Return the paths of the video folders.
+
+    Args:
+        data_path: path of the data directory
+        video_names: list of video folder names
+    Returns:
+        Sorted list of paths, one per video folder.
+    """
+    list_video_pathes = []
+    for video_name in video_names:
+        list_video_pathes.append(os.path.join(data_path, video_name))
+    return sorted(list_video_pathes)
+
+
+vp = video_pathes(data_path, video_names)
+fram = sorted(os.listdir(vp[0]))
+
+
+def video_frames(video_pathes):
+    """
+    Return the frame paths of the videos in sequential order.
+
+    Args:
+        video_pathes: list of video folder paths
+    Returns:
+        List of frame paths.
+    """
+    frames_pathes_list = []
+    for video_name in video_pathes:
+        frames = sorted(os.listdir(video_name))
+        for frame in frames:
+            frames_pathes_list.append(os.path.join(video_name, frame))
+    return frames_pathes_list
+
+
+video_frames_pathes = video_frames(vp)
+
+# Create dictionary mapping triplet ids to readable labels
+with open(os.path.join(dict_path, 'triplet.txt'), 'r') as f:
+    triplet_info = f.readlines()
+    triplet_dict = {}
+    for l in triplet_info:
+        triplet_id, triplet_label = l.split(':')
+        triplet_dict[int(triplet_id)] = triplet_label.rstrip()
+
+
+# Function to get the triplet labels for each frame
+def Tripplite_label(triplet_path, video_names):
+    data_set = {}
+    for video_name in video_names:
+        with open(os.path.join(triplet_path, video_name + '.txt'), mode='r') as infile:
+            reader = csv.reader(infile)
+            for line in reader:
+                line = np.array(line, np.int64)
+                frame_id, triplet_label = line[0], line[1:]
+                image_path = os.path.join(data_path, video_name, "%06d.png" % frame_id)
+                image = np.array(Image.open(image_path), np.float32) / 255.0
+                indices = list(compress(range(len(triplet_label)), triplet_label))
+                data_set.update({image_path: indices})
+    return data_set
+
+
+data_set = Tripplite_label(triplet_path, video_names)
+
+
+# Function to map the image paths to their triplets in plain English
+def mapping(data_set, triplet_dict):
+    """
+    Create a dictionary of image paths and their labels in English.
+
+    Args:
+        data_set: dictionary of image paths and their binary label indices
+        triplet_dict: dictionary mapping triplet ids to English triplet strings
+    Returns:
+        Dictionary mapping each image path to its triplets; a frame can have
+        more than one triplet.
+    """
+    data = {}
+    for i in data_set:
+        indeces_singel_list = data_set[i]
+        if len(indeces_singel_list) != 0:
+            label = []
+            for indx in indeces_singel_list:
+                label.append([triplet_dict[indx]])
+            data.update({i: label})
+    return data
+
+
+# # Create a json object from the dictionary and write it to dict.json
+# d = mapping(data_set, triplet_dict)
+# with open("dict.json", "w") as f:
+#     f.write(json.dumps(d))
+
+
+# Creating Dataset Class
+class MinimagenDatasetNew(Dataset):
+    """Triplet dataset.
+
+    Args:
+        dect_dataset: dictionary of image paths and their triplets
+        encoder_name: name of the T5 encoder used to encode captions
+        max_length: maximum caption token length
+        transform: optional transform applied to each sample
+    """
+
+    def __init__(self, dect_dataset, encoder_name, max_length, transform=None):
+        self.Triplet_data = list(dect_dataset.items())
+        self.encoder_name = encoder_name
+        self.max_length = max_length
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.Triplet_data)
+
+    def __getitem__(self, idx):
+        if torch.is_tensor(idx):
+            idx = idx.tolist()
+        img_name = self.Triplet_data[idx][0]
+        description = self.Triplet_data[idx][1][0]
+        str_description = ' '.join(description).replace(',', ' ')
+        image = io.imread(img_name)
+
+        enc, msk = t5_encode_text([str_description], self.encoder_name, self.max_length)
+
+        sample = {'image': image, 'encoding': enc, 'mask': msk}
+        if self.transform:
+            sample = self.transform(sample)
+        return sample
+
+
+class Rescale(object):
+    """Rescale the image in a sample to a given size.
+
+    Args:
+        output_size (tuple or int): If tuple, output is matched to
+            output_size. If int, the smaller of the image edges is matched
+            to output_size, keeping the aspect ratio the same.
+    """
+
+    def __init__(self, output_size):
+        assert isinstance(output_size, (int, tuple))
+        self.output_size = output_size
+
+    def __call__(self, sample):
+        image = sample['image']
+        enc = sample['encoding']
+        msk = sample['mask']
+        h, w = image.shape[:2]
+        if isinstance(self.output_size, int):
+            if h > w:
+                new_h, new_w = self.output_size * h / w, self.output_size
+            else:
+                new_h, new_w = self.output_size, self.output_size * w / h
+        else:
+            new_h, new_w = self.output_size
+        new_h, new_w = int(new_h), int(new_w)
+        img = transform.resize(image, (new_h, new_w))
+        return {'image': img, 'encoding': enc, 'mask': msk}
+
+
+class ToTensor(object):
+    """Convert ndarrays in a sample to Tensors."""
+
+    def __call__(self, sample):
+        image = sample['image']
+        enc = sample['encoding']
+        msk = sample['mask']
+        # numpy image: H x W x C -> torch image: C x H x W
+        image = image.transpose((2, 0, 1))
+        return {'image': torch.from_numpy(image), 'encoding': enc, 'mask': msk}
+
+
+# Quick smoke test; guarded so that importing this module does not run it
+if __name__ == "__main__":
+    with open('/home/kareemelgohary/Desktop/minImagen/MinImagen/dict.json') as f:
+        d = json.load(f)
+    T_dataset = MinimagenDatasetNew(d, "t5_base", 28, transform=transforms.Compose([Rescale(256), ToTensor()]))
+    for i in range(len(T_dataset)):
+        sample = T_dataset[i]
+        print(i, sample['image'].shape, sample['encoding'].shape, sample['mask'].shape)
+        if i == 7:
+            break
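Note: MinimagenDatasetNew.__getitem__ keeps only the first triplet of each frame
(self.Triplet_data[idx][1][0]), so frames annotated with several concurrent instrument
actions lose all but one label. If every triplet should contribute to the caption, a sketch
(hypothetical helper name caption_from_all_triplets):

    def caption_from_all_triplets(triplet_lists):
        # [["grasper,retract,liver"], ["bipolar,coagulate,liver"]]
        #   -> "grasper retract liver, bipolar coagulate liver"
        parts = [' '.join(t).replace(',', ' ') for t in triplet_lists]
        return ', '.join(parts)

The max_length passed to t5_encode_text (28 above) would then need to cover the longer
combined captions.
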
+ """ + + + def __init__(self,output_size): + assert isinstance(output_size,(int,tuple)) + self.output_size = output_size + + def __call__(self,sample): + image = sample['image'] + enc = sample['encoding'] + msk= sample['mask'] + h,w=image.shape[:2] + if isinstance(self.output_size,int): + if h>w: + new_h,new_w = self.output_size * h/w , self.output_size + else: + new_h,new_w = self.output_size,self.output_size * w/h + else: + new_h , new_w = self.output_size + new_h, new_w = int(new_h), int(new_w) + img = transform.resize(image, (new_h, new_w)) + return {'image': image, 'encoding': enc, 'mask': msk} + +class ToTensor(object): + """Convert ndarrays in sample to Tensors.""" + def __call__(self,sample): + image = sample['image'] + enc = sample['encoding'] + msk= sample['mask'] + image = image.transpose((2, 0, 1)) + return {'image': torch.from_numpy(image), 'encoding': enc, 'mask': msk} + + + + + + + + +with open('/home/kareemelgohary/Desktop/minImagen/MinImagen/dict.json') as f: + d = json.load(f) +T_dataset=MinimagenDatasetNew(d,"t5_base",28,transform=transforms.Compose([Rescale(256),ToTensor()])) +print("script ends here") +# fig = plt.figure() +for i in range(len(T_dataset)): + sample = T_dataset[i] + print(i, sample['image'].shape,sample['encoding'].shape,sample['mask'].shape) + if i ==7: + break + + diff --git a/train.py b/train.py index 7a9b8d4..8dd9004 100644 --- a/train.py +++ b/train.py @@ -3,7 +3,7 @@ import torch.utils.data from torch import optim - +from preparing_data import MinimagenDatasetNew from minimagen.Imagen import Imagen from minimagen.Unet import Unet, Base, Super, BaseTest, SuperTest diff --git a/try.py b/try.py new file mode 100644 index 0000000..105260b --- /dev/null +++ b/try.py @@ -0,0 +1,35 @@ +import inspect +import json +import os +import signal +from argparse import ArgumentParser +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager +from functools import partial +import io +import urllib +from typing import Literal + +from tqdm import tqdm + +import datasets +import PIL.Image +from einops import rearrange +import torch.utils.data +import torch.nn.functional as F +from torchvision.transforms import Compose, ToTensor + +from datasets import load_dataset +from datasets.utils.file_utils import get_datasets_user_agent +from resize_right import resize + +from minimagen import Unet +from minimagen.helpers import exists +from minimagen.t5 import t5_encode_text + + +enc, msk =t5_encode_text("fish walks on the moon", "t5_small", 256) + + +print(enc) +print(msk) \ No newline at end of file