Skip to content

Commit d783bf8

Browse files
committed
feat: add MethodGet(methodID)
- It will be used to query each method
1 parent 87730c8 commit d783bf8

File tree

3 files changed

+75
-0
lines changed

3 files changed

+75
-0
lines changed

internal/testutils/testutils.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,9 @@ func MustExtractAPITokenFromEnv() string {
1212
}
1313
return apiToken
1414
}
15+
16+
17+
// ToStringPtr returns a pointer to the given string.
18+
func ToStringPtr(s string) *string {
19+
return &s
20+
}

method_get.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package paperswithcode_go
2+
3+
import (
4+
"encoding/json"
5+
"github.com/codingpot/paperswithcode-go/v2/models"
6+
)
7+
8+
// MethodGet returns a method in a paper.
9+
// See https://paperswithcode-client.readthedocs.io/en/latest/api/client.html#paperswithcode.client.PapersWithCodeClient.method_list
10+
func (c *Client) MethodGet(methodID string) (*models.Method, error) {
11+
url := c.baseURL + "/methods/" + methodID
12+
13+
response, err := c.httpClient.Get(url)
14+
if err != nil {
15+
return nil, err
16+
}
17+
18+
var result models.Method
19+
20+
err = json.NewDecoder(response.Body).Decode(&result)
21+
if err != nil {
22+
return nil, err
23+
}
24+
25+
return &result, nil
26+
}

method_get_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package paperswithcode_go
2+
3+
import (
4+
"github.com/codingpot/paperswithcode-go/v2/internal/testutils"
5+
"github.com/codingpot/paperswithcode-go/v2/models"
6+
"github.com/stretchr/testify/assert"
7+
"testing"
8+
)
9+
10+
func TestClient_MethodGet(t *testing.T) {
11+
tests := []struct {
12+
name string
13+
methodID string
14+
want *models.Method
15+
wantErr bool
16+
}{
17+
{
18+
name: "With a correct methodID, it returns a method",
19+
methodID: "multi-head-attention",
20+
want: &models.Method{
21+
ID: "multi-head-attention",
22+
Name: "Multi-Head Attention",
23+
FullName: "Multi-Head Attention",
24+
Description: "**Multi-head Attention** is a module for attention mechanisms which runs through an attention mechanism several times in parallel. The independent attention outputs are then concatenated and linearly transformed into the expected dimension. Intuitively, multiple attention heads allows for attending to parts of the sequence differently (e.g. longer-term dependencies versus shorter-term dependencies). \r\n\r\n$$ \\text{MultiHead}\\left(\\textbf{Q}, \\textbf{K}, \\textbf{V}\\right) = \\left[\\text{head}\\_{1},\\dots,\\text{head}\\_{h}\\right]\\textbf{W}_{0}$$\r\n\r\n$$\\text{where} \\text{ head}\\_{i} = \\text{Attention} \\left(\\textbf{Q}\\textbf{W}\\_{i}^{Q}, \\textbf{K}\\textbf{W}\\_{i}^{K}, \\textbf{V}\\textbf{W}\\_{i}^{V} \\right) $$\r\n\r\nAbove $\\textbf{W}$ are all learnable parameter matrices.\r\n\r\nNote that [scaled dot-product attention](https://paperswithcode.com/method/scaled) is most commonly used in this module, although in principle it can be swapped out for other types of attention mechanism.\r\n\r\nSource: [Lilian Weng](https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html#a-family-of-attention-mechanisms)",
25+
Paper: testutils.ToStringPtr("attention-is-all-you-need"),
26+
},
27+
wantErr: false,
28+
},
29+
}
30+
for _, tt := range tests {
31+
t.Run(tt.name, func(t *testing.T) {
32+
c := NewClient()
33+
got, err := c.MethodGet(tt.methodID)
34+
if tt.wantErr {
35+
assert.Error(t, err)
36+
} else {
37+
38+
assert.NoError(t, err)
39+
}
40+
assert.Equal(t, tt.want, got)
41+
})
42+
}
43+
}

0 commit comments

Comments
 (0)