Skip to content

Commit 9e58638

Browse files
authored
Merge pull request #14 from codingpot/feat-add-method-get
feat: add MethodGet(methodID)
2 parents 725a9d7 + d783bf8 commit 9e58638

File tree

3 files changed

+75
-0
lines changed

3 files changed

+75
-0
lines changed

internal/testutils/testutils.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,9 @@ func MustExtractAPITokenFromEnv() string {
1212
}
1313
return apiToken
1414
}
15+
16+
17+
// ToStringPtr returns a pointer to the given string.
18+
func ToStringPtr(s string) *string {
19+
return &s
20+
}

method_get.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package paperswithcode_go
2+
3+
import (
4+
"encoding/json"
5+
"github.com/codingpot/paperswithcode-go/v2/models"
6+
)
7+
8+
// MethodGet returns a method in a paper.
9+
// See https://paperswithcode-client.readthedocs.io/en/latest/api/client.html#paperswithcode.client.PapersWithCodeClient.method_list
10+
func (c *Client) MethodGet(methodID string) (*models.Method, error) {
11+
url := c.baseURL + "/methods/" + methodID
12+
13+
response, err := c.httpClient.Get(url)
14+
if err != nil {
15+
return nil, err
16+
}
17+
18+
var result models.Method
19+
20+
err = json.NewDecoder(response.Body).Decode(&result)
21+
if err != nil {
22+
return nil, err
23+
}
24+
25+
return &result, nil
26+
}

method_get_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package paperswithcode_go
2+
3+
import (
4+
"github.com/codingpot/paperswithcode-go/v2/internal/testutils"
5+
"github.com/codingpot/paperswithcode-go/v2/models"
6+
"github.com/stretchr/testify/assert"
7+
"testing"
8+
)
9+
10+
func TestClient_MethodGet(t *testing.T) {
11+
tests := []struct {
12+
name string
13+
methodID string
14+
want *models.Method
15+
wantErr bool
16+
}{
17+
{
18+
name: "With a correct methodID, it returns a method",
19+
methodID: "multi-head-attention",
20+
want: &models.Method{
21+
ID: "multi-head-attention",
22+
Name: "Multi-Head Attention",
23+
FullName: "Multi-Head Attention",
24+
Description: "**Multi-head Attention** is a module for attention mechanisms which runs through an attention mechanism several times in parallel. The independent attention outputs are then concatenated and linearly transformed into the expected dimension. Intuitively, multiple attention heads allows for attending to parts of the sequence differently (e.g. longer-term dependencies versus shorter-term dependencies). \r\n\r\n$$ \\text{MultiHead}\\left(\\textbf{Q}, \\textbf{K}, \\textbf{V}\\right) = \\left[\\text{head}\\_{1},\\dots,\\text{head}\\_{h}\\right]\\textbf{W}_{0}$$\r\n\r\n$$\\text{where} \\text{ head}\\_{i} = \\text{Attention} \\left(\\textbf{Q}\\textbf{W}\\_{i}^{Q}, \\textbf{K}\\textbf{W}\\_{i}^{K}, \\textbf{V}\\textbf{W}\\_{i}^{V} \\right) $$\r\n\r\nAbove $\\textbf{W}$ are all learnable parameter matrices.\r\n\r\nNote that [scaled dot-product attention](https://paperswithcode.com/method/scaled) is most commonly used in this module, although in principle it can be swapped out for other types of attention mechanism.\r\n\r\nSource: [Lilian Weng](https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html#a-family-of-attention-mechanisms)",
25+
Paper: testutils.ToStringPtr("attention-is-all-you-need"),
26+
},
27+
wantErr: false,
28+
},
29+
}
30+
for _, tt := range tests {
31+
t.Run(tt.name, func(t *testing.T) {
32+
c := NewClient()
33+
got, err := c.MethodGet(tt.methodID)
34+
if tt.wantErr {
35+
assert.Error(t, err)
36+
} else {
37+
38+
assert.NoError(t, err)
39+
}
40+
assert.Equal(t, tt.want, got)
41+
})
42+
}
43+
}

0 commit comments

Comments
 (0)