Skip to content

Commit b63cb72

Browse files
authored
[ET-VK] Fix caching mechanism to account for included files
Differential Revision: D78275585 Pull Request resolved: #12441
1 parent 1540659 commit b63cb72

File tree

1 file changed

+149
-54
lines changed

1 file changed

+149
-54
lines changed

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 149 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,9 @@ def escape(line: str) -> str:
545545
def preprocess(
546546
input_text: str, variables: Dict[str, Any], input_path: str = "codegen"
547547
) -> str:
548+
# Workaround to handle source files using \ to extend mecros to a new line
549+
input_text = re.sub(r"\\$", r"\\\\", input_text, flags=re.MULTILINE)
550+
548551
input_lines = input_text.splitlines()
549552
python_lines = []
550553

@@ -654,8 +657,8 @@ def addSrcAndYamlFiles(self, src_dir_paths: List[str]) -> None:
654657
for src_path in src_dir_paths:
655658
# Collect glsl source files
656659
src_files_list = glob.glob(
657-
os.path.join(src_path, "**", "*.glsl*"), recursive=True
658-
)
660+
os.path.join(src_path, "**", "*.[gh]lsl*"), recursive=True
661+
) + glob.glob(os.path.join(src_path, "**", "*.h"), recursive=True)
659662
for file in src_files_list:
660663
if len(file) > 1:
661664
self.src_files[extract_filename(file, keep_ext=False)] = file
@@ -851,47 +854,150 @@ def generateSPV( # noqa: C901
851854
cache_dir: Optional[str] = None,
852855
force_rebuild: bool = False,
853856
) -> Dict[str, str]:
854-
output_file_map = {}
857+
# The key of this dictionary is the full path to a generated source file. The
858+
# value is a tuple that contains 3 entries:
859+
#
860+
# 1. A bool indicationg if the file has changed since the last compilation; this
861+
# is determined by comparing against the cached version.
862+
# 2. List of other source files included by the generated file.
863+
gen_file_meta: Dict[str, Tuple[bool, List[str], str]] = {}
864+
865+
# Return value of the function mapping the abspath of compiled SPIR-V binaries
866+
# to the abspath of the generated GLSL file they were compiled from.
867+
spv_to_glsl_map: Dict[str, str] = {}
868+
869+
# Convert output_dir to absolute path
870+
assert os.path.exists(output_dir)
871+
output_dir = os.path.abspath(output_dir)
872+
873+
if cache_dir is not None:
874+
assert os.path.exists(cache_dir)
875+
876+
def get_glsl_includes(glsl_text):
877+
"""
878+
Parse GLSL text content and return a list of included files.
879+
880+
Args:
881+
glsl_text: String containing the GLSL file content to analyze
882+
883+
Returns:
884+
List of included file names (e.g., ["random.h"])
885+
"""
886+
includes = []
887+
for line in glsl_text.splitlines():
888+
# Look for #include directives with quoted filenames
889+
# Matches: #include "filename.h" or #include <filename.h>
890+
include_match = re.match(
891+
r'^\s*#include\s+[<"]([^>"]+)[>"]', line.strip()
892+
)
893+
if include_match:
894+
includes.append(include_match.group(1))
895+
896+
return includes
897+
898+
def file_has_changed(gen_file_path, cached_file_path):
899+
# If the file does not exist in the cache, then return True
900+
if not os.path.exists(cached_file_path):
901+
return True
902+
current_checksum = self.get_md5_checksum(gen_file_path)
903+
cached_checksum = self.get_md5_checksum(cached_file_path)
904+
return current_checksum != cached_checksum
905+
906+
def any_sources_changed(gen_file_path, output_dir):
907+
"""
908+
Given the path to a generated source file, check the gen_file_meta dict to
909+
determine if the ANY of the source files contributing to the compilation of
910+
this file were changed since the last successful compilation.
911+
"""
912+
gen_file_changed, includes_list = gen_file_meta[gen_file_path]
913+
any_changed = gen_file_changed
914+
for included_file in includes_list:
915+
included_file_path = os.path.join(output_dir, included_file)
916+
any_changed = any_changed or any_sources_changed(
917+
included_file_path, output_dir
918+
)
919+
920+
return any_changed
855921

856-
def generate_src_file(shader_paths_pair):
857-
# Extract components from the input tuple
858-
# name of .glsl, .glslh, or .h to be generated
922+
def generate_src_file(shader_paths_pair) -> Tuple[bool, List[str]]:
923+
"""
924+
Given an input tuple containing the following items:
925+
(src_file_name, (template_file_path, codegen_params))
926+
927+
This function generates src_file_name by processing
928+
template_file_path with the Python preprocessor using the
929+
parameters specified by codegen_params.
930+
931+
Then, it returns a tuple containing:
932+
1. The path of the generated source file
933+
2. A bool indicating if the generated source file has changed since the last
934+
compilation.
935+
3. A list of files included by the generated source file
936+
"""
937+
# name of .glsl, .glslh, or .h file to be generated
859938
src_file_name = shader_paths_pair[0]
860939
# path of template file used for codegen
861-
src_file_fullpath = shader_paths_pair[1][0]
940+
template_file_path = shader_paths_pair[1][0]
862941
# args to be used for codegen
863942
codegen_params = shader_paths_pair[1][1]
864943

865944
# Assume that generated files will have the same file extension as the
866945
# source template file.
867-
src_file_ext = extract_extension(src_file_fullpath)
868-
out_file_ext = src_file_ext
946+
out_file_ext = extract_extension(template_file_path)
869947

870948
# Construct generated file name
871949
gen_out_path = os.path.join(output_dir, f"{src_file_name}.{out_file_ext}")
950+
# Construct path of cached generated file
951+
cached_gen_out_path = os.path.join(
952+
cache_dir, f"{src_file_name}.{out_file_ext}"
953+
)
872954

873955
# Execute codegen to generate the output file
874-
with codecs.open(src_file_fullpath, "r", encoding="utf-8") as input_file:
956+
with codecs.open(template_file_path, "r", encoding="utf-8") as input_file:
875957
input_text = input_file.read()
876958
input_text = self.maybe_replace_u16vecn(input_text)
877959
output_text = preprocess(input_text, codegen_params)
878960

961+
included_files = get_glsl_includes(output_text)
962+
879963
with codecs.open(gen_out_path, "w", encoding="utf-8") as output_file:
880964
output_file.write(output_text)
881965

882-
def compile_spirv(shader_paths_pair):
883-
# Extract components from the input tuple
884-
# name of generated .glsl, .glslh, or .h
966+
file_changed = (
967+
file_has_changed(gen_out_path, cached_gen_out_path) or force_rebuild
968+
)
969+
970+
# Save the generated file to cache so it can be used for future checks
971+
if cache_dir is not None and file_changed:
972+
shutil.copyfile(gen_out_path, cached_gen_out_path)
973+
974+
return gen_out_path, file_changed, included_files
975+
976+
def compile_spirv(shader_paths_pair) -> Tuple[str, str]:
977+
"""
978+
Given an input tuple containing the following items:
979+
(src_file_name, (template_file_path, codegen_params))
980+
981+
Infer the path of the GLSL source file generated by generate_src_file and
982+
compile a SPIR-V binary from it. Returns the path of the compiled SPIR-V
983+
binary and the path of the source file used to compile it.
984+
985+
This function also utilizes a caching mechanism; if generate_src_file
986+
reported that the source file was unchanged since the last successful
987+
compilation, AND if the SPIR-V from the last successful compilation was
988+
stored in the cache, then directly use the cached SPIR-V without triggering
989+
a re-compilation.
990+
"""
991+
# name of generated .glsl, .glslh, or .h from generate_src_file
885992
src_file_name = shader_paths_pair[0]
886993
# path of template file used for codegen
887-
src_file_fullpath = shader_paths_pair[1][0]
994+
template_file_path = shader_paths_pair[1][0]
888995
# args used for codegen
889996
codegen_params = shader_paths_pair[1][1]
890997

891998
# Assume that generated files will have the same file extension as the
892999
# source template file.
893-
src_file_ext = extract_extension(src_file_fullpath)
894-
out_file_ext = src_file_ext
1000+
out_file_ext = extract_extension(template_file_path)
8951001

8961002
# Infer name of generated file (created by generate_src_file)
8971003
gen_out_path = os.path.join(output_dir, f"{src_file_name}.{out_file_ext}")
@@ -900,32 +1006,21 @@ def compile_spirv(shader_paths_pair):
9001006
if out_file_ext != "glsl":
9011007
return (None, gen_out_path)
9021008

903-
# Construct name of SPIR-V file to be compiled, if needed
1009+
# Validate that the source file actually exists
1010+
assert os.path.exists(gen_out_path) and gen_out_path in gen_file_meta
1011+
1012+
# Construct name of SPIR-V file to be compiled
9041013
spv_out_path = os.path.join(output_dir, f"{src_file_name}.spv")
9051014

9061015
if cache_dir is not None:
9071016
# Construct the file names of cached SPIR-V file to check if they exist
9081017
# in the cache.
909-
cached_gen_out_path = os.path.join(
910-
cache_dir, f"{src_file_name}.{out_file_ext}"
911-
)
9121018
cached_spv_out_path = os.path.join(cache_dir, f"{src_file_name}.spv")
9131019

914-
# Only use cached artifacts if all of the expected artifacts are present
915-
if (
916-
not force_rebuild
917-
and os.path.exists(cached_gen_out_path)
918-
and os.path.exists(cached_spv_out_path)
919-
):
920-
current_checksum = self.get_md5_checksum(gen_out_path)
921-
cached_checksum = self.get_md5_checksum(cached_gen_out_path)
922-
# If the cached generated GLSL file is the same as the current GLSL
923-
# generated file, then assume that the generated GLSL and SPIR-V
924-
# will not have changed. In that case, just copy over the GLSL and
925-
# SPIR-V files from the cache and return.
926-
if current_checksum == cached_checksum:
927-
shutil.copyfile(cached_spv_out_path, spv_out_path)
928-
return (spv_out_path, gen_out_path)
1020+
can_use_cached = not any_sources_changed(gen_out_path, output_dir)
1021+
if can_use_cached and os.path.exists(cached_spv_out_path):
1022+
shutil.copyfile(cached_spv_out_path, spv_out_path)
1023+
return (spv_out_path, gen_out_path)
9291024

9301025
vk_version = codegen_params.get("VK_VERSION", "1.1")
9311026
# Only proceed if a GLSL compiler was specified
@@ -938,10 +1033,8 @@ def compile_spirv(shader_paths_pair):
9381033
spv_out_path,
9391034
"--target-env=vulkan{}".format(vk_version),
9401035
"-Werror",
941-
] + [
942-
arg
943-
for src_dir_path in self.src_dir_paths
944-
for arg in ["-I", src_dir_path]
1036+
"-I",
1037+
output_dir,
9451038
]
9461039
cmd = cmd_base + self.glslc_flags
9471040

@@ -955,43 +1048,45 @@ def compile_spirv(shader_paths_pair):
9551048
try:
9561049
subprocess.run(cmd_no_opt, check=True, capture_output=True)
9571050
except subprocess.CalledProcessError as e_no_opt:
1051+
# Delete any existing cached SPIR-V file if it exists
1052+
if os.path.exists(cached_spv_out_path):
1053+
os.remove(cached_spv_out_path)
1054+
9581055
raise RuntimeError(
9591056
f"{err_msg_base} {e_no_opt.stderr}"
9601057
) from e_no_opt
9611058

9621059
else:
1060+
# Delete any existing cached SPIR-V file if it exists
1061+
if os.path.exists(cached_spv_out_path):
1062+
os.remove(cached_spv_out_path)
1063+
9631064
raise RuntimeError(f"{err_msg_base} {e.stderr}") from e
9641065

965-
# If compilation was successful, store the source GLSL file and the
966-
# compiled SPIR-V file in the cache for future comparison.
1066+
# If compilation was successful, store the compiled SPIR-V file in the
1067+
# cache for future use.
9671068
if cache_dir is not None:
968-
shutil.copyfile(gen_out_path, cached_gen_out_path)
9691069
shutil.copyfile(spv_out_path, cached_spv_out_path)
9701070

9711071
return (spv_out_path, gen_out_path)
9721072

9731073
# Run codegen serially to ensure that all .glsl, .glslh, and .h files are up to
9741074
# date before compilation
9751075
for generated_file_tuple in self.output_file_map.items():
976-
generate_src_file(generated_file_tuple)
1076+
gen_out_path, file_changed, include_list = generate_src_file(
1077+
generated_file_tuple
1078+
)
1079+
gen_file_meta[gen_out_path] = (file_changed, include_list)
9771080

9781081
# Parallelize SPIR-V compilation to optimize build time
9791082
with ThreadPool(os.cpu_count()) as pool:
9801083
for spv_out_path, glsl_out_path in pool.map(
9811084
compile_spirv, self.output_file_map.items()
9821085
):
983-
output_file_map[spv_out_path] = glsl_out_path
984-
985-
# Save all source GLSL files to the cache. Only do this at the very end since
986-
# multiple variants may use the same source file.
987-
if cache_dir is not None:
988-
for _, src_file_fullpath in self.src_files.items():
989-
cached_src_file = os.path.join(
990-
cache_dir, os.path.basename(src_file_fullpath) + ".t"
991-
)
992-
shutil.copyfile(src_file_fullpath, cached_src_file)
1086+
print(spv_to_glsl_map)
1087+
spv_to_glsl_map[spv_out_path] = glsl_out_path
9931088

994-
return output_file_map
1089+
return spv_to_glsl_map
9951090

9961091

9971092
##############################################

0 commit comments

Comments
 (0)