Skip to content

[CI][Benchmarks] Automatically detect component versions in benchmark CI #18339

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 30 commits into from
Jun 2, 2025
Merged
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
eb4e123
Add L0 driver code for detecting compute-runtime versions
ianayl Apr 30, 2025
86263e1
instrument main with detect_versions
ianayl May 1, 2025
b21be5f
change wording, add frontend
ianayl May 1, 2025
0c77780
Fix bug
ianayl May 1, 2025
274be61
Hook up benchmark script to detect_versions
ianayl May 1, 2025
d8ab622
test changes in ci
ianayl May 2, 2025
4cb45ce
Fix bug
ianayl May 2, 2025
3648a35
Fix bug
ianayl May 2, 2025
7d927f7
Remove test code
ianayl May 5, 2025
a447867
Remove more checking for args
ianayl May 5, 2025
c67e052
Remove some odd choices
ianayl May 6, 2025
d611471
Merge branch 'sycl' of https://github.com/intel/llvm into ianayl/benc…
ianayl May 6, 2025
a07c211
add newline
ianayl May 6, 2025
3b0c996
Fix bug
ianayl May 6, 2025
233c062
darker format python
ianayl May 6, 2025
6915b3e
Apply clang-format
ianayl May 6, 2025
2bb4959
Add a way to predefine a cache beforehand
ianayl May 7, 2025
9fa6a1c
Add to workflow for testing
ianayl May 21, 2025
b651ca1
Remove debug messages
ianayl May 21, 2025
083cff0
Fix spelling
ianayl May 21, 2025
5b02887
apply clang-format
ianayl May 21, 2025
54bb593
Add logging for detect versions
ianayl May 22, 2025
c0949da
fix indent
ianayl May 22, 2025
a1c75cd
fix typo
ianayl May 23, 2025
633f51e
better signifier we are detecting versions
ianayl May 28, 2025
23201fe
More logging for detect versions
ianayl May 28, 2025
8a29481
More logging for detect versions
ianayl May 28, 2025
12cdbf5
I broke darker
ianayl May 28, 2025
2de4b33
apply formatting
ianayl May 29, 2025
2dd72e4
Merge branch 'sycl' of https://github.com/intel/llvm into ianayl/benc…
ianayl May 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 7 additions & 15 deletions devops/actions/run-tests/benchmark/action.yml
Original file line number Diff line number Diff line change
@@ -108,19 +108,6 @@ runs:
pip install --user --break-system-packages -r ./devops/scripts/benchmarks/requirements.txt
echo "-----"
# clang builds have git repo / commit hashes in their --version output,
# same goes for dpcpp. Obtain git repo / commit hash info this way:
# First line of --version is formatted 'clang version ... (<repo> <commit>)'
# thus we parse for (<repo> <commit>):
sycl_git_info="$(clang++ --version | head -n 1 | grep -oE '\([^ ]+ [a-f0-9]+\)$' | tr -d '()')"
if [ -z "$sycl_git_info" ]; then
echo "Error: Unable to deduce SYCL build source repo/commit: Are you sure dpcpp variable is in PATH?"
exit 1
fi
sycl_git_repo="$(printf "$sycl_git_info" | cut -d' ' -f1)"
sycl_git_commit="$(printf "$sycl_git_info" | cut -d' ' -f2)"
# By default, the benchmark scripts forceload level_zero
FORCELOAD_ADAPTER="${ONEAPI_DEVICE_SELECTOR%%:*}"
echo "Adapter: $FORCELOAD_ADAPTER"
@@ -138,6 +125,12 @@ runs:
# TODO accommodate different GPUs and backends
SAVE_NAME="${SAVE_PREFIX}_PVC_${SAVE_SUFFIX}"
SAVE_TIMESTAMP="$(date -u +'%Y%m%d_%H%M%S')" # Timestamps are in UTC time
# Cache the compute_runtime version from dependencies.json, but perform a
# check with L0 version before using it: This value is not guaranteed to
# accurately reflect the current compute_runtime version used, as the
# docker images are built nightly.
export COMPUTE_RUNTIME_TAG_CACHE="$(cat ./devops/dependencies.json | jq -r .linux.compute_runtime.github_tag)"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this envvar needed outside of this current script? if so we need to set it in a different way

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is only needed in the main.py script ran immediately below this, which is why I haven't put it in GITHUB_ENV


sycl-ls
echo "-----"
@@ -152,8 +145,7 @@ runs:
--output-dir "./llvm-ci-perf-results/" \
--preset "$PRESET" \
--timestamp-override "$SAVE_TIMESTAMP" \
--github-repo "$sycl_git_repo" \
--git-commit "$sycl_git_commit"
--detect-version sycl,compute_runtime
echo "-----"
python3 ./devops/scripts/benchmarks/compare.py to_hist \
--name "$SAVE_NAME" \
31 changes: 7 additions & 24 deletions devops/scripts/benchmarks/compare.py
Original file line number Diff line number Diff line change
@@ -145,6 +145,9 @@ def validate_benchmark_result(result: BenchmarkRun) -> bool:

def reset_aggregate() -> dict:
return {
# TODO determine which command args have an
# impact on perf results, and do not compare results
# whose command args make them incomparable
"command_args": set(test_run.command[1:]),
"aggregate": aggregator(starting_elements=[test_run.value]),
}
@@ -153,27 +156,7 @@ def reset_aggregate() -> dict:
if test_run.name not in average_aggregate:
average_aggregate[test_run.name] = reset_aggregate()
else:
# Check that we are comparing runs with the same cmd args:
if (
set(test_run.command[1:])
== average_aggregate[test_run.name]["command_args"]
):
average_aggregate[test_run.name]["aggregate"].add(
test_run.value
)
else:
# If the command args used between runs are different,
# discard old run data and prefer new command args
#
# This relies on the fact that paths from get_result_paths()
# is sorted from older to newer
print(
f"Warning: Command args for {test_run.name} from {result_path} is different from prior runs."
)
print(
"DISCARDING older data and OVERRIDING with data using new arg."
)
average_aggregate[test_run.name] = reset_aggregate()
average_aggregate[test_run.name]["aggregate"].add(test_run.value)

return {
name: BenchmarkHistoricAverage(
@@ -217,9 +200,9 @@ def halfway_round(value: int, n: int):
for test in target.results:
if test.name not in hist_avg:
continue
if hist_avg[test.name].command_args != set(test.command[1:]):
print(f"Warning: skipped {test.name} due to command args mismatch.")
continue
# TODO compare command args which have an impact on performance
# (i.e. ignore --save-name): if command results are incomparable,
# skip the result.

delta = 1 - (
test.value / hist_avg[test.name].value
28 changes: 22 additions & 6 deletions devops/scripts/benchmarks/history.py
Original file line number Diff line number Diff line change
@@ -13,6 +13,8 @@
from utils.utils import run
from utils.validate import Validate

from utils.detect_versions import DetectVersion


class BenchmarkHistory:
runs = []
@@ -94,9 +96,13 @@ def git_info_from_path(path: Path) -> (str, str):
return git_hash, github_repo

if options.git_commit_override is None or options.github_repo_override is None:
git_hash, github_repo = git_info_from_path(
os.path.dirname(os.path.abspath(__file__))
)
if options.detect_versions.sycl:
print(f"Auto-detecting sycl version...")
github_repo, git_hash = DetectVersion.instance().get_dpcpp_git_info()
else:
git_hash, github_repo = git_info_from_path(
os.path.dirname(os.path.abspath(__file__))
)
else:
git_hash, github_repo = (
options.git_commit_override,
@@ -119,9 +125,19 @@ def git_info_from_path(path: Path) -> (str, str):
throw=ValueError("Illegal characters found in specified RUNNER_NAME."),
)

compute_runtime = (
options.compute_runtime_tag if options.build_compute_runtime else ""
)
compute_runtime = None
if options.build_compute_runtime:
compute_runtime = options.compute_runtime_tag
elif options.detect_versions.compute_runtime:
print(f"Auto-detecting compute_runtime version...")
detect_res = DetectVersion.instance()
compute_runtime = detect_res.get_compute_runtime_ver()
if detect_res.get_compute_runtime_ver_cached() is None:
print(
"Warning: Could not find compute_runtime version via github tags API."
)
else:
compute_runtime = "unknown"

return BenchmarkRun(
name=name,
49 changes: 46 additions & 3 deletions devops/scripts/benchmarks/main.py
Original file line number Diff line number Diff line change
@@ -19,11 +19,13 @@
from utils.utils import prepare_workdir
from utils.compute_runtime import *
from utils.validate import Validate
from utils.detect_versions import DetectVersion
from presets import enabled_suites, presets

import argparse
import re
import statistics
import os

# Update this if you are changing the layout of the results files
INTERNAL_WORKDIR_VERSION = "2.0"
@@ -506,7 +508,7 @@ def validate_and_parse_env_args(env_args):
type=lambda ts: Validate.timestamp(
ts,
throw=argparse.ArgumentTypeError(
"Specified timestamp not in YYYYMMDD_HHMMSS format."
"Specified timestamp not in YYYYMMDD_HHMMSS format"
),
),
help="Manually specify timestamp used in metadata",
@@ -517,7 +519,7 @@ def validate_and_parse_env_args(env_args):
type=lambda gh_repo: Validate.github_repo(
gh_repo,
throw=argparse.ArgumentTypeError(
"Specified github repo not in <owner>/<repo> format."
"Specified github repo not in <owner>/<repo> format"
),
),
help="Manually specify github repo metadata of component tested (e.g. SYCL, UMF)",
@@ -528,13 +530,32 @@ def validate_and_parse_env_args(env_args):
type=lambda commit: Validate.commit_hash(
commit,
throw=argparse.ArgumentTypeError(
"Specified commit is not a valid commit hash."
"Specified commit is not a valid commit hash"
),
),
help="Manually specify commit hash metadata of component tested (e.g. SYCL, UMF)",
default=options.git_commit_override,
)

parser.add_argument(
"--detect-version",
type=lambda components: Validate.on_re(
components,
r"[a-z_,]+",
throw=argparse.ArgumentTypeError(
"Specified --detect-version is not a comma-separated list"
),
),
help="Detect versions of components used: comma-separated list with choices from sycl,compute_runtime",
default=None,
)
parser.add_argument(
"--detect-version-cpp-path",
type=Path,
help="Location of detect_version.cpp used to query e.g. DPC++, L0",
default=None,
)

args = parser.parse_args()
additional_env_vars = validate_and_parse_env_args(args.env)

@@ -586,6 +607,28 @@ def validate_and_parse_env_args(env_args):
options.github_repo_override = args.github_repo
options.git_commit_override = args.git_commit

# Automatically detect versions:
if args.detect_version is not None:
detect_ver_path = args.detect_version_cpp_path
if detect_ver_path is None:
detect_ver_path = Path(
f"{os.path.dirname(__file__)}/utils/detect_versions.cpp"
)
if not detect_ver_path.is_file():
parser.error(
f"Unable to find detect_versions.cpp at {detect_ver_path}, please specify --detect-version-cpp-path"
)
elif not detect_ver_path.is_file():
parser.error(f"Specified --detect-version-cpp-path is not a valid file")

enabled_components = args.detect_version.split(",")
options.detect_versions.sycl = "sycl" in enabled_components
options.detect_versions.compute_runtime = (
"compute_runtime" in enabled_components
)

detect_res = DetectVersion.init(detect_ver_path)

benchmark_filter = re.compile(args.filter) if args.filter else None

main(
29 changes: 29 additions & 0 deletions devops/scripts/benchmarks/options.py
Original file line number Diff line number Diff line change
@@ -16,6 +16,31 @@ class MarkdownSize(Enum):
FULL = "full"


@dataclass
class DetectVersionsOptions:
    """
    Options for automatic version detection
    """

    # Components to detect versions for:
    sycl: bool = False
    compute_runtime: bool = False
    # umf: bool = False
    # level_zero: bool = False

    # TODO unauthenticated users only get 60 API calls per hour: this will not
    # work if we enable benchmark CI in precommit.
    compute_runtime_tag_api: str = (
        "https://api.github.com/repos/intel/compute-runtime/tags"
    )

    # Placeholder text, should automatic version detection fail: This text will
    # only be used if automatic version detection for x component is explicitly
    # specified.
    #
    # Annotated so that it is a real dataclass field (settable through the
    # constructor and included in repr/eq); declared after the previously
    # existing fields so positional construction keeps its meaning.
    not_found_placeholder: str = "unknown"  # None

    # Max amount of api calls permitted on each run of the benchmark scripts
    max_api_calls: int = 4

@dataclass
class Options:
workdir: str = None
@@ -64,5 +89,9 @@ class Options:
github_repo_override: str = None
git_commit_override: str = None

detect_versions: DetectVersionsOptions = field(
default_factory=DetectVersionsOptions
)


options = Options()
82 changes: 82 additions & 0 deletions devops/scripts/benchmarks/utils/detect_versions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <vector>

#include <level_zero/ze_api.h>

// Prints "Error: <msg>" and terminates the process with exit code 1 when
// `cond` is false. `msg` may be any expression streamable into std::cout.
//
// Wrapped in do { ... } while (0) so that `_assert(...);` behaves as a single
// statement, including inside an unbraced if/else.
#define _assert(cond, msg)                                                     \
  do {                                                                         \
    if (!(cond)) {                                                             \
      std::cout << std::endl << "Error: " << msg << std::endl;                 \
      exit(1);                                                                 \
    }                                                                          \
  } while (0)

// True when a Level Zero call result equals ZE_RESULT_SUCCESS. Fully
// parenthesized so the expansion composes safely with surrounding operators.
#define _success(res) ((res) == ZE_RESULT_SUCCESS)

// Returns the compiler's __clang_version__ string, i.e. the version of the
// clang-based compiler this file was built with.
std::string query_dpcpp_ver() {
  const std::string compiler_version(__clang_version__);
  return compiler_version;
}

// Returns the driver version string reported by the first Level Zero GPU
// driver, obtained through the ZE_intel_get_driver_version_string extension.
//
// Any failure (L0 initialization, missing driver, missing extension, failed
// query) prints an error and terminates the process with exit code 1.
std::string query_l0_driver_ver() {
  // Initialize L0 drivers:
  ze_init_driver_type_desc_t driver_type = {};
  driver_type.stype = ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC;
  driver_type.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU;
  driver_type.pNext = nullptr;

  // First call only obtains the driver count, second call fills the handles.
  uint32_t driver_count = 0;
  ze_result_t result = zeInitDrivers(&driver_count, nullptr, &driver_type);
  _assert(_success(result), "Failed to initialize L0.");
  _assert(driver_count > 0, "No L0 drivers available.");

  std::vector<ze_driver_handle_t> drivers(driver_count);
  result = zeInitDrivers(&driver_count, drivers.data(), &driver_type);
  _assert(_success(result), "Could not fetch L0 drivers.");

  // Check support for fetching driver version strings:
  uint32_t ext_count = 0;
  result = zeDriverGetExtensionProperties(drivers[0], &ext_count, nullptr);
  _assert(_success(result), "Failed to obtain L0 extensions count.");
  _assert(ext_count > 0, "No L0 extensions available.");

  std::vector<ze_driver_extension_properties_t> extensions(ext_count);
  result =
      zeDriverGetExtensionProperties(drivers[0], &ext_count, extensions.data());
  _assert(_success(result), "Failed to obtain L0 extensions.");
  bool version_ext_support = false;
  for (const auto &extension : extensions) {
    // strcmp returns 0 on equality: compare against 0 explicitly, otherwise
    // every extension *other* than the wanted one would count as a match.
    if (strcmp(extension.name, "ZE_intel_get_driver_version_string") == 0) {
      version_ext_support = true;
      break;
    }
  }
  _assert(version_ext_support,
          "ZE_intel_get_driver_version_string extension is not supported.");

  // Fetch L0 driver version:
  ze_result_t (*pfnGetDriverVersionFn)(ze_driver_handle_t, char *, size_t *) =
      nullptr;
  result = zeDriverGetExtensionFunctionAddress(
      drivers[0], "zeIntelGetDriverVersionString",
      reinterpret_cast<void **>(&pfnGetDriverVersionFn));
  _assert(_success(result), "Failed to obtain GetDriverVersionString fn.");
  _assert(pfnGetDriverVersionFn != nullptr,
          "Failed to obtain GetDriverVersionString fn.");

  // Query the required length first, then fetch the string itself.
  size_t ver_str_len = 0;
  result = pfnGetDriverVersionFn(drivers[0], nullptr, &ver_str_len);
  _assert(_success(result), "Call to GetDriverVersionString failed.");

  std::cout << "ver_str_len: " << ver_str_len << std::endl;
  ver_str_len++; // ver_str_len does not account for '\0'
  // Zero-filled buffer: the result stays NUL-terminated, and the storage is
  // released automatically on every exit path.
  std::vector<char> ver_str(ver_str_len, '\0');
  result = pfnGetDriverVersionFn(drivers[0], ver_str.data(), &ver_str_len);
  _assert(_success(result), "Failed to write driver version string.");

  return std::string(ver_str.data());
}

// Prints the detected versions to stdout, one NAME='value' pair per line:
// first DPCPP_VER, then L0_VER.
int main() {
  // Arguments are fully evaluated before emit() starts printing, so anything
  // a query prints itself appears before the corresponding NAME='value' line.
  const auto emit = [](const char *prefix, const std::string &value) {
    std::cout << prefix << value << "'" << std::endl;
  };

  emit("DPCPP_VER='", query_dpcpp_ver());
  emit("L0_VER='", query_l0_driver_ver());
  return 0;
}
250 changes: 250 additions & 0 deletions devops/scripts/benchmarks/utils/detect_versions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
import os
import re
import sys
import json
import urllib
import tempfile
import subprocess
from urllib import request
from pathlib import Path
import argparse

if __name__ == "__main__":
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from options import options


def _get_patch_from_ver(ver: str) -> str:
"""Extract patch from a version string."""
# L0 version strings follows semver: major.minor.patch+optional
# compute-runtime version tags follow year.WW.patch.optional instead,
# but both follow a quasi-semver versioning where the patch, optional
# is still the same across both version string formats.
patch = re.sub(r"^\d+\.\d+\.", "", ver)
patch = re.sub(r"\+", ".", patch, count=1)
return patch


class DetectVersion:
    """
    Singleton holding the component versions detected on this machine.

    Call init() once to build and run detect_versions.cpp and populate the
    instance, then instance() to retrieve it. Direct construction is
    disallowed.
    """

    _instance = None

    def __init__(self):
        raise RuntimeError("Use init() to init and instance() to get instead.")

    @classmethod
    def init(cls, detect_ver_path: Path, dpcpp_exec: str = "clang++"):
        """
        Constructs the singleton instance for DetectVersion, and initializes by
        building and running detect_versions.cpp, which outputs:
        - L0 driver version via ZE_intel_get_driver_version_string extension,
        - DPC++ version via `__clang_version__` builtin.

        Reminder: DO NOT allow user input in args.

        Parameters:
        detect_ver_path (Path): Path to detect_versions.cpp
        dpcpp_exec (str): Name of DPC++ executable

        Returns the singleton instance; if init() already succeeded earlier,
        the existing instance is returned unchanged.

        Raises subprocess.CalledProcessError if building or running the
        detection program fails, and RuntimeError if its output lacks an
        expected variable.
        """
        if cls._instance is not None:
            return cls._instance

        # Build into a private temporary directory: unlike tempfile.mktemp()
        # the path cannot be claimed by another process in between, and the
        # binary is removed once it has been run.
        with tempfile.TemporaryDirectory() as build_dir:
            detect_ver_exe = os.path.join(build_dir, "detect_versions")
            subprocess.run(
                [dpcpp_exec, "-lze_loader", detect_ver_path, "-o", detect_ver_exe],
                check=True,
                env=os.environ,
            )
            result = subprocess.run(
                [detect_ver_exe],
                check=True,
                text=True,
                capture_output=True,
                env=os.environ,
            )
        # Variables are printed to stdout, each var is on its own line
        result_vars = result.stdout.strip().split("\n")

        def get_var(var_name: str) -> str:
            prefix = f"{var_name}='"
            for line in result_vars:
                if re.match(f"^{var_name}='.*'", line):
                    return line[len(prefix) : -len("'")]
            raise RuntimeError(
                f"detect_version: {var_name} not found in output of {detect_ver_path}"
            )

        # Populate a local object first and publish it only once complete, so
        # a failure part-way through never leaves a half-built singleton.
        instance = cls.__new__(cls)
        instance.l0_ver = get_var("L0_VER")
        instance.dpcpp_ver = get_var("DPCPP_VER")
        instance.dpcpp_exec = dpcpp_exec

        # Populate the compute-runtime version string cache: Since API calls
        # are expensive, we want to avoid API calls when possible, i.e.:
        # - Avoid a second API call if compute_runtime_ver was already obtained
        # - Avoid an API call altogether if the user provides a valid
        #   COMPUTE_RUNTIME_TAG_CACHE environment variable.
        instance.compute_runtime_ver_cache = None
        l0_ver_patch = _get_patch_from_ver(instance.l0_ver)
        env_cache_ver = os.getenv("COMPUTE_RUNTIME_TAG_CACHE", default="")
        env_cache_patch = _get_patch_from_ver(env_cache_ver)
        # L0 patch often gets padded with 0's: if the environment variable
        # matches up with the prefix of the l0 version patch, the cache is
        # indeed referring to the same version.
        #
        # An unset/empty variable is skipped explicitly: the empty string is a
        # prefix of everything and would otherwise be cached as the version.
        if env_cache_patch:
            if l0_ver_patch.startswith(env_cache_patch):
                print(
                    f"Using compute_runtime tag from COMPUTE_RUNTIME_TAG_CACHE: {env_cache_ver}"
                )
                instance.compute_runtime_ver_cache = env_cache_ver
            else:
                print(
                    f"Mismatch between COMPUTE_RUNTIME_TAG_CACHE {env_cache_ver} and patch reported by level_zero {instance.l0_ver}"
                )

        cls._instance = instance
        return cls._instance

    @classmethod
    def instance(cls):
        """
        Returns singleton instance of DetectVersion if it has been initialized
        via init(), otherwise return None.
        """
        return cls._instance

    def get_l0_ver(self) -> str:
        """
        Returns the full L0 version string.
        """
        return self.l0_ver

    def get_dpcpp_ver(self) -> str:
        """
        Returns the full DPC++ version / clang version string of DPC++ used.
        """
        return self.dpcpp_ver

    def get_dpcpp_git_info(self) -> list:
        """
        Returns: [git_repo, commit_hash]

        Raises RuntimeError if the compiler version string does not contain
        a "(<git url> <commit>)" section.
        """
        # clang++ formats are in <clang ver> (<git url> <commit>): if this
        # regex does not match, it is likely this is not upstream clang.
        git_info_match = re.search(r"\(http.+ [0-9a-f]+\)", self.dpcpp_ver)
        if git_info_match is None:
            raise RuntimeError(
                f"detect_version: Unable to obtain git info from {self.dpcpp_exec}, are you sure you are using DPC++?"
            )
        git_info = git_info_match.group(0)
        return git_info[1:-1].split(" ")

    def get_dpcpp_commit(self) -> str:
        """
        Returns the DPC++ commit hash, or the configured placeholder if git
        info cannot be derived from the compiler version string.
        """
        # get_dpcpp_git_info() signals failure by raising, never by None.
        try:
            return self.get_dpcpp_git_info()[1]
        except RuntimeError:
            return options.detect_versions.not_found_placeholder

    def get_dpcpp_repo(self) -> str:
        """
        Returns the DPC++ git repo url, or the configured placeholder if git
        info cannot be derived from the compiler version string.
        """
        try:
            return self.get_dpcpp_git_info()[0]
        except RuntimeError:
            return options.detect_versions.not_found_placeholder

    def get_compute_runtime_ver_cached(self) -> str:
        """
        Returns the cached compute-runtime tag, or None if none is cached.
        """
        return self.compute_runtime_ver_cache

    @staticmethod
    def _parse_link_header(header) -> dict:
        """
        Parses an HTTP `Link` header into a {rel: url} dict. A missing or
        empty header, and entries that cannot be parsed, yield no entries.
        """
        links = {}
        if not header:
            return links
        for part in header.split(", "):
            match = re.search(r'<([^>]*)>.*?rel="(\w+)"', part)
            if match is not None:
                links[match.group(2)] = match.group(1)
        return links

    def get_compute_runtime_ver(self) -> str:
        """
        Returns the compute-runtime version by deriving from l0 version.

        A cached tag is returned directly if present; otherwise the tags API
        is queried (at most options.detect_versions.max_api_calls requests)
        and the first matching tag is cached and returned. If no tag matches
        or a request fails, the configured placeholder is returned.
        """
        if self.compute_runtime_ver_cache is not None:
            print(
                f"Using cached compute-runtime tag {self.compute_runtime_ver_cache}..."
            )
            return self.compute_runtime_ver_cache

        patch = _get_patch_from_ver(self.l0_ver)

        # TODO unauthenticated users only get 60 API calls per hour: this will
        # not work if we enable benchmark CI in precommit.
        url = options.detect_versions.compute_runtime_tag_api

        print(f"Fetching compute-runtime tag from {url}...")
        try:
            for _ in range(options.detect_versions.max_api_calls):
                with request.urlopen(url) as res:
                    tags = [tag["name"] for tag in json.loads(res.read())]
                    link_header = res.getheader("Link")

                for tag in tags:
                    tag_patch = _get_patch_from_ver(tag)
                    # compute-runtime's cmake files produces "optional" fields
                    # padded with 0's: this means e.g. L0 version string
                    # 1.6.32961.200000 could be either compute-runtime ver.
                    # 25.09.32961.2, 25.09.32961.20, or even 25.09.32961.200.
                    #
                    # Thus, we take the longest match. Since the github api
                    # provides tags from newer -> older, we take the first tag
                    # that matches as it would be the "longest" ver. to match.
                    #
                    # Empty patches are ignored: they would match any version.
                    if tag_patch and patch.startswith(tag_patch):
                        self.compute_runtime_ver_cache = tag
                        return tag

                # The Link header is absent when there are no further pages.
                next_url = self._parse_link_header(link_header).get("next")
                if next_url is None:
                    break
                url = next_url

        except urllib.error.HTTPError as e:
            print(f"HTTP error {e.code}: {e.read().decode('utf-8')}")

        except urllib.error.URLError as e:
            print(f"URL error: {e.reason}")

        print(f"WARNING: unable to find compute-runtime version")
        return options.detect_versions.not_found_placeholder


def main(components: [str]):
    """
    Print NAME=value lines for the requested components.

    Initializes DetectVersion from the detect_versions.cpp located next to
    this file, then walks `components` in order: a known component is printed
    to stdout as "<COMPONENT upper-cased>=<value>", an unknown one produces a
    warning on stderr and is skipped.
    """
    detector = DetectVersion.init(f"{os.path.dirname(__file__)}/detect_versions.cpp")

    # Component name -> bound method producing its value.
    getters = {
        "dpcpp_repo": detector.get_dpcpp_repo,
        "dpcpp_commit": detector.get_dpcpp_commit,
        "l0_ver": detector.get_l0_ver,
        "compute_runtime_ver": detector.get_compute_runtime_ver,
    }

    for component in components:
        if component not in getters:
            print(f"# Warn: unknown component: {component}", file=sys.stderr)
            continue
        print(f"{component.upper()}={getters[component]()}")


if __name__ == "__main__":
    # Command-line entry point: takes a single comma-separated list of
    # component names and prints the version information for each.
    cli = argparse.ArgumentParser(
        description="Get version information for specified components."
    )
    cli.add_argument(
        "components",
        type=str,
        help="""
Comma-separated list of components to get version information for.
Valid options: dpcpp_repo,dpcpp_commit,l0_ver,compute_runtime_ver
""",
    )
    cli_args = cli.parse_args()

    # Surrounding whitespace around each name is stripped before lookup.
    main(name.strip() for name in cli_args.components.split(","))
44 changes: 22 additions & 22 deletions devops/scripts/benchmarks/utils/validate.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
import re

class Validate:
"""Static class containing methods for validating various fields"""

def validate_on_re(val: str, regex: re.Pattern, throw: Exception = None):
"""
Returns True if val is matched by pattern defined by regex, otherwise False.
If `throw` argument is not None: return val as-is if val matches regex,
otherwise raise error defined by throw.
"""
is_matching: bool = re.compile(regex).match(val) is not None

if throw is None:
return is_matching
elif not is_matching:
raise throw
else:
return val
@staticmethod
def on_re(val: str, regex: str, throw: Exception = None):
"""
Returns True if val is matched by pattern defined by regex, otherwise
False.
If `throw` argument is not None: return val as-is if val matches regex,
otherwise raise error defined by throw.
"""
is_matching: bool = re.compile(regex).match(val) is not None

class Validate:
"""Static class containing methods for validating various fields"""
if throw is None:
return is_matching
elif not is_matching:
raise throw
else:
return val

@staticmethod
def runner_name(runner_name: str, throw: Exception = None):
"""
Returns True if runner_name is clean (no illegal characters).
"""
return validate_on_re(runner_name, r"^[a-zA-Z0-9_]+$", throw=throw)
return Validate.on_re(runner_name, r"^[a-zA-Z0-9_]+$", throw=throw)

@staticmethod
def timestamp(t: str, throw: Exception = None):
@@ -36,7 +36,7 @@ def timestamp(t: str, throw: Exception = None):
If throw argument is specified: return t as-is if t is in aforementioned
format, otherwise raise error defined by throw.
"""
return validate_on_re(
return Validate.on_re(
t,
r"^\d{4}(0[1-9]|1[0-2])([0-2][0-9]|3[01])_([01][0-9]|2[0-3])[0-5][0-9][0-5][0-9]$",
throw=throw,
@@ -50,7 +50,7 @@ def github_repo(repo: str, throw: Exception = None):
If throw argument is specified: return repo as-is if repo is in
aforementioned format, otherwise raise error defined by throw.
"""
return validate_on_re(
return Validate.on_re(
re.sub(r"^https?://github.com/", "", repo),
r"^[a-zA-Z0-9_-]{1,39}/[a-zA-Z0-9_.-]{1,100}$",
throw=throw,
@@ -67,6 +67,6 @@ def commit_hash(commit: str, throw: Exception = None, trunc: int = 40):
"""
commit_re = r"^[a-f0-9]{7,40}$"
if throw is None:
return validate_on_re(commit, commit_re)
return Validate.on_re(commit, commit_re)
else:
return validate_on_re(commit, commit_re, throw=throw)[:trunc]
return Validate.on_re(commit, commit_re, throw=throw)[:trunc]