
Commit 3e99930

larryliu0820 authored and facebook-github-bot committed
hash constant buffer to improve deduplication performance (#1185)
Summary: Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #1185

Currently, for each constant tensor we have to compare its buffer against every previously emitted buffer to check for a duplicate. This is especially costly when the program contains more than one ExecutionPlan. This PR introduces a map from a buffer's hash value to its buffer index, which should largely reduce deduplication time.

Pull Request resolved: #1185

Reviewed By: tarun292, JacobSzwejbka

Differential Revision: D51182655

Pulled By: larryliu0820

fbshipit-source-id: df63bb800238c2021ae2355ad7d2e8cc13313067
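To make the change concrete, here is a minimal, self-contained sketch of the idea (hypothetical names like `add_constant` and `cached_hash_to_index`, not the emitter's actual API): instead of scanning every previously emitted buffer, keep a dict from the sha256 digest of a buffer's bytes to the buffer index already assigned to those bytes, so each deduplication check is a single dict lookup.

```python
import hashlib
from typing import Dict, List

# Index 0 is reserved for non-constant tensors, so start with an empty placeholder.
constant_buffer: List[bytes] = [b""]
# Maps sha256 hex digest of a buffer's bytes -> index into constant_buffer.
cached_hash_to_index: Dict[str, int] = {}


def add_constant(data: bytes) -> int:
    """Return the buffer index for `data`, reusing an existing buffer when the bytes match."""
    digest = hashlib.sha256(data).hexdigest()
    idx = cached_hash_to_index.get(digest, -1)
    if idx == -1:
        # Unseen bytes: append a new buffer and remember which index its hash maps to.
        idx = len(constant_buffer)
        constant_buffer.append(data)
        cached_hash_to_index[digest] = idx
    return idx


# The second call hits the dict and reuses index 1: an O(1) lookup instead of an O(N) scan.
assert add_constant(b"weights") == add_constant(b"weights") == 1
```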
1 parent 47900c9 commit 3e99930

File tree

2 files changed: +23 -66 lines changed


exir/emit/_emit_program.py

Lines changed: 0 additions & 4 deletions
@@ -185,10 +185,6 @@ def emit_program(
         emitter.run()
         plans.append(emitter.plan())
 
-        # update list length for future constant deduplication checks
-        emitter.program_state.cached_spec_list_length = len(
-            program_state.allocated_specs
-        )
         debug_handle_map[name] = emitter.debug_handle_map
         method_to_delegate_debug_id_map[
             name

exir/emit/_emitter.py

Lines changed: 23 additions & 62 deletions
@@ -29,6 +29,7 @@
 
 # pyre-strict
 import ctypes
+import hashlib
 import operator
 import typing
 from dataclasses import dataclass, field
@@ -104,9 +105,9 @@ class _ProgramState:
     # as index 0 in the constant_buffer is reserved.
     allocated_specs: List[TensorSpec] = field(default_factory=list)
     # Weights in any arbitrary graph_module only need to compare against weights from previously
-    # emitted graph modules, not any weights emitted from itself. set to len(allocated_specs) after
-    # every method emission.
-    cached_spec_list_length: int = 0
+    # emitted graph modules, not any weights emitted from itself. This should speed up the lookup,
+    # from O(N) to O(1)
+    cached_spec_hash_values: Dict[str, int] = field(default_factory=dict)
     # The 0 index is reserved to be pointed to by non-constant tensors, so add an empty placeholder.
     constant_buffer: List[Buffer] = field(default_factory=lambda: [Buffer(storage=b"")])
     # Delegate data stored directly in the flatbuffer. Pointed to by BackendDelegateDataReference,
@@ -346,74 +347,34 @@ def _tensor_spec_to_evalue(self, spec: TensorSpec) -> EValue:
             # For non-constant tensors, constant_buffer = 0.
             return EValue(make_tensor_value(0, allocation_info, spec))
 
-        def _get_buffer_idx(spec: TensorSpec, program_state: _ProgramState) -> int:
-            """Determines where in the program state the constant buffer corresponding to spec is
-            located.
+        # Constant tensor. Reserve a buffer for the constant tensor.
+        spec_array_type = (
+            ctypes.c_char * typing.cast(torch.UntypedStorage, spec.storage).nbytes()
+        )
 
-            Returns the index into the constant buffers list if this spec has been previously
-            allocated, -1 if unseen before. O(N^2) as for every tensor we have to compare it against
-            every tensor previously allocated. Could improve this to O(N) if we hashed the weights.
-            """
-            for i in range(0, program_state.cached_spec_list_length):
-                other_spec = program_state.allocated_specs[i]
+        buffer_data = (
+            bytes(
+                ctypes.cast(
+                    typing.cast(torch.UntypedStorage, spec.storage).data_ptr(),
+                    ctypes.POINTER(spec_array_type),
+                ).contents
+            )
+            if spec.allocated_memory != 0
+            else b""
+        )
 
-                # Check for an empty buffer, special cased to avoid nullptr in buffer check.
-                if spec.allocated_memory == 0 and other_spec.allocated_memory == 0:
-                    return i + 1
+        hashed = hashlib.sha256(buffer_data).hexdigest()
 
-                # compare meta data
-                if (
-                    spec.scalar_type == other_spec.scalar_type
-                    and spec.shape == other_spec.shape
-                    and spec.dim_order == other_spec.dim_order
-                    and typing.cast(torch.UntypedStorage, spec.storage).nbytes()
-                    == typing.cast(torch.UntypedStorage, other_spec.storage).nbytes()
-                ):
-                    spec_array_type = (
-                        ctypes.c_char
-                        * typing.cast(torch.UntypedStorage, spec.storage).nbytes()
-                    )
-                    other_spec_array_type = (
-                        ctypes.c_char
-                        * typing.cast(torch.UntypedStorage, other_spec.storage).nbytes()
-                    )
-                    # compare data
-                    if bytes(
-                        ctypes.cast(
-                            typing.cast(torch.UntypedStorage, spec.storage).data_ptr(),
-                            ctypes.POINTER(spec_array_type),
-                        ).contents
-                    ) == bytes(
-                        ctypes.cast(
-                            typing.cast(
-                                torch.UntypedStorage, other_spec.storage
-                            ).data_ptr(),
-                            ctypes.POINTER(other_spec_array_type),
-                        ).contents
-                    ):
-                        return i + 1  # +1 because the first buffer location is reserved
-            return -1
-
-        buffer_idx = _get_buffer_idx(spec, self.program_state)
+        buffer_idx = self.program_state.cached_spec_hash_values.get(hashed, -1)
 
         # Haven't seen this constant before
         if buffer_idx == -1:
-            if spec.allocated_memory == 0:
-                buffer = Buffer(storage=b"")
-            else:
-                array_type = (
-                    ctypes.c_char
-                    * typing.cast(torch.UntypedStorage, spec.storage).nbytes()
-                )
-                spec_array = ctypes.cast(
-                    typing.cast(torch.UntypedStorage, spec.storage).data_ptr(),
-                    ctypes.POINTER(array_type),
-                ).contents
-                buffer = Buffer(storage=bytes(spec_array))
-
             # Update buffer_idx to point to the end of the list where we are adding the new buffer.
+            buffer = Buffer(storage=buffer_data)
             buffer_idx = len(self.program_state.constant_buffer)
             self.program_state.allocated_specs.append(spec)
+            # +1 because the first buffer location is reserved
+            self.program_state.cached_spec_hash_values[hashed] = buffer_idx
             self.program_state.constant_buffer.append(buffer)
 
         # For constant tensors, allocation_info = None.
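As a side note, the byte-extraction and hashing pattern that the new `buffer_data` / `hashed` lines rely on can be shown in isolation. The sketch below uses a hypothetical `hash_storage_bytes` helper (not part of the emitter) that reads a tensor's untyped storage through ctypes and hashes it with sha256, the same mechanism the patch applies to each constant TensorSpec.

```python
import ctypes
import hashlib

import torch


def hash_storage_bytes(tensor: torch.Tensor) -> str:
    """sha256 hex digest of the tensor's raw storage bytes."""
    storage = tensor.untyped_storage()
    if storage.nbytes() == 0:
        # Empty storage is special-cased to avoid dereferencing a null pointer.
        return hashlib.sha256(b"").hexdigest()
    array_type = ctypes.c_char * storage.nbytes()
    raw = bytes(ctypes.cast(storage.data_ptr(), ctypes.POINTER(array_type)).contents)
    return hashlib.sha256(raw).hexdigest()


# Two tensors with identical contents hash to the same digest, so they can share one buffer.
print(hash_storage_bytes(torch.ones(4)) == hash_storage_bytes(torch.ones(4)))  # True
```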
