

@serach24

No description provided.

@gemini-code-assist
Contributor

Summary of Changes

Hello @serach24, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a powerful new frontend in JAX that allows users to execute custom device kernels, starting with NVIDIA PTX, directly within their JAX programs. This feature significantly extends JAX's capabilities by enabling low-level, device-specific optimizations while retaining compatibility with JAX's high-level transformations like JIT compilation, automatic differentiation, and automatic batching. The implementation includes robust mechanisms for defining kernel parameters, handling data layouts, and ensuring proper integration with JAX's execution model, making it easier to leverage specialized hardware features.

Highlights

  • New kernel_call Frontend: Introduces jax.experimental.device_kernels.kernel_call for executing device-specific kernels, initially supporting NVIDIA PTX.
  • JAX Integration: The new kernel_call primitive is fully integrated with JAX's core functionalities, including JIT compilation, automatic differentiation (JVP and transpose rules), and batching (vmap).
  • Batch Partitioning Support: PTX kernels are registered as batch partitionable by default, enabling automatic partitioning across leading dimensions for sharded data.
  • Lowering Mechanism: Provides build_device_kernel_lowering_function and kernel_lowering to construct MLIR lowering rules for custom device kernels.
  • Example and Tests: Includes a documentation example (docs/device_kernels/add.py) demonstrating PTX vector addition and a comprehensive test suite (tests/device_kernels_test.py) validating various aspects of the new functionality. A hedged usage sketch follows this list.
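
To make the highlights concrete, here is a minimal usage sketch in the spirit of the docs/device_kernels/add.py example. The exact signature of kernel_call is not shown in this thread, so everything beyond the module path and the kernel_type/vmap_method names mentioned in the review (the callable-returning style, the positional PTX argument, and the ShapeDtypeStruct output spec) is an assumption modeled on jax.ffi.ffi_call, not the PR's confirmed interface.

```python
# Hedged sketch only: the kernel_call calling convention is assumed, modeled
# on jax.ffi.ffi_call, and is not confirmed by this PR.
import jax
import jax.numpy as jnp
from jax.experimental import device_kernels  # module introduced by this PR

# Placeholder PTX for elementwise float32 addition; the real kernel source
# lives in docs/device_kernels/add.py.
ADD_PTX = "..."

def ptx_add(a, b):
  call = device_kernels.kernel_call(
      ADD_PTX,
      jax.ShapeDtypeStruct(a.shape, a.dtype),  # assumed output-spec form
      kernel_type="ptx",            # parameter name taken from the review thread
      vmap_method="broadcast_all",  # assumed, mirroring jax.ffi.ffi_call
  )
  return call(a, b)

a = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
b = jnp.array([4.0, 5.0, 6.0], dtype=jnp.float32)
result = jax.jit(ptx_add)(a, b)  # composes with jit per the highlights
```

Per the highlights, the primitive is also registered as batch partitionable, so wrapping ptx_add in jax.vmap should work without further changes.
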
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature              | Command             | Description
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder at the root of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist (Contributor) left a comment

Code Review

This pull request introduces a valuable new frontend for executing device-specific kernels, starting with PTX for NVIDIA GPUs. The implementation is well-structured, leveraging the existing FFI infrastructure effectively. The API design is clean, and the inclusion of comprehensive tests is commendable. I've identified a few minor issues, primarily concerning error messages, API consistency, and test code clarity, which are detailed in the comments.

  Returns:
    Result array(s) from kernel execution
  """
  # Validate kernel_type

Severity: medium

The vectorized argument is deprecated in other JAX APIs like ffi_call. For consistency and to guide users towards the new API, a check should be added to kernel_call to raise an error if vectorized is used.

Suggested change:
- # Validate kernel_type
+ if not isinstance(vectorized, DeprecatedArg):
+   raise ValueError(
+       "The 'vectorized' argument of jax.experimental.device_kernels.kernel_call is deprecated. "
+       "Use 'vmap_method' instead.")
+ # Validate kernel_type
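
The sentinel pattern this suggestion relies on mirrors how jax.ffi.ffi_call handles its own deprecated vectorized argument. Below is a standalone sketch of that pattern; the class and function names are illustrative and not taken from the PR's actual code.

```python
# Illustrative deprecated-argument sentinel pattern; names are not the PR's.
class DeprecatedArg:
  """Default marker distinguishing 'argument not passed' from any real value."""

def kernel_call(*args, vectorized=DeprecatedArg(), vmap_method=None, **kwargs):
  # Any explicit value for 'vectorized' (even None or False) fails the
  # isinstance check, so only the untouched default passes silently.
  if not isinstance(vectorized, DeprecatedArg):
    raise ValueError(
        "The 'vectorized' argument of kernel_call is deprecated. "
        "Use 'vmap_method' instead.")
  ...  # remainder of the call elided
```
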

  )
  if not all(isinstance(idx, int) and 0 <= idx < len(args) + expected_num_outputs for idx in output_indices):
    raise ValueError(
        f"Output indices must be integers in range [0, {len(args)}), got {output_indices}"

Severity: medium

The error message for output_indices validation is incorrect. The upper bound of the valid range should be len(args) + expected_num_outputs, but the error message states len(args). This could be confusing for users.

Suggested change:
- f"Output indices must be integers in range [0, {len(args)}), got {output_indices}"
+ f"Output indices must be integers in range [0, {len(args) + expected_num_outputs}), got {output_indices}"

Comment on lines +90 to +91
# assert jnp.allclose(b, jnp.array([4.0, 5.0, 6.0], dtype=jnp.float32))
# assert jnp.allclose(expected, jnp.array([5.0, 7.0, 9.0], dtype=jnp.float32))

Severity: medium

This test contains commented-out assertions. They should be removed to improve test clarity and maintainability. An explicit check should be added to ensure input b is not modified by the kernel call, similar to the existing check for a.

Suggested change:
- # assert jnp.allclose(b, jnp.array([4.0, 5.0, 6.0], dtype=jnp.float32))
- # assert jnp.allclose(expected, jnp.array([5.0, 7.0, 9.0], dtype=jnp.float32))
+ assert jnp.allclose(b, jnp.array([4.0, 5.0, 6.0], dtype=jnp.float32))
