Skip to content

Conversation

skazakov1
Copy link
Contributor

PR implements host-side (and device-side) scalar support in SDPA primitive.
Addresses https://jira.devtools.intel.com/browse/MFDNN-13611
based on #3506

@skazakov1 skazakov1 requested review from a team as code owners September 9, 2025 16:58
@github-actions github-actions bot added platform:gpu-intel Codeowner: @oneapi-src/onednn-gpu-intel component:graph-api Codeowner: @oneapi-src/onednn-graph component:tests Codeowner: @oneapi-src/onednn-arch component:common labels Sep 9, 2025
@skazakov1 skazakov1 force-pushed the skazakov/sdpa-host-scalar-PR branch from e8944dc to 8206e84 Compare September 9, 2025 16:59
@@ -79,6 +79,7 @@ struct sdpa_desc_t : public op_desc_t {

memory_desc_t dst_desc;
memory_desc_t attn_mask_desc;
memory_desc_t scale_desc;
data_type_t scale_dt {};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems to be removed.

And a new member should be reflected in serialization, operator== and hashing functions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reworked, please review

@skazakov1 skazakov1 force-pushed the skazakov/sdpa-host-scalar-PR branch from 8206e84 to bd2c712 Compare September 9, 2025 18:08
@skazakov1
Copy link
Contributor Author

make test
set test_scope=NIGHTLY

@@ -589,7 +589,7 @@ void serialize(serialization_stream_t &sstream, const sdpa_desc_t &desc) {
desc.vs_zero_points.serialize(sstream);
serialize(sstream, desc.dst_desc);
serialize(sstream, desc.attn_mask_desc);
sstream.append(desc.scale_dt);
sstream.append(desc.scale_desc);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
sstream.append(desc.scale_desc);
serialize(sstream, desc.scale_desc);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Fixed

Copy link
Contributor

@TaoLv TaoLv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Graph backend part looks good to me.

@skazakov1
Copy link
Contributor Author

make test
disable benchdnn_all
set test_scope=NIGHTLY
enable benchdnn_graph
disable test_device_cpu
enable test_device_gpu
enable arch_gpu_xe-hpc
enable arch_gpu_xe-hpg-atsm
enable arch_gpu_xe-hpg-dg2
enable arch_gpu_xe-lp
enable arch_gpu_xe-lpg
enable arch_gpu_xe-lpg+
enable arch_gpu_xe2-hpg-bmg
enable arch_gpu_xe2-lpg

@skazakov1
Copy link
Contributor Author

make test perf-gpu
set primitive=sdpa

#if WITH_HOST_SCALE
#if INVERT_SCALE
iscale = SCALES_TO_FLOAT(scale_val);
scale = native_recip(iscale);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice if we could remove this native_recip call from the kenrel since we can do it on the host side before the kernel is called. This removes a floating point division from each work-item which is pretty significant.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Umar! will do it in separate PR, as it does not relate to host-scalar directly.

@@ -32,6 +32,8 @@ using mdt = memory::data_type;
using dnnl::accumulation_mode;

enum class mask_type { no_mask, oneD, twoD, causal_br, causal_tl };
enum class scale_type { host_side, device_side };
constexpr scale_type default_scale_type = scale_type::device_side;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should define the data type of the scalar as well as the type of the memory descriptor so we can test other scale data types.

If we are only supporting one data type then we should add checks to micro.hpp.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one more Thanks!
As I understand, currently, we dont have an option to test other scale data type != p.dt.dt, So it looks like feature/improvement, let me do it in separate PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
component:common component:graph-api Codeowner: @oneapi-src/onednn-graph component:tests Codeowner: @oneapi-src/onednn-arch platform:gpu-intel Codeowner: @oneapi-src/onednn-gpu-intel
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants