Conversation
vfdev-5 left a comment
Thanks for the PR @rwtarpit !
I checked the code and am thinking about how to make it a bit better.
We also need to figure out precisely what kinds of inputs we can compute a top-k metric on: binary, multi-class, multi-label? All of them, or only a subset...
ignite/metrics/top_k.py
Outdated
```python
masked.scatter_(-1, top_indices, 1.0)
return (masked, y)

@reinit__is_reduced
```
I think we do not need to decorate the method with `reinit__is_reduced`, as the base metric will handle that.
ignite/metrics/top_k.py
Outdated
```python
for attr in self._base_metric._state_dict_all_req_keys:
    setattr(self._base_metric, attr, self._states[k].get(attr, getattr(self._base_metric, attr)))
```
Can we use `Metric.state_dict` and `Metric.load_state_dict` for that?
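A minimal sketch of that single-instance, state-swap pattern, using a toy class whose `state_dict`/`load_state_dict` names only mimic ignite's `Metric` API (this is illustrative, not ignite code):

```python
# Toy stand-in for an ignite Metric; not the real class.
class ToyCounter:
    def __init__(self):
        self.total = 0

    def update(self, n):
        self.total += n

    def state_dict(self):
        return {"total": self.total}

    def load_state_dict(self, sd):
        self.total = sd["total"]

# One metric instance, one saved state per k, swapped in before each update:
metric = ToyCounter()
states = {k: metric.state_dict() for k in (1, 5)}

for k in (1, 5):
    metric.load_state_dict(states[k])  # restore the per-k state
    metric.update(k)                   # update as if this were the k-th copy
    states[k] = metric.state_dict()    # save the state back

print(states)  # {1: {'total': 1}, 5: {'total': 5}}
```

The per-k states stay independent even though only one metric object exists, which is the behaviour the wrapper needs.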
I'll take a look at this.
Yes, this would need some research before finalizing the metric.
@vfdev-5 TopK wouldn't support the binary case:

```python
y_pred = torch.tensor([0.9, 0.3, 0.7])  # (N,)
y = torch.tensor([1, 0, 1])
# only one element per sample
```

Multiclass case (k=2):

```python
y_pred = torch.tensor([[0.1, 0.6, 0.2, 0.4],   # top2: classes 1, 3
                       [0.5, 0.2, 0.8, 0.3],   # top2: classes 2, 0
                       [0.3, 0.1, 0.2, 0.9]])  # top2: classes 3, 0
y = torch.tensor([1, 2, 1])  # true classes

# accuracy@2 = 2/3 = 0.66
# samples: recall@2 = (1/1 + 1/1 + 0/1)/3 = 0.66 = accuracy@2
# micro:   recall@2 = sum(TP)/sum(actual) = 2/3 = 0.66 = accuracy@2
# samples: precision@2 = (1/2 + 1/2 + 0/2)/3 = 0.333
# micro:   precision@2 = sum(TP)/(N*k) = 2/6 = 0.333
# macro recall@2:
#   class 1: actual = samples 1 and 3, predicted = sample 1; recall@class1 = 1/2 = 0.5
#   class 2: actual = sample 2, predicted = sample 2; recall@class2 = 1/1 = 1.0
#   macro recall@2 = (0.5 + 1.0)/2 = 0.75
```

Multilabel case: similarly.

Can you confirm this too? I can then proceed with the implementation.
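To double-check the hand-computed numbers above, here is a small plain-Python sketch (no torch; the `topk_indices` helper is made up for illustration) that reproduces accuracy@2 and the samples-averaged precision@2/recall@2 for the multiclass example:

```python
def topk_indices(scores, k):
    # indices of the k highest scores for one sample
    return set(sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:k])

y_pred = [[0.1, 0.6, 0.2, 0.4],
          [0.5, 0.2, 0.8, 0.3],
          [0.3, 0.1, 0.2, 0.9]]
y = [1, 2, 1]
k = 2

# 1 if the true class is among the top-k predictions, else 0
hits = [int(t in topk_indices(p, k)) for p, t in zip(y_pred, y)]
accuracy_at_2 = sum(hits) / len(y)                     # 2/3
recall_samples = sum(h / 1 for h in hits) / len(y)     # (1/1 + 1/1 + 0/1)/3 = 2/3
precision_samples = sum(h / k for h in hits) / len(y)  # (1/2 + 1/2 + 0/2)/3 = 1/3
print(accuracy_at_2, recall_samples, precision_samples)
```

This matches the values in the comment: accuracy@2 = recall@2 = 0.66 and precision@2 = 0.333.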
I have updated the draft PR with idea of keeping the
Advantages:
Disadvantages:
NOTE: I also checked how torchmetrics implements TopK: they put the top-k logic in each metric, with `top_k` passed as an argument. Currently I have kept it to precision/recall and added a small test case to check the sanity of the idea. This is a rough sketch of the idea, and we still need to think about adding other metrics and handling edge cases/exceptions before moving forward and finalising it.
| """ | ||
| self._check_shape(output) | ||
| self._check_type(output) | ||
| if not getattr(self, "_skip_checks", False): |
```diff
-if not getattr(self, "_skip_checks", False):
+if not self._skip_checks:
```
I think using `getattr` implies the attribute might not exist when this method runs, so it is better to initialize `self._skip_checks = False` up front.
```python
import torch
from typing import Sequence

from ignite.metrics import Metric
```
```diff
 from ignite.metrics import Metric
+from ignite.metrics.metric import reinit__is_reduced
```
```python
self._transform = transform
self._base_metric = base_metric
self._ks = sorted(top_k) if isinstance(top_k, list) else [top_k]
```
```diff
 self._ks = sorted(top_k) if isinstance(top_k, list) else [top_k]
+identity = lambda x: x
+if base_metric.output_transform is not identity:
+    import warnings
+    warnings.warn(
+        "base_metric's output_transform will never be called inside TopK. "
+        "Pass output_transform to TopK directly instead.",
+        UserWarning,
+    )
```
TopK feeds the data directly to the base metric's `update` method, completely bypassing the base metric's own `output_transform`. If a user passes a custom transform to the base metric, it will silently never run, so I think it's better to add a warning.
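A toy illustration of the problem (the classes here are stand-ins, not ignite's real ones): the wrapper calls the base metric's `update` directly, so a transform attached to the base metric is never invoked.

```python
class ToyMetric:
    def __init__(self, output_transform=lambda x: x):
        self.output_transform = output_transform
        self.seen = []

    def update(self, output):
        # In a normal engine loop the transform would run before update;
        # the wrapper below never applies it.
        self.seen.append(output)

class ToyTopK:
    def __init__(self, base):
        self.base = base

    def update(self, output):
        self.base.update(output)  # bypasses base.output_transform entirely

calls = []  # records every invocation of the base metric's transform
base = ToyMetric(output_transform=lambda x: calls.append(x) or x)
ToyTopK(base).update((1, 2))
print(calls)  # [] -- the base metric's transform was never invoked
```

The base metric still received the raw output, which is exactly the silent failure the warning is meant to surface.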
```python
self._ks = sorted(top_k) if isinstance(top_k, list) else [top_k]
super().__init__(output_transform=output_transform, device=device, skip_unrolling=skip_unrolling)

def reset(self):
```
```diff
+@reinit__is_reduced
 def reset(self):
```
It's better to add this decorator because it resets the flag when a new epoch starts.
@vfdev-5 I just checked your previous comment that the base metric handles the `reinit__is_reduced` flag automatically. If `self._base_metric.update()` handles the flag reset for the wrapper, then we can safely ignore the decorator suggestions.
```python
self._base_metric.reset()
self._states = {k: self._base_metric.state_dict() for k in self._ks}

def update(self, output):
```
```diff
+@reinit__is_reduced
 def update(self, output):
```
I guess we need the decorator here too
```python
self._base_metric._skip_checks = False

def compute(self) -> list:
```
Distributed sync is missing here: the state tensors for each k inside `self._states` are never reduced across GPUs before `compute()` is called. I think we should add a docstring noting that distributed evaluation is not yet supported for TopK, to avoid silent failures.
```python
def _precision_recall_topk_transform(output: Sequence[torch.Tensor], k: int):
    """top_k transform for precision and recall"""
    y_pred, y = output[0], output[1]
    _, top_indices = torch.topk(y_pred, k=k, dim=-1)
```
```diff
-_, top_indices = torch.topk(y_pred, k=k, dim=-1)
+actual_k = min(k, y_pred.shape[-1])
+_, top_indices = torch.topk(y_pred, k=actual_k, dim=-1)
```
I think we need to guard against k exceeding the total number of items. If a user asks for k=10 but the test batch only has 5 items, `torch.topk` raises a RuntimeError and crashes; taking the min prevents this.
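A plain-Python sketch of the clamping behaviour (illustrative only; `safe_topk_indices` is a made-up helper, the real code uses `torch.topk`):

```python
def safe_topk_indices(scores, k):
    # Clamp k so requesting more items than exist cannot raise,
    # mirroring actual_k = min(k, y_pred.shape[-1]) in the suggestion.
    actual_k = min(k, len(scores))
    order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
    return order[:actual_k]

print(safe_topk_indices([0.9, 0.3, 0.8], k=5))  # [0, 2, 1] -- no crash
print(safe_topk_indices([0.9, 0.3, 0.8], k=2))  # [0, 2]
```

With the clamp, k=5 on a 3-item batch just returns all three indices instead of failing.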
```python
assert len(result) == 3
assert pytest.approx(result[0], abs=1e-4) == 1.0
assert pytest.approx(result[1], abs=1e-4) == 1.0
assert pytest.approx(result[2], abs=1e-4) == 2 / 3
```
```diff
 assert pytest.approx(result[2], abs=1e-4) == 2 / 3
+
+
+def test_top_k_k_exceeds_num_items():
+    """torch.topk crashes if k > number of items; verify the guard works."""
+    import torch
+
+    from ignite.metrics import Precision
+    from ignite.metrics.top_k import TopK
+
+    y_pred = torch.tensor([[0.9, 0.3, 0.8]])  # 3 items
+    y_true = torch.tensor([[1, 0, 1]])
+    metric = TopK(Precision(average="samples", is_multilabel=True), top_k=[5])  # k=5 > 3 items
+    metric.update((y_pred, y_true))  # should NOT crash
+    result = metric.compute()
+    assert len(result) == 1
```
Let's add this unit test to verify that the `min(k, total_items)` guard works and prevents a RuntimeError!
```python
def __init__(
    self,
    base_metric: Metric,
    top_k: int | list[int],
    output_transform=lambda x: x,
    device: str | torch.device = torch.device("cpu"),
    skip_unrolling: bool = False,
):
```
@TahaZahid05 seems like we have a similar idea/problem as SubgroupMetric from your PR. Can you check this implementation and give feedback, thanks!
@vfdev-5 I have made a general comment on the PR.
@rwtarpit you can check `ignite/metrics/metrics_lambda.py`, lines 148 to 170 at 6bfbdc4.
@vfdev-5 Thanks for tagging. Part of my approach has already been discussed in #3568 from what I can see. I'll list my approach again for explanation purposes. The main difference is in state management: this PR uses a single instance and swaps its state per k, whereas I keep one deep copy per k:

```python
self._metrics = {k: copy.deepcopy(base_metric) for k in keys}
```

This simplifies the wrapper. Following is how the update looks under my implementation:

```python
for k in self._ks:
    k_output = self._transform(output, k)
    self._metrics[k].update(k_output)
```

and compute:

```python
return {k: self._metrics[k].compute() for k in self._ks}
```

This is better in my opinion, as you don't need to add the `_skip_checks` flag. @rwtarpit mentioned in #3568 that they considered the use of `deepcopy`.
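A runnable sketch of that deep-copy-per-k design, using toy classes rather than the actual PR code (all names here are illustrative):

```python
import copy

class ToyHits:
    """Toy base metric: sums whatever the transform hands it."""
    def __init__(self):
        self.hits = 0

    def update(self, output):
        self.hits += sum(output)

    def compute(self):
        return self.hits

class ToyTopK:
    def __init__(self, base_metric, ks, transform):
        # One independent deep copy per k, so no state-dict swapping is needed.
        self._ks = ks
        self._transform = transform
        self._metrics = {k: copy.deepcopy(base_metric) for k in ks}

    def update(self, output):
        for k in self._ks:
            self._metrics[k].update(self._transform(output, k))

    def compute(self):
        return {k: self._metrics[k].compute() for k in self._ks}

# Toy transform: pretend each of the top-k predictions is a hit.
topk = ToyTopK(ToyHits(), ks=[1, 2], transform=lambda out, k: [1] * min(k, len(out)))
topk.update([0.9, 0.1, 0.4])
print(topk.compute())  # {1: 1, 2: 2}
```

Each per-k copy accumulates its own state, so `update` and `compute` stay a plain loop with no flag juggling.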
Thanks for the explanation @TahaZahid05, I was a bit reluctant about this approach too, due to touching the stable API of metrics.
@rwtarpit I could not understand the usefulness of the registry approach. If you can pitch the idea again with a clear example, it would be helpful.
So I tried to interleave the many issues we were facing with TopK:

```python
# TopK's update
class TopK:
    # top-k with branching logic
    def update(base_metric, top_k, ...):
        if isinstance(base_metric, PrecisionRecall):
            ...
        if isinstance(base_metric, Accuracy):
            ...

    # with transform-registry logic:
    def update(base_metric, top_k, ...):
        transform = None
        for metric_type, k_transform in self._output_transform_registry.items():
            if isinstance(base_metric, metric_type):
                transform = k_transform
        ...

# registering metrics to TopK
TopK.register(PrecisionRecall, _precision_recall_transform)
TopK.register(Accuracy, _accuracy_transform)
...
```
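The pseudocode above can be made concrete as a self-contained sketch with toy metric classes (all names here are illustrative; `register` is assumed to be a classmethod storing type-to-transform pairs):

```python
class ToyPrecision: ...
class ToyAccuracy: ...

class TopK:
    _output_transform_registry = {}

    @classmethod
    def register(cls, metric_type, k_transform):
        # Map a metric type to its top-k output transform.
        cls._output_transform_registry[metric_type] = k_transform

    def __init__(self, base_metric, k):
        self.base_metric, self.k = base_metric, k

    def pick_transform(self):
        # Registry lookup replaces per-metric if/elif branching.
        for metric_type, k_transform in self._output_transform_registry.items():
            if isinstance(self.base_metric, metric_type):
                return k_transform
        raise TypeError(f"no top-k transform registered for {type(self.base_metric)}")

# Toy transforms that just tag which path was chosen.
def _precision_recall_transform(output, k): return ("pr", k)
def _accuracy_transform(output, k): return ("acc", k)

TopK.register(ToyPrecision, _precision_recall_transform)
TopK.register(ToyAccuracy, _accuracy_transform)

print(TopK(ToyAccuracy(), k=2).pick_transform()(None, 2))  # ('acc', 2)
```

Adding support for a new metric then only requires one `TopK.register(...)` call, with no edits inside `update`.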
Fixes #3568

Description:

As discussed in the issue thread, this is a draft PR implementing a base structure for the `TopK` wrapper class.

- The idea is to skip redundant checks on the input data by not running `_check_shape` and `_check_type` k times.
- The wrapper only instantiates a single object of the base class and maintains states for each k using a dict.
- For each batch, we check the data shape and type only once before updating the metric for each k individually. For this we use a `_skip_checks` flag that can only be set from `TopK` itself, which ensures the metric functions as usual when called without the wrapper.
- The `TopK` class provides `_wrap_prepare_output`, which handles the top-k logic and transforms the data accordingly. The metric can then be calculated by the base metric's `update` and `compute` as usual, by means of `state_dict` and `_state_dict_all_req_keys`.

I have currently only implemented the wrapper for recall and precision, to get an idea of how it would look when extended to other metrics in the future.

Check list: