Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backends/apple/metal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ find_package_torch()

set(_aoti_metal_sources
runtime/metal_backend.cpp
runtime/stats.cpp
runtime/shims/memory.cpp
runtime/shims/et_metal.mm
runtime/shims/et_metal_ops.mm
Expand Down Expand Up @@ -68,6 +69,9 @@ target_link_libraries(

target_compile_options(metal_backend PUBLIC -fexceptions -frtti -fPIC)

# Define C++ preprocessor macro for Metal backend availability
target_compile_definitions(metal_backend PUBLIC ET_BUILD_METAL)

target_link_options(metal_backend PUBLIC -Wl,-export_dynamic)

# Find PyTorch's OpenMP library specifically for libtorch-less AOTI
Expand Down
129 changes: 129 additions & 0 deletions backends/apple/metal/runtime/metal_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <unistd.h>
#include <chrono>
#include <cstdio>

#include <filesystem>
#include <fstream>
#include <mutex>
#include <string>
#include <vector>

Expand All @@ -29,9 +31,92 @@
#include <executorch/backends/apple/metal/runtime/shims/shim_mps.h>
#include <executorch/backends/apple/metal/runtime/shims/tensor_attribute.h>
#include <executorch/backends/apple/metal/runtime/shims/utils.h>
#include <executorch/backends/apple/metal/runtime/stats.h>

namespace executorch::backends::metal {

// Running totals kept for a single method name.
struct MethodStats {
  // Sum of the elapsed milliseconds of every recorded call.
  double total_ms{0.0};
  // How many calls have been folded into total_ms.
  int64_t call_count{0};
};

// All timing counters for the Metal backend, bundled with the mutex that the
// accessor functions in this file lock before touching any other field.
struct StatsData {
  std::mutex mutex;

  // Aggregate execute() timing.
  double execute_total_ms{0.0};
  int64_t execute_call_count{0};

  // Aggregate init() timing.
  double init_total_ms{0.0};
  int64_t init_call_count{0};

  // Breakdown keyed by method name: execute() first, then init().
  std::unordered_map<std::string, MethodStats> method_stats;
  std::unordered_map<std::string, MethodStats> init_method_stats;
};

// Returns the one StatsData instance shared by the whole process. The
// function-local static is constructed on first use, and C++11 guarantees that
// this initialization is thread-safe.
static StatsData& get_stats_data() {
  static StatsData singleton;
  return singleton;
}

// Accessor functions for execute timing statistics
double get_metal_backend_execute_total_ms() {
auto& stats = get_stats_data();
std::lock_guard<std::mutex> lock(stats.mutex);
return stats.execute_total_ms;
}

int64_t get_metal_backend_execute_call_count() {
auto& stats = get_stats_data();
std::lock_guard<std::mutex> lock(stats.mutex);
return stats.execute_call_count;
}

// Accessor functions for init timing statistics
double get_metal_backend_init_total_ms() {
auto& stats = get_stats_data();
std::lock_guard<std::mutex> lock(stats.mutex);
return stats.init_total_ms;
}

int64_t get_metal_backend_init_call_count() {
auto& stats = get_stats_data();
std::lock_guard<std::mutex> lock(stats.mutex);
return stats.init_call_count;
}

void reset_metal_backend_stats() {
auto& stats = get_stats_data();
std::lock_guard<std::mutex> lock(stats.mutex);
stats.execute_total_ms = 0.0;
stats.execute_call_count = 0;
stats.init_total_ms = 0.0;
stats.init_call_count = 0;
stats.method_stats.clear();
stats.init_method_stats.clear();
}

std::unordered_map<std::string, std::pair<double, int64_t>>
get_metal_backend_per_method_stats() {
auto& stats = get_stats_data();
std::lock_guard<std::mutex> lock(stats.mutex);
std::unordered_map<std::string, std::pair<double, int64_t>> result;
for (const auto& entry : stats.method_stats) {
result[entry.first] = {entry.second.total_ms, entry.second.call_count};
}
return result;
}

std::unordered_map<std::string, std::pair<double, int64_t>>
get_metal_backend_init_per_method_stats() {
auto& stats = get_stats_data();
std::lock_guard<std::mutex> lock(stats.mutex);
std::unordered_map<std::string, std::pair<double, int64_t>> result;
for (const auto& entry : stats.init_method_stats) {
result[entry.first] = {entry.second.total_ms, entry.second.call_count};
}
return result;
}

#define LOAD_SYMBOL(handle, member, name, so_handle) \
do { \
handle->member = reinterpret_cast<name##Func>(dlsym(so_handle, #name)); \
Expand Down Expand Up @@ -137,6 +222,7 @@ class ET_EXPERIMENTAL MetalBackend final
FreeableBuffer* processed, // This will be a empty buffer
ArrayRef<CompileSpec> compile_specs // This will be my empty list
) const override {
auto init_start = std::chrono::high_resolution_clock::now();
ET_LOG(Info, "MetalBackend::init - Starting initialization");

std::string method_name;
Expand Down Expand Up @@ -261,6 +347,27 @@ class ET_EXPERIMENTAL MetalBackend final
}

ET_LOG(Info, "MetalBackend::init - Initialization completed successfully");

// Accumulate init timing statistics
auto init_end = std::chrono::high_resolution_clock::now();
double elapsed_ms =
std::chrono::duration<double, std::milli>(init_end - init_start)
.count();

{
auto& stats_data = get_stats_data();
std::lock_guard<std::mutex> lock(stats_data.mutex);
stats_data.init_total_ms += elapsed_ms;
stats_data.init_call_count++;

// Track per-method init timing
if (!method_name.empty()) {
auto& method_stats = stats_data.init_method_stats[method_name];
method_stats.total_ms += elapsed_ms;
method_stats.call_count++;
}
}

return (DelegateHandle*)handle; // Return the handle post-processing
}

Expand All @@ -269,6 +376,7 @@ class ET_EXPERIMENTAL MetalBackend final
BackendExecutionContext& context,
DelegateHandle* handle_,
Span<EValue*> args) const override {
auto execute_start = std::chrono::high_resolution_clock::now();
ET_LOG(Debug, "MetalBackend execute");

AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;
Expand Down Expand Up @@ -514,6 +622,27 @@ class ET_EXPERIMENTAL MetalBackend final

ET_LOG(Debug, "MetalBackend execution completed successfully");

// Accumulate timing statistics
auto execute_end = std::chrono::high_resolution_clock::now();
double elapsed_ms =
std::chrono::duration<double, std::milli>(execute_end - execute_start)
.count();

{
auto& stats_data = get_stats_data();
std::lock_guard<std::mutex> lock(stats_data.mutex);
stats_data.execute_total_ms += elapsed_ms;
stats_data.execute_call_count++;

// Track per-method timing
const char* method_name = context.get_method_name();
if (method_name != nullptr) {
auto& method_stats = stats_data.method_stats[method_name];
method_stats.total_ms += elapsed_ms;
method_stats.call_count++;
}
}

return Error::Ok;
}

Expand Down
81 changes: 81 additions & 0 deletions backends/apple/metal/runtime/stats.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/apple/metal/runtime/stats.h>
#include <iostream>

namespace executorch {
namespace backends {
namespace metal {

// Writes "<total> ms (<count> calls)" to stdout, followed by
// " (avg: <total/count> ms/call)" when at least one call was recorded, and a
// terminating newline. The caller prints the label that precedes this text.
static void print_timing_summary(double total_ms, int64_t call_count) {
  std::cout << total_ms << " ms (" << call_count << " calls)";
  if (call_count > 0) {
    // The average is only meaningful once something has been recorded.
    std::cout << " (avg: " << total_ms / call_count << " ms/call)";
  }
  std::cout << '\n';
}

// Writes `heading` and then one summary line per entry of `per_method`
// (method_name -> (total_ms, call_count)). Prints nothing at all when the map
// is empty. Entries appear in the map's iteration order, which is unspecified
// for std::unordered_map.
static void print_method_breakdown(
    const char* heading,
    const std::unordered_map<std::string, std::pair<double, int64_t>>&
        per_method) {
  if (per_method.empty()) {
    return;
  }
  std::cout << heading << '\n';
  for (const auto& entry : per_method) {
    std::cout << " " << entry.first << ": ";
    print_timing_summary(entry.second.first, entry.second.second);
  }
}

// Prints a human-readable report of the Metal backend timing statistics to
// stdout: aggregate init() totals, the per-method init breakdown, aggregate
// execute() totals, and the per-method execute breakdown.
//
// Each value is fetched through its own accessor call, so when other threads
// are recording concurrently the total and the call count on one line may come
// from slightly different moments.
//
// Lines are terminated with '\n' and stdout is flushed once at the end, rather
// than flushing after every line with std::endl.
void print_metal_backend_stats() {
  std::cout << "\n--- Metal Backend Performance Statistics ---\n";

  // Init stats
  const double init_total_ms = get_metal_backend_init_total_ms();
  const int64_t init_call_count = get_metal_backend_init_call_count();
  std::cout << "Metal init() total: ";
  print_timing_summary(init_total_ms, init_call_count);
  print_method_breakdown(
      " Per-method init breakdown:", get_metal_backend_init_per_method_stats());

  // Execute stats
  const double execute_total_ms = get_metal_backend_execute_total_ms();
  const int64_t execute_call_count = get_metal_backend_execute_call_count();
  std::cout << "\nMetal execute() total: ";
  print_timing_summary(execute_total_ms, execute_call_count);
  print_method_breakdown(
      " Per-method execute breakdown:", get_metal_backend_per_method_stats());

  // The single std::endl emits the trailing blank line and flushes the report.
  std::cout << "--------------------------------------------\n" << std::endl;
}

} // namespace metal
} // namespace backends
} // namespace executorch
46 changes: 46 additions & 0 deletions backends/apple/metal/runtime/stats.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>

namespace executorch {
namespace backends {
namespace metal {

// -----------------------------------------------------------------------------
// Metal backend timing statistics
//
// The counters are process-wide and keep accumulating until
// reset_metal_backend_stats() is called. Map-returning functions hand back a
// copy, so the result can be inspected without further coordination.
// -----------------------------------------------------------------------------

// ---- init() timing ----

/// Total milliseconds accumulated across recorded init() calls.
double get_metal_backend_init_total_ms();

/// Number of init() calls that have been recorded.
int64_t get_metal_backend_init_call_count();

/// Per-method init() breakdown: method_name -> (total_ms, call_count).
std::unordered_map<std::string, std::pair<double, int64_t>>
get_metal_backend_init_per_method_stats();

// ---- execute() timing ----

/// Total milliseconds accumulated across recorded execute() calls.
double get_metal_backend_execute_total_ms();

/// Number of execute() calls that have been recorded.
int64_t get_metal_backend_execute_call_count();

/// Per-method execute() breakdown: method_name -> (total_ms, call_count).
std::unordered_map<std::string, std::pair<double, int64_t>>
get_metal_backend_per_method_stats();

// ---- maintenance / reporting ----

/// Clears every counter and every per-method entry.
void reset_metal_backend_stats();

/// Writes all of the statistics above to stdout in human-readable form.
void print_metal_backend_stats();

} // namespace metal
} // namespace backends
} // namespace executorch
7 changes: 7 additions & 0 deletions examples/models/parakeet/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
#include <executorch/extension/tensor/tensor_ptr_maker.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/platform/log.h>
#ifdef ET_BUILD_METAL
#include <executorch/backends/apple/metal/runtime/stats.h>
#endif

DEFINE_string(model_path, "parakeet.pte", "Path to Parakeet model (.pte).");
DEFINE_string(audio_path, "", "Path to input audio file (.wav).");
Expand Down Expand Up @@ -416,6 +419,10 @@ int main(int argc, char** argv) {
decoded_tokens, *tokenizer);
std::cout << "Transcribed text: " << text << std::endl;

#ifdef ET_BUILD_METAL
executorch::backends::metal::print_metal_backend_stats();
#endif // ET_BUILD_METAL

if (!timestamp_mode.enabled()) {
return 0;
}
Expand Down
Loading