mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fixing transformer training bug (#9877)
This commit is contained in:
parent
bd5939514d
commit
32ed128598
1 changed files with 1 additions and 1 deletions
|
|
@ -17,7 +17,7 @@ def make_dataset():
|
|||
random.shuffle(ds)
|
||||
ds = np.array(ds).astype(np.float32)
|
||||
ds_X = ds[:, 0:6]
|
||||
ds_Y = np.copy(ds[:, 1:])
|
||||
ds_Y = np.copy(ds[:, 1:]).astype(np.int32)
|
||||
ds_X_train, ds_X_test = ds_X[0:8000], ds_X[8000:]
|
||||
ds_Y_train, ds_Y_test = ds_Y[0:8000], ds_Y[8000:]
|
||||
return ds_X_train, ds_Y_train, ds_X_test, ds_Y_test
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue