mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
add int64 as supported dtype from numpy (#699)
* add int64 as supported dtype from numpy Without this, examples/transformer.py didn't run. With this change it runs successfully. * Update helpers.py * Update transformer.py * Update training.py
This commit is contained in:
parent
f355b02987
commit
73bd0b217b
2 changed files with 2 additions and 2 deletions
|
|
@ -14,7 +14,7 @@ def make_dataset():
|
|||
s = i+j
|
||||
ds.append([i//10, i%10, j//10, j%10, s//100, (s//10)%10, s%10])
|
||||
random.shuffle(ds)
|
||||
ds = np.array(ds)
|
||||
ds = np.array(ds).astype(np.float32)
|
||||
ds_X = ds[:, 0:6]
|
||||
ds_Y = np.copy(ds[:, 1:])
|
||||
ds_X_train, ds_X_test = ds_X[0:8000], ds_X[8000:]
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from tinygrad.helpers import getenv
|
|||
|
||||
def sparse_categorical_crossentropy(out, Y):
|
||||
num_classes = out.shape[-1]
|
||||
YY = Y.flatten()
|
||||
YY = Y.flatten().astype(np.int32)
|
||||
y = np.zeros((YY.shape[0], num_classes), np.float32)
|
||||
# correct loss for NLL, torch NLL loss returns one per row
|
||||
y[range(y.shape[0]),YY] = -1.0*num_classes
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue