Skip to content
27 changes: 18 additions & 9 deletions .github/scripts/therock_configure_ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sys
from typing import Iterable, Optional, Mapping


def gha_set_output(vars: Mapping[str, str | Path]):
"""Sets values in a step's output parameters.

Expand All @@ -25,6 +26,7 @@ def gha_set_output(vars: Mapping[str, str | Path]):
with open(step_output_file, "a") as f:
f.writelines(f"{k}={str(v)}" + "\n" for k, v in vars.items())


def get_modified_paths(base_ref: str) -> Optional[Iterable[str]]:
"""Returns the paths of modified files relative to the base reference."""
try:
Expand All @@ -42,11 +44,13 @@ def get_modified_paths(base_ref: str) -> Optional[Iterable[str]]:
file=sys.stderr,
)
return None



GITHUB_WORKFLOWS_CI_PATTERNS = [
"therock*",
]


def is_path_workflow_file_related_to_ci(path: str) -> bool:
return any(
fnmatch.fnmatch(path, ".github/workflows/" + pattern)
Expand All @@ -56,11 +60,13 @@ def is_path_workflow_file_related_to_ci(path: str) -> bool:
for pattern in GITHUB_WORKFLOWS_CI_PATTERNS
)


def check_for_workflow_file_related_to_ci(paths: Optional[Iterable[str]]) -> bool:
if paths is None:
return False
return any(is_path_workflow_file_related_to_ci(p) for p in paths)


# Paths matching any of these patterns are considered to have no influence over
# build or test workflows so any related jobs can be skipped if all paths
# modified by a commit/PR match a pattern in this list.
Expand All @@ -70,23 +76,26 @@ def check_for_workflow_file_related_to_ci(paths: Optional[Iterable[str]]) -> boo
"*.md",
"*.pre-commit-config.*",
"*LICENSE",
'Jenkinsfile',
'.github/ISSUE_TEMPLATE/*',
'.github/CODEOWNERS',
'.github/*.md',
'.github/dependabot.yml',
"Jenkinsfile",
".github/ISSUE_TEMPLATE/*",
".github/CODEOWNERS",
".github/*.md",
".github/dependabot.yml",
]


def is_path_skippable(path: str) -> bool:
"""Determines if a given relative path to a file matches any skippable patterns."""
return any(fnmatch.fnmatch(path, pattern) for pattern in SKIPPABLE_PATH_PATTERNS)


def check_for_non_skippable_path(paths: Optional[Iterable[str]]) -> bool:
"""Returns true if at least one path is not in the skippable set."""
if paths is None:
return False
return any(not is_path_skippable(p) for p in paths)


def should_ci_run_given_modified_paths(paths: Optional[Iterable[str]]) -> bool:
"""Returns true if CI workflows should run given a list of modified paths."""

Expand Down Expand Up @@ -118,16 +127,16 @@ def should_ci_run_given_modified_paths(paths: Optional[Iterable[str]]) -> bool:
)
return False


def main(args):
base_ref = args.get("base_ref")
modified_paths = get_modified_paths(base_ref)
print("modified_paths (max 200):", modified_paths[:200])
enable_jobs = should_ci_run_given_modified_paths(modified_paths)
output = {
'enable_therock_ci': json.dumps(enable_jobs)
}
output = {"enable_therock_ci": json.dumps(enable_jobs)}
gha_set_output(output)


if __name__ == "__main__":
args = {}
args["base_ref"] = os.environ.get("BASE_REF", "HEAD^1")
Expand Down
15 changes: 8 additions & 7 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

html_theme_options = {"flavor": "list"}

with open('../CMakeLists.txt', encoding='utf-8') as f:
match = re.search(r'.*set\(version ([0-9.]+)[^0-9.]+', f.read())
with open("../CMakeLists.txt", encoding="utf-8") as f:
match = re.search(r".*set\(version ([0-9.]+)[^0-9.]+", f.read())
if not match:
raise ValueError("VERSION not found!")
version_number = match[1]
Expand All @@ -34,17 +34,18 @@
external_projects_current_project = "composable_kernel"

mathjax3_config = {
'tex': {
'macros': {
'diag': '\\operatorname{diag}',
"tex": {
"macros": {
"diag": "\\operatorname{diag}",
}
}
}

for sphinx_var in ROCmDocs.SPHINX_VARS:
globals()[sphinx_var] = getattr(docs_core, sphinx_var)

extensions += ['sphinxcontrib.bibtex']
bibtex_bibfiles = ['refs.bib']

extensions += ["sphinxcontrib.bibtex"] # noqa F821: `extensions` is injected elsewhere.
bibtex_bibfiles = ["refs.bib"]

cpp_id_attributes = ["__global__", "__device__", "__host__"]
2 changes: 1 addition & 1 deletion example/ck_tile/01_fmha/codegen/cmake_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

GEN_DIR = "" # in Cmake, have to generate files in same folder
GEN_DIR = "" # in Cmake, have to generate files in same folder
130 changes: 62 additions & 68 deletions example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,35 @@
# generate kernel instances to speed up compilation

FWD_DTYPE_MAP = {
"fp32" : "FmhaFwdFp32",
"fp16" : "FmhaFwdFp16",
"bf16" : "FmhaFwdBf16",
"fp8" : "FmhaFwdFp8",
"fp32": "FmhaFwdFp32",
"fp16": "FmhaFwdFp16",
"bf16": "FmhaFwdBf16",
"fp8": "FmhaFwdFp8",
"fp8fp16": "FmhaFwdFp8Fp16",
"fp8bf16": "FmhaFwdFp8Bf16",
"fp8fp32": "FmhaFwdFp8Fp32"
"fp8fp32": "FmhaFwdFp8Fp32",
}

BWD_DTYPE_MAP = {
"fp32": "FmhaBwdFp32",
"fp16": "FmhaBwdFp16",
"bf16": "FmhaBwdBf16"
}
BWD_DTYPE_MAP = {"fp32": "FmhaBwdFp32", "fp16": "FmhaBwdFp16", "bf16": "FmhaBwdBf16"}

MASK_IMPL = {
"generic" : "ck_tile::GenericAttentionMask",
"simplified" : "ck_tile::SimplifiedGenericAttentionMask"
"generic": "ck_tile::GenericAttentionMask",
"simplified": "ck_tile::SimplifiedGenericAttentionMask",
}

_MASK_SIMPLIFIED_MAP = {
"s_no" : "ck_tile::SimplifiedGenericAttentionMask<false>",
"s_mask" : "ck_tile::SimplifiedGenericAttentionMask<true>",
"s_no": "ck_tile::SimplifiedGenericAttentionMask<false>",
"s_mask": "ck_tile::SimplifiedGenericAttentionMask<true>",
}

_MASK_MAP = {
"no" : "FmhaMasks::NoMask",
"causal" : "FmhaMasks::CausalMask",
"generic" : "FmhaMasks::GenericMask"
"no": "FmhaMasks::NoMask",
"causal": "FmhaMasks::CausalMask",
"generic": "FmhaMasks::GenericMask",
}

def get_mask_map(mask : str):

def get_mask_map(mask: str):
if mask == "generic":
return _MASK_MAP
elif mask == "simplified":
Expand All @@ -43,18 +40,20 @@ def get_mask_map(mask : str):
assert False
return None


_MASK_CHECK_MAP = {
"no" : "t.mask_type == mask_enum::no_mask",
"causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
"generic" : "t.mask_type == mask_enum::window_generic",
"no": "t.mask_type == mask_enum::no_mask",
"causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
"generic": "t.mask_type == mask_enum::window_generic",
}

_MASK_SIMPLIFIED_CHECK_MAP = {
"s_no" : "t.mask_type == mask_enum::no_mask",
"s_mask" : "t.mask_type != mask_enum::no_mask",
"s_no": "t.mask_type == mask_enum::no_mask",
"s_mask": "t.mask_type != mask_enum::no_mask",
}

def get_mask_check_map(mask : str):

def get_mask_check_map(mask: str):
if mask == "generic":
return _MASK_CHECK_MAP
elif mask == "simplified":
Expand All @@ -63,76 +62,71 @@ def get_mask_check_map(mask : str):
assert False
return None


BIAS_MAP = {
"no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
"bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
"alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI"
"no": "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
"bias": "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
"alibi": "ck_tile::BlockAttentionBiasEnum::ALIBI",
}

# TODO: this is ugly
BIAS_CHECK_MAP = {
"no" : "bias_enum::no_bias",
"bias" : "bias_enum::elementwise_bias",
"alibi" : "bias_enum::alibi"
"no": "bias_enum::no_bias",
"bias": "bias_enum::elementwise_bias",
"alibi": "bias_enum::alibi",
}

DROPOUT_MAP = {
"no" : "ck_tile::BlockDropoutBwd<false, true, false>",
"dropout_wg32" : "ck_tile::BlockDropoutBwd<true, true, false>",
"dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd<true, true, true >",
"dropout_wg16" : "ck_tile::BlockDropoutBwd<true, false, false>",
"dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd<true, false, true >"
"no": "ck_tile::BlockDropoutBwd<false, true, false>",
"dropout_wg32": "ck_tile::BlockDropoutBwd<true, true, false>",
"dropout_wg32_storerandval": "ck_tile::BlockDropoutBwd<true, true, true >",
"dropout_wg16": "ck_tile::BlockDropoutBwd<true, false, false>",
"dropout_wg16_storerandval": "ck_tile::BlockDropoutBwd<true, false, true >",
}

DROPOUT_CHECK_MAP = {
"no" : "t.has_dropout == false",
"dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
"dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
"no": "t.has_dropout == false",
"dropout_wg32": "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg32_storerandval": "t.has_dropout == true && t.is_store_randval == true",
"dropout_wg16": "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg16_storerandval": "t.has_dropout == true && t.is_store_randval == true",
}

ROPE_MAP = {
"no" : "ck_tile::RotaryEmbeddingEnum::NONE",
"inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED",
"half" : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED"
"no": "ck_tile::RotaryEmbeddingEnum::NONE",
"inter": "ck_tile::RotaryEmbeddingEnum::INTERLEAVED",
"half": "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED",
}

ROPE_CHECK_MAP = {
"no" : "rope_enum::none",
"inter" : "rope_enum::interleaved",
"half" : "rope_enum::half_rotated"
"no": "rope_enum::none",
"inter": "rope_enum::interleaved",
"half": "rope_enum::half_rotated",
}

MODE_MAP = {
"batch" : "false",
"group" : "true"
}
MODE_MAP = {"batch": "false", "group": "true"}

LAYOUT_MAP = {
"row" : "true",
"col" : "false"
}
LAYOUT_MAP = {"row": "true", "col": "false"}

PIPELINE_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineQRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
"qs" : "ck_tile::BlockFmhaPipelineQSKSVS",
"qr_async_trload" : "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
"qr": "ck_tile::BlockFmhaPipelineQRKSVS",
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync",
"qs": "ck_tile::BlockFmhaPipelineQSKSVS",
"qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
}

PIPELINE_ENUM_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
"qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
"qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async_trload" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD",
"qr": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
"qr_nwarp_sshuffle": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
"qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD",
}

BOOL_MAP = {
"t" : "true",
"f" : "false",
True : "true",
False : "false",
"t": "true",
"f": "false",
True: "true",
False: "false",
}
Loading
Loading