llama3: move dl to numpy & jit more (#14677)

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
wozeparrot 2026-02-10 18:16:40 -08:00 committed by GitHub
commit a60220bed9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 11 additions and 8 deletions

View file

@ -552,7 +552,7 @@ class BinIdxDataset:
version, = struct.unpack("<Q", self.idx.read(8))
assert version == 1, "unsupported index version"
dtype_code, = struct.unpack("<B", self.idx.read(1))
self.dtype = {1:dtypes.uint8, 2:dtypes.int8, 3:dtypes.int16, 4:dtypes.int32, 5:dtypes.int64, 6:dtypes.float64, 7:dtypes.double, 8:dtypes.uint16}[dtype_code]
self.dtype = {1:np.dtype(np.uint8), 2:np.dtype(np.int8), 3:np.dtype(np.int16), 4:np.dtype(np.int32), 5:np.dtype(np.int64), 6:np.dtype(np.float64), 7:np.dtype(np.double), 8:np.dtype(np.uint16)}[dtype_code]
self.count, = struct.unpack("<Q", self.idx.read(8))
doc_count, = struct.unpack("<Q", self.idx.read(8))
@ -569,7 +569,7 @@ class BinIdxDataset:
self.doc_idx = self.idx_t[start:end].bitcast(dtypes.int64).numpy()
# bin file
self.bin_t = Tensor(base_path.with_name(f"{base_path.name}.bin"))
self.bin_t = Tensor(base_path.with_name(f"{base_path.name}.bin")).numpy()
def _index(self, idx) -> tuple[int, int]:
# Look up entry `idx` in self.pointers and self.sizes and return the pair as plain Python ints.
# NOTE(review): the pointer is presumably a byte offset into the .bin data and the size an element count,
# judging by how the caller elsewhere in this listing scales offsets and lengths by dtype.itemsize - confirm.
return int(self.pointers[idx]), int(self.sizes[idx])
@ -578,7 +578,7 @@ class BinIdxDataset:
ptr, size = self._index(idx)
if length is None: length = size - offset
ptr += offset * self.dtype.itemsize
return self.bin_t[ptr:ptr+length*self.dtype.itemsize].bitcast(self.dtype).to(None)
return self.bin_t[ptr:ptr+length*self.dtype.itemsize].view(self.dtype)
# https://docs.nvidia.com/megatron-core/developer-guide/latest/api-guide/datasets.html
class GPTDataset:
@ -637,7 +637,7 @@ class GPTDataset:
sample_parts.append(self.indexed_dataset.get(int(self.doc_idx[i]), offset=int(offset), length=length))
# concat all parts
text = Tensor.cat(*sample_parts)
text = np.concatenate(sample_parts, axis=0)
return text
@ -780,7 +780,8 @@ def get_llama3_dataset(samples:int, seqlen:int, base_dir:Path, seed:int=0, val:b
def iterate_llama3_dataset(dataset:BlendedGPTDataset, bs:int):
# Generator: walks the dataset in ceil(dataset.samples / bs) steps, fetching `bs` consecutive samples per step.
for b in range(math.ceil(dataset.samples / bs)):
# NOTE(review): every step requests a full `bs` samples, so when dataset.samples is not a multiple of bs
# the last step asks for indices >= dataset.samples - confirm dataset.get tolerates that.
batch = [dataset.get(b * bs + i) for i in range(bs)]
# NOTE(review): this listing has lost its +/- markers; the hunk header (-780,7 +780,8) indicates one line was
# replaced by two, so the Tensor.stack yield and the np.stack + Tensor(..., device="NPY") yield belong to
# different revisions rather than both executing.
yield Tensor.stack(batch, dim=0)
# Stack the per-sample arrays along a new leading axis, then wrap the result in a Tensor on the "NPY" device.
stacked = np.stack(batch, axis=0)
yield Tensor(stacked, device="NPY")
def batch_load_llama3(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True, small:bool=False):
# Convenience wrapper: builds the dataset with get_llama3_dataset(samples, seqlen, base_dir, seed, val, small)
# and returns the iterate_llama3_dataset generator over it with batch size `bs`.
return iterate_llama3_dataset(get_llama3_dataset(samples, seqlen, base_dir, seed, val, small), bs)

View file

@ -1390,6 +1390,7 @@ def train_llama3():
@TinyJit
def minibatch(tokens:Tensor):
tokens = tokens.to(None)
if (DP := getenv("DP", 1)) > 1:
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
tokens = tokens.shard(device, 0)
@ -1401,7 +1402,7 @@ def train_llama3():
loss.backward()
assert all(p.grad is g for p,g in zip(optim.params, grads))
Tensor.realize(loss, *grads)
return loss
return loss.flatten().float().to("CPU")
@TinyJit
def optim_step():
@ -1428,11 +1429,12 @@ def train_llama3():
lr = optim.lr
Tensor.realize(lr, *grads)
return lr
return lr.float().to("CPU")
@TinyJit
@Tensor.train(False)
def eval_step(tokens:Tensor):
tokens = tokens.to(None)
if (DP := getenv("DP", 1)) > 1:
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
tokens = tokens.shard(device, 0)
@ -1441,7 +1443,7 @@ def train_llama3():
tokens = tokens.shard(device)
logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan)
loss = logits.sparse_categorical_crossentropy(tokens[:, 1:])
return loss.flatten().float()
return loss.flatten().float().to("CPU")
# ** data iters **
def fake_data(bs, samples):