@@ -29,6 +29,7 @@

 # pyre-strict
 import ctypes
+import hashlib
 import operator
 import typing
 from dataclasses import dataclass, field
@@ -104,9 +105,9 @@ class _ProgramState:
     # as index 0 in the constant_buffer is reserved.
     allocated_specs: List[TensorSpec] = field(default_factory=list)
     # Weights in any arbitrary graph_module only need to compare against weights from previously
-    # emitted graph modules, not any weights emitted from itself. set to len(allocated_specs) after
-    # every method emission.
-    cached_spec_list_length: int = 0
+    # emitted graph modules, not any weights emitted from itself. This should speed up the lookup
+    # from O(N) to O(1).
+    cached_spec_hash_values: Dict[str, int] = field(default_factory=dict)
     # The 0 index is reserved to be pointed to by non-constant tensors, so add an empty placeholder.
     constant_buffer: List[Buffer] = field(default_factory=lambda: [Buffer(storage=b"")])
     # Delegate data stored directly in the flatbuffer. Pointed to by BackendDelegateDataReference,
@@ -346,74 +347,34 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
             # For non-constant tensors, constant_buffer = 0.
             return EValue(make_tensor_value(0, allocation_info, spec))

-        def _get_buffer_idx(spec: TensorSpec, program_state: _ProgramState) -> int:
-            """Determines where in the program state the constant buffer corresponding to spec is
-            located.
+        # Constant tensor. Reserve a buffer for the constant tensor.
+        spec_array_type = (
+            ctypes.c_char * typing.cast(torch.UntypedStorage, spec.storage).nbytes()
+        )

-            Returns the index into the constant buffers list if this spec has been previously
-            allocated, -1 if unseen before. O(N^2) as for every tensor we have to compare it against
-            every tensor previously allocated. Could improve this to O(N) if we hashed the weights.
-            """
-            for i in range(0, program_state.cached_spec_list_length):
-                other_spec = program_state.allocated_specs[i]
+        buffer_data = (
+            bytes(
+                ctypes.cast(
+                    typing.cast(torch.UntypedStorage, spec.storage).data_ptr(),
+                    ctypes.POINTER(spec_array_type),
+                ).contents
+            )
+            if spec.allocated_memory != 0
+            else b""
+        )

-                # Check for an empty buffer, special cased to avoid nullptr in buffer check.
-                if spec.allocated_memory == 0 and other_spec.allocated_memory == 0:
-                    return i + 1
+        hashed = hashlib.sha256(buffer_data).hexdigest()

-                # compare meta data
-                if (
-                    spec.scalar_type == other_spec.scalar_type
-                    and spec.shape == other_spec.shape
-                    and spec.dim_order == other_spec.dim_order
-                    and typing.cast(torch.UntypedStorage, spec.storage).nbytes()
-                    == typing.cast(torch.UntypedStorage, other_spec.storage).nbytes()
-                ):
-                    spec_array_type = (
-                        ctypes.c_char
-                        * typing.cast(torch.UntypedStorage, spec.storage).nbytes()
-                    )
-                    other_spec_array_type = (
-                        ctypes.c_char
-                        * typing.cast(torch.UntypedStorage, other_spec.storage).nbytes()
-                    )
-                    # compare data
-                    if bytes(
-                        ctypes.cast(
-                            typing.cast(torch.UntypedStorage, spec.storage).data_ptr(),
-                            ctypes.POINTER(spec_array_type),
-                        ).contents
-                    ) == bytes(
-                        ctypes.cast(
-                            typing.cast(
-                                torch.UntypedStorage, other_spec.storage
-                            ).data_ptr(),
-                            ctypes.POINTER(other_spec_array_type),
-                        ).contents
-                    ):
-                        return i + 1  # +1 because the first buffer location is reserved
-            return -1
-
-        buffer_idx = _get_buffer_idx(spec, self.program_state)
+        buffer_idx = self.program_state.cached_spec_hash_values.get(hashed, -1)

         # Haven't seen this constant before
         if buffer_idx == -1:
-            if spec.allocated_memory == 0:
-                buffer = Buffer(storage=b"")
-            else:
-                array_type = (
-                    ctypes.c_char
-                    * typing.cast(torch.UntypedStorage, spec.storage).nbytes()
-                )
-                spec_array = ctypes.cast(
-                    typing.cast(torch.UntypedStorage, spec.storage).data_ptr(),
-                    ctypes.POINTER(array_type),
-                ).contents
-                buffer = Buffer(storage=bytes(spec_array))
-
             # Update buffer_idx to point to the end of the list where we are adding the new buffer.
+            buffer = Buffer(storage=buffer_data)
             buffer_idx = len(self.program_state.constant_buffer)
             self.program_state.allocated_specs.append(spec)
+            # +1 because the first buffer location is reserved
+            self.program_state.cached_spec_hash_values[hashed] = buffer_idx
             self.program_state.constant_buffer.append(buffer)

         # For constant tensors, allocation_info = None.