New BERT dataloader (#5881)

* One file == One topic

* update test

* new dataloader

* update train script

* get index is faster
This commit is contained in:
Elias Wahl 2024-08-02 21:12:23 +02:00 committed by GitHub
commit 4a114756f6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 89 additions and 102 deletions

View file

@ -1,13 +1,14 @@
import os, random, pickle, functools, itertools
from typing import List, Tuple
import os, random, pickle, queue
from typing import List
from pathlib import Path
import numpy as np
from PIL import Image
from tqdm import tqdm
from tinygrad import dtypes, Tensor
from tinygrad.helpers import getenv, prod, Context, round_up
from collections import deque
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count, Pool
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count
### ResNet
class MyQueue:
def __init__(self, multiple_readers=True, multiple_writers=True):
@ -165,10 +166,7 @@ def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None, pad_fir
# happens with BENCHMARK set
pass
@functools.lru_cache(maxsize=128)
def load_bert_file(fn:str) -> List[dict]:
  """Unpickle the file at `fn` and return its contents.

  Results are memoised per path (up to 128 distinct paths), so repeated calls
  with the same `fn` return the same object without touching the disk again.
  """
  with open(fn, "rb") as handle:
    return pickle.load(handle)
### BERT
def process_batch_bert(data: List[dict]) -> dict[str, Tensor]:
return {
@ -181,72 +179,75 @@ def process_batch_bert(data: List[dict]) -> dict[str, Tensor]:
"next_sentence_labels": Tensor(np.concatenate([s["next_sentence_labels"] for s in data], axis=0), dtype=dtypes.float32),
}
def shuffle_parts(file_paths: List[str]) -> List[str]:
  """Shuffle files at part granularity while keeping each part's files in index order.

  The stem of every path is split on '_': the first field identifies the part and
  the second field is the integer index of the file inside that part.
  """
  by_part = {}
  for path in file_paths:
    by_part.setdefault(Path(path).stem.split('_')[0], []).append(path)

  order = list(by_part)
  random.shuffle(order)  # only the order of the parts is randomised

  def _file_index(path): return int(Path(path).stem.split('_')[1])
  return [path for part_id in order for path in sorted(by_part[part_id], key=_file_index)]

def load_file(file: str):
  """Return the object unpickled from the file at `file`."""
  with open(file, "rb") as handle:
    contents = pickle.load(handle)
  return contents
class InterleavedDataset:
  """Round-robin reader over `cycle_length` in-memory queues, each fed from one pickled file at a time.

  `files` is consumed in place: every file that gets loaded is popped from the
  front of the list that was passed in.
  """
  def __init__(self, files:List[str], cycle_length:int):
    self.dataset = files
    self.cycle_length = cycle_length

    self.queues = [queue.Queue() for _ in range(self.cycle_length)]
    # Prime every queue with one file; fill() leaves a queue empty when there are fewer files than queues.
    for i in range(self.cycle_length): self.fill(i)
    # Start on the last queue so that the first get() advances to queue 0.
    self.queue_pointer = len(self.queues) - 1

  def get(self):
    """Return the next sample, visiting the queues round-robin.

    An empty queue is refilled from the next unread file and the scan moves on to
    the following queue, so the refilled queue is read on the next round.

    Raises:
      queue.Empty: every queue is empty and no unread files remain.
    """
    misses = 0  # consecutive empty queues seen while no file was left to refill them
    while misses < self.cycle_length:
      self.advance()
      try:
        return self.queues[self.queue_pointer].get_nowait()
      except queue.Empty:
        if self.dataset:
          self.fill(self.queue_pointer)
          misses = 0
        else:
          misses += 1
    # A full lap without a sample and without a refill: stop instead of looping forever.
    raise queue.Empty("InterleavedDataset is exhausted: every queue is empty and no files are left")

  def advance(self):
    """Move the read pointer to the next queue, wrapping around after the last one."""
    self.queue_pointer = (self.queue_pointer + 1) % self.cycle_length

  def fill(self, queue_index: int):
    """Load the next unread file into queue `queue_index`; do nothing when no files remain."""
    if not self.dataset: return
    # Extend the Queue's underlying deque directly so a whole file is appended in one call.
    self.queues[queue_index].queue.extend(load_file(self.dataset.pop(0)))

def random_sample(data: List[str]):
  """Pick a uniformly random element of `data`; returns the pair (element, index)."""
  index = random.randint(0, len(data) - 1)
  return data[index], index

def load_datasample(file_and_offset:Tuple[str, int]) -> List[dict]:
  """Return the entry at position `file_and_offset[1]` of the data loaded from `file_and_offset[0]`."""
  return load_bert_file(file_and_offset[0])[file_and_offset[1]]
# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 394
def batch_load_train_bert(BS:int, start_step:int = 0):
def batch_load_train_bert(BS:int):
# Generator that yields process_batch_bert(...) over BS training samples per step.
# NOTE(review): this span interleaves two revisions of the body (one taking start_step, one not);
# the notes below refer to the revision that builds an InterleavedDataset.
from extra.datasets.wikipedia import get_wiki_train_files
files = shuffle_parts(get_wiki_train_files())
dataset = []
for f in tqdm(files, desc="Building dataset"):
lists = [(f, o) for o in range(int(Path(f).stem.split("_")[3].split(".")[0]))]
dataset.extend(lists)
dataset = dataset[start_step:]
active_set = deque(dataset[:1000])
remaining_set = deque(dataset[1000:])
# Start from a sorted listing, then draw the files one by one in random order.
fs = sorted(get_wiki_train_files())
train_files = []
# The remaining files are reshuffled before every pop, so this loop is quadratic in the number of files.
while fs: # TF shuffle
random.shuffle(fs)
train_files.append(fs.pop(0))
while dataset:
blob = []
# NOTE(review): os.cpu_count() can return None, in which case min(None, 8) raises TypeError — confirm this is acceptable.
cycle_length = min(getenv("NUM_CPU_THREADS", min(os.cpu_count(), 8)), len(train_files))
# NOTE(review): assert is stripped under `python -O`; an empty file list would then reach InterleavedDataset unchecked.
assert cycle_length > 0, "cycle_length must be greater than 0"
dataset = InterleavedDataset(train_files, cycle_length)
# Shuffle buffer of 1000 samples; the 999 upper bound passed to randint below must stay in sync with this size.
buffer = [dataset.get() for _ in range(1000)]
while True:
batch = []
for _ in range(BS):
if active_set:
index = random.randint(0, len(active_set) - 1)
sample = active_set[index]
active_set.remove(sample)
blob.append(sample)
if remaining_set:
active_set.append(remaining_set.popleft())
yield process_batch_bert([load_datasample(sample) for sample in blob])
# Emit a uniformly random buffer slot, then refill that slot with the next sample from the dataset.
index = random.randint(0, 999)
batch.append(buffer[index])
buffer[index] = dataset.get()
yield process_batch_bert(batch)
# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 416
def batch_load_val_bert(BS:int):
# Endless generator over the evaluation samples: yields process_batch_bert(...) for consecutive
# windows of BS samples, wrapping around to the start once the end of the dataset is reached.
from extra.datasets.wikipedia import get_wiki_val_files
files = get_wiki_val_files()
dataset = list(itertools.chain.from_iterable([load_bert_file(f) for f in files]))
# Here the eval set is read from a single eval.pkl under BASEDIR; the default base is
# extra/datasets/wiki resolved from two directories above this file's directory.
# NOTE(review): `/` needs a Path on the left — presumably getenv casts BASEDIR to the default's type; verify.
file = getenv("BASEDIR", Path(__file__).parent.parents[1] / "extra" / "datasets" / "wiki") / "eval.pkl"
dataset = load_file(file)
idx = 0
while True:
# Window bounds are taken modulo len(dataset).
# NOTE(review): an empty dataset raises ZeroDivisionError here, and a BS larger than
# len(dataset) yields fewer than BS samples per batch.
start_idx = (idx * BS) % len(dataset)
end_idx = ((idx + 1) * BS) % len(dataset)
if start_idx < end_idx:
# NOTE(review): the duplicated yield lines in each branch are presumably the two sides of a
# whitespace-only change in the diff rendering, not two yields per iteration — confirm.
yield process_batch_bert(dataset[start_idx:end_idx])
yield process_batch_bert(dataset[start_idx:end_idx])
else: # wrap around the end to the beginning of the dataset
yield process_batch_bert(dataset[start_idx:] + dataset[:end_idx])
yield process_batch_bert(dataset[start_idx:] + dataset[:end_idx])
idx += 1
### UNET3D
def load_unet3d_data(preprocessed_dataset_dir, seed, queue_in, queue_out, X:Tensor, Y:Tensor):
from extra.datasets.kits19 import rand_balanced_crop, rand_flip, random_brightness_augmentation, gaussian_noise

View file

@ -520,7 +520,9 @@ def train_bert():
if not INITMLPERF:
eval_it = iter(batch_load_val_bert(EVAL_BS))
train_it = iter(tqdm(batch_load_train_bert(BS, start_step), initial=start_step, total=train_steps, disable=BENCHMARK))
train_it = iter(tqdm(batch_load_train_bert(BS), total=train_steps, disable=BENCHMARK))
for _ in range(start_step): next(train_it) # Fast forward
step_times = []
# ** train loop **

View file

@ -351,12 +351,14 @@ def instance_to_features(instance:dict, tokenizer:Tokenizer) -> dict:
def process_part(part:int):
# Builds the features for training part `part` and pickles them under BASEDIR/train.
# NOTE(review): this span interleaves two revisions of the body (per-batch files in train/<part>/
# versus a single train/<part>.pkl) and contains a stray `def process_iterate` header line.
tokenizer = Tokenizer(getenv("BASEDIR", Path(__file__).parent / "wiki") / "vocab.txt")
os.makedirs(BASEDIR / "train" / str(part), exist_ok=True)
for i, feature_batch in enumerate(process_iterate(tokenizer, val=False, part=part)):
with open(BASEDIR / f"train/{str(part)}/{part}_{i}_of_{len(feature_batch)}.pkl", "wb") as f:
pickle.dump(feature_batch, f)
os.makedirs(BASEDIR / "train", exist_ok=True)
def process_iterate(tokenizer:Tokenizer, val:bool=False, part:int=0) -> list[dict]: # Convert raw text to masked NSP samples
# Skip parts whose train/<part>.pkl already exists.
# NOTE(review): the Tokenizer above is constructed before this check, so skipped parts still pay for loading the vocab.
if os.path.exists(BASEDIR / f"train/{str(part)}.pkl"): return
features = get_features_from_part(tokenizer, val=False, part=part)
# NOTE(review): the pickle is written in place; an interrupted run could leave a partial file that the
# exists-check above would later treat as complete — consider writing to a temp file and renaming.
with open(BASEDIR / f"train/{str(part)}.pkl", "wb") as f:
pickle.dump(features, f)
def get_features_from_part(tokenizer:Tokenizer, val:bool=False, part:int=0) -> list[dict]: # Convert raw text to masked NSP samples
rng = random.Random(getenv('RANDOM_SEED', 12345))
if val:
@ -368,25 +370,16 @@ def process_iterate(tokenizer:Tokenizer, val:bool=False, part:int=0) -> list[dic
tqdm.write(f"Picking 10000 samples")
pick_ratio = len(instances) / 10000
picks = [instance_to_features(instances[int(inst*pick_ratio)], tokenizer) for inst in range(10000)]
for batch in range(10):
yield picks[batch*1000:(batch+1)*1000]
return [instance_to_features(instances[int(inst*pick_ratio)], tokenizer) for inst in range(10000)]
else:
documents = get_documents(rng, tokenizer, f"results4/part-{part:05d}-of-00500")
instances = get_instances(rng, tokenizer, documents)
while len(instances) > 0:
batch_size = min(1000, len(instances)) # We batch 1000 samples to one file
batch = instances[:batch_size]
del instances[:batch_size]
yield [instance_to_features(instance, tokenizer) for instance in batch]
return [instance_to_features(instance, tokenizer) for instance in instances]
##################### Load files #####################
def get_wiki_val_files():
  """List the .pkl files directly inside BASEDIR/eval, sorted by path."""
  eval_dir = BASEDIR / "eval/"
  return sorted(eval_dir.glob("*.pkl"))
# Sorted list of the training pickles under BASEDIR/train. The two definitions below differ only in
# the glob: "*/*.pkl" matches files one directory below train/, "*.pkl" matches files directly in train/.
# NOTE(review): @diskcache presumably persists the returned listing; if so, a listing cached for one
# layout would be stale for the other — verify, and clear the cache when regenerating the data.
@diskcache
def get_wiki_train_files(): return sorted(list((BASEDIR / "train/").glob("*/*.pkl")))
def get_wiki_train_files(): return sorted(list((BASEDIR / "train/").glob("*.pkl")))
if __name__ == "__main__":
tokenizer = Tokenizer(getenv("BASEDIR", Path(__file__).parent / "wiki") / "vocab.txt")
@ -394,18 +387,12 @@ if __name__ == "__main__":
assert len(sys.argv) > 1, "Usage: python wikipedia.py pre-eval|pre-train [part]|all"
if sys.argv[1] == "pre-eval": # Generate 10000 eval samples
os.makedirs(BASEDIR / "eval", exist_ok=True)
for i, feature_batch in tqdm(enumerate(process_iterate(tokenizer, val=True)), total=10):
with open(BASEDIR / f"eval/{i}.pkl", "wb") as f:
pickle.dump(feature_batch, f)
with open(BASEDIR / "eval.pkl", "wb") as f:
pickle.dump(get_features_from_part(tokenizer, val=True), f)
elif sys.argv[1] == "pre-train":
os.makedirs(BASEDIR / "train", exist_ok=True)
if sys.argv[2] == "all": # Use all 500 parts for training generation
process_map(process_part, [part for part in range(500)], max_workers=getenv('NUM_WORKERS', min(os.cpu_count(), 32)), chunksize=1)
else: # Use a specific part for training generation
part = int(sys.argv[2])
os.makedirs(BASEDIR / "train" / str(part), exist_ok=True)
for i, feature_batch in tqdm(enumerate(process_iterate(tokenizer, val=False, part=part))):
with open(BASEDIR / f"train/{str(part)}/{part}_{i}_of_{len(feature_batch)}.pkl", "wb") as f:
pickle.dump(feature_batch, f)
part = sys.argv[2]
print(f"Processing part {part}...")
process_part(int(part))

View file

@ -10,15 +10,15 @@
# Command: python3 pick_eval_samples.py --input_tfrecord=/path/to/eval.tfrecord --output_tfrecord=/path/to/output_eval.tfrecord
# 3. Run `wikipedia.py` to preprocess the data with tinygrad (Use python > 3.7)
# Command: BASEDIR=/path/to/basedir python3 wikipedia.py pre-train XXX (NOTE: part number needs to match part of step 2)
# This will output to /path/to/basedir/train/XXX
# Command: BASEDIR=/path/to/basedir python3 wikipedia.py pre-train X (NOTE: part number needs to match part of step 2)
# This will output to /path/to/basedir/train/X.pkl
#
# 3.1 For eval:
# Command: BASEDIR=/path/to/basedir python3 wikipedia.py pre-eval
# This will output to /path/to/basedir/eval
# This will output to /path/to/basedir/eval.pkl
# 4. Run this script to verify the correctness of the preprocessing script for specific part
# Command: python3 external_test_preprocessing_part.py --preprocessed_part_dir=/path/to/basedir/part --tf_records=/path/to/output.tfrecord
# Command: python3 external_test_preprocessing_part.py --preprocessed_part=/path/to/basedir/train/X.pkl --tf_records=/path/to/output.tfrecord
import os, argparse, pickle
from tqdm import tqdm
@ -51,23 +51,20 @@ def load_dataset(file_path, max_seq_length=512, max_predictions_per_seq=76):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Verify the correctness of the preprocessing script for specific part",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--preprocessed_part_dir", type=str, default=None,
help="Path to dir with preprocessed samples from `wikipedia.py`")
parser.add_argument("--preprocessed_part", type=str, default=None,
help="Path to preprocessed samples file from `wikipedia.py`")
parser.add_argument("--tf_records", type=str, default=None,
help="Path to TFRecords file from `create_pretraining_data.py` (Reference implementation)")
parser.add_argument("--max_seq_length", type=int, default=512, help="Max sequence length. For MLPerf keep it as 512")
parser.add_argument("--max_predictions_per_seq", type=int, default=76, help="Max predictions per sequence. For MLPerf keep it as 76")
parser.add_argument("--max_seq_length", type=int, default=512, help="Max sequence length. For MLPerf keep it at 512")
parser.add_argument("--max_predictions_per_seq", type=int, default=76, help="Max predictions per sequence. For MLPerf keep it at 76")
parser.add_argument("--is_eval", type=bool, default=False, help="Whether to run eval or train preprocessing")
args = parser.parse_args()
assert os.path.isdir(args.preprocessed_part_dir), f"The specified directory {args.preprocessed_part_dir} does not exist."
assert os.path.isfile(args.preprocessed_part), f"The specified file {args.preprocessed_part} does not exist."
assert os.path.isfile(args.tf_records), f"The specified TFRecords file {args.tf_records} does not exist."
preprocessed_samples = []
for file_name in sorted(os.listdir(args.preprocessed_part_dir), key=lambda x: int(x.split("_")[1]) if not args.is_eval else int(x.split(".")[0])): # 0_3.pkl -> 3 # noqa: E501
with open(os.path.join(args.preprocessed_part_dir, file_name), 'rb') as f:
samples = pickle.load(f)
preprocessed_samples.extend(samples)
with open(args.preprocessed_part, 'rb') as f:
preprocessed_samples = pickle.load(f)
dataset = load_dataset(args.tf_records, args.max_seq_length, args.max_predictions_per_seq)
tf_record_count = sum(1 for _ in dataset)