Skip to content

Commit 6396415

Browse files
muditgokhale2copybara-github
authored andcommitted
Add multi-threading support to the processing of tensor core planes.
PiperOrigin-RevId: 815596292
1 parent daf41ac commit 6396415

File tree

3 files changed

+88
-3
lines changed

3 files changed

+88
-3
lines changed

xprof/utils/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ cc_test(
235235
"@xla//xla/tsl/profiler/utils:xplane_builder",
236236
"@xla//xla/tsl/profiler/utils:xplane_schema",
237237
"@xla//xla/tsl/profiler/utils:xplane_test_utils",
238+
"@xla//xla/tsl/profiler/utils:xplane_utils",
238239
"@xla//xla/tsl/profiler/utils:xplane_visitor",
239240
],
240241
)

xprof/utils/derived_timeline.cc

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -743,10 +743,21 @@ void GenerateDerivedTimeLines(
743743
if (host_plane) {
744744
DeriveEventsFromHostTrace(host_plane, group_metadata_map, device_planes);
745745
}
746-
for (XPlane* plane : FindMutableTensorCorePlanes(space)) {
747-
DeriveLinesFromStats(plane);
748-
tsl::profiler::SortXPlane(plane);
746+
747+
std::vector<XPlane*> tensor_core_planes = FindMutableTensorCorePlanes(space);
748+
749+
int thread_pool_size = std::min(tsl::port::MaxParallelism(),
750+
static_cast<int>(device_planes.size()));
751+
auto plane_processing_executor = std::make_unique<XprofThreadPoolExecutor>(
752+
"ProcessTensorCorePlanes", thread_pool_size);
753+
// TODO(b/449633660) Analyze multi-threading inside DeriveLinesFromStats.
754+
for (XPlane* plane : tensor_core_planes) {
755+
plane_processing_executor->Execute([plane]() {
756+
DeriveLinesFromStats(plane);
757+
tsl::profiler::SortXPlane(plane);
758+
});
749759
}
760+
plane_processing_executor->JoinAll();
750761
}
751762

752763
void DeriveLinesFromStats(XPlane* device_trace) {

xprof/utils/derived_timeline_test.cc

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ limitations under the License.
3636
#include "xla/tsl/profiler/utils/xplane_builder.h"
3737
#include "xla/tsl/profiler/utils/xplane_schema.h"
3838
#include "xla/tsl/profiler/utils/xplane_test_utils.h"
39+
#include "xla/tsl/profiler/utils/xplane_utils.h"
3940
#include "xla/tsl/profiler/utils/xplane_visitor.h"
4041
#include "tsl/profiler/protobuf/xplane.pb.h"
4142

@@ -661,6 +662,78 @@ TEST(DerivedTimelineTest, EnsureAllGpuEventsAreGrouped) {
661662
});
662663
}
663664

665+
// Tests that the multi-threaded processing of Tensor Core planes works
666+
// correctly.
667+
TEST(DerivedTimelineTest, MultiThreadedTensorCorePlaneProcessing) {
668+
constexpr int kNumPlanes = 4;
669+
constexpr int kNumEvents = 10;
670+
constexpr int kEventDurationPs = 100;
671+
672+
XSpace space;
673+
tsl::profiler::GroupMetadataMap group_metadata_map;
674+
675+
// Create multiple Tensor Core planes, each with a line of unsorted events.
676+
for (int i = 0; i < kNumPlanes; ++i) {
677+
XPlane* plane = tsl::profiler::GetOrCreateTpuXPlane(
678+
&space, /*device_ordinal=*/i, "TPU V4", 0, 0);
679+
XPlaneBuilder plane_builder(plane);
680+
auto line_builder = plane_builder.GetOrCreateLine(0);
681+
const std::string tf_op_name = absl::StrCat("MyOp:", i);
682+
683+
for (int j = 0; j < kNumEvents; ++j) {
684+
// Add events in reverse order to test sorting.
685+
int64_t offset = (kNumEvents - 1 - j) * kEventDurationPs * 2;
686+
CreateXEvent(&plane_builder, &line_builder, "kernel", offset,
687+
kEventDurationPs, {{StatType::kTfOp, tf_op_name}});
688+
}
689+
}
690+
691+
// This will trigger the multi-threaded logic you added.
692+
GenerateDerivedTimeLines(group_metadata_map, &space);
693+
694+
// Verify that each plane was processed correctly.
695+
for (int i = 0; i < kNumPlanes; ++i) {
696+
const std::string plane_name = absl::StrCat("/device:TPU:", i);
697+
const XPlane* plane = tsl::profiler::FindPlaneWithName(space, plane_name);
698+
ASSERT_NE(plane, nullptr);
699+
XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(plane);
700+
701+
// 1. Verify that the events on the original line are now sorted.
702+
const XLine* original_line = nullptr;
703+
for (const auto& line : plane->lines()) {
704+
if (line.id() == 0) {
705+
original_line = &line;
706+
break;
707+
}
708+
}
709+
ASSERT_NE(original_line, nullptr);
710+
711+
int64_t last_timestamp_ps = -1;
712+
for (const auto& event : original_line->events()) {
713+
ASSERT_GE(event.offset_ps(), last_timestamp_ps);
714+
last_timestamp_ps = event.offset_ps();
715+
}
716+
EXPECT_EQ(original_line->events_size(), kNumEvents);
717+
718+
// 2. Verify that DeriveLinesFromStats created the derived TF Op line.
719+
bool tf_op_line_found = false;
720+
plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) {
721+
if (line_visitor.Name() == tsl::profiler::kTensorFlowOpLineName) {
722+
tf_op_line_found = true;
723+
EXPECT_EQ(line_visitor.NumEvents(), 1); // Should be merged into one.
724+
line_visitor.ForEachEvent([&](const XEventVisitor& event) {
725+
EXPECT_EQ(event.Name(), absl::StrCat("MyOp:", i));
726+
// Check the duration of the merged event.
727+
int64_t expected_duration =
728+
(kNumEvents - 1) * kEventDurationPs * 2 + kEventDurationPs;
729+
EXPECT_EQ(event.DurationPs(), expected_duration);
730+
});
731+
}
732+
});
733+
EXPECT_TRUE(tf_op_line_found);
734+
}
735+
}
736+
664737
} // namespace
665738
} // namespace profiler
666739
} // namespace tensorflow

0 commit comments

Comments
 (0)