| | |
| | import os |
| | from typing import (Generic, Iterable, Iterator, List, Optional, Sequence, |
| | Sized, TypeVar, Union) |
| |
|
| | import cv2 |
| | import numpy as np |
| | import torch |
| | from PIL import Image |
| | from torch.utils.data import BatchSampler, Dataset, Sampler |
| |
|
| | |
# Hand-picked [height, width] bucket sizes, keyed by their aspect ratio
# (height / width) rendered as a 4-decimal string — e.g. 336 / 592 = 0.5676.
# This matches the ratio = height / width convention used by
# get_closest_ratio() and the samplers below.
# NOTE(review): not referenced anywhere in this file chunk; presumably
# consumed by external code — confirm before removing or reordering.
CUSTOM_ASPECT_RATIOS = {
    "0.5676": [336, 592],
    "1.7619": [592, 336],
    "0.5682": [400, 704],
    "0.5556": [320, 576],
    "1.7600": [704, 400],
    "0.5319": [400, 752],
    "1.8000": [576, 320],
    "0.5128": [320, 624],
    "1.8800": [752, 400],
    "1.9000": [608, 320],
    "0.4237": [400, 944],
}
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
# Predefined [height, width] buckets for ~512px training, keyed by the
# bucket's aspect ratio (height / width) rendered as a short decimal string.
ASPECT_RATIO_512 = {
    '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0],
    '0.28': [256.0, 928.0], '0.32': [288.0, 896.0], '0.33': [288.0, 864.0],
    '0.35': [288.0, 832.0], '0.4': [320.0, 800.0], '0.42': [320.0, 768.0],
    '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
    '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0],
    '0.72': [416.0, 576.0], '0.78': [448.0, 576.0], '0.82': [448.0, 544.0],
    '0.88': [480.0, 544.0], '0.94': [480.0, 512.0], '1.0': [512.0, 512.0],
    '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
    '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0],
    '1.67': [640.0, 384.0], '1.75': [672.0, 384.0], '2.0': [704.0, 352.0],
    '2.09': [736.0, 352.0], '2.4': [768.0, 320.0], '2.5': [800.0, 320.0],
    '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
    '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0],
    '4.0': [1024.0, 256.0],
}

# Reduced bucket set for random-crop training.
ASPECT_RATIO_RANDOM_CROP_512 = {
    '0.42': [320.0, 768.0], '0.5': [352.0, 704.0], '0.57': [384.0, 672.0],
    '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0],
    '0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0],
    '1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0],
    '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.4': [768.0, 320.0],
}

# Raw sampling weights (15 entries, near-square shapes weighted heaviest),
# then normalized in place into a probability vector. Presumably
# index-aligned with ASPECT_RATIO_RANDOM_CROP_512 — confirm at the call site.
ASPECT_RATIO_RANDOM_CROP_PROB = [1, 2, 4, 4, 4, 4, 8, 8, 8, 4, 4, 4, 4, 2, 1]
ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB)


def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512):
    """Pick the bucket whose aspect ratio best matches ``height / width``.

    Args:
        height: Source height in pixels.
        width: Source width in pixels.
        ratios: Mapping of ratio-string -> [bucket_height, bucket_width].

    Returns:
        Tuple of (bucket size list, chosen ratio key as a float).
    """
    target = height / width
    best_key = min(ratios, key=lambda key: abs(float(key) - target))
    return ratios[best_key], float(best_key)
| |
|
def get_image_size_without_loading(path):
    """Return the (width, height) of the image at *path*.

    PIL's ``Image.open`` reads only the header lazily, so the pixel data is
    never decoded; the handle is closed before returning.
    """
    img = Image.open(path)
    try:
        return img.size
    finally:
        img.close()
| |
|
class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.

    If with replacement, then user can specify :attr:`num_samples` to draw.

    Unlike torch's built-in ``RandomSampler``, this one keeps a ``_pos_start``
    cursor into the current permutation, which appears intended to let a
    re-created iterator resume mid-epoch (e.g. after a checkpoint restore) —
    TODO(review): confirm against the training loop that persists this sampler.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
        generator (Generator): Generator used in sampling.
    """

    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        # None means "use len(data_source)"; resolved lazily by the
        # ``num_samples`` property so a growing dataset is re-measured.
        self._num_samples = num_samples
        self.generator = generator
        # Cursor into the current permutation (see class docstring).
        self._pos_start = 0

        if not isinstance(self.replacement, bool):
            raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime, so compute on demand
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            # Seed a fresh generator from torch's global RNG state so each
            # iterator is independently (but reproducibly) randomized.
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        if self.replacement:
            # Draw with replacement in chunks of 32 to amortize the cost of
            # torch.randint calls, then emit the remainder.
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            # Emit `num_samples // n` full permutations, then a partial one.
            for _ in range(self.num_samples // n):
                xx = torch.randperm(n, generator=generator).tolist()
                # Skip indices a previous iterator already yielded from this
                # epoch; the cursor advances per yielded item and wraps.
                if self._pos_start >= n:
                    self._pos_start = 0
                for idx in range(self._pos_start, n):
                    yield xx[idx]
                    self._pos_start = (self._pos_start + 1) % n
                # A fully-consumed permutation resets the resume cursor.
                self._pos_start = 0
            # Remainder: a prefix of one more fresh permutation.
            yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

    def __len__(self) -> int:
        return self.num_samples
| |
|
class AspectRatioBatchImageSampler(BatchSampler):
    """A sampler wrapper for grouping images with similar aspect ratio into a same batch.

    Each index drawn from ``sampler`` is routed to the bucket of the
    predefined aspect ratio closest to the item's ``height / width``; a
    bucket is yielded as a batch whenever it fills up. Items whose metadata
    cannot be read are skipped with a printed warning.

    Args:
        sampler (Sampler): Base sampler yielding dataset indices.
        dataset (Dataset): Dataset whose items are dicts carrying ``width`` /
            ``height`` (or a ``file_path`` to read them from).
        batch_size (int): Size of mini-batch.
        train_folder (str, optional): Root that ``file_path`` entries are
            relative to; ``None`` means the paths are absolute.
        aspect_ratios (dict, optional): Mapping ratio-string -> [height, width].
            Defaults to ``ASPECT_RATIO_512`` (resolved at call time to avoid a
            shared mutable default argument).
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
        config: Opaque configuration object, stored as-is.
    """
    def __init__(
        self,
        sampler: Sampler,
        dataset: Dataset,
        batch_size: int,
        train_folder: Optional[str] = None,
        aspect_ratios: Optional[dict] = None,
        drop_last: bool = False,
        config=None,
        **kwargs
    ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.train_folder = train_folder
        self.batch_size = batch_size
        # BUGFIX: mutable module-level dict was used as a default argument.
        self.aspect_ratios = ASPECT_RATIO_512 if aspect_ratios is None else aspect_ratios
        self.drop_last = drop_last
        self.config = config
        # One pending-index list per predefined ratio.
        self._aspect_ratio_buckets = {ratio: [] for ratio in self.aspect_ratios}
        # All buckets are enabled by default.
        self.current_available_bucket_keys = list(self.aspect_ratios.keys())

    def _resolve_ratio(self, image_dict) -> float:
        """Return height / width for one item, reading the image header if the
        dict carries no dimensions. Raises on unreadable items."""
        width, height = image_dict.get("width", None), image_dict.get("height", None)
        if width is None or height is None:
            # Fall back to probing the file (header only, no pixel decode).
            image_id = image_dict['file_path']
            if self.train_folder is None:
                image_dir = image_id
            else:
                image_dir = os.path.join(self.train_folder, image_id)
            width, height = get_image_size_without_loading(image_dir)
        # int() also normalizes numeric strings / floats stored in metadata.
        return int(height) / int(width)

    def __iter__(self):
        for idx in self.sampler:
            try:
                ratio = self._resolve_ratio(self.dataset[idx])
            except Exception as e:
                print(e)
                continue

            closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
            if closest_ratio not in self.current_available_bucket_keys:
                continue
            bucket = self._aspect_ratio_buckets[closest_ratio]
            bucket.append(idx)
            # Yield a copy and reset the bucket once it is full.
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]

        # BUGFIX: honor ``drop_last`` — previously partial buckets were always
        # discarded at the end of iteration regardless of the flag.
        if not self.drop_last:
            for bucket in self._aspect_ratio_buckets.values():
                if bucket:
                    yield bucket[:]
                    del bucket[:]
| |
|
class AspectRatioBatchSampler(BatchSampler):
    """A sampler wrapper for grouping videos with similar aspect ratio into a same batch.

    Each index drawn from ``sampler`` is routed to the bucket of the
    predefined aspect ratio closest to the item's ``height / width``; a bucket
    is yielded as a batch whenever it fills up. Unreadable items are skipped
    with a printed warning.

    Args:
        sampler (Sampler): Base sampler yielding dataset indices.
        dataset (Dataset): Dataset whose items are dicts carrying ``width`` /
            ``height``, or enough path info to probe the video file.
        batch_size (int): Size of mini-batch.
        video_folder (str, optional): Root folder for relative video paths.
        train_data_format (str): ``"normal"`` items carry ``file_path``;
            anything else is treated as webvid-style (``videoid`` -> ``<id>.mp4``).
        aspect_ratios (dict, optional): Mapping ratio-string -> [height, width].
            Defaults to ``ASPECT_RATIO_512`` (resolved at call time to avoid a
            shared mutable default argument).
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
        config: Opaque configuration object, stored as-is.
    """
    def __init__(
        self,
        sampler: Sampler,
        dataset: Dataset,
        batch_size: int,
        video_folder: Optional[str] = None,
        train_data_format: str = "webvid",
        aspect_ratios: Optional[dict] = None,
        drop_last: bool = False,
        config=None,
        **kwargs
    ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.video_folder = video_folder
        self.train_data_format = train_data_format
        self.batch_size = batch_size
        # BUGFIX: mutable module-level dict was used as a default argument.
        self.aspect_ratios = ASPECT_RATIO_512 if aspect_ratios is None else aspect_ratios
        self.drop_last = drop_last
        self.config = config
        # One pending-index list per predefined ratio.
        self._aspect_ratio_buckets = {ratio: [] for ratio in self.aspect_ratios}
        # All buckets are enabled by default.
        self.current_available_bucket_keys = list(self.aspect_ratios.keys())

    def _resolve_ratio(self, video_dict) -> float:
        """Return height / width for one item, probing the video file when the
        dict carries no dimensions. Raises on unreadable items."""
        # BUGFIX: the original unpacked the height into a variable named
        # ``more`` and then read the undefined name ``height``, so every item
        # that carried width/height metadata raised NameError and was skipped.
        width, height = video_dict.get("width", None), video_dict.get("height", None)
        if width is None or height is None:
            if self.train_data_format == "normal":
                video_id = video_dict['file_path']
                if self.video_folder is None:
                    video_dir = video_id
                else:
                    video_dir = os.path.join(self.video_folder, video_id)
            else:
                # webvid-style naming: <video_folder>/<videoid>.mp4
                videoid = video_dict['videoid']
                video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
            cap = cv2.VideoCapture(video_dir)
            try:
                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            finally:
                # BUGFIX: the capture handle was never released (resource leak).
                cap.release()
        return int(height) / int(width)

    def __iter__(self):
        for idx in self.sampler:
            try:
                ratio = self._resolve_ratio(self.dataset[idx])
            except Exception as e:
                # Do not index self.dataset[idx] again here: if the lookup
                # itself failed, re-indexing would raise outside the handler.
                print(e, "This item is error, please check it.")
                continue

            closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
            if closest_ratio not in self.current_available_bucket_keys:
                continue
            bucket = self._aspect_ratio_buckets[closest_ratio]
            bucket.append(idx)
            # Yield a copy and reset the bucket once it is full.
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]

        # BUGFIX: honor ``drop_last`` — previously partial buckets were always
        # discarded at the end of iteration regardless of the flag.
        if not self.drop_last:
            for bucket in self._aspect_ratio_buckets.values():
                if bucket:
                    yield bucket[:]
                    del bucket[:]
| |
|
class AspectRatioBatchImageVideoSampler(BatchSampler):
    """A sampler wrapper for grouping images/videos with similar aspect ratio into a same batch.

    Maintains separate bucket sets for ``'image'`` and ``'video'`` items (the
    item's ``type`` key, defaulting to ``'video'``) so a batch never mixes the
    two. Each index is routed to the bucket of the closest predefined aspect
    ratio; a bucket is yielded as a batch whenever it fills up. Unreadable
    items are skipped with a printed warning.

    Args:
        sampler (Sampler): Base sampler yielding dataset indices.
        dataset (Dataset): Dataset whose items are dicts carrying ``width`` /
            ``height``, or enough path info to probe the media file.
        batch_size (int): Size of mini-batch.
        train_folder (str, optional): Root folder for relative media paths.
        aspect_ratios (dict, optional): Mapping ratio-string -> [height, width].
            Defaults to ``ASPECT_RATIO_512`` (resolved at call time to avoid a
            shared mutable default argument).
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
    """

    def __init__(self,
                 sampler: Sampler,
                 dataset: Dataset,
                 batch_size: int,
                 train_folder: Optional[str] = None,
                 aspect_ratios: Optional[dict] = None,
                 drop_last: bool = False
                 ) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.dataset = dataset
        self.train_folder = train_folder
        self.batch_size = batch_size
        # BUGFIX: mutable module-level dict was used as a default argument.
        self.aspect_ratios = ASPECT_RATIO_512 if aspect_ratios is None else aspect_ratios
        self.drop_last = drop_last

        # All buckets are enabled by default.
        self.current_available_bucket_keys = list(self.aspect_ratios.keys())
        # Independent pending-index lists per content type and per ratio.
        self.bucket = {
            'image': {ratio: [] for ratio in self.aspect_ratios},
            'video': {ratio: [] for ratio in self.aspect_ratios}
        }

    def _resolve_ratio(self, data_dict, content_type) -> float:
        """Return height / width for one item, probing the media file when the
        dict carries no dimensions. Raises on unreadable items."""
        width, height = data_dict.get("width", None), data_dict.get("height", None)
        if width is None or height is None:
            if content_type == 'image':
                # Read only the image header — no pixel decode.
                image_id = data_dict.get('file_path', '')
                if self.train_folder is None:
                    image_dir = image_id
                else:
                    image_dir = os.path.join(self.train_folder, image_id)
                width, height = get_image_size_without_loading(image_dir)
            else:
                # Accept any of the known path keys, in priority order.
                video_id = (
                    data_dict.get('original_video')
                    or data_dict.get('edited_video')
                    or data_dict.get('file_path')
                )
                if video_id is None:
                    raise ValueError(f"No valid video path found in dataset item: {data_dict}")
                if self.train_folder is None:
                    video_dir = video_id
                else:
                    video_dir = os.path.join(self.train_folder, video_id)
                cap = cv2.VideoCapture(video_dir)
                try:
                    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                finally:
                    cap.release()
                if width == 0 or height == 0:
                    raise ValueError(f"Invalid video size for {video_dir}: {width}x{height}")
        else:
            height = int(height)
            width = int(width)
        return height / width

    def __iter__(self):
        for idx in self.sampler:
            try:
                # BUGFIX: fetch the item once — the original indexed the
                # dataset up to three times per index, including once outside
                # any try (an exception there crashed the whole iterator).
                data_dict = self.dataset[idx]
                content_type = data_dict.get('type', 'video')
                ratio = self._resolve_ratio(data_dict, content_type)
            except Exception as e:
                # Do not re-index the dataset here: a failing lookup would
                # raise again outside the handler.
                print(e, "This item is error, please check it.")
                continue

            closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
            if closest_ratio not in self.current_available_bucket_keys:
                continue

            bucket = self.bucket[content_type][closest_ratio]
            bucket.append(idx)
            # Yield a copy and reset the bucket once it is full.
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]

        # BUGFIX: honor ``drop_last`` — previously partial buckets were always
        # discarded at the end of iteration regardless of the flag.
        if not self.drop_last:
            for type_buckets in self.bucket.values():
                for leftover in type_buckets.values():
                    if leftover:
                        yield leftover[:]
                        del leftover[:]