| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import Dataset, DataLoader |
| from torchvision import transforms |
| from PIL import Image |
| import json |
| import os |
| import subprocess |
| from transformers import BertTokenizer, BertModel |
| import wandb |
|
|
| |
# Model-wide dimensions (also recorded in the wandb run config by main()).
#   LATENT_DIM: size of the CVAE latent vector z.
#   HIDDEN_DIM: width of the CVAE image-feature vector and of the text
#               conditioning vector produced by TextEncoder in main().
LATENT_DIM, HIDDEN_DIM = 128, 256
|
|
| |
class Text2ImageDataset(Dataset):
    """Image/prompt pairs read from ``image_dir`` plus a JSON metadata index.

    ``metadata_file`` is parsed with ``json.load`` and indexed by integer
    position. Each entry must provide ``'file_name'`` (joined onto
    ``image_dir``) and ``'description'`` (the text prompt). Presumably the
    file holds a JSON list of objects -- TODO confirm against the data-prep
    step.

    Each image is converted to RGBA, turned into a float tensor in [0, 1] by
    ``ToTensor`` and mapped channel-wise to [-1, 1] by ``Normalize(0.5, 0.5)``.
    No resizing happens here. The CVAE in this file only accepts 16x16
    inputs, so files are assumed to already be that size -- TODO confirm.
    """

    def __init__(self, image_dir, metadata_file):
        self.image_dir = image_dir
        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(metadata_file, 'r', encoding='utf-8') as f:
            self.metadata = json.load(f)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5, 0.5), (0.5, 0.5, 0.5, 0.5)),
        ])

    def __len__(self):
        return len(self.metadata)

    def _load_image(self, item):
        """Open ``item['file_name']`` as an RGBA PIL image; print the error and return None on failure."""
        image_path = os.path.join(self.image_dir, item['file_name'])
        try:
            # convert() loads the pixels into a new image, so the file handle
            # can be closed as soon as the with-block exits.
            with Image.open(image_path) as img:
                return img.convert('RGBA')
        except FileNotFoundError:
            print(f"Image not found: {image_path}")
        except Exception as e:
            print(f"Error loading image {image_path}: {e}")
        return None

    def __getitem__(self, idx):
        """Return ``(image_tensor, prompt)`` for ``idx``.

        If the image for ``idx`` cannot be loaded, the error is printed and the
        following entries (wrapping around the end) are tried instead. This
        way the DataLoader's default collate never receives ``None``, which it
        cannot batch. Because of the fallback, a neighbouring sample may appear
        twice in an epoch.

        Raises:
            RuntimeError: if no entry in the dataset has a loadable image.
        """
        n_items = len(self.metadata)
        for offset in range(n_items):
            # Python's % keeps negative indices valid, matching list indexing.
            item = self.metadata[(idx + offset) % n_items]
            image = self._load_image(item)
            if image is not None:
                return self.transform(image), str(item['description'])
        raise RuntimeError(
            f"No loadable image found in {self.image_dir} (requested index {idx})"
        )
|
|
| |
class TextEncoder(nn.Module):
    """Map a batch of token ids to a fixed-width conditioning vector.

    Runs the pretrained ``bert-base-uncased`` model and feeds the last-layer
    hidden state at sequence position 0 through one linear projection. With
    inputs from ``BertTokenizer`` using its default special tokens, position 0
    is the [CLS] token.

    Args:
        hidden_size: accepted for interface compatibility; not used.
        output_size: width of the vector returned by ``forward``.
    """

    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        bert_width = self.bert.config.hidden_size
        self.fc = nn.Linear(bert_width, output_size)

    def forward(self, input_ids, attention_mask):
        """Return the projected position-0 embedding, one row per sequence in the batch."""
        hidden_states = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        first_token = hidden_states[:, 0]
        return self.fc(first_token)
|
|
| |
class CVAE(nn.Module):
    """Conditional VAE over 4-channel 16x16 images, conditioned on a vector ``c``.

    Shape contract (fixed by the layer arithmetic below):

    * ``x``: ``(B, 4, 16, 16)``. The encoder downsamples 16 -> 16 -> 8 -> 4
      before the hard-coded ``128 * 4 * 4`` flatten. The decoder upsamples
      4 -> 8 -> 16 and always emits 16x16.
    * ``c``: ``(B, hidden_dim)``. It is concatenated with both the image
      features and the latent code.
    * Decoder output values lie in [-1, 1] (final ``Tanh``).

    ``text_encoder`` is only stored as a submodule, so its parameters are part
    of ``parameters()`` and the ``state_dict``. ``forward`` does not call it;
    callers compute ``c`` themselves.

    Args:
        text_encoder: module registered as ``self.text_encoder``.
        latent_dim: size of z. Defaults to the module constant ``LATENT_DIM``.
        hidden_dim: width of the image features and of ``c``. Defaults to the
            module constant ``HIDDEN_DIM``.
    """

    def __init__(self, text_encoder, latent_dim=None, hidden_dim=None):
        super().__init__()
        # Resolved at call time (not as default values) so existing callers
        # keep using the module-level constants.
        if latent_dim is None:
            latent_dim = LATENT_DIM
        if hidden_dim is None:
            hidden_dim = HIDDEN_DIM
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim

        self.text_encoder = text_encoder

        # Submodule names and creation order are kept stable: they determine
        # state_dict keys (saved checkpoints) and the RNG order of weight init.
        self.encoder = nn.Sequential(
            nn.Conv2d(4, 32, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, hidden_dim),
        )

        # Input is [image features ; c], hence hidden_dim + hidden_dim.
        self.fc_mu = nn.Linear(hidden_dim + hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim + hidden_dim, latent_dim)

        # Input is [z ; c]; output is reshaped to (B, 128, 4, 4) in decode().
        self.decoder_input = nn.Linear(latent_dim + hidden_dim, 128 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 4, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def encode(self, x, c):
        """Return ``(mu, logvar)`` of q(z | x, c), each ``(B, latent_dim)``."""
        features = self.encoder(x)
        joint = torch.cat([features, c], dim=1)
        return self.fc_mu(joint), self.fc_logvar(joint)

    def decode(self, z, c):
        """Decode ``[z ; c]`` into a ``(B, 4, 16, 16)`` image in [-1, 1]."""
        joint = torch.cat([z, c], dim=1)
        h = self.decoder_input(joint).view(-1, 128, 4, 4)
        return self.decoder(h)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, c):
        """Return ``(reconstruction, mu, logvar)`` for images ``x`` under condition ``c``."""
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, c), mu, logvar
|
|
| |
def loss_function(recon_x, x, mu, logvar, beta=1.0):
    """Return the (beta-weighted) negative ELBO summed over the whole batch.

    The reconstruction term is the summed squared error (MSE with
    ``reduction='sum'``, not BCE) between ``recon_x`` and ``x``. The KL term is
    the closed-form KL(N(mu, exp(logvar)) || N(0, I)), also summed over every
    element.

    Args:
        recon_x: decoder output; same shape as ``x``.
        x: target images.
        mu: posterior means.
        logvar: posterior log-variances; same shape as ``mu``.
        beta: weight on the KL term. The default 1.0 gives the standard VAE
            objective; other values give a beta-VAE.

    Returns:
        Scalar tensor ``recon + beta * kl``.
    """
    recon_loss = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_div
|
|
| |
def train(model, train_loader, optimizer, device, tokenizer):
    """Run one training epoch and return the average summed loss per sample.

    For every batch, the steps are:

    1. Tokenise the prompts.
    2. Encode them with ``model.text_encoder``.
    3. Run the CVAE and back-propagate ``loss_function``.
    4. Step the optimizer.

    Gradients flow into ``model.text_encoder`` as well. Whether it is actually
    updated depends on which parameters ``optimizer`` was built with.

    Per-batch metrics are logged to wandb. Note that they use different scales:

    * ``batch_loss``: summed over the batch.
    * ``batch_reconstruction_loss``: mean per element.
    * ``batch_kl_divergence``: per sample.

    Args:
        model: CVAE exposing ``text_encoder``; ``model(data, cond)`` returns
            ``(recon, mu, logvar)``.
        train_loader: yields ``(images, prompts)`` batches.
        optimizer: optimizer stepped once per batch.
        device: device the images and token tensors are moved to.
        tokenizer: callable producing ``input_ids``/``attention_mask`` tensors.

    Returns:
        Sum of per-batch losses divided by ``len(train_loader.dataset)``.
        Since ``loss_function`` sums over the batch, this is a per-sample mean.

    Raises:
        ValueError: if the dataset is empty.
    """
    num_samples = len(train_loader.dataset)
    if num_samples == 0:
        raise ValueError("train_loader.dataset is empty; nothing to train on")

    model.train()
    train_loss = 0.0
    for data, prompt in train_loader:
        data = data.to(device)
        optimizer.zero_grad()

        encoded_input = tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
        input_ids = encoded_input['input_ids'].to(device)
        attention_mask = encoded_input['attention_mask'].to(device)

        text_encoding = model.text_encoder(input_ids, attention_mask)

        recon_batch, mu, logvar = model(data, text_encoding)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()

        # Single host sync; the value is reused for the epoch sum and the log.
        batch_loss = loss.item()
        train_loss += batch_loss

        # Logging-only metrics: no_grad keeps them from building an autograd
        # graph on top of recon_batch/mu/logvar.
        with torch.no_grad():
            recon_mse = nn.functional.mse_loss(recon_batch, data, reduction='mean').item()
            kl_per_sample = (
                -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / data.size(0)
            ).item()

        wandb.log({
            "batch_loss": batch_loss,
            "batch_reconstruction_loss": recon_mse,
            "batch_kl_divergence": kl_per_sample,
        })

    return train_loss / num_samples
|
|
| |
def _log_generated_image(model, device, save_dir, epoch):
    """Generate an image for a fixed prompt, save it as a PNG under ``save_dir`` and log it to wandb."""
    # Local import: generate.py is only needed once sampling actually runs.
    # After the first call the module is cached in sys.modules.
    from generate import generate_image

    prompt = "A blue sword made of diamond"
    output_image = os.path.join(save_dir, f"output_epoch_{epoch}.png")

    # eval() turns off dropout inside the BERT text encoder while sampling.
    # no_grad avoids building a graph. generate_image's internals are not
    # visible here -- NOTE(review): drop no_grad if it ever needs gradients.
    model.eval()
    with torch.no_grad():
        generated_image = generate_image(model, prompt, device)
    generated_image.save(output_image)

    wandb.log({
        "generated_image": wandb.Image(output_image, caption=f"Generated at epoch {epoch} with prompt {prompt}")
    })


def _log_reconstructions(model, train_loader, tokenizer, device, num_samples=4):
    """Log up to ``num_samples`` original/reconstruction pairs from one fresh batch to wandb."""
    model.eval()
    with torch.no_grad():
        sample_data, sample_prompt = next(iter(train_loader))
        # Guard against a final/only batch smaller than num_samples.
        n = min(num_samples, sample_data.size(0))
        sample_data = sample_data[:n].to(device)
        encoded_input = tokenizer(sample_prompt[:n], padding=True, truncation=True, return_tensors="pt")
        input_ids = encoded_input['input_ids'].to(device)
        attention_mask = encoded_input['attention_mask'].to(device)
        text_encoding = model.text_encoder(input_ids, attention_mask)
        recon_batch, _, _ = model(sample_data, text_encoding)

    # Undo the dataset's [-1, 1] normalisation before converting to PIL.
    to_pil = transforms.ToPILImage()
    original_images = [to_pil((sample_data[i] * 0.5 + 0.5).cpu()) for i in range(n)]
    reconstructed_images = [to_pil((recon_batch[i] * 0.5 + 0.5).cpu()) for i in range(n)]

    wandb.log({
        f"original_vs_reconstructed_{i}": [
            wandb.Image(original_images[i], caption=f"Original {i}"),
            wandb.Image(reconstructed_images[i], caption=f"Reconstructed {i}"),
        ]
        for i in range(n)
    })


def main():
    """Train the BitRoss text-conditioned CVAE and log progress to wandb.

    Reads images and prompts from ``./trainingData/``. On a per-epoch schedule
    it does the following:

    * Logs the average training loss (every epoch).
    * Saves and logs a sample image (every ``SAVE_INTERVAL_IMAGE`` epochs).
    * Checkpoints the ``state_dict`` into ``SAVE_DIR`` (every
      ``SAVE_INTERVAL`` epochs).
    * Logs original/reconstruction pairs (every ``RECON_LOG_INTERVAL`` epochs).

    The sampling helpers switch the model to eval mode; ``train()`` switches
    it back to train mode at the start of every epoch.
    """
    NUM_EPOCHS = 500
    BATCH_SIZE = 128
    LEARNING_RATE = 1e-4

    SAVE_INTERVAL = 25         # epochs between checkpoints
    SAVE_INTERVAL_IMAGE = 1    # epochs between generated-image logs
    RECON_LOG_INTERVAL = 10    # epochs between reconstruction logs
    PROJECT_NAME = "BitRoss"
    MODEL_NAME = "BitRoss"
    SAVE_DIR = "/models/BitRoss/"

    DATA_DIR = "./trainingData/"
    METADATA_FILE = "./trainingData/metadata.json"

    os.makedirs(SAVE_DIR, exist_ok=True)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    wandb.init(project=PROJECT_NAME, config={
        "LATENT_DIM": LATENT_DIM,
        "HIDDEN_DIM": HIDDEN_DIM,
        "NUM_EPOCHS": NUM_EPOCHS,
        "BATCH_SIZE": BATCH_SIZE,
        "LEARNING_RATE": LEARNING_RATE,
        "SAVE_INTERVAL": SAVE_INTERVAL,
        "MODEL_NAME": MODEL_NAME
    })

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = Text2ImageDataset(DATA_DIR, METADATA_FILE)
    train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    text_encoder = TextEncoder(hidden_size=HIDDEN_DIM, output_size=HIDDEN_DIM)
    model = CVAE(text_encoder).to(device)
    # model.parameters() includes the BERT text encoder, so it is fine-tuned too.
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    wandb.watch(model, log="all", log_freq=100)

    for epoch in range(1, NUM_EPOCHS + 1):
        train_loss = train(model, train_loader, optimizer, device, tokenizer)
        print(f'Epoch {epoch}, Loss: {train_loss:.4f}')

        wandb.log({
            "epoch": epoch,
            "train_loss": train_loss,
        })

        if epoch % SAVE_INTERVAL_IMAGE == 0:
            _log_generated_image(model, device, SAVE_DIR, epoch)

        if epoch % SAVE_INTERVAL == 0:
            model_save_path = os.path.join(SAVE_DIR, f"{MODEL_NAME}_epoch_{epoch}.pth")
            torch.save(model.state_dict(), model_save_path)
            print(f"Model saved to {model_save_path}")

        if epoch % RECON_LOG_INTERVAL == 0:
            _log_reconstructions(model, train_loader, tokenizer, device)

    wandb.finish()


if __name__ == "__main__":
    main()