mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
training param for batchnorm
This commit is contained in:
parent
21c78b9316
commit
d5d9cffe7c
1 changed files with 2 additions and 1 deletions
|
|
@ -16,6 +16,7 @@ NUM = int(os.getenv("NUM", 2))
|
|||
BS = int(os.getenv("BS", 8))
|
||||
CNT = int(os.getenv("CNT", 10))
|
||||
BACKWARD = int(os.getenv("BACKWARD", 0))
|
||||
TRAINING = int(os.getenv("TRAINING", 1))
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f"NUM:{NUM} BS:{BS} CNT:{CNT}")
|
||||
|
|
@ -23,7 +24,7 @@ if __name__ == "__main__":
|
|||
parameters = get_parameters(model)
|
||||
optimizer = optim.SGD(parameters, lr=0.001)
|
||||
|
||||
Tensor.training = True
|
||||
Tensor.training = TRAINING
|
||||
for i in trange(CNT):
|
||||
cpy = time.monotonic()
|
||||
x_train = Tensor.randn(BS, 3, 224, 224, requires_grad=False).realize()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue