mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
New BERT dataloader (#5881)
* One file == One topic * update test * new dataloader * update train script * get index is faster
This commit is contained in:
parent
2777784b91
commit
4a114756f6
4 changed files with 89 additions and 102 deletions
|
|
@ -1,13 +1,14 @@
|
|||
import os, random, pickle, functools, itertools
|
||||
from typing import List, Tuple
|
||||
import os, random, pickle, queue
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from tinygrad import dtypes, Tensor
|
||||
from tinygrad.helpers import getenv, prod, Context, round_up
|
||||
from collections import deque
|
||||
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count, Pool
|
||||
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count
|
||||
|
||||
### ResNet
|
||||
|
||||
class MyQueue:
|
||||
def __init__(self, multiple_readers=True, multiple_writers=True):
|
||||
|
|
@ -165,10 +166,7 @@ def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None, pad_fir
|
|||
# happens with BENCHMARK set
|
||||
pass
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def load_bert_file(fn:str) -> List[dict]:
|
||||
with open(fn, "rb") as f: data = pickle.load(f)
|
||||
return data
|
||||
### BERT
|
||||
|
||||
def process_batch_bert(data: List[dict]) -> dict[str, Tensor]:
|
||||
return {
|
||||
|
|
@ -181,72 +179,75 @@ def process_batch_bert(data: List[dict]) -> dict[str, Tensor]:
|
|||
"next_sentence_labels": Tensor(np.concatenate([s["next_sentence_labels"] for s in data], axis=0), dtype=dtypes.float32),
|
||||
}
|
||||
|
||||
def shuffle_parts(file_paths: List[str]) -> List[str]:
|
||||
parts = {}
|
||||
for f in file_paths:
|
||||
part = Path(f).stem.split('_')[0]
|
||||
if part not in parts: parts[part] = []
|
||||
parts[part].append(f)
|
||||
|
||||
part_ids = list(parts.keys())
|
||||
random.shuffle(part_ids)
|
||||
def load_file(file: str):
|
||||
with open(file, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
shuffled_files = []
|
||||
for p in part_ids:
|
||||
parts[p].sort(key=lambda x: int(Path(x).stem.split('_')[1]))
|
||||
shuffled_files.extend(parts[p])
|
||||
return shuffled_files
|
||||
class InterleavedDataset:
|
||||
def __init__(self, files:List[str], cycle_length:int):
|
||||
self.dataset = files
|
||||
self.cycle_length = cycle_length
|
||||
self.queues = [queue.Queue() for _ in range(self.cycle_length)]
|
||||
for i in range(len(self.queues)): self.queues[i].queue.extend(load_file(self.dataset.pop(0)))
|
||||
self.queue_pointer = len(self.queues) - 1
|
||||
|
||||
def random_sample(data: List[str]):
|
||||
index = random.randint(0, len(data) - 1)
|
||||
selected_sample = data[index]
|
||||
return selected_sample, index
|
||||
def get(self):
|
||||
# Round-robin across queues
|
||||
try:
|
||||
self.advance()
|
||||
return self.queues[self.queue_pointer].get_nowait()
|
||||
except queue.Empty:
|
||||
self.fill(self.queue_pointer)
|
||||
return self.get()
|
||||
|
||||
def load_datasample(file_and_offset:Tuple[str, int]) -> List[dict]:
|
||||
data = load_bert_file(file_and_offset[0])
|
||||
return data[file_and_offset[1]]
|
||||
def advance(self):
|
||||
self.queue_pointer = (self.queue_pointer + 1) % self.cycle_length
|
||||
|
||||
def fill(self, queue_index: int):
|
||||
try:
|
||||
file = self.dataset.pop(0)
|
||||
except IndexError:
|
||||
return
|
||||
self.queues[queue_index].queue.extend(load_file(file))
|
||||
|
||||
# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 394
|
||||
def batch_load_train_bert(BS:int, start_step:int = 0):
|
||||
def batch_load_train_bert(BS:int):
|
||||
from extra.datasets.wikipedia import get_wiki_train_files
|
||||
files = shuffle_parts(get_wiki_train_files())
|
||||
dataset = []
|
||||
for f in tqdm(files, desc="Building dataset"):
|
||||
lists = [(f, o) for o in range(int(Path(f).stem.split("_")[3].split(".")[0]))]
|
||||
dataset.extend(lists)
|
||||
|
||||
dataset = dataset[start_step:]
|
||||
|
||||
active_set = deque(dataset[:1000])
|
||||
remaining_set = deque(dataset[1000:])
|
||||
fs = sorted(get_wiki_train_files())
|
||||
train_files = []
|
||||
while fs: # TF shuffle
|
||||
random.shuffle(fs)
|
||||
train_files.append(fs.pop(0))
|
||||
|
||||
while dataset:
|
||||
blob = []
|
||||
cycle_length = min(getenv("NUM_CPU_THREADS", min(os.cpu_count(), 8)), len(train_files))
|
||||
assert cycle_length > 0, "cycle_length must be greater than 0"
|
||||
|
||||
dataset = InterleavedDataset(train_files, cycle_length)
|
||||
buffer = [dataset.get() for _ in range(1000)]
|
||||
while True:
|
||||
batch = []
|
||||
for _ in range(BS):
|
||||
if active_set:
|
||||
index = random.randint(0, len(active_set) - 1)
|
||||
sample = active_set[index]
|
||||
active_set.remove(sample)
|
||||
blob.append(sample)
|
||||
if remaining_set:
|
||||
active_set.append(remaining_set.popleft())
|
||||
yield process_batch_bert([load_datasample(sample) for sample in blob])
|
||||
index = random.randint(0, 999)
|
||||
batch.append(buffer[index])
|
||||
buffer[index] = dataset.get()
|
||||
yield process_batch_bert(batch)
|
||||
|
||||
# Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 416
|
||||
def batch_load_val_bert(BS:int):
|
||||
from extra.datasets.wikipedia import get_wiki_val_files
|
||||
files = get_wiki_val_files()
|
||||
dataset = list(itertools.chain.from_iterable([load_bert_file(f) for f in files]))
|
||||
file = getenv("BASEDIR", Path(__file__).parent.parents[1] / "extra" / "datasets" / "wiki") / "eval.pkl"
|
||||
dataset = load_file(file)
|
||||
idx = 0
|
||||
while True:
|
||||
start_idx = (idx * BS) % len(dataset)
|
||||
end_idx = ((idx + 1) * BS) % len(dataset)
|
||||
if start_idx < end_idx:
|
||||
yield process_batch_bert(dataset[start_idx:end_idx])
|
||||
yield process_batch_bert(dataset[start_idx:end_idx])
|
||||
else: # wrap around the end to the beginning of the dataset
|
||||
yield process_batch_bert(dataset[start_idx:] + dataset[:end_idx])
|
||||
yield process_batch_bert(dataset[start_idx:] + dataset[:end_idx])
|
||||
idx += 1
|
||||
|
||||
### UNET3D
|
||||
|
||||
def load_unet3d_data(preprocessed_dataset_dir, seed, queue_in, queue_out, X:Tensor, Y:Tensor):
|
||||
from extra.datasets.kits19 import rand_balanced_crop, rand_flip, random_brightness_augmentation, gaussian_noise
|
||||
|
||||
|
|
|
|||
|
|
@ -520,7 +520,9 @@ def train_bert():
|
|||
|
||||
if not INITMLPERF:
|
||||
eval_it = iter(batch_load_val_bert(EVAL_BS))
|
||||
train_it = iter(tqdm(batch_load_train_bert(BS, start_step), initial=start_step, total=train_steps, disable=BENCHMARK))
|
||||
train_it = iter(tqdm(batch_load_train_bert(BS), total=train_steps, disable=BENCHMARK))
|
||||
for _ in range(start_step): next(train_it) # Fast forward
|
||||
|
||||
|
||||
step_times = []
|
||||
# ** train loop **
|
||||
|
|
|
|||
|
|
@ -351,12 +351,14 @@ def instance_to_features(instance:dict, tokenizer:Tokenizer) -> dict:
|
|||
|
||||
def process_part(part:int):
|
||||
tokenizer = Tokenizer(getenv("BASEDIR", Path(__file__).parent / "wiki") / "vocab.txt")
|
||||
os.makedirs(BASEDIR / "train" / str(part), exist_ok=True)
|
||||
for i, feature_batch in enumerate(process_iterate(tokenizer, val=False, part=part)):
|
||||
with open(BASEDIR / f"train/{str(part)}/{part}_{i}_of_{len(feature_batch)}.pkl", "wb") as f:
|
||||
pickle.dump(feature_batch, f)
|
||||
os.makedirs(BASEDIR / "train", exist_ok=True)
|
||||
|
||||
def process_iterate(tokenizer:Tokenizer, val:bool=False, part:int=0) -> list[dict]: # Convert raw text to masked NSP samples
|
||||
if os.path.exists(BASEDIR / f"train/{str(part)}.pkl"): return
|
||||
features = get_features_from_part(tokenizer, val=False, part=part)
|
||||
with open(BASEDIR / f"train/{str(part)}.pkl", "wb") as f:
|
||||
pickle.dump(features, f)
|
||||
|
||||
def get_features_from_part(tokenizer:Tokenizer, val:bool=False, part:int=0) -> list[dict]: # Convert raw text to masked NSP samples
|
||||
rng = random.Random(getenv('RANDOM_SEED', 12345))
|
||||
|
||||
if val:
|
||||
|
|
@ -368,25 +370,16 @@ def process_iterate(tokenizer:Tokenizer, val:bool=False, part:int=0) -> list[dic
|
|||
tqdm.write(f"Picking 10000 samples")
|
||||
|
||||
pick_ratio = len(instances) / 10000
|
||||
picks = [instance_to_features(instances[int(inst*pick_ratio)], tokenizer) for inst in range(10000)]
|
||||
for batch in range(10):
|
||||
yield picks[batch*1000:(batch+1)*1000]
|
||||
return [instance_to_features(instances[int(inst*pick_ratio)], tokenizer) for inst in range(10000)]
|
||||
else:
|
||||
documents = get_documents(rng, tokenizer, f"results4/part-{part:05d}-of-00500")
|
||||
instances = get_instances(rng, tokenizer, documents)
|
||||
|
||||
while len(instances) > 0:
|
||||
batch_size = min(1000, len(instances)) # We batch 1000 samples to one file
|
||||
batch = instances[:batch_size]
|
||||
del instances[:batch_size]
|
||||
yield [instance_to_features(instance, tokenizer) for instance in batch]
|
||||
return [instance_to_features(instance, tokenizer) for instance in instances]
|
||||
|
||||
##################### Load files #####################
|
||||
|
||||
def get_wiki_val_files(): return sorted(list((BASEDIR / "eval/").glob("*.pkl")))
|
||||
|
||||
@diskcache
|
||||
def get_wiki_train_files(): return sorted(list((BASEDIR / "train/").glob("*/*.pkl")))
|
||||
def get_wiki_train_files(): return sorted(list((BASEDIR / "train/").glob("*.pkl")))
|
||||
|
||||
if __name__ == "__main__":
|
||||
tokenizer = Tokenizer(getenv("BASEDIR", Path(__file__).parent / "wiki") / "vocab.txt")
|
||||
|
|
@ -394,18 +387,12 @@ if __name__ == "__main__":
|
|||
assert len(sys.argv) > 1, "Usage: python wikipedia.py pre-eval|pre-train [part]|all"
|
||||
|
||||
if sys.argv[1] == "pre-eval": # Generate 10000 eval samples
|
||||
os.makedirs(BASEDIR / "eval", exist_ok=True)
|
||||
|
||||
for i, feature_batch in tqdm(enumerate(process_iterate(tokenizer, val=True)), total=10):
|
||||
with open(BASEDIR / f"eval/{i}.pkl", "wb") as f:
|
||||
pickle.dump(feature_batch, f)
|
||||
with open(BASEDIR / "eval.pkl", "wb") as f:
|
||||
pickle.dump(get_features_from_part(tokenizer, val=True), f)
|
||||
elif sys.argv[1] == "pre-train":
|
||||
os.makedirs(BASEDIR / "train", exist_ok=True)
|
||||
if sys.argv[2] == "all": # Use all 500 parts for training generation
|
||||
process_map(process_part, [part for part in range(500)], max_workers=getenv('NUM_WORKERS', min(os.cpu_count(), 32)), chunksize=1)
|
||||
else: # Use a specific part for training generation
|
||||
part = int(sys.argv[2])
|
||||
os.makedirs(BASEDIR / "train" / str(part), exist_ok=True)
|
||||
for i, feature_batch in tqdm(enumerate(process_iterate(tokenizer, val=False, part=part))):
|
||||
with open(BASEDIR / f"train/{str(part)}/{part}_{i}_of_{len(feature_batch)}.pkl", "wb") as f:
|
||||
pickle.dump(feature_batch, f)
|
||||
part = sys.argv[2]
|
||||
print(f"Processing part {part}...")
|
||||
process_part(int(part))
|
||||
|
|
|
|||
|
|
@ -10,15 +10,15 @@
|
|||
# Command: python3 pick_eval_samples.py --input_tfrecord=/path/to/eval.tfrecord --output_tfrecord=/path/to/output_eval.tfrecord
|
||||
|
||||
# 3. Run `wikipedia.py` to preprocess the data with tinygrad (Use python > 3.7)
|
||||
# Command: BASEDIR=/path/to/basedir python3 wikipedia.py pre-train XXX (NOTE: part number needs to match part of step 2)
|
||||
# This will output to /path/to/basedir/train/XXX
|
||||
# Command: BASEDIR=/path/to/basedir python3 wikipedia.py pre-train X (NOTE: part number needs to match part of step 2)
|
||||
# This will output to /path/to/basedir/train/X.pkl
|
||||
#
|
||||
# 3.1 For eval:
|
||||
# Command: BASEDIR=/path/to/basedir python3 wikipedia.py pre-eval
|
||||
# This will output to /path/to/basedir/eval
|
||||
# This will output to /path/to/basedir/eval.pkl
|
||||
|
||||
# 4. Run this script to verify the correctness of the preprocessing script for specific part
|
||||
# Command: python3 external_test_preprocessing_part.py --preprocessed_part_dir=/path/to/basedir/part --tf_records=/path/to/output.tfrecord
|
||||
# Command: python3 external_test_preprocessing_part.py --preprocessed_part=/path/to/basedir/train/X.pkl --tf_records=/path/to/output.tfrecord
|
||||
import os, argparse, pickle
|
||||
from tqdm import tqdm
|
||||
|
||||
|
|
@ -51,23 +51,20 @@ def load_dataset(file_path, max_seq_length=512, max_predictions_per_seq=76):
|
|||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Verify the correctness of the preprocessing script for specific part",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("--preprocessed_part_dir", type=str, default=None,
|
||||
help="Path to dir with preprocessed samples from `wikipedia.py`")
|
||||
parser.add_argument("--preprocessed_part", type=str, default=None,
|
||||
help="Path to preprocessed samples file from `wikipedia.py`")
|
||||
parser.add_argument("--tf_records", type=str, default=None,
|
||||
help="Path to TFRecords file from `create_pretraining_data.py` (Reference implementation)")
|
||||
parser.add_argument("--max_seq_length", type=int, default=512, help="Max sequence length. For MLPerf keep it as 512")
|
||||
parser.add_argument("--max_predictions_per_seq", type=int, default=76, help="Max predictions per sequence. For MLPerf keep it as 76")
|
||||
parser.add_argument("--max_seq_length", type=int, default=512, help="Max sequence length. For MLPerf keep it at 512")
|
||||
parser.add_argument("--max_predictions_per_seq", type=int, default=76, help="Max predictions per sequence. For MLPerf keep it at 76")
|
||||
parser.add_argument("--is_eval", type=bool, default=False, help="Whether to run eval or train preprocessing")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert os.path.isdir(args.preprocessed_part_dir), f"The specified directory {args.preprocessed_part_dir} does not exist."
|
||||
assert os.path.isfile(args.preprocessed_part), f"The specified file {args.preprocessed_part} does not exist."
|
||||
assert os.path.isfile(args.tf_records), f"The specified TFRecords file {args.tf_records} does not exist."
|
||||
|
||||
preprocessed_samples = []
|
||||
for file_name in sorted(os.listdir(args.preprocessed_part_dir), key=lambda x: int(x.split("_")[1]) if not args.is_eval else int(x.split(".")[0])): # 0_3.pkl -> 3 # noqa: E501
|
||||
with open(os.path.join(args.preprocessed_part_dir, file_name), 'rb') as f:
|
||||
samples = pickle.load(f)
|
||||
preprocessed_samples.extend(samples)
|
||||
with open(args.preprocessed_part, 'rb') as f:
|
||||
preprocessed_samples = pickle.load(f)
|
||||
|
||||
dataset = load_dataset(args.tf_records, args.max_seq_length, args.max_predictions_per_seq)
|
||||
tf_record_count = sum(1 for _ in dataset)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue