
No ONNX support for scaled_dot_product_attention #1752

Answered by rwightman
drexalt asked this question in Q&A


@jturner116 hrmm, didn't notice the ONNX issue. I don't understand why PyTorch always breaks things like this when they add new ops :/

Using the vit one as an example, does it work if you do something like this (note the is_tracing addition)?

        if self.fast_attn and not torch.jit.is_tracing():
            # fused kernel path; the ONNX exporter can't handle this op yet,
            # so skip it while tracing
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p,
            )
        else:
            # explicit attention fallback that traces and exports cleanly
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v
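
For context, here's a minimal, self-contained sketch of that pattern (the module layout and names are my own simplification, not timm's exact Attention class). torch.onnx.export drives the JIT tracer, so torch.jit.is_tracing() returns True during export and the explicit-matmul branch is the one that lands in the ONNX graph:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Attention(nn.Module):
        def __init__(self, dim, num_heads=8, attn_drop=0.):
            super().__init__()
            self.num_heads = num_heads
            self.head_dim = dim // num_heads
            self.scale = self.head_dim ** -0.5
            self.qkv = nn.Linear(dim, dim * 3)
            self.attn_drop = nn.Dropout(attn_drop)
            self.proj = nn.Linear(dim, dim)
            # only take the fast path if this torch build has the fused op
            self.fast_attn = hasattr(F, 'scaled_dot_product_attention')

        def forward(self, x):
            B, N, C = x.shape
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
            q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each (B, heads, N, head_dim)

            if self.fast_attn and not torch.jit.is_tracing():
                # fused kernel, skipped under the tracer
                x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
            else:
                # plain matmul/softmax attention, exportable to ONNX
                q = q * self.scale
                attn = q @ k.transpose(-2, -1)
                attn = attn.softmax(dim=-1)
                attn = self.attn_drop(attn)
                x = attn @ v

            x = x.transpose(1, 2).reshape(B, N, C)
            return self.proj(x)

    # torch.onnx.export traces the model, so is_tracing() is True and the
    # else branch is what gets recorded into the exported graph
    model = Attention(dim=64).eval()
    dummy = torch.randn(1, 16, 64)
    torch.onnx.export(model, dummy, 'attn.onnx', opset_version=14)
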

Answer selected by drexalt