-
Notifications
You must be signed in to change notification settings - Fork 1.1k
[GPU] Add host side scalars support in SDPA primitive #3909
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
e8944dc
to
8206e84
Compare
src/common/sdpa_types.hpp
Outdated
@@ -79,6 +79,7 @@ struct sdpa_desc_t : public op_desc_t { | |||
|
|||
memory_desc_t dst_desc; | |||
memory_desc_t attn_mask_desc; | |||
memory_desc_t scale_desc; | |||
data_type_t scale_dt {}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems to be removed.
And a new member should be reflected in serialization, operator== and hashing functions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reworked, please review
8206e84
to
bd2c712
Compare
make test |
@@ -589,7 +589,7 @@ void serialize(serialization_stream_t &sstream, const sdpa_desc_t &desc) { | |||
desc.vs_zero_points.serialize(sstream); | |||
serialize(sstream, desc.dst_desc); | |||
serialize(sstream, desc.attn_mask_desc); | |||
sstream.append(desc.scale_dt); | |||
sstream.append(desc.scale_desc); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sstream.append(desc.scale_desc); | |
serialize(sstream, desc.scale_desc); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Fixed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Graph backend part looks good to me.
make test |
make test perf-gpu |
#if WITH_HOST_SCALE | ||
#if INVERT_SCALE | ||
iscale = SCALES_TO_FLOAT(scale_val); | ||
scale = native_recip(iscale); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice if we could remove this native_recip call from the kenrel since we can do it on the host side before the kernel is called. This removes a floating point division from each work-item which is pretty significant.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Umar! will do it in separate PR, as it does not relate to host-scalar directly.
@@ -32,6 +32,8 @@ using mdt = memory::data_type; | |||
using dnnl::accumulation_mode; | |||
|
|||
enum class mask_type { no_mask, oneD, twoD, causal_br, causal_tl }; | |||
enum class scale_type { host_side, device_side }; | |||
constexpr scale_type default_scale_type = scale_type::device_side; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should define the data type of the scalar as well as the type of the memory descriptor so we can test other scale data types.
If we are only supporting one data type then we should add checks to micro.hpp.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one more Thanks!
As I understand, currently, we dont have an option to test other scale data type != p.dt.dt, So it looks like feature/improvement, let me do it in separate PR.
PR implements host-side (and device-side) scalar support in SDPA primitive.
Addresses https://jira.devtools.intel.com/browse/MFDNN-13611
based on #3506