CmdPal: Start extensions in parallel (#39203)

This reduces our extension startup time by approximately 70% on my
machine (I have 17 extensions). I'd guess the gains scale with the
number of extensions. That's 8s -> 3s on average, and now I also get 2.5s reloads.

This retains the order of the list of extensions, by only starting the
processes in parallel. Once we have all the command provider instances,
then actually retrieving the commands.

It also adds a timeout on startup & load, so that one misbehaving extension won't block everyone else.

closes: #38529
This commit is contained in:
Mike Griese 2025-05-02 19:43:31 -05:00 committed by GitHub
parent fe067def65
commit 7e92a9a5e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,6 +4,7 @@
using System.Collections.Immutable;
using System.Collections.ObjectModel;
using System.Diagnostics;
using CommunityToolkit.Mvvm.ComponentModel;
using CommunityToolkit.Mvvm.Input;
using CommunityToolkit.Mvvm.Messaging;
@ -53,53 +54,44 @@ public partial class TopLevelCommandManager : ObservableObject,
{
CommandProviderWrapper wrapper = new(provider, _taskScheduler);
_builtInCommands.Add(wrapper);
await LoadTopLevelCommandsFromProvider(wrapper);
var commands = await LoadTopLevelCommandsFromProvider(wrapper);
lock (TopLevelCommands)
{
foreach (var c in commands)
{
TopLevelCommands.Add(c);
}
}
}
return true;
}
// May be called from a background thread
private async Task LoadTopLevelCommandsFromProvider(CommandProviderWrapper commandProvider)
private async Task<IEnumerable<TopLevelViewModel>> LoadTopLevelCommandsFromProvider(CommandProviderWrapper commandProvider)
{
WeakReference<IPageContext> weakSelf = new(this);
await commandProvider.LoadTopLevelCommands(_serviceProvider, weakSelf);
var settings = _serviceProvider.GetService<SettingsModel>()!;
var makeAndAdd = (ICommandItem? i, bool fallback) =>
List<TopLevelViewModel> commands = [];
foreach (var item in commandProvider.TopLevelItems)
{
var commandItemViewModel = new CommandItemViewModel(new(i), weakSelf);
var topLevelViewModel = new TopLevelViewModel(commandItemViewModel, fallback, commandProvider.ExtensionHost, commandProvider.ProviderId, settings, _serviceProvider);
commands.Add(item);
}
lock (TopLevelCommands)
{
TopLevelCommands.Add(topLevelViewModel);
}
};
await Task.Factory.StartNew(
() =>
{
lock (TopLevelCommands)
{
foreach (var item in commandProvider.TopLevelItems)
{
TopLevelCommands.Add(item);
}
foreach (var item in commandProvider.FallbackItems)
{
TopLevelCommands.Add(item);
}
}
},
CancellationToken.None,
TaskCreationOptions.None,
_taskScheduler);
foreach (var item in commandProvider.FallbackItems)
{
commands.Add(item);
}
commandProvider.CommandsChanged -= CommandProvider_CommandsChanged;
commandProvider.CommandsChanged += CommandProvider_CommandsChanged;
return commands;
}
// By all accounts, we're already on a background thread (the COM call
@ -239,25 +231,71 @@ public partial class TopLevelCommandManager : ObservableObject,
private async Task StartExtensionsAndGetCommands(IEnumerable<IExtensionWrapper> extensions)
{
// TODO This most definitely needs a lock
foreach (var extension in extensions)
{
Logger.LogDebug($"Starting {extension.PackageFullName}");
try
{
// start it ...
await extension.StartExtensionAsync();
var timer = new Stopwatch();
timer.Start();
// ... and fetch the command provider from it.
CommandProviderWrapper wrapper = new(extension, _taskScheduler);
_extensionCommandProviders.Add(wrapper);
await LoadTopLevelCommandsFromProvider(wrapper);
}
catch (Exception ex)
// Start all extensions in parallel
var startTasks = extensions.Select(StartExtensionWithTimeoutAsync);
// Wait for all extensions to start
var wrappers = (await Task.WhenAll(startTasks)).Where(wrapper => wrapper != null).Select(w => w!).ToList();
foreach (var wrapper in wrappers)
{
_extensionCommandProviders.Add(wrapper!);
}
// Load the commands from the providers in parallel
var loadTasks = wrappers.Select(LoadCommandsWithTimeoutAsync);
var commandSets = (await Task.WhenAll(loadTasks)).Where(results => results != null).Select(r => r!).ToList();
lock (TopLevelCommands)
{
foreach (var commands in commandSets)
{
Logger.LogError(ex.ToString());
foreach (var c in commands)
{
TopLevelCommands.Add(c);
}
}
}
timer.Stop();
Logger.LogDebug($"Loading extensions took {timer.ElapsedMilliseconds} ms");
}
private async Task<CommandProviderWrapper?> StartExtensionWithTimeoutAsync(IExtensionWrapper extension)
{
Logger.LogDebug($"Starting {extension.PackageFullName}");
try
{
await extension.StartExtensionAsync().WaitAsync(TimeSpan.FromSeconds(10));
return new CommandProviderWrapper(extension, _taskScheduler);
}
catch (Exception ex)
{
Logger.LogError($"Failed to start extension {extension.PackageFullName}: {ex}");
return null; // Return null for failed extensions
}
}
private async Task<IEnumerable<TopLevelViewModel>?> LoadCommandsWithTimeoutAsync(CommandProviderWrapper wrapper)
{
try
{
return await LoadTopLevelCommandsFromProvider(wrapper!).WaitAsync(TimeSpan.FromSeconds(10));
}
catch (TimeoutException)
{
Logger.LogError($"Loading commands from {wrapper!.ExtensionHost?.Extension?.PackageFullName} timed out");
}
catch (Exception ex)
{
Logger.LogError($"Failed to load commands for extension {wrapper!.ExtensionHost?.Extension?.PackageFullName}: {ex}");
}
return null;
}
private void ExtensionService_OnExtensionRemoved(IExtensionService sender, IEnumerable<IExtensionWrapper> extensions)