Skip to content

Commit 96f06df

Browse files
authored
Update distributionFocalLoss.m
1 parent ec16116 commit 96f06df

File tree

1 file changed

+16
-28
lines changed

1 file changed

+16
-28
lines changed

+helper/distributionFocalLoss.m

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
function boxes = distributionFocalLoss(boxes)
1+
function boxes = distributionFocalLoss(boxesInp)
22
% Distribution Focal Loss Module
33

4-
sz = size(boxes);
4+
sz = size(boxesInp);
55
c1=16; % pre-specified
66

77
% compute batch. channel and anchors
@@ -14,43 +14,31 @@
1414
anchors = sz(2);
1515

1616
% Reshape Operation
17-
boxes = permute(boxes,[2,1,3,4]); % 1 d qual like python
18-
boxes = reshape(boxes,anchors,c1,4,batch);
19-
boxes = permute(boxes,[2,1,3,4]);
17+
boxesInp = permute(boxesInp,[2,1,3,4]);
18+
boxesReshaped = reshape(boxesInp,anchors,c1,4,batch);
19+
boxesMapped = permute(boxesReshaped,[2,1,3,4]);
2020

2121
% Transpose Operation
22-
boxes = permute(boxes,[3,2,1,4]);
22+
boxesTrans = permute(boxesMapped,[3,2,1,4]);
23+
boxesTrans = extractdata(boxesTrans);
2324

2425
% softmax along the channel dimension
25-
boxes = softmax(dlarray(boxes,'SSC')); % produces a diff of 10^-3
26-
boxes = extractdata(boxes);
26+
boxesMax = softmax(dlarray(boxesTrans,'SSC'));
27+
boxesMax = extractdata(boxesMax);
2728

2829
% 1-d conv operation
2930
% Define weights
3031
weights = [0:c1-1];
31-
m = size(boxes,1);
32-
n = size(boxes,2);
32+
m = size(boxesMax,1);
33+
n = size(boxesMax,2);
3334
weights = reshape(repmat(weights,m*n,1),m,n,[]);
3435
% Conv operation
35-
boxes = boxes .* weights;
36-
boxes = sum(boxes,3); % diff of 10^-2 because of above diff
36+
boxesConv = boxesMax .* weights;
37+
boxesTotal = sum(boxesConv,3);
3738

3839
% Reshape Operation
39-
boxes = permute(boxes,[2,1,3,4]);
40-
boxes = reshape(boxes,anchors,4,batch);
41-
boxes = permute(boxes,[2,1,3,4]);
40+
boxesTmp = permute(boxesTotal,[2,1,3,4]);
41+
boxesTmpReshaped = reshape(boxesTmp,anchors,4,batch);
42+
boxes = permute(boxesTmpReshaped,[2,1,3,4]);
4243

4344
end
44-
45-
46-
47-
48-
49-
50-
51-
52-
53-
54-
55-
56-

0 commit comments

Comments
 (0)