升级XR插件版本

This commit is contained in:
Sora丶kong
2026-03-02 17:56:21 +08:00
parent 8962657674
commit 60f512a9bc
1317 changed files with 110305 additions and 48249 deletions

View File

@@ -0,0 +1,18 @@
using System.Threading.Tasks;
namespace MCPForUnity.Editor.Services.Transport
{
/// <summary>
/// Abstraction for MCP transport implementations (e.g. WebSocket push, stdio).
/// </summary>
public interface IMcpTransportClient
{
bool IsConnected { get; }
string TransportName { get; }
TransportState State { get; }
Task<bool> StartAsync();
Task StopAsync();
Task<bool> VerifyAsync();
}
}

View File

@@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: 042446a50a4744170bb294acf827376f
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,450 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using MCPForUnity.Editor.Helpers;
using MCPForUnity.Editor.Models;
using MCPForUnity.Editor.Services;
using MCPForUnity.Editor.Tools;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using UnityEditor;
namespace MCPForUnity.Editor.Services.Transport
{
/// <summary>
/// Centralised command execution pipeline shared by all transport implementations.
/// Guarantees that MCP commands are executed on the Unity main thread while preserving
/// the legacy response format expected by the server.
/// </summary>
[InitializeOnLoad]
internal static class TransportCommandDispatcher
{
private static SynchronizationContext _mainThreadContext;
private static int _mainThreadId;
private static int _processingFlag;
private sealed class PendingCommand
{
public PendingCommand(
string commandJson,
TaskCompletionSource<string> completionSource,
CancellationToken cancellationToken,
CancellationTokenRegistration registration)
{
CommandJson = commandJson;
CompletionSource = completionSource;
CancellationToken = cancellationToken;
CancellationRegistration = registration;
QueuedAt = DateTime.UtcNow;
}
public string CommandJson { get; }
public TaskCompletionSource<string> CompletionSource { get; }
public CancellationToken CancellationToken { get; }
public CancellationTokenRegistration CancellationRegistration { get; }
public bool IsExecuting { get; set; }
public DateTime QueuedAt { get; }
public void Dispose()
{
CancellationRegistration.Dispose();
}
public void TrySetResult(string payload)
{
CompletionSource.TrySetResult(payload);
}
public void TrySetCanceled()
{
CompletionSource.TrySetCanceled(CancellationToken);
}
}
private static readonly Dictionary<string, PendingCommand> Pending = new();
private static readonly object PendingLock = new();
private static bool updateHooked;
private static bool initialised;
static TransportCommandDispatcher()
{
// Ensure this runs on the Unity main thread at editor load.
_mainThreadContext = SynchronizationContext.Current;
_mainThreadId = Thread.CurrentThread.ManagedThreadId;
EnsureInitialised();
// Always keep the update hook installed so commands arriving from background
// websocket tasks don't depend on a background-thread event subscription.
if (!updateHooked)
{
updateHooked = true;
EditorApplication.update += ProcessQueue;
}
}
/// <summary>
/// Schedule a command for execution on the Unity main thread and await its JSON response.
/// </summary>
public static Task<string> ExecuteCommandJsonAsync(string commandJson, CancellationToken cancellationToken)
{
if (commandJson is null)
{
throw new ArgumentNullException(nameof(commandJson));
}
EnsureInitialised();
var id = Guid.NewGuid().ToString("N");
var tcs = new TaskCompletionSource<string>(TaskCreationOptions.RunContinuationsAsynchronously);
var registration = cancellationToken.CanBeCanceled
? cancellationToken.Register(() => CancelPending(id, cancellationToken))
: default;
var pending = new PendingCommand(commandJson, tcs, cancellationToken, registration);
lock (PendingLock)
{
Pending[id] = pending;
}
// Proactively wake up the main thread execution loop. This improves responsiveness
// in scenarios where EditorApplication.update is throttled or temporarily not firing
// (e.g., Unity unfocused, compiling, or during domain reload transitions).
RequestMainThreadPump();
return tcs.Task;
}
internal static Task<T> RunOnMainThreadAsync<T>(Func<T> func, CancellationToken cancellationToken)
{
if (func is null)
{
throw new ArgumentNullException(nameof(func));
}
var tcs = new TaskCompletionSource<T>(TaskCreationOptions.RunContinuationsAsynchronously);
var registration = cancellationToken.CanBeCanceled
? cancellationToken.Register(() => tcs.TrySetCanceled(cancellationToken))
: default;
void Invoke()
{
try
{
if (tcs.Task.IsCompleted)
{
return;
}
var result = func();
tcs.TrySetResult(result);
}
catch (Exception ex)
{
tcs.TrySetException(ex);
}
finally
{
registration.Dispose();
}
}
// Best-effort nudge: if we're posting from a background thread (e.g., websocket receive),
// encourage Unity to run a loop iteration so the posted callback can execute even when unfocused.
try { EditorApplication.QueuePlayerLoopUpdate(); } catch { }
if (_mainThreadContext != null && Thread.CurrentThread.ManagedThreadId != _mainThreadId)
{
_mainThreadContext.Post(_ => Invoke(), null);
return tcs.Task;
}
Invoke();
return tcs.Task;
}
private static void RequestMainThreadPump()
{
void Pump()
{
try
{
// Hint Unity to run a loop iteration soon.
EditorApplication.QueuePlayerLoopUpdate();
}
catch
{
// Best-effort only.
}
ProcessQueue();
}
if (_mainThreadContext != null && Thread.CurrentThread.ManagedThreadId != _mainThreadId)
{
_mainThreadContext.Post(_ => Pump(), null);
return;
}
Pump();
}
private static void EnsureInitialised()
{
if (initialised)
{
return;
}
CommandRegistry.Initialize();
initialised = true;
}
private static void HookUpdate()
{
// Deprecated: we keep the update hook installed permanently (see static ctor).
if (updateHooked) return;
updateHooked = true;
EditorApplication.update += ProcessQueue;
}
private static void UnhookUpdateIfIdle()
{
// Intentionally no-op: keep update hook installed so background commands always process.
// This avoids "must focus Unity to re-establish contact" edge cases.
return;
}
private static void ProcessQueue()
{
if (Interlocked.Exchange(ref _processingFlag, 1) == 1)
{
return;
}
try
{
List<(string id, PendingCommand pending)> ready;
lock (PendingLock)
{
// Early exit inside lock to prevent per-frame List allocations (GitHub issue #577)
if (Pending.Count == 0)
{
return;
}
ready = new List<(string, PendingCommand)>(Pending.Count);
foreach (var kvp in Pending)
{
if (kvp.Value.IsExecuting)
{
continue;
}
kvp.Value.IsExecuting = true;
ready.Add((kvp.Key, kvp.Value));
}
if (ready.Count == 0)
{
UnhookUpdateIfIdle();
return;
}
}
foreach (var (id, pending) in ready)
{
ProcessCommand(id, pending);
}
}
finally
{
Interlocked.Exchange(ref _processingFlag, 0);
}
}
private static void ProcessCommand(string id, PendingCommand pending)
{
if (pending.CancellationToken.IsCancellationRequested)
{
RemovePending(id, pending);
pending.TrySetCanceled();
return;
}
string commandText = pending.CommandJson?.Trim();
if (string.IsNullOrEmpty(commandText))
{
pending.TrySetResult(SerializeError("Empty command received"));
RemovePending(id, pending);
return;
}
if (string.Equals(commandText, "ping", StringComparison.OrdinalIgnoreCase))
{
var pingResponse = new
{
status = "success",
result = new { message = "pong" }
};
pending.TrySetResult(JsonConvert.SerializeObject(pingResponse));
RemovePending(id, pending);
return;
}
if (!IsValidJson(commandText))
{
var invalidJsonResponse = new
{
status = "error",
error = "Invalid JSON format",
receivedText = commandText.Length > 50 ? commandText[..50] + "..." : commandText
};
pending.TrySetResult(JsonConvert.SerializeObject(invalidJsonResponse));
RemovePending(id, pending);
return;
}
try
{
var command = JsonConvert.DeserializeObject<Command>(commandText);
if (command == null)
{
pending.TrySetResult(SerializeError("Command deserialized to null", "Unknown", commandText));
RemovePending(id, pending);
return;
}
if (string.IsNullOrWhiteSpace(command.type))
{
pending.TrySetResult(SerializeError("Command type cannot be empty"));
RemovePending(id, pending);
return;
}
if (string.Equals(command.type, "ping", StringComparison.OrdinalIgnoreCase))
{
var pingResponse = new
{
status = "success",
result = new { message = "pong" }
};
pending.TrySetResult(JsonConvert.SerializeObject(pingResponse));
RemovePending(id, pending);
return;
}
var parameters = command.@params ?? new JObject();
// Block execution of disabled resources
var resourceMeta = MCPServiceLocator.ResourceDiscovery.GetResourceMetadata(command.type);
if (resourceMeta != null && !MCPServiceLocator.ResourceDiscovery.IsResourceEnabled(command.type))
{
pending.TrySetResult(SerializeError(
$"Resource '{command.type}' is disabled in the Unity Editor."));
RemovePending(id, pending);
return;
}
// Block execution of disabled tools
var toolMeta = MCPServiceLocator.ToolDiscovery.GetToolMetadata(command.type);
if (toolMeta != null && !MCPServiceLocator.ToolDiscovery.IsToolEnabled(command.type))
{
pending.TrySetResult(SerializeError(
$"Tool '{command.type}' is disabled in the Unity Editor."));
RemovePending(id, pending);
return;
}
var result = CommandRegistry.ExecuteCommand(command.type, parameters, pending.CompletionSource);
if (result == null)
{
// Async command cleanup after completion on next editor frame to preserve order.
pending.CompletionSource.Task.ContinueWith(_ =>
{
EditorApplication.delayCall += () => RemovePending(id, pending);
}, TaskScheduler.Default);
return;
}
var response = new { status = "success", result };
pending.TrySetResult(JsonConvert.SerializeObject(response));
RemovePending(id, pending);
}
catch (Exception ex)
{
McpLog.Error($"Error processing command: {ex.Message}\n{ex.StackTrace}");
pending.TrySetResult(SerializeError(ex.Message, "Unknown (error during processing)", ex.StackTrace));
RemovePending(id, pending);
}
}
private static void CancelPending(string id, CancellationToken token)
{
PendingCommand pending = null;
lock (PendingLock)
{
if (Pending.Remove(id, out pending))
{
UnhookUpdateIfIdle();
}
}
pending?.TrySetCanceled();
pending?.Dispose();
}
private static void RemovePending(string id, PendingCommand pending)
{
lock (PendingLock)
{
Pending.Remove(id);
UnhookUpdateIfIdle();
}
pending.Dispose();
}
private static string SerializeError(string message, string commandType = null, string stackTrace = null)
{
var errorResponse = new
{
status = "error",
error = message,
command = commandType ?? "Unknown",
stackTrace
};
return JsonConvert.SerializeObject(errorResponse);
}
private static bool IsValidJson(string text)
{
if (string.IsNullOrWhiteSpace(text))
{
return false;
}
text = text.Trim();
if ((text.StartsWith("{") && text.EndsWith("}")) || (text.StartsWith("[") && text.EndsWith("]")))
{
try
{
JToken.Parse(text);
return true;
}
catch
{
return false;
}
}
return false;
}
}
}

View File

@@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: 27407cc9c1ea0412d80b9f8964a5a29d
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,152 @@
using System;
using System.Threading.Tasks;
using MCPForUnity.Editor.Helpers;
using MCPForUnity.Editor.Services.Transport.Transports;
namespace MCPForUnity.Editor.Services.Transport
{
/// <summary>
/// Coordinates the active transport client and exposes lifecycle helpers.
/// </summary>
public class TransportManager
{
private IMcpTransportClient _httpClient;
private IMcpTransportClient _stdioClient;
private TransportState _httpState = TransportState.Disconnected("http");
private TransportState _stdioState = TransportState.Disconnected("stdio");
private Func<IMcpTransportClient> _webSocketFactory;
private Func<IMcpTransportClient> _stdioFactory;
public TransportManager()
{
Configure(
() => new WebSocketTransportClient(MCPServiceLocator.ToolDiscovery),
() => new StdioTransportClient());
}
public void Configure(
Func<IMcpTransportClient> webSocketFactory,
Func<IMcpTransportClient> stdioFactory)
{
_webSocketFactory = webSocketFactory ?? throw new ArgumentNullException(nameof(webSocketFactory));
_stdioFactory = stdioFactory ?? throw new ArgumentNullException(nameof(stdioFactory));
}
private IMcpTransportClient GetOrCreateClient(TransportMode mode)
{
return mode switch
{
TransportMode.Http => _httpClient ??= _webSocketFactory(),
TransportMode.Stdio => _stdioClient ??= _stdioFactory(),
_ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"),
};
}
private IMcpTransportClient GetClient(TransportMode mode)
{
return mode switch
{
TransportMode.Http => _httpClient,
TransportMode.Stdio => _stdioClient,
_ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"),
};
}
public async Task<bool> StartAsync(TransportMode mode)
{
IMcpTransportClient client = GetOrCreateClient(mode);
bool started = await client.StartAsync();
if (!started)
{
try
{
await client.StopAsync();
}
catch (Exception ex)
{
McpLog.Warn($"Error while stopping transport {client.TransportName}: {ex.Message}");
}
UpdateState(mode, TransportState.Disconnected(client.TransportName, client.State?.Error ?? "Failed to start"));
return false;
}
UpdateState(mode, client.State ?? TransportState.Connected(client.TransportName));
return true;
}
public async Task StopAsync(TransportMode? mode = null)
{
async Task StopClient(IMcpTransportClient client, TransportMode clientMode)
{
if (client == null) return;
try { await client.StopAsync(); }
catch (Exception ex) { McpLog.Warn($"Error while stopping transport {client.TransportName}: {ex.Message}"); }
finally { UpdateState(clientMode, TransportState.Disconnected(client.TransportName)); }
}
if (mode == null)
{
await StopClient(_httpClient, TransportMode.Http);
await StopClient(_stdioClient, TransportMode.Stdio);
return;
}
if (mode == TransportMode.Http)
{
await StopClient(_httpClient, TransportMode.Http);
}
else
{
await StopClient(_stdioClient, TransportMode.Stdio);
}
}
public async Task<bool> VerifyAsync(TransportMode mode)
{
IMcpTransportClient client = GetClient(mode);
if (client == null)
{
return false;
}
bool ok = await client.VerifyAsync();
var state = client.State ?? TransportState.Disconnected(client.TransportName, "No state reported");
UpdateState(mode, state);
return ok;
}
public TransportState GetState(TransportMode mode)
{
return mode switch
{
TransportMode.Http => _httpState,
TransportMode.Stdio => _stdioState,
_ => throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode"),
};
}
public bool IsRunning(TransportMode mode) => GetState(mode).IsConnected;
private void UpdateState(TransportMode mode, TransportState state)
{
switch (mode)
{
case TransportMode.Http:
_httpState = state;
break;
case TransportMode.Stdio:
_stdioState = state;
break;
default:
throw new ArgumentOutOfRangeException(nameof(mode), mode, "Unsupported transport mode");
}
}
}
public enum TransportMode
{
Http,
Stdio
}
}

View File

@@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: 65fc8ff4c9efb4fc98a0910ba7ca8b02
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,52 @@
namespace MCPForUnity.Editor.Services.Transport
{
/// <summary>
/// Lightweight snapshot of a transport's runtime status for editor UI and diagnostics.
/// </summary>
public sealed class TransportState
{
public bool IsConnected { get; }
public string TransportName { get; }
public int? Port { get; }
public string SessionId { get; }
public string Details { get; }
public string Error { get; }
private TransportState(
bool isConnected,
string transportName,
int? port,
string sessionId,
string details,
string error)
{
IsConnected = isConnected;
TransportName = transportName;
Port = port;
SessionId = sessionId;
Details = details;
Error = error;
}
public static TransportState Connected(
string transportName,
int? port = null,
string sessionId = null,
string details = null)
=> new TransportState(true, transportName, port, sessionId, details, null);
public static TransportState Disconnected(
string transportName,
string error = null,
int? port = null)
=> new TransportState(false, transportName, port, null, null, error);
public TransportState WithError(string error) => new TransportState(
IsConnected,
TransportName,
Port,
SessionId,
Details,
error);
}
}

View File

@@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: 67ab8e43f6a804698bb5b216cdef0645
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: 3d467a63b6fad42fa975c731af4b83b3
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: fd295cefe518e438693c12e9c7f37488
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,50 @@
using System;
using System.Threading.Tasks;
using MCPForUnity.Editor.Helpers;
namespace MCPForUnity.Editor.Services.Transport.Transports
{
/// <summary>
/// Adapts the existing TCP bridge into the transport abstraction.
/// </summary>
public class StdioTransportClient : IMcpTransportClient
{
private TransportState _state = TransportState.Disconnected("stdio");
public bool IsConnected => StdioBridgeHost.IsRunning;
public string TransportName => "stdio";
public TransportState State => _state;
public Task<bool> StartAsync()
{
try
{
StdioBridgeHost.StartAutoConnect();
_state = TransportState.Connected("stdio", port: StdioBridgeHost.GetCurrentPort());
return Task.FromResult(true);
}
catch (Exception ex)
{
_state = TransportState.Disconnected("stdio", ex.Message);
return Task.FromResult(false);
}
}
public Task StopAsync()
{
StdioBridgeHost.Stop();
_state = TransportState.Disconnected("stdio");
return Task.CompletedTask;
}
public Task<bool> VerifyAsync()
{
bool running = StdioBridgeHost.IsRunning;
_state = running
? TransportState.Connected("stdio", port: StdioBridgeHost.GetCurrentPort())
: TransportState.Disconnected("stdio", "Bridge not running");
return Task.FromResult(running);
}
}
}

View File

@@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: b2743f3468d5f433dbf2220f0838d8d1
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

View File

@@ -0,0 +1,741 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using MCPForUnity.Editor.Constants;
using MCPForUnity.Editor.Helpers;
using MCPForUnity.Editor.Services;
using MCPForUnity.Editor.Services.Transport;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using UnityEditor;
using UnityEngine;
namespace MCPForUnity.Editor.Services.Transport.Transports
{
/// <summary>
/// Maintains a persistent WebSocket connection to the MCP server plugin hub.
/// Handles registration, keep-alives, and command dispatch back into Unity via
/// <see cref="TransportCommandDispatcher"/>.
/// </summary>
public class WebSocketTransportClient : IMcpTransportClient, IDisposable
{
private const string TransportDisplayName = "websocket";
private static readonly TimeSpan[] ReconnectSchedule =
{
TimeSpan.Zero,
TimeSpan.FromSeconds(1),
TimeSpan.FromSeconds(3),
TimeSpan.FromSeconds(5),
TimeSpan.FromSeconds(10),
TimeSpan.FromSeconds(30)
};
private static readonly TimeSpan DefaultKeepAliveInterval = TimeSpan.FromSeconds(15);
private static readonly TimeSpan DefaultCommandTimeout = TimeSpan.FromSeconds(30);
private readonly IToolDiscoveryService _toolDiscoveryService;
private ClientWebSocket _socket;
private CancellationTokenSource _lifecycleCts;
private CancellationTokenSource _connectionCts;
private Task _receiveTask;
private Task _keepAliveTask;
private readonly SemaphoreSlim _sendLock = new(1, 1);
private Uri _endpointUri;
private string _sessionId;
private string _projectHash;
private string _projectName;
private string _projectPath;
private string _unityVersion;
private TimeSpan _keepAliveInterval = DefaultKeepAliveInterval;
private TimeSpan _socketKeepAliveInterval = DefaultKeepAliveInterval;
private volatile bool _isConnected;
private int _isReconnectingFlag;
private TransportState _state = TransportState.Disconnected(TransportDisplayName, "Transport not started");
private string _apiKey;
private bool _disposed;
public WebSocketTransportClient(IToolDiscoveryService toolDiscoveryService = null)
{
_toolDiscoveryService = toolDiscoveryService;
}
public bool IsConnected => _isConnected;
public string TransportName => TransportDisplayName;
public TransportState State => _state;
private Task<List<ToolMetadata>> GetEnabledToolsOnMainThreadAsync(CancellationToken token)
{
return TransportCommandDispatcher.RunOnMainThreadAsync(
() => _toolDiscoveryService?.GetEnabledTools() ?? new List<ToolMetadata>(),
token);
}
public async Task<bool> StartAsync()
{
// Capture identity values on the main thread before any async context switching
_projectName = ProjectIdentityUtility.GetProjectName();
_projectHash = ProjectIdentityUtility.GetProjectHash();
_unityVersion = Application.unityVersion;
_apiKey = HttpEndpointUtility.IsRemoteScope()
? EditorPrefs.GetString(EditorPrefKeys.ApiKey, string.Empty)
: string.Empty;
// Get project root path (strip /Assets from dataPath) for focus nudging
string dataPath = Application.dataPath;
if (!string.IsNullOrEmpty(dataPath))
{
string normalized = dataPath.TrimEnd('/', '\\');
if (string.Equals(System.IO.Path.GetFileName(normalized), "Assets", StringComparison.Ordinal))
{
_projectPath = System.IO.Path.GetDirectoryName(normalized) ?? normalized;
}
else
{
_projectPath = normalized; // Fallback if path doesn't end with Assets
}
}
await StopAsync();
_lifecycleCts = new CancellationTokenSource();
_endpointUri = BuildWebSocketUri(HttpEndpointUtility.GetBaseUrl());
_sessionId = null;
if (!await EstablishConnectionAsync(_lifecycleCts.Token))
{
await StopAsync();
return false;
}
// State is connected but session ID might be pending until 'registered' message
_state = TransportState.Connected(TransportDisplayName, sessionId: "pending", details: _endpointUri.ToString());
_isConnected = true;
return true;
}
public async Task StopAsync()
{
if (_lifecycleCts == null)
{
return;
}
try
{
_lifecycleCts.Cancel();
}
catch { }
await StopConnectionLoopsAsync().ConfigureAwait(false);
if (_socket != null)
{
try
{
if (_socket.State == WebSocketState.Open || _socket.State == WebSocketState.CloseReceived)
{
await _socket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Shutdown", CancellationToken.None).ConfigureAwait(false);
}
}
catch { }
finally
{
_socket.Dispose();
_socket = null;
}
}
_isConnected = false;
_state = TransportState.Disconnected(TransportDisplayName);
_lifecycleCts.Dispose();
_lifecycleCts = null;
}
public async Task<bool> VerifyAsync()
{
if (_socket == null || _socket.State != WebSocketState.Open)
{
return false;
}
if (_lifecycleCts == null)
{
return false;
}
try
{
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(_lifecycleCts.Token);
timeoutCts.CancelAfter(TimeSpan.FromSeconds(5));
await SendPongAsync(timeoutCts.Token).ConfigureAwait(false);
return true;
}
catch (Exception ex)
{
McpLog.Warn($"[WebSocket] Verify ping failed: {ex.Message}");
return false;
}
}
public void Dispose()
{
if (_disposed)
{
return;
}
try
{
// Ensure background loops are stopped before disposing shared resources
StopAsync().GetAwaiter().GetResult();
}
catch (Exception ex)
{
McpLog.Warn($"[WebSocket] Dispose failed to stop cleanly: {ex.Message}");
}
_sendLock?.Dispose();
_socket?.Dispose();
_lifecycleCts?.Dispose();
_disposed = true;
}
private async Task<bool> EstablishConnectionAsync(CancellationToken token)
{
await StopConnectionLoopsAsync().ConfigureAwait(false);
_connectionCts?.Dispose();
_connectionCts = CancellationTokenSource.CreateLinkedTokenSource(token);
CancellationToken connectionToken = _connectionCts.Token;
_socket?.Dispose();
_socket = new ClientWebSocket();
_socket.Options.KeepAliveInterval = _socketKeepAliveInterval;
// Add API key header if configured (for remote-hosted mode)
if (!string.IsNullOrEmpty(_apiKey))
{
_socket.Options.SetRequestHeader(AuthConstants.ApiKeyHeader, _apiKey);
}
try
{
await _socket.ConnectAsync(_endpointUri, connectionToken).ConfigureAwait(false);
}
catch (Exception ex)
{
string errorMsg = "Connection failed. Check that the server URL is correct, the server is running, and your API key (if required) is valid.";
McpLog.Error($"[WebSocket] {errorMsg} (Detail: {ex.Message})");
_state = TransportState.Disconnected(TransportDisplayName, errorMsg);
return false;
}
StartBackgroundLoops(connectionToken);
try
{
await SendRegisterAsync(connectionToken).ConfigureAwait(false);
}
catch (Exception ex)
{
string regMsg = $"Registration with server failed: {ex.Message}";
McpLog.Error($"[WebSocket] {regMsg}");
_state = TransportState.Disconnected(TransportDisplayName, regMsg);
return false;
}
return true;
}
/// <summary>
/// Stops the connection loops and disposes of the connection CTS.
/// Particularly useful when reconnecting, we want to ensure that background loops are cancelled correctly before starting new oens
/// </summary>
/// <param name="awaitTasks">Whether to await the receive and keep alive tasks before disposing.</param>
private async Task StopConnectionLoopsAsync(bool awaitTasks = true)
{
if (_connectionCts != null && !_connectionCts.IsCancellationRequested)
{
try { _connectionCts.Cancel(); } catch { }
}
if (_receiveTask != null)
{
if (awaitTasks)
{
try { await _receiveTask.ConfigureAwait(false); } catch { }
_receiveTask = null;
}
else if (_receiveTask.IsCompleted)
{
_receiveTask = null;
}
}
if (_keepAliveTask != null)
{
if (awaitTasks)
{
try { await _keepAliveTask.ConfigureAwait(false); } catch { }
_keepAliveTask = null;
}
else if (_keepAliveTask.IsCompleted)
{
_keepAliveTask = null;
}
}
if (_connectionCts != null)
{
_connectionCts.Dispose();
_connectionCts = null;
}
}
private void StartBackgroundLoops(CancellationToken token)
{
if ((_receiveTask != null && !_receiveTask.IsCompleted) || (_keepAliveTask != null && !_keepAliveTask.IsCompleted))
{
return;
}
_receiveTask = Task.Run(() => ReceiveLoopAsync(token), CancellationToken.None);
_keepAliveTask = Task.Run(() => KeepAliveLoopAsync(token), CancellationToken.None);
}
private async Task ReceiveLoopAsync(CancellationToken token)
{
while (!token.IsCancellationRequested)
{
try
{
string message = await ReceiveMessageAsync(token).ConfigureAwait(false);
if (message == null)
{
continue;
}
await HandleMessageAsync(message, token).ConfigureAwait(false);
}
catch (OperationCanceledException)
{
break;
}
catch (WebSocketException wse)
{
McpLog.Warn($"[WebSocket] Receive loop error: {wse.Message}");
await HandleSocketClosureAsync(wse.Message).ConfigureAwait(false);
break;
}
catch (Exception ex)
{
McpLog.Warn($"[WebSocket] Unexpected receive error: {ex.Message}");
await HandleSocketClosureAsync(ex.Message).ConfigureAwait(false);
break;
}
}
}
private async Task<string> ReceiveMessageAsync(CancellationToken token)
{
if (_socket == null)
{
return null;
}
byte[] rentedBuffer = System.Buffers.ArrayPool<byte>.Shared.Rent(8192);
var buffer = new ArraySegment<byte>(rentedBuffer);
using var ms = new MemoryStream(8192);
try
{
while (!token.IsCancellationRequested)
{
WebSocketReceiveResult result = await _socket.ReceiveAsync(buffer, token).ConfigureAwait(false);
if (result.MessageType == WebSocketMessageType.Close)
{
await HandleSocketClosureAsync(result.CloseStatusDescription ?? "Server closed connection").ConfigureAwait(false);
return null;
}
if (result.Count > 0)
{
ms.Write(buffer.Array!, buffer.Offset, result.Count);
}
if (result.EndOfMessage)
{
break;
}
}
if (ms.Length == 0)
{
return null;
}
return Encoding.UTF8.GetString(ms.ToArray());
}
finally
{
System.Buffers.ArrayPool<byte>.Shared.Return(rentedBuffer);
}
}
private async Task HandleMessageAsync(string message, CancellationToken token)
{
JObject payload;
try
{
payload = JObject.Parse(message);
}
catch (Exception ex)
{
McpLog.Warn($"[WebSocket] Invalid JSON payload: {ex.Message}");
return;
}
string messageType = payload.Value<string>("type") ?? string.Empty;
switch (messageType)
{
case "welcome":
ApplyWelcome(payload);
break;
case "registered":
await HandleRegisteredAsync(payload, token).ConfigureAwait(false);
break;
case "execute":
await HandleExecuteAsync(payload, token).ConfigureAwait(false);
break;
case "ping":
await SendPongAsync(token).ConfigureAwait(false);
break;
default:
// No-op for unrecognised types (keep-alives, telemetry, etc.)
break;
}
}
private void ApplyWelcome(JObject payload)
{
int? keepAliveSeconds = payload.Value<int?>("keepAliveInterval");
if (keepAliveSeconds.HasValue && keepAliveSeconds.Value > 0)
{
_keepAliveInterval = TimeSpan.FromSeconds(keepAliveSeconds.Value);
_socketKeepAliveInterval = _keepAliveInterval;
}
int? serverTimeoutSeconds = payload.Value<int?>("serverTimeout");
if (serverTimeoutSeconds.HasValue)
{
int sourceSeconds = keepAliveSeconds ?? serverTimeoutSeconds.Value;
int safeSeconds = Math.Max(5, Math.Min(serverTimeoutSeconds.Value, sourceSeconds));
_socketKeepAliveInterval = TimeSpan.FromSeconds(safeSeconds);
}
}
private async Task HandleRegisteredAsync(JObject payload, CancellationToken token)
{
string newSessionId = payload.Value<string>("session_id");
if (!string.IsNullOrEmpty(newSessionId))
{
_sessionId = newSessionId;
ProjectIdentityUtility.SetSessionId(_sessionId);
_state = TransportState.Connected(TransportDisplayName, sessionId: _sessionId, details: _endpointUri.ToString());
McpLog.Info($"[WebSocket] Registered with session ID: {_sessionId}", false);
await SendRegisterToolsAsync(token).ConfigureAwait(false);
}
}
private async Task SendRegisterToolsAsync(CancellationToken token)
{
if (_toolDiscoveryService == null) return;
token.ThrowIfCancellationRequested();
var tools = await GetEnabledToolsOnMainThreadAsync(token).ConfigureAwait(false);
token.ThrowIfCancellationRequested();
McpLog.Info($"[WebSocket] Preparing to register {tools.Count} tool(s) with the bridge.", false);
var toolsArray = new JArray();
foreach (var tool in tools)
{
var toolObj = new JObject
{
["name"] = tool.Name,
["description"] = tool.Description,
["structured_output"] = tool.StructuredOutput,
["requires_polling"] = tool.RequiresPolling,
["poll_action"] = tool.PollAction
};
var paramsArray = new JArray();
if (tool.Parameters != null)
{
foreach (var p in tool.Parameters)
{
paramsArray.Add(new JObject
{
["name"] = p.Name,
["description"] = p.Description,
["type"] = p.Type,
["required"] = p.Required,
["default_value"] = p.DefaultValue
});
}
}
toolObj["parameters"] = paramsArray;
toolsArray.Add(toolObj);
}
var payload = new JObject
{
["type"] = "register_tools",
["tools"] = toolsArray
};
await SendJsonAsync(payload, token).ConfigureAwait(false);
McpLog.Info($"[WebSocket] Sent {tools.Count} tools registration", false);
}
private async Task HandleExecuteAsync(JObject payload, CancellationToken token)
{
string commandId = payload.Value<string>("id");
string commandName = payload.Value<string>("name");
JObject parameters = payload.Value<JObject>("params") ?? new JObject();
int timeoutSeconds = payload.Value<int?>("timeout") ?? (int)DefaultCommandTimeout.TotalSeconds;
if (string.IsNullOrEmpty(commandId) || string.IsNullOrEmpty(commandName))
{
McpLog.Warn("[WebSocket] Invalid execute payload (missing id or name)");
return;
}
var commandEnvelope = new JObject
{
["type"] = commandName,
["params"] = parameters
};
string responseJson;
try
{
using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(token);
timeoutCts.CancelAfter(TimeSpan.FromSeconds(Math.Max(1, timeoutSeconds)));
responseJson = await TransportCommandDispatcher.ExecuteCommandJsonAsync(commandEnvelope.ToString(Formatting.None), timeoutCts.Token).ConfigureAwait(false);
}
catch (OperationCanceledException)
{
responseJson = JsonConvert.SerializeObject(new
{
status = "error",
error = $"Command '{commandName}' timed out after {timeoutSeconds} seconds"
});
}
catch (Exception ex)
{
responseJson = JsonConvert.SerializeObject(new
{
status = "error",
error = ex.Message
});
}
JToken resultToken;
try
{
resultToken = JToken.Parse(responseJson);
}
catch
{
resultToken = new JObject
{
["status"] = "error",
["error"] = "Invalid response payload"
};
}
var responsePayload = new JObject
{
["type"] = "command_result",
["id"] = commandId,
["result"] = resultToken
};
await SendJsonAsync(responsePayload, token).ConfigureAwait(false);
}
private async Task KeepAliveLoopAsync(CancellationToken token)
{
while (!token.IsCancellationRequested)
{
try
{
await Task.Delay(_keepAliveInterval, token).ConfigureAwait(false);
if (_socket == null || _socket.State != WebSocketState.Open)
{
break;
}
await SendPongAsync(token).ConfigureAwait(false);
}
catch (OperationCanceledException)
{
break;
}
catch (Exception ex)
{
McpLog.Warn($"[WebSocket] Keep-alive failed: {ex.Message}");
await HandleSocketClosureAsync(ex.Message).ConfigureAwait(false);
break;
}
}
}
private async Task SendRegisterAsync(CancellationToken token)
{
var registerPayload = new JObject
{
["type"] = "register",
// session_id is now server-authoritative; omitted here or sent as null
["project_name"] = _projectName,
["project_hash"] = _projectHash,
["unity_version"] = _unityVersion,
["project_path"] = _projectPath
};
await SendJsonAsync(registerPayload, token).ConfigureAwait(false);
}
private Task SendPongAsync(CancellationToken token)
{
var payload = new JObject
{
["type"] = "pong",
["session_id"] = _sessionId // Include session ID for server-side tracking
};
return SendJsonAsync(payload, token);
}
private async Task SendJsonAsync(JObject payload, CancellationToken token)
{
if (_socket == null)
{
throw new InvalidOperationException("WebSocket is not initialised");
}
string json = payload.ToString(Formatting.None);
byte[] bytes = Encoding.UTF8.GetBytes(json);
var buffer = new ArraySegment<byte>(bytes);
await _sendLock.WaitAsync(token).ConfigureAwait(false);
try
{
if (_socket.State != WebSocketState.Open)
{
throw new InvalidOperationException("WebSocket is not open");
}
await _socket.SendAsync(buffer, WebSocketMessageType.Text, true, token).ConfigureAwait(false);
}
finally
{
_sendLock.Release();
}
}
private async Task HandleSocketClosureAsync(string reason)
{
// Capture stack trace for debugging disconnection triggers
var stackTrace = new System.Diagnostics.StackTrace(true);
McpLog.Debug($"[WebSocket] HandleSocketClosureAsync called. Reason: {reason}\nStack trace:\n{stackTrace}");
if (_lifecycleCts == null || _lifecycleCts.IsCancellationRequested)
{
return;
}
if (Interlocked.CompareExchange(ref _isReconnectingFlag, 1, 0) != 0)
{
return;
}
_isConnected = false;
_state = _state.WithError(reason ?? "Connection closed");
McpLog.Warn($"[WebSocket] Connection closed: {reason}");
await StopConnectionLoopsAsync(awaitTasks: false).ConfigureAwait(false);
_ = Task.Run(() => AttemptReconnectAsync(_lifecycleCts.Token), CancellationToken.None);
}
private async Task AttemptReconnectAsync(CancellationToken token)
{
try
{
await StopConnectionLoopsAsync().ConfigureAwait(false);
foreach (TimeSpan delay in ReconnectSchedule)
{
if (token.IsCancellationRequested)
{
return;
}
if (delay > TimeSpan.Zero)
{
try { await Task.Delay(delay, token).ConfigureAwait(false); }
catch (OperationCanceledException) { return; }
}
if (await EstablishConnectionAsync(token).ConfigureAwait(false))
{
_state = TransportState.Connected(TransportDisplayName, sessionId: _sessionId, details: _endpointUri.ToString());
_isConnected = true;
McpLog.Info("[WebSocket] Reconnected to MCP server", false);
return;
}
}
}
finally
{
Interlocked.Exchange(ref _isReconnectingFlag, 0);
}
_state = TransportState.Disconnected(TransportDisplayName, "Failed to reconnect");
}
private static Uri BuildWebSocketUri(string baseUrl)
{
if (!Uri.TryCreate(baseUrl, UriKind.Absolute, out var httpUri))
{
throw new InvalidOperationException($"Invalid MCP base URL: {baseUrl}");
}
// Replace bind-only addresses with localhost for client connections
// 0.0.0.0 and :: are only valid for server binding, not client connections
string host = httpUri.Host;
if (host == "0.0.0.0" || host == "::")
{
McpLog.Warn($"[WebSocket] Base URL host '{host}' is bind-only; using 'localhost' for client connection.");
host = "localhost";
}
var builder = new UriBuilder(httpUri)
{
Scheme = httpUri.Scheme.Equals("https", StringComparison.OrdinalIgnoreCase) ? "wss" : "ws",
Host = host,
Path = httpUri.AbsolutePath.TrimEnd('/') + "/hub/plugin"
};
return builder.Uri;
}
}
}

View File

@@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: 044c8f7beb4af4a77a14d677190c21dc
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant: