Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions apps/sim/lib/copilot/persistence/tool-confirm/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,24 @@ import {
import { getAsyncToolCalls } from '@/lib/copilot/async-runs/repository'
import { MothershipStreamV1ToolOutcome } from '@/lib/copilot/generated/mothership-stream-v1'
import { getRedisClient } from '@/lib/core/config/redis'
import { createPubSubChannel } from '@/lib/events/pubsub'
import { createPubSubChannel, type PubSubChannel } from '@/lib/events/pubsub'

const logger = createLogger('CopilotOrchestratorPersistence')
const TOOL_CONFIRMATION_TTL_SECONDS = 60 * 10
const toolConfirmationKey = (toolCallId: string) => `copilot:tool-confirmation:${toolCallId}`

const toolConfirmationChannel = createPubSubChannel<AsyncCompletionEnvelope>({
channel: 'copilot:tool-confirmation',
label: 'CopilotToolConfirmation',
})
/**
 * `globalThis` extended with an optional slot that holds the
 * tool-confirmation pub/sub channel.
 */
type ToolConfirmGlobal = typeof globalThis & {
_toolConfirmationChannel?: PubSubChannel<AsyncCompletionEnvelope>
}

// The channel lives on globalThis and is created only when the slot is empty,
// so evaluating this module again reuses the existing channel instead of
// calling createPubSubChannel a second time.
// NOTE(review): presumably this guards against module re-evaluation (e.g. dev
// hot reload or duplicated bundles) — confirm against the PR description.
const _g = globalThis as ToolConfirmGlobal
if (!_g._toolConfirmationChannel) {
_g._toolConfirmationChannel = createPubSubChannel<AsyncCompletionEnvelope>({
channel: 'copilot:tool-confirmation',
label: 'CopilotToolConfirmation',
})
}
// Module-local alias for the shared channel stored on globalThis.
const toolConfirmationChannel = _g._toolConfirmationChannel

/**
* Get a tool call confirmation state from the durable async tool row.
Expand Down
20 changes: 15 additions & 5 deletions apps/sim/lib/copilot/tasks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
* Channel: `task:status_changed`
*/

import { createPubSubChannel } from '@/lib/events/pubsub'
import { createPubSubChannel, type PubSubChannel } from '@/lib/events/pubsub'

interface TaskStatusEvent {
workspaceId: string
Expand All @@ -16,10 +16,20 @@ interface TaskStatusEvent {
streamId?: string
}

const channel =
typeof window !== 'undefined'
? null
: createPubSubChannel<TaskStatusEvent>({ channel: 'task:status_changed', label: 'task' })
/**
 * `globalThis` extended with an optional slot for the task-status channel.
 * The slot may hold `null`, which is the value stored when `window` is defined.
 */
type TaskPubSubGlobal = typeof globalThis & {
_taskStatusChannel?: PubSubChannel<TaskStatusEvent> | null
}

const g = globalThis as TaskPubSubGlobal

// Presence is tested with `in` rather than truthiness: `null` is a valid stored
// value, and a truthiness test would redo the initialization whenever the slot
// holds null.
if (!('_taskStatusChannel' in g)) {
// When `window` exists no channel is created and `null` is stored instead.
g._taskStatusChannel =
typeof window !== 'undefined'
? null
: createPubSubChannel<TaskStatusEvent>({ channel: 'task:status_changed', label: 'task' })
}

// Module-local alias; null or the shared channel stored on globalThis.
const channel = g._taskStatusChannel

export const taskPubSub = channel
? {
Expand Down
108 changes: 59 additions & 49 deletions apps/sim/lib/core/config/redis.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,55 +54,65 @@ export function getRedisConnectionDefaults(
}
}

let globalRedisClient: Redis | null = null
let pingFailures = 0
let pingInterval: NodeJS.Timeout | null = null
let pingInFlight = false
/** Mutable Redis client state, kept in one object that is stored on `globalThis`. */
interface RedisState {
/** Cached client handed out by getRedisClient(); null when none exists or it has been dropped. */
client: Redis | null
/** Count of consecutive failed PINGs; set back to 0 on a successful PING and after a forced reconnect. */
pingFailures: number
/** Handle of the PING health-check timer, or null when no health check is running. */
pingInterval: NodeJS.Timeout | null
/** True while a PING is being awaited, so an interval tick that fires meanwhile returns early. */
pingInFlight: boolean
/** Callbacks registered through onRedisReconnect(), run when the health check forces a reconnect. */
reconnectListeners: Array<() => void>
}

// Create the state object only when globalThis does not already carry one, so
// every evaluation of this module shares the same client, timer and listeners.
// NOTE(review): presumably this is meant to survive module re-evaluation
// (e.g. dev hot reload or duplicated bundles) — confirm.
const g = globalThis as typeof globalThis & { _redisState?: RedisState }
if (!g._redisState) {
g._redisState = {
client: null,
pingFailures: 0,
pingInterval: null,
pingInFlight: false,
reconnectListeners: [],
}
}
// Module-local alias for the shared state object.
const state = g._redisState

const PING_INTERVAL_MS = 15_000
const MAX_PING_FAILURES = 2

/** Callbacks invoked when the PING health check forces a reconnect. */
const reconnectListeners: Array<() => void> = []

/**
 * Register a callback that fires when the PING health check forces a reconnect.
 * Useful for resetting cached adapters that hold a stale Redis reference.
 *
 * Callbacks are only appended; in the code visible here the list is emptied
 * solely by resetForTesting().
 *
 * @param cb - Called with no arguments. The health check wraps each call in
 *   try/catch.
 */
export function onRedisReconnect(cb: () => void): void {
reconnectListeners.push(cb)
state.reconnectListeners.push(cb)
}
Comment thread
waleedlatif1 marked this conversation as resolved.

function startPingHealthCheck(redis: Redis): void {
if (pingInterval) return
if (state.pingInterval) return

pingInterval = setInterval(async () => {
if (pingInFlight) return
pingInFlight = true
state.pingInterval = setInterval(async () => {
if (state.pingInFlight) return
state.pingInFlight = true
try {
await redis.ping()
pingFailures = 0
state.pingFailures = 0
} catch (error) {
pingFailures++
state.pingFailures++
logger.warn('Redis PING failed', {
consecutiveFailures: pingFailures,
consecutiveFailures: state.pingFailures,
error: toError(error).message,
})

if (pingFailures >= MAX_PING_FAILURES) {
if (state.pingFailures >= MAX_PING_FAILURES) {
logger.error('Redis PING failed consecutive times — forcing reconnect', {
consecutiveFailures: pingFailures,
consecutiveFailures: state.pingFailures,
})
pingFailures = 0
// Drop the cached client and stop this health check before disconnecting,
// so the next getRedisClient() builds a fresh client and a fresh PING loop.
// Listeners may call getRedisClient() and must observe the cleared global.
globalRedisClient = null
if (pingInterval) {
clearInterval(pingInterval)
pingInterval = null
state.pingFailures = 0
// Clear before notifying listeners — they may call getRedisClient() and must see the reset state.
state.client = null
if (state.pingInterval) {
clearInterval(state.pingInterval)
state.pingInterval = null
}
for (const cb of reconnectListeners) {
for (const cb of state.reconnectListeners) {
try {
cb()
} catch (cbError) {
Expand All @@ -116,7 +126,7 @@ function startPingHealthCheck(redis: Redis): void {
}
}
} finally {
pingInFlight = false
state.pingInFlight = false
}
}, PING_INTERVAL_MS)
}
Expand All @@ -131,15 +141,15 @@ function startPingHealthCheck(redis: Redis): void {
export function getRedisClient(): Redis | null {
if (typeof window !== 'undefined') return null
if (!redisUrl) return null
if (globalRedisClient) return globalRedisClient
if (state.client) return state.client

// Outside the try/catch so config errors aren't silently swallowed.
const defaults = getRedisConnectionDefaults(redisUrl)

try {
logger.info('Initializing Redis client')

globalRedisClient = new Redis(redisUrl, {
state.client = new Redis(redisUrl, {
...defaults,
commandTimeout: 5000,
maxRetriesPerRequest: 5,
Expand All @@ -162,17 +172,17 @@ export function getRedisClient(): Redis | null {
},
})

globalRedisClient.on('connect', () => logger.info('Redis connected'))
globalRedisClient.on('ready', () => logger.info('Redis ready'))
globalRedisClient.on('error', (err: Error) => {
state.client.on('connect', () => logger.info('Redis connected'))
state.client.on('ready', () => logger.info('Redis ready'))
state.client.on('error', (err: Error) => {
logger.error('Redis error', { error: err.message, code: (err as any).code })
})
globalRedisClient.on('close', () => logger.warn('Redis connection closed'))
globalRedisClient.on('end', () => logger.error('Redis connection ended'))
state.client.on('close', () => logger.warn('Redis connection closed'))
state.client.on('end', () => logger.error('Redis connection ended'))

startPingHealthCheck(globalRedisClient)
startPingHealthCheck(state.client)

return globalRedisClient
return state.client
} catch (error) {
logger.error('Failed to initialize Redis client', { error })
return null
Expand Down Expand Up @@ -274,18 +284,18 @@ export async function extendLock(
* Use for graceful shutdown.
*/
export async function closeRedisConnection(): Promise<void> {
// Stop the PING health-check timer before touching the client.
if (pingInterval) {
clearInterval(pingInterval)
pingInterval = null
if (state.pingInterval) {
clearInterval(state.pingInterval)
state.pingInterval = null
}

// Nothing to close when no client has been cached.
if (globalRedisClient) {
if (state.client) {
try {
await globalRedisClient.quit()
await state.client.quit()
} catch (error) {
// A failed quit() is logged rather than rethrown, so this function still resolves.
logger.error('Error closing Redis connection', { error })
} finally {
// Drop the cached reference whether or not quit() succeeded.
globalRedisClient = null
state.client = null
}
}
}
Expand All @@ -294,12 +304,12 @@ export async function closeRedisConnection(): Promise<void> {
* Reset all module-level state. Only intended for use in tests.
*/
export function resetForTesting(): void {
// Cancel the PING health-check timer if one is running.
if (pingInterval) {
clearInterval(pingInterval)
pingInterval = null
if (state.pingInterval) {
clearInterval(state.pingInterval)
state.pingInterval = null
}
// The client is not quit or disconnected here; only the cached reference is dropped.
globalRedisClient = null
pingFailures = 0
pingInFlight = false
// Truncate in place: the listener array is reached through a const binding,
// so it is emptied rather than replaced.
reconnectListeners.length = 0
state.client = null
state.pingFailures = 0
state.pingInFlight = false
state.reconnectListeners.length = 0
}
34 changes: 21 additions & 13 deletions apps/sim/lib/core/rate-limiter/storage/factory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,27 @@ import { RedisTokenBucket } from './redis-token-bucket'

const logger = createLogger('RateLimitStorage')

let cachedAdapter: RateLimitStorageAdapter | null = null
let reconnectListenerRegistered = false
/**
 * `globalThis` extended with the cached rate-limit storage adapter and a flag
 * recording whether the Redis reconnect listener has been registered.
 */
type FactoryGlobal = typeof globalThis & {
_rlCachedAdapter?: RateLimitStorageAdapter | null
_rlReconnectListenerRegistered?: boolean
}

const g = globalThis as FactoryGlobal
// Both slots are initialized together, keyed on the presence of
// `_rlCachedAdapter`. An `in` check is used because the cached adapter is
// legitimately null after a reset, and that must not trigger re-initialization
// of the listener flag.
if (!('_rlCachedAdapter' in g)) {
g._rlCachedAdapter = null
g._rlReconnectListenerRegistered = false
}

export function createStorageAdapter(): RateLimitStorageAdapter {
if (cachedAdapter) {
return cachedAdapter
if (g._rlCachedAdapter) {
return g._rlCachedAdapter
}

if (!reconnectListenerRegistered) {
if (!g._rlReconnectListenerRegistered) {
onRedisReconnect(() => {
cachedAdapter = null
g._rlCachedAdapter = null
})
reconnectListenerRegistered = true
g._rlReconnectListenerRegistered = true
}

const storageMethod = getStorageMethod()
Expand All @@ -30,27 +38,27 @@ export function createStorageAdapter(): RateLimitStorageAdapter {
logger.warn(
'Redis configured but client unavailable - falling back to PostgreSQL for rate limiting'
)
cachedAdapter = new DbTokenBucket()
g._rlCachedAdapter = new DbTokenBucket()
} else {
logger.info('Rate limiting: Using Redis')
cachedAdapter = new RedisTokenBucket(redis)
g._rlCachedAdapter = new RedisTokenBucket(redis)
}
} else {
logger.info('Rate limiting: Using PostgreSQL')
cachedAdapter = new DbTokenBucket()
g._rlCachedAdapter = new DbTokenBucket()
}

return cachedAdapter
return g._rlCachedAdapter!
}

/**
 * Report which storage method is in effect.
 *
 * @returns The value produced by getStorageMethod(); no caching is applied here.
 */
export function getAdapterType(): StorageMethod {
return getStorageMethod()
}

/**
 * Clear the cached adapter so the next createStorageAdapter() call builds a
 * new one.
 */
export function resetStorageAdapter(): void {
cachedAdapter = null
g._rlCachedAdapter = null
}

/**
 * Install `adapter` as the cached adapter; createStorageAdapter() returns it
 * for as long as it remains set.
 *
 * @param adapter - Adapter to cache in place of whatever was cached before.
 */
export function setStorageAdapter(adapter: RateLimitStorageAdapter): void {
cachedAdapter = adapter
g._rlCachedAdapter = adapter
}
12 changes: 8 additions & 4 deletions apps/sim/lib/execution/cancellation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,20 @@ export type ExecutionCancellationRecordResult =
reason: 'redis_unavailable' | 'redis_write_failed'
}

let sharedChannel: PubSubChannel<ExecutionCancelEvent> | null = null
/**
 * `globalThis` extended with an optional slot that holds the
 * execution-cancel pub/sub channel.
 */
type CancellationGlobal = typeof globalThis & {
_executionCancelChannel?: PubSubChannel<ExecutionCancelEvent>
}

// Typed view of globalThis; the slot is filled lazily by getCancellationChannel().
const _g = globalThis as CancellationGlobal

/**
 * Return the shared pub/sub channel for execution-cancel events, creating it
 * on the first call.
 *
 * @returns The channel for EXECUTION_CANCEL_CHANNEL; later calls return the
 *   instance created by the first call.
 */
export function getCancellationChannel(): PubSubChannel<ExecutionCancelEvent> {
if (!sharedChannel) {
sharedChannel = createPubSubChannel<ExecutionCancelEvent>({
if (!_g._executionCancelChannel) {
_g._executionCancelChannel = createPubSubChannel<ExecutionCancelEvent>({
channel: EXECUTION_CANCEL_CHANNEL,
label: 'execution-cancel',
})
}
return sharedChannel
return _g._executionCancelChannel
}

export function isRedisCancellationEnabled(): boolean {
Expand Down
13 changes: 10 additions & 3 deletions apps/sim/lib/mcp/connection-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,13 @@ export class McpConnectionManager {
}
}

export const mcpConnectionManager: McpConnectionManager | null = isTest
? null
: new McpConnectionManager()
/**
 * `globalThis` extended with an optional slot for the MCP connection manager.
 * `null` is a valid stored value (it is what gets stored when `isTest` is true).
 */
type McpManagerGlobal = typeof globalThis & {
_mcpConnectionManager?: McpConnectionManager | null
}

const _g = globalThis as McpManagerGlobal
// Presence is tested with `in` so a stored null is kept and no manager is
// constructed on a later evaluation of this module.
if (!('_mcpConnectionManager' in _g)) {
_g._mcpConnectionManager = isTest ? null : new McpConnectionManager()
}

// `?? null` folds the `undefined` allowed by the optional slot type into null,
// matching the exported `McpConnectionManager | null` type.
export const mcpConnectionManager: McpConnectionManager | null = _g._mcpConnectionManager ?? null
Loading
Loading