Flutter에서 YOLOv8 사용하기 - 2
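이제 생성한 모델을 불러왔으니,
모델 출력에서 검출 결과를 추출하는 NMS(Non-Maximum Suppression) 함수를 살펴보자.
모델 출력은 3차원 배열(텐서)로 들어오는데, 함수에는 이를 2차원 리스트 `rawOutput[행][앵커]`로 넘겨 앞의 4개 행에서는 박스 좌표(x, y, w, h)를, 나머지 행에서는 클래스별 점수를 각각 추출해야 한다.
이 부분은 텐서에서 추출한 값을 그대로 사용할 수 있도록 함수를 조정하였다.
소스는 다음과 같으니 참고하기 바란다.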
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Filters raw YOLOv8 detections by confidence and applies per-class
/// non-maximum suppression (NMS).
///
/// [rawOutput] is read as `rawOutput[row][anchor]`: rows 0-3 hold the box in
/// (center x, center y, width, height) form and rows 4 to `count - 1` hold one
/// score per class. (This layout is implied by the indexing below; confirm it
/// against the model's output tensor.)
///
/// [count] is the number of rows to scan, i.e. 4 + the number of classes.
/// An anchor becomes a candidate when its best class score is strictly
/// greater than [confidenceThreshold]. A candidate is suppressed when a
/// higher-scoring kept box of the same class overlaps it with IoU strictly
/// greater than [iouThreshold].
///
/// Returns `(classes, boxes, scores)` for the kept detections, ordered by
/// descending score. Boxes are returned in the same xywh form as the input.
(List<int>, List<List<double>>, List<double>) nms(
  List<List<double>> rawOutput,
  int count, {
  double confidenceThreshold = 0.1,
  double iouThreshold = 0.4,
}) {
  // Number of anchors (columns). Reading the row length directly avoids
  // depending on a `.shape` extension from an external package, and an empty
  // output yields no detections instead of throwing.
  final int anchorCount = rawOutput.isEmpty ? 0 : rawOutput[0].length;

  // 1) Confidence filtering: keep each anchor's best class if it is confident.
  final List<int> bestClasses = [];
  final List<double> bestScores = [];
  final List<List<double>> candidateBoxes = [];
  for (int i = 0; i < anchorCount; i++) {
    double bestScore = 0;
    int bestCls = -1;
    for (int j = 4; j < count; j++) {
      final double clsScore = rawOutput[j][i];
      if (clsScore > bestScore) {
        bestScore = clsScore;
        bestCls = j - 4;
      }
    }
    if (bestScore > confidenceThreshold) {
      bestClasses.add(bestCls);
      bestScores.add(bestScore);
      candidateBoxes.add([for (int k = 0; k < 4; k++) rawOutput[k][i]]);
    }
  }

  // 2) Argsort candidates by descending score. Sorting indices (instead of
  // mapping each sorted score back through `indexOf`) keeps every candidate
  // exactly once even when several share the same score. Ties fall back to
  // the original anchor order, so the result is deterministic.
  final List<int> order = List<int>.generate(bestScores.length, (i) => i)
    ..sort((a, b) {
      final int byScore = bestScores[b].compareTo(bestScores[a]);
      return byScore != 0 ? byScore : a.compareTo(b);
    });

  // 3) Greedy per-class suppression in score order: each kept box suppresses
  // every later, still-unsuppressed box of the same class that overlaps it
  // too much.
  final List<bool> suppressed = List<bool>.filled(order.length, false);
  final List<List<double>> finalBboxes = [];
  final List<double> finalScores = [];
  final List<int> finalClasses = [];
  for (int a = 0; a < order.length; a++) {
    if (suppressed[a]) {
      continue;
    }
    final int keep = order[a];
    finalBboxes.add(candidateBoxes[keep]);
    finalScores.add(bestScores[keep]);
    finalClasses.add(bestClasses[keep]);
    final List<double> keptXyxy = xywh2xyxy(candidateBoxes[keep]);
    for (int b = a + 1; b < order.length; b++) {
      final int other = order[b];
      if (!suppressed[b] &&
          bestClasses[other] == bestClasses[keep] &&
          computeIou(keptXyxy, xywh2xyxy(candidateBoxes[other])) >
              iouThreshold) {
        suppressed[b] = true;
      }
    }
  }
  return (finalClasses, finalBboxes, finalScores);
}
/// Converts a box from center form `[cx, cy, w, h]` to corner form
/// `[cx - w/2, cy - h/2, cx + w/2, cy + h/2]`.
List<double> xywh2xyxy(List<double> bbox) {
  final double cx = bbox[0];
  final double cy = bbox[1];
  final double dx = bbox[2] / 2;
  final double dy = bbox[3] / 2;
  return [cx - dx, cy - dy, cx + dx, cy + dy];
}
/// Computes the intersection-over-union (IoU) of two boxes given in corner
/// form `[xMin, yMin, xMax, yMax]`.
///
/// Returns a value in `[0, 1]`. Boxes that do not overlap, or whose union has
/// zero area (e.g. two zero-size boxes), yield 0 rather than NaN.
double computeIou(List<double> bbox1, List<double> bbox2) {
  // Corners must not be inverted. Zero-width/height boxes are allowed: a
  // detector can emit them and they must not crash debug builds.
  assert(bbox1[0] <= bbox1[2]);
  assert(bbox1[1] <= bbox1[3]);
  assert(bbox2[0] <= bbox2[2]);
  assert(bbox2[1] <= bbox2[3]);
  final double xLeft = max(bbox1[0], bbox2[0]);
  final double yTop = max(bbox1[1], bbox2[1]);
  final double xRight = min(bbox1[2], bbox2[2]);
  final double yBottom = min(bbox1[3], bbox2[3]);
  if (xRight < xLeft || yBottom < yTop) {
    return 0;
  }
  final double intersectionArea = (xRight - xLeft) * (yBottom - yTop);
  final double bbox1Area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]);
  final double bbox2Area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]);
  final double unionArea = bbox1Area + bbox2Area - intersectionArea;
  // Degenerate boxes can give a zero union; 0/0 would otherwise be NaN.
  if (unionArea <= 0) {
    return 0;
  }
  final double iou = intersectionArea / unionArea;
  assert(iou >= 0 && iou <= 1);
  return iou;
}
이 함수에서 주목할 점은 텐서로 전달받은 값에서 인덱스로 각 항목(박스 좌표는 `rawOutput[0..3]`, 클래스별 점수는 `rawOutput[4..]`)을 지정하여 추출하는 방법이다.
댓글
댓글 쓰기