@@ -36,6 +36,7 @@ limitations under the License.
3636#include " xla/tsl/profiler/utils/xplane_builder.h"
3737#include " xla/tsl/profiler/utils/xplane_schema.h"
3838#include " xla/tsl/profiler/utils/xplane_test_utils.h"
39+ #include " xla/tsl/profiler/utils/xplane_utils.h"
3940#include " xla/tsl/profiler/utils/xplane_visitor.h"
4041#include " tsl/profiler/protobuf/xplane.pb.h"
4142
@@ -661,6 +662,78 @@ TEST(DerivedTimelineTest, EnsureAllGpuEventsAreGrouped) {
661662 });
662663}
663664
665+ // Tests that the multi-threaded processing of Tensor Core planes works
666+ // correctly.
667+ TEST (DerivedTimelineTest, MultiThreadedTensorCorePlaneProcessing) {
668+ constexpr int kNumPlanes = 4 ;
669+ constexpr int kNumEvents = 10 ;
670+ constexpr int kEventDurationPs = 100 ;
671+
672+ XSpace space;
673+ tsl::profiler::GroupMetadataMap group_metadata_map;
674+
675+ // Create multiple Tensor Core planes, each with a line of unsorted events.
676+ for (int i = 0 ; i < kNumPlanes ; ++i) {
677+ XPlane* plane = tsl::profiler::GetOrCreateTpuXPlane (
678+ &space, /* device_ordinal=*/ i, " TPU V4" , 0 , 0 );
679+ XPlaneBuilder plane_builder (plane);
680+ auto line_builder = plane_builder.GetOrCreateLine (0 );
681+ const std::string tf_op_name = absl::StrCat (" MyOp:" , i);
682+
683+ for (int j = 0 ; j < kNumEvents ; ++j) {
684+ // Add events in reverse order to test sorting.
685+ int64_t offset = (kNumEvents - 1 - j) * kEventDurationPs * 2 ;
686+ CreateXEvent (&plane_builder, &line_builder, " kernel" , offset,
687+ kEventDurationPs , {{StatType::kTfOp , tf_op_name}});
688+ }
689+ }
690+
691+ // This will trigger the multi-threaded logic you added.
692+ GenerateDerivedTimeLines (group_metadata_map, &space);
693+
694+ // Verify that each plane was processed correctly.
695+ for (int i = 0 ; i < kNumPlanes ; ++i) {
696+ const std::string plane_name = absl::StrCat (" /device:TPU:" , i);
697+ const XPlane* plane = tsl::profiler::FindPlaneWithName (space, plane_name);
698+ ASSERT_NE (plane, nullptr );
699+ XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor (plane);
700+
701+ // 1. Verify that the events on the original line are now sorted.
702+ const XLine* original_line = nullptr ;
703+ for (const auto & line : plane->lines ()) {
704+ if (line.id () == 0 ) {
705+ original_line = &line;
706+ break ;
707+ }
708+ }
709+ ASSERT_NE (original_line, nullptr );
710+
711+ int64_t last_timestamp_ps = -1 ;
712+ for (const auto & event : original_line->events ()) {
713+ ASSERT_GE (event.offset_ps (), last_timestamp_ps);
714+ last_timestamp_ps = event.offset_ps ();
715+ }
716+ EXPECT_EQ (original_line->events_size (), kNumEvents );
717+
718+ // 2. Verify that DeriveLinesFromStats created the derived TF Op line.
719+ bool tf_op_line_found = false ;
720+ plane_visitor.ForEachLine ([&](const XLineVisitor& line_visitor) {
721+ if (line_visitor.Name () == tsl::profiler::kTensorFlowOpLineName ) {
722+ tf_op_line_found = true ;
723+ EXPECT_EQ (line_visitor.NumEvents (), 1 ); // Should be merged into one.
724+ line_visitor.ForEachEvent ([&](const XEventVisitor& event) {
725+ EXPECT_EQ (event.Name (), absl::StrCat (" MyOp:" , i));
726+ // Check the duration of the merged event.
727+ int64_t expected_duration =
728+ (kNumEvents - 1 ) * kEventDurationPs * 2 + kEventDurationPs ;
729+ EXPECT_EQ (event.DurationPs (), expected_duration);
730+ });
731+ }
732+ });
733+ EXPECT_TRUE (tf_op_line_found);
734+ }
735+ }
736+
664737} // namespace
665738} // namespace profiler
666739} // namespace tensorflow
0 commit comments