mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-06-24 02:14:17 +00:00
fix bugs at examples/yolov3.py (#11614)
* Update load_weight. Give valid model url * Fix bug in iou function
This commit is contained in:
parent
0c97d6de1b
commit
ca7a641442
1 changed files with 9 additions and 9 deletions
|
|
@ -71,8 +71,8 @@ def bbox_iou(box1, box2):
|
|||
# get the coordinates of the intersection rectangle
|
||||
inter_rect_x1 = np.maximum(b1_x1, b2_x1)
|
||||
inter_rect_y1 = np.maximum(b1_y1, b2_y1)
|
||||
inter_rect_x2 = np.maximum(b1_x2, b2_x2)
|
||||
inter_rect_y2 = np.maximum(b1_y2, b2_y2)
|
||||
inter_rect_x2 = np.minimum(b1_x2, b2_x2)
|
||||
inter_rect_y2 = np.minimum(b1_y2, b2_y2)
|
||||
#Intersection area
|
||||
inter_area = np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, 99999) * np.clip(inter_rect_y2 - inter_rect_y1 + 1, 0, 99999)
|
||||
#Union Area
|
||||
|
|
@ -297,13 +297,13 @@ class Darknet:
|
|||
# Get the number of weights of batchnorm
|
||||
num_bn_biases = math.prod(bn.bias.shape)
|
||||
# Load weights
|
||||
bn_biases = Tensor(weights[ptr:ptr + num_bn_biases])
|
||||
bn_biases = Tensor(weights[ptr:ptr + num_bn_biases].astype(np.float32))
|
||||
ptr += num_bn_biases
|
||||
bn_weights = Tensor(weights[ptr:ptr+num_bn_biases])
|
||||
bn_weights = Tensor(weights[ptr:ptr+num_bn_biases].astype(np.float32))
|
||||
ptr += num_bn_biases
|
||||
bn_running_mean = Tensor(weights[ptr:ptr+num_bn_biases])
|
||||
bn_running_mean = Tensor(weights[ptr:ptr+num_bn_biases].astype(np.float32))
|
||||
ptr += num_bn_biases
|
||||
bn_running_var = Tensor(weights[ptr:ptr+num_bn_biases])
|
||||
bn_running_var = Tensor(weights[ptr:ptr+num_bn_biases].astype(np.float32))
|
||||
ptr += num_bn_biases
|
||||
# Cast the loaded weights into dims of model weights
|
||||
bn_biases = bn_biases.reshape(shape=tuple(bn.bias.shape))
|
||||
|
|
@ -319,7 +319,7 @@ class Darknet:
|
|||
# load biases of the conv layer
|
||||
num_biases = math.prod(conv.bias.shape)
|
||||
# Load weights
|
||||
conv_biases = Tensor(weights[ptr: ptr+num_biases])
|
||||
conv_biases = Tensor(weights[ptr: ptr+num_biases].astype(np.float32))
|
||||
ptr += num_biases
|
||||
# Reshape
|
||||
conv_biases = conv_biases.reshape(shape=tuple(conv.bias.shape))
|
||||
|
|
@ -327,7 +327,7 @@ class Darknet:
|
|||
conv.bias = conv_biases
|
||||
# Load weighys for conv layers
|
||||
num_weights = math.prod(conv.weight.shape)
|
||||
conv_weights = Tensor(weights[ptr:ptr+num_weights])
|
||||
conv_weights = Tensor(weights[ptr:ptr+num_weights].astype(np.float32))
|
||||
ptr += num_weights
|
||||
conv_weights = conv_weights.reshape(shape=tuple(conv.weight.shape))
|
||||
conv.weight = conv_weights
|
||||
|
|
@ -371,7 +371,7 @@ class Darknet:
|
|||
if __name__ == "__main__":
|
||||
model = Darknet(fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg').read_bytes())
|
||||
print("Loading weights file (237MB). This might take a while…")
|
||||
model.load_weights('https://pjreddie.com/media/files/yolov3.weights')
|
||||
model.load_weights('https://github.com/shadiakiki1986/yolov3.weights/releases/download/3.0.1/yolov3.weights')
|
||||
if len(sys.argv) > 1:
|
||||
url = sys.argv[1]
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue