diff --git a/src/commands/wallet.ts b/src/commands/wallet.ts index 66d68e4..335c747 100644 --- a/src/commands/wallet.ts +++ b/src/commands/wallet.ts @@ -11,6 +11,7 @@ import { formatChainId, formatChainIds } from "../lib/chains"; import { c } from "../lib/color"; import { openBrowser } from "../lib/browser"; import { selectOption, prompt } from "../lib/prompt"; +import { withApprovalGate } from "../lib/walletGate"; import qrcode from "qrcode-terminal"; export function registerWalletCommands(program: Command): void { @@ -37,10 +38,8 @@ export function registerWalletCommands(program: Command): void { .action(async (opts, cmd) => { const json = isJson(cmd); try { - const provider = await createProviderAdapter(); - const signature = await provider.signMessage( - Number(opts.chainId), - opts.message + const signature = await withApprovalGate((provider) => + provider.signMessage(Number(opts.chainId), opts.message) ); outputResult(json, { signature }); } catch (err) { @@ -82,10 +81,8 @@ export function registerWalletCommands(program: Command): void { ); } - const provider = await createProviderAdapter(); - const signature = await provider.signTypedData( - Number(opts.chainId), - typedData + const signature = await withApprovalGate((provider) => + provider.signTypedData(Number(opts.chainId), typedData) ); outputResult(json, { signature }); } catch (err) { @@ -112,16 +109,6 @@ export function registerWalletCommands(program: Command): void { ); } - const provider = await createProviderAdapter(); - const supportedChainIds = await provider.getSupportedChainIds(); - if (!supportedChainIds.includes(chainId)) { - throw new CliError( - `Unsupported chain ID: ${formatChainId(chainId)}`, - "VALIDATION_ERROR", - `Supported chains: ${formatChainIds(supportedChainIds)}` - ); - } - if (!isAddress(opts.to)) { throw new CliError( `Invalid --to address: ${opts.to}`, @@ -151,10 +138,20 @@ export function registerWalletCommands(program: Command): void { } } - const transactionHash = await provider.sendTransaction(chainId, { - to: opts.to, - ...(opts.data !== undefined ? { data: opts.data } : {}), - ...(value !== undefined ? { value } : {}), + const transactionHash = await withApprovalGate(async (provider) => { + const supportedChainIds = await provider.getSupportedChainIds(); + if (!supportedChainIds.includes(chainId)) { + throw new CliError( + `Unsupported chain ID: ${formatChainId(chainId)}`, + "VALIDATION_ERROR", + `Supported chains: ${formatChainIds(supportedChainIds)}` + ); + } + return provider.sendTransaction(chainId, { + to: opts.to, + ...(opts.data !== undefined ? { data: opts.data } : {}), + ...(value !== undefined ? { value } : {}), + }); }); outputResult(json, { transactionHash }); } catch (err) { @@ -368,9 +365,9 @@ export function registerWalletCommands(program: Command): void { if (!json && isTTY()) { process.stdout.write(" Signing wallet verification..."); } - signature = await provider.signMessage( - chainId, - initResult.data.challenge + const challenge = initResult.data.challenge; + signature = await withApprovalGate((p) => + p.signMessage(chainId, challenge) ); if (!json && isTTY()) { console.log(` ${c.green("✓")}`); diff --git a/src/lib/agentFactory.ts b/src/lib/agentFactory.ts index eda2302..c396b08 100644 --- a/src/lib/agentFactory.ts +++ b/src/lib/agentFactory.ts @@ -11,7 +11,10 @@ import { SseTransport, AcpApiClient, } from "@virtuals-protocol/acp-node-v2"; -import type { IEvmProviderAdapter } from "@virtuals-protocol/acp-node-v2"; +import type { + IEvmProviderAdapter, + SupportedStreams, +} from "@virtuals-protocol/acp-node-v2"; import { getBuilderCode, getActiveWallet, @@ -164,6 +167,31 @@ export async function createProviderAdapter(): Promise { return createProviderFromConfig(chains, serverUrl, privyAppId); } +export async function createSseTransport( + provider: IEvmProviderAdapter, + streams: SupportedStreams[] +): Promise { + const isTestnet = process.env.IS_TESTNET === "true"; + const serverUrl = isTestnet ? ACP_TESTNET_SERVER_URL : ACP_SERVER_URL; + const [agentAddress, providerSupportedChainIds] = await Promise.all([ + provider.getAddress(), + provider.getSupportedChainIds(), + ]); + + const ctx = { + agentAddress, + contractAddresses: ACP_CONTRACT_ADDRESSES, + providerSupportedChainIds, + signTypedData: (chainId, typedData) => + provider.signTypedData(chainId, typedData), + } as Parameters[0]; + + const transport = new SseTransport({ serverUrl }); + transport.setContext(ctx); + await transport.connect(undefined, streams); + return transport; +} + export function getWalletAddress(): string { const addr = getActiveWallet(); if (!addr) { diff --git a/src/lib/walletGate.ts b/src/lib/walletGate.ts new file mode 100644 index 0000000..6b4769d --- /dev/null +++ b/src/lib/walletGate.ts @@ -0,0 +1,17 @@ +import { + STREAMS, + type IEvmProviderAdapter, +} from "@virtuals-protocol/acp-node-v2"; +import { createProviderAdapter, createSseTransport } from "./agentFactory"; + +export async function withApprovalGate( + fn: (provider: IEvmProviderAdapter) => Promise +): Promise { + const provider = await createProviderAdapter(); + const transport = await createSseTransport(provider, [STREAMS.WALLET]); + try { + return await fn(provider); + } finally { + void Promise.resolve(transport.disconnect()).catch(() => {}); + } +}