Flutter 에서 yolov8 사용하기 - 2

이제 생성된 모델을 가져왔으니

모델에서 값을  추출하는 nms 함수를 살펴보아야 한다.

3차원 배열로 들어온 값에서 각각 맞는 값을 추출해야 한다.

이 부분은 텐서에서 추출한 값으로 설정되도록 함수를 조정하였다.

소스는 다음과 같으니 참고하기 바란다.

(List<int>, List<List<double>>, List<double>) nms(
List<List<double>> rawOutput, int count,
{double confidenceThreshold = 0.1, double iouThreshold = 0.4}) {
List<int> bestClasses = [];
List<double> bestScores = [];
List<int> boxesToSave = [];
for (int i = 0; i < rawOutput.shape[1]; i++) {
double bestScore = 0;
int bestCls = -1;
for (int j = 4; j < count; j++) {
double clsScore = rawOutput[j][i];
if (clsScore > bestScore) {
bestScore = clsScore;
bestCls = j - 4;
}
}
if (bestScore > confidenceThreshold) {
bestClasses.add(bestCls);
bestScores.add(bestScore);
boxesToSave.add(i);
}
}
List<List<double>> candidateBoxes = [];
for (var index in boxesToSave) {
List<double> savedBox = [];
for (int i = 0; i < 4; i++) {
savedBox.add(rawOutput[i][index]);
}
candidateBoxes.add(savedBox);
}
var sortedBestScores = List.from(bestScores);
sortedBestScores.sort((a, b) => -a.compareTo(b));
List<int> argSortList =
sortedBestScores.map((e) => bestScores.indexOf(e)).toList();
List<int> sortedBestClasses = [];
List<List<double>> sortedCandidateBoxes = [];
for (var index in argSortList) {
sortedBestClasses.add(bestClasses[index]);
sortedCandidateBoxes.add(candidateBoxes[index]);
}
List<List<double>> finalBboxes = [];
List<double> finalScores = [];
List<int> finalClasses = [];
while (sortedCandidateBoxes.isNotEmpty) {
var bbox1xywh = sortedCandidateBoxes.removeAt(0);
finalBboxes.add(bbox1xywh);
var bbox1xyxy = xywh2xyxy(bbox1xywh);
finalScores.add(sortedBestScores.removeAt(0));
var class1 = sortedBestClasses.removeAt(0);
finalClasses.add(class1);
List<int> indexesToRemove = [];
for (int i = 0; i < sortedCandidateBoxes.length; i++) {
if (class1 == sortedBestClasses[i]) {
if (computeIou(bbox1xyxy, xywh2xyxy(sortedCandidateBoxes[i])) >
iouThreshold) {
indexesToRemove.add(i);
}
}
}
for (var index in indexesToRemove.reversed) {
sortedCandidateBoxes.removeAt(index);
sortedBestClasses.removeAt(index);
sortedBestScores.removeAt(index);
}
}
return (finalClasses, finalBboxes, finalScores);
}
List<double> xywh2xyxy(List<double> bbox) {
double halfWidth = bbox[2] / 2;
double halfHeight = bbox[3] / 2;
return [
bbox[0] - halfWidth,
bbox[1] - halfHeight,
bbox[0] + halfWidth,
bbox[1] + halfHeight,
];
}
double computeIou(List<double> bbox1, List<double> bbox2) {
assert(bbox1[0] < bbox1[2]);
assert(bbox1[1] < bbox1[3]);
assert(bbox2[0] < bbox2[2]);
assert(bbox2[1] < bbox2[3]);
double xLeft = max(bbox1[0], bbox2[0]);
double yTop = max(bbox1[1], bbox2[1]);
double xRight = min(bbox1[2], bbox2[2]);
double yBottom = min(bbox1[3], bbox2[3]);
if (xRight < xLeft || yBottom < yTop) {
return 0;
}
double intersectionArea = (xRight - xLeft) * (yBottom - yTop);
double bbox1Area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]);
double bbox2Area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]);
double iou = intersectionArea / (bbox1Area + bbox2Area - intersectionArea);
assert(iou >= 0 && iou <= 1);
return iou;
}
view raw nms.dart hosted with ❤ by GitHub

이 함수에서 주목할 점은 텐서로 전달 받은 값에서 각 항목을 지정하여 추출하는 방법이다.

댓글

이 블로그의 인기 게시물

한글 2010 에서 Ctrl + F10 누르면 특수문자 안뜰 때

맥 화면이 안나올때 조치방법

아이폰에서 RFID 사용하는 방법