mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
llama3: move dl to numpy & jit more (#14677)
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
parent
4565958792
commit
a60220bed9
2 changed files with 11 additions and 8 deletions
|
|
@ -552,7 +552,7 @@ class BinIdxDataset:
|
|||
version, = struct.unpack("<Q", self.idx.read(8))
|
||||
assert version == 1, "unsupported index version"
|
||||
dtype_code, = struct.unpack("<B", self.idx.read(1))
|
||||
self.dtype = {1:dtypes.uint8, 2:dtypes.int8, 3:dtypes.int16, 4:dtypes.int32, 5:dtypes.int64, 6:dtypes.float64, 7:dtypes.double, 8:dtypes.uint16}[dtype_code]
|
||||
self.dtype = {1:np.dtype(np.uint8), 2:np.dtype(np.int8), 3:np.dtype(np.int16), 4:np.dtype(np.int32), 5:np.dtype(np.int64), 6:np.dtype(np.float64), 7:np.dtype(np.double), 8:np.dtype(np.uint16)}[dtype_code]
|
||||
self.count, = struct.unpack("<Q", self.idx.read(8))
|
||||
doc_count, = struct.unpack("<Q", self.idx.read(8))
|
||||
|
||||
|
|
@ -569,7 +569,7 @@ class BinIdxDataset:
|
|||
self.doc_idx = self.idx_t[start:end].bitcast(dtypes.int64).numpy()
|
||||
|
||||
# bin file
|
||||
self.bin_t = Tensor(base_path.with_name(f"{base_path.name}.bin"))
|
||||
self.bin_t = Tensor(base_path.with_name(f"{base_path.name}.bin")).numpy()
|
||||
|
||||
def _index(self, idx) -> tuple[int, int]:
|
||||
return int(self.pointers[idx]), int(self.sizes[idx])
|
||||
|
|
@ -578,7 +578,7 @@ class BinIdxDataset:
|
|||
ptr, size = self._index(idx)
|
||||
if length is None: length = size - offset
|
||||
ptr += offset * self.dtype.itemsize
|
||||
return self.bin_t[ptr:ptr+length*self.dtype.itemsize].bitcast(self.dtype).to(None)
|
||||
return self.bin_t[ptr:ptr+length*self.dtype.itemsize].view(self.dtype)
|
||||
|
||||
# https://docs.nvidia.com/megatron-core/developer-guide/latest/api-guide/datasets.html
|
||||
class GPTDataset:
|
||||
|
|
@ -637,7 +637,7 @@ class GPTDataset:
|
|||
sample_parts.append(self.indexed_dataset.get(int(self.doc_idx[i]), offset=int(offset), length=length))
|
||||
|
||||
# concat all parts
|
||||
text = Tensor.cat(*sample_parts)
|
||||
text = np.concatenate(sample_parts, axis=0)
|
||||
|
||||
return text
|
||||
|
||||
|
|
@ -780,7 +780,8 @@ def get_llama3_dataset(samples:int, seqlen:int, base_dir:Path, seed:int=0, val:b
|
|||
def iterate_llama3_dataset(dataset:BlendedGPTDataset, bs:int):
|
||||
for b in range(math.ceil(dataset.samples / bs)):
|
||||
batch = [dataset.get(b * bs + i) for i in range(bs)]
|
||||
yield Tensor.stack(batch, dim=0)
|
||||
stacked = np.stack(batch, axis=0)
|
||||
yield Tensor(stacked, device="NPY")
|
||||
|
||||
def batch_load_llama3(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True, small:bool=False):
|
||||
return iterate_llama3_dataset(get_llama3_dataset(samples, seqlen, base_dir, seed, val, small), bs)
|
||||
|
|
|
|||
|
|
@ -1390,6 +1390,7 @@ def train_llama3():
|
|||
|
||||
@TinyJit
|
||||
def minibatch(tokens:Tensor):
|
||||
tokens = tokens.to(None)
|
||||
if (DP := getenv("DP", 1)) > 1:
|
||||
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
|
||||
tokens = tokens.shard(device, 0)
|
||||
|
|
@ -1401,7 +1402,7 @@ def train_llama3():
|
|||
loss.backward()
|
||||
assert all(p.grad is g for p,g in zip(optim.params, grads))
|
||||
Tensor.realize(loss, *grads)
|
||||
return loss
|
||||
return loss.flatten().float().to("CPU")
|
||||
|
||||
@TinyJit
|
||||
def optim_step():
|
||||
|
|
@ -1428,11 +1429,12 @@ def train_llama3():
|
|||
lr = optim.lr
|
||||
Tensor.realize(lr, *grads)
|
||||
|
||||
return lr
|
||||
return lr.float().to("CPU")
|
||||
|
||||
@TinyJit
|
||||
@Tensor.train(False)
|
||||
def eval_step(tokens:Tensor):
|
||||
tokens = tokens.to(None)
|
||||
if (DP := getenv("DP", 1)) > 1:
|
||||
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
|
||||
tokens = tokens.shard(device, 0)
|
||||
|
|
@ -1441,7 +1443,7 @@ def train_llama3():
|
|||
tokens = tokens.shard(device)
|
||||
logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan)
|
||||
loss = logits.sparse_categorical_crossentropy(tokens[:, 1:])
|
||||
return loss.flatten().float()
|
||||
return loss.flatten().float().to("CPU")
|
||||
|
||||
# ** data iters **
|
||||
def fake_data(bs, samples):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue