fixing transformer training bug (#9877)

This commit is contained in:
Nishant Rajadhyaksha 2025-04-13 16:34:20 -07:00 committed by GitHub
commit 32ed128598
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -17,7 +17,7 @@ def make_dataset():
random.shuffle(ds)
ds = np.array(ds).astype(np.float32)
ds_X = ds[:, 0:6]
ds_Y = np.copy(ds[:, 1:])
ds_Y = np.copy(ds[:, 1:]).astype(np.int32)
ds_X_train, ds_X_test = ds_X[0:8000], ds_X[8000:]
ds_Y_train, ds_Y_test = ds_Y[0:8000], ds_Y[8000:]
return ds_X_train, ds_Y_train, ds_X_test, ds_Y_test