Skip to content

Commit 8652f63

Browse files
committed
Update pytorch.js (#1061)
1 parent b85e1ca commit 8652f63

File tree

4 files changed

+161
-127
lines changed

4 files changed

+161
-127
lines changed

source/python.js

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6178,11 +6178,14 @@ python.Execution = class {
61786178
get annotation_str() {
61796179
return this._annotation_str;
61806180
}
6181-
equals(/* rhs */) {
6182-
throw new python.Error(`Not implemented '${this.kind()}'.`);
6181+
equals(rhs) {
6182+
return this.kind() === rhs.kind();
61836183
}
6184-
isSubtypeOf(/* rhs */) {
6185-
throw new python.Error(`Not implemented '${this.kind()}'.`);
6184+
isSubtypeOf(rhs) {
6185+
if (rhs.kind() === 'OptionalType') {
6186+
return rhs.getElementType().equals(this);
6187+
}
6188+
return false;
61866189
}
61876190
str() {
61886191
if (this._kind === 'VarType' && this._annotation_str) {
@@ -6258,6 +6261,9 @@ python.Execution = class {
62586261
getElementType() {
62596262
return this._elem;
62606263
}
6264+
equals(rhs) {
6265+
return this.kind() === rhs.kind() && this.getElementType().equals(rhs.getElementType());
6266+
}
62616267
str() {
62626268
return `${this.getElementType().str()}?`;
62636269
}
@@ -6381,6 +6387,9 @@ python.Execution = class {
63816387
torch.NoneType.value = torch.NoneType.value || new torch.NoneType();
63826388
return torch.NoneType.value;
63836389
}
6390+
equals(rhs) {
6391+
return this.kind() === rhs.kind();
6392+
}
63846393
str() {
63856394
return 'NoneType';
63866395
}
@@ -6451,7 +6460,7 @@ python.Execution = class {
64516460
return this.kind() === rhs.kind();
64526461
}
64536462
isSubtypeOf(rhs) {
6454-
return this.kind() === 'NumberType' || super.isSubtypeOf(rhs);
6463+
return rhs.kind() === 'NumberType' || rhs.kind() === 'FloatType' || super.isSubtypeOf(rhs);
64556464
}
64566465
str() {
64576466
return 'int';
@@ -10812,8 +10821,7 @@ python.Execution = class {
1081210821
return undefined;
1081310822
}
1081410823

10815-
target(expression, context, resolve) {
10816-
resolve = resolve === false ? false : true;
10824+
target(expression, context) {
1081710825
let current = expression;
1081810826
let path = [];
1081910827
for (;;) {
@@ -10836,7 +10844,7 @@ python.Execution = class {
1083610844
break;
1083710845
}
1083810846
}
10839-
if (!target && resolve) {
10847+
if (!target) {
1084010848
path.reverse();
1084110849
const name = path.join('.');
1084210850
const file = `${path.join('/')}.py`;

source/pytorch-metadata.json

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,33 @@
172172
{
173173
"name": "aten::__and__.int(int a, int b) -> int"
174174
},
175+
{
176+
"name": "aten::__contains__.Tensor(Dict(Tensor, t) dict, Tensor key) -> bool"
177+
},
178+
{
179+
"name": "aten::__contains__.bool(Dict(bool, t) dict, bool key) -> bool"
180+
},
181+
{
182+
"name": "aten::__contains__.complex(Dict(complex, t) dict, complex key) -> bool"
183+
},
184+
{
185+
"name": "aten::__contains__.float(Dict(float, t) dict, float key) -> bool"
186+
},
187+
{
188+
"name": "aten::__contains__.float_list(float[] l, float item) -> bool"
189+
},
190+
{
191+
"name": "aten::__contains__.int(Dict(int, t) dict, int key) -> bool"
192+
},
193+
{
194+
"name": "aten::__contains__.int_list(int[] l, int item) -> bool"
195+
},
196+
{
197+
"name": "aten::__contains__.str(Dict(str, t) dict, str key) -> bool"
198+
},
199+
{
200+
"name": "aten::__contains__.str_list(str[] l, str item) -> bool"
201+
},
175202
{
176203
"name": "aten::__getitem__.Dict_Tensor(Dict(Tensor, t) self, Tensor key) -> t(*)"
177204
},

0 commit comments

Comments
 (0)