mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix TRAIN_BEAM and Tensor.training for mlperf bert (#4525)
Also hard-coded the BERT model config instead of reading it from a file.
This commit is contained in:
parent
7fab8c9e17
commit
b00b6b16f0
3 changed files with 24 additions and 8 deletions
|
|
@ -194,7 +194,22 @@ def get_bert_qa_prediction(features, example, start_end_logits):
|
|||
return _get_final_text(tok_text, orig_text)
|
||||
return "empty"
|
||||
|
||||
def get_mlperf_bert_model(config_path:str):
|
||||
def get_mlperf_bert_config():
  """Return the hard-coded BERT model configuration for the MLPerf model.

  The values are embedded here instead of being read from a
  ``bert_config.json`` file, so no config file has to be present on disk.

  Returns:
    A freshly built dict on every call, so a caller may modify its copy
    without affecting later calls.
  """
  return {
    "attention_probs_dropout_prob": 0.1,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "hidden_size": 1024,
    "initializer_range": 0.02,
    "intermediate_size": 4096,
    "max_position_embeddings": 512,
    "num_attention_heads": 16,
    "num_hidden_layers": 24,
    "type_vocab_size": 2,
    "vocab_size": 30522
  }
|
||||
|
||||
def get_mlperf_bert_model():
|
||||
from extra.models import bert
|
||||
from examples.mlperf.initializers import LinearBert, EmbeddingBert, LayerNormBert
|
||||
|
||||
|
|
@ -204,8 +219,7 @@ def get_mlperf_bert_model(config_path:str):
|
|||
|
||||
from extra.models.bert import BertForMLPerf
|
||||
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
config = get_mlperf_bert_config()
|
||||
return BertForMLPerf(
|
||||
config["hidden_size"],
|
||||
config["intermediate_size"],
|
||||
|
|
|
|||
|
|
@ -426,7 +426,7 @@ def train_bert():
|
|||
|
||||
Tensor.manual_seed(seed) # seed for weight initialization
|
||||
|
||||
model = get_mlperf_bert_model(BASEDIR / "bert_config.json")
|
||||
model = get_mlperf_bert_model()
|
||||
if init_ckpt: init_bert_from_checkpoint(model, init_ckpt)
|
||||
|
||||
for _, x in get_state_dict(model).items():
|
||||
|
|
@ -472,10 +472,10 @@ def train_bert():
|
|||
step_times = []
|
||||
# ** train loop **
|
||||
wc_start = time.perf_counter()
|
||||
Tensor.training = True
|
||||
BEAM.value = TRAIN_BEAM
|
||||
i, train_data = 0, get_data_bert(GPUS, train_it)
|
||||
while train_data is not None and i < train_steps and not achieved:
|
||||
Tensor.training = True
|
||||
BEAM.value = TRAIN_BEAM
|
||||
st = time.perf_counter()
|
||||
GlobalCounters.reset()
|
||||
loss = train_step_bert(model, optimizer_group, scheduler, train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
|
||||
|
|
@ -525,7 +525,7 @@ def train_bert():
|
|||
Tensor.training = False
|
||||
BEAM.value = EVAL_BEAM
|
||||
|
||||
for _ in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps, disable=BENCHMARK):
|
||||
for j in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps, disable=BENCHMARK):
|
||||
eval_data = get_data_bert(GPUS, eval_it)
|
||||
GlobalCounters.reset()
|
||||
st = time.time()
|
||||
|
|
@ -541,6 +541,8 @@ def train_bert():
|
|||
et = time.time()
|
||||
eval_times.append(et - st)
|
||||
|
||||
if BENCHMARK and j == BENCHMARK: break
|
||||
|
||||
eval_step_bert.reset()
|
||||
Tensor.training = True
|
||||
total_lm_loss = sum(pair[0] for pair in eval_loss) / len(eval_loss)
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ if __name__ == "__main__":
|
|||
|
||||
Tensor.training = False
|
||||
|
||||
model = get_mlperf_bert_model(os.path.join(BASEDIR, "bert_config.json"))
|
||||
model = get_mlperf_bert_model()
|
||||
init_bert_from_checkpoint(model, INIT_CKPT_DIR) # Test the actual loading of the checkpoint
|
||||
|
||||
for _, x in get_state_dict(model).items():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue