import datasets
import importlib.util
import tqdm
import transformers
import typer

# Set in main() and shared with the datasets.map() worker processes.
tokenizer = None


def load_config(config_file: str):
    # Execute the config file as a regular Python module, so the config can
    # hold arbitrary Python values (nested dicts, computed weights).
    spec = importlib.util.spec_from_file_location("config", config_file)
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)
    return config_module.sources, config_module.tokenizer_name, config_module.prefix


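# A minimal sketch of the config module load_config() expects. The attribute
# names (sources, tokenizer_name, prefix) and the per-source keys used below
# ("format", "uri", "shards", "shard_index", "train", "test") come from this
# script; the concrete values are illustrative assumptions only.
#
#   sources = {
#       "c4": {
#           "format": "json",                 # first argument to load_dataset()
#           "uri": "data/c4/{key}.jsonl.gz",  # templated with each subset key
#           "shards": 64,                     # baseline shard count
#           "shard_index": [0, 1],            # shard index(es) this job covers
#           "train": {"en": 1.0},             # subset key -> sampling weight
#           "test": {"en": 1.0},
#       },
#   }
#   tokenizer_name = "gpt2"
#   prefix = "mix"

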
def tokenize(batch: dict):
    # Batched map(): every returned column must have one entry per input row,
    # so the no-tokenizer fallback returns a zero for each text.
    if tokenizer:
        return {"num_tokens": tokenizer(batch["text"], padding="do_not_pad", return_length=True)["length"]}
    return {"num_tokens": [0] * len(batch["text"])}


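# Illustrative call, once main() has loaded a tokenizer (the exact counts are
# assumptions and depend on the tokenizer's vocabulary):
#   tokenize({"text": ["Hello world!", "Hi"]})
#   -> {"num_tokens": [3, 1]}  # e.g. with a GPT-2-style tokenizer

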
def shard_indices(shard_index):
    if not isinstance(shard_index, list):
        shard_index = [shard_index]
    return shard_index


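# For example, shard_indices(3) returns [3] and shard_indices([0, 1]) returns
# [0, 1], so a config may give either a single shard index or a list of them.

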
def preprocess_shard(ds: datasets.Dataset, num_shards: int, index: int, num_proc: int):
    # contiguous=True selects an unbroken block of rows, so distinct shard
    # indices never overlap; flatten_indices() materializes the selection
    # before the parallel tokenization pass.
    shard = ds.shard(num_shards=num_shards, index=index, contiguous=True)
    shard = shard.flatten_indices()
    shard = shard.map(tokenize, batched=True, batch_size=1000, num_proc=num_proc)
    return shard


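# Worked example of the sharding arithmetic in preprocess_subset() below (the
# numbers are assumed, not from a real config): with shards=64, frac=0.5 and
# shard_index=[0], the subset is split into int(64 / 0.5) = 128 contiguous
# shards and only shard 0 is kept, i.e. 1/128 of the rows for this job.

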
def preprocess_subset(weights: dict, subsets: list, source: str, src_info: dict, dc: datasets.DownloadConfig, num_proc: int):
    for key, frac in tqdm.tqdm(weights.items(), desc="Loading subsets"):
        uri_template = src_info["uri"]
        print(f" Loading subset: {key} with fraction 1/{frac} from {uri_template.format(key=key)}")
        ds = datasets.load_dataset(
            src_info["format"],
            data_files=uri_template.format(key=key),
            split="train",
            download_config=dc,
        )
        # Keep only the text and tag every row with its provenance.
        ds = ds.select_columns(["text"])
        ds = ds.add_column("source", [source] * len(ds))
        ds = ds.add_column("subset", [key] * len(ds))
        ds = ds.shuffle(seed=42)
        # Scale the shard count by the subset weight and tokenize only the
        # shard indices assigned to this job.
        dss = [
            preprocess_shard(ds, int(src_info["shards"] / frac), i, num_proc)
            for i in shard_indices(src_info["shard_index"])
        ]
        ds = datasets.concatenate_datasets(dss)
        ds = ds.cast_column("text", datasets.Value("large_string"))
        print(f" Finished preprocessing subset: {key} with {sum(ds['num_tokens'])} tokens")
        subsets.append(ds)


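# Each dataset appended to `subsets` carries four columns: "text" (cast to
# large_string), the "source" and "subset" tags, and the "num_tokens" column
# produced by tokenize(); main() concatenates them into the final splits.

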
def main(
    config_file: str,
    num_proc: int = 96,
    max_retries: int = 10,
):
    sources, tokenizer_name, prefix = load_config(config_file)
    # Publish the tokenizer as a module-level global so the map() workers
    # spawned by preprocess_shard() can reach it.
    global tokenizer
    tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name) if tokenizer_name else None
    dc = datasets.DownloadConfig(num_proc=num_proc, max_retries=max_retries)
    train_subsets = []
    test_subsets = []
    # Encode every source's shard selection into the output file name.
    file_name = f"{prefix}-"
    for source, src_info in sources.items():
        print(f"Processing source: {source}")
        indices = shard_indices(src_info["shard_index"])
        file_name += f"{source}-{'_'.join(str(s) for s in indices)}-of-{src_info['shards']}-"
        preprocess_subset(src_info["train"], train_subsets, source, src_info, dc, num_proc)
        preprocess_subset(src_info["test"], test_subsets, source, src_info, dc, num_proc)
    print("Concatenating train subsets")
    final_train = datasets.concatenate_datasets(train_subsets)
    print("Shuffling final train dataset")
    final_train = final_train.shuffle(seed=42)
    print("Flattening final train dataset")
    final_train = final_train.flatten_indices()
    print("Concatenating test subsets")
    final_test = datasets.concatenate_datasets(test_subsets)
    print("Shuffling final test dataset")
    final_test = final_test.shuffle(seed=42)
    print("Flattening final test dataset")
    final_test = final_test.flatten_indices()
    test_file = f"{file_name}test/{file_name}test.parquet"
    print(f"Writing final test dataset with {sum(final_test['num_tokens'])} tokens to {test_file}")
    final_test.to_parquet(test_file)
    train_file = f"{file_name}train/{file_name}train.parquet"
    print(f"Writing final train dataset with {sum(final_train['num_tokens'])} tokens to {train_file}")
    final_train.to_parquet(train_file)


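# Example invocation (script and config paths are placeholders):
#   python preprocess.py config.py --num-proc 16 --max-retries 5

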
if __name__ == "__main__":
    typer.run(main)