Skip to content

Commit

Permalink
feat: add chain Id to extension connection/disconnection
Browse files Browse the repository at this point in the history
  • Loading branch information
jurevans committed Sep 26, 2024
1 parent 5df8577 commit 6a3b09e
Show file tree
Hide file tree
Showing 11 changed files with 93 additions and 47 deletions.
10 changes: 8 additions & 2 deletions apps/extension/src/Approvals/ApproveConnection.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,17 @@ export const ApproveConnection: React.FC = () => {
const requester = useRequester();
const params = useQuery();
const interfaceOrigin = params.get("interfaceOrigin");
const chainId = params.get("chainId")!;

const handleResponse = async (allowConnection: boolean): Promise<void> => {
if (interfaceOrigin) {
await requester.sendMessage(
Ports.Background,
new ConnectInterfaceResponseMsg(interfaceOrigin, allowConnection)
new ConnectInterfaceResponseMsg(
interfaceOrigin,
chainId,
allowConnection
)
);
await closeCurrentTab();
}
Expand All @@ -26,7 +31,8 @@ export const ApproveConnection: React.FC = () => {
<PageHeader title="Approve Request" />
<Stack full className="justify-between" gap={12}>
<Alert type="warning">
Approve connection for <strong>{interfaceOrigin}</strong>?
Approve connection for <strong>{interfaceOrigin}</strong> and enable
signing for <strong>{chainId}</strong>?
</Alert>
<Stack gap={2}>
<ActionButton onClick={() => handleResponse(true)}>
Expand Down
6 changes: 4 additions & 2 deletions apps/extension/src/background/approvals/handler.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ describe("approvals handler", () => {
});

test("handlers switch", () => {
const chainId = "";
const service: jest.Mocked<ApprovalsService> = createMockInstance(
ApprovalsService as any
);
Expand Down Expand Up @@ -72,16 +73,17 @@ describe("approvals handler", () => {
handler(env, rejectTxMsg);
expect(service.rejectSignTx).toBeCalled();

const isConnectionApprovedMsg = new IsConnectionApprovedMsg();
const isConnectionApprovedMsg = new IsConnectionApprovedMsg(chainId);
handler(env, isConnectionApprovedMsg);
expect(service.isConnectionApproved).toBeCalled();

const approveConnectInterfaceMsg = new ApproveConnectInterfaceMsg();
const approveConnectInterfaceMsg = new ApproveConnectInterfaceMsg(chainId);
handler(env, approveConnectInterfaceMsg);
expect(service.approveConnection).toBeCalled();

const connectInterfaceResponseMsg = new ConnectInterfaceResponseMsg(
"",
chainId,
true
);
handler(env, connectInterfaceResponseMsg);
Expand Down
12 changes: 6 additions & 6 deletions apps/extension/src/background/approvals/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,16 @@ export const getHandler: (service: ApprovalsService) => Handler = (service) => {
const handleIsConnectionApprovedMsg: (
service: ApprovalsService
) => InternalHandler<IsConnectionApprovedMsg> = (service) => {
return async (_, { origin }) => {
return await service.isConnectionApproved(origin);
return async (_, { origin, chainId }) => {
return await service.isConnectionApproved(origin, chainId);
};
};

const handleApproveConnectInterfaceMsg: (
service: ApprovalsService
) => InternalHandler<ApproveConnectInterfaceMsg> = (service) => {
return async (_, { origin }) => {
return await service.approveConnection(origin);
return async (_, { origin, chainId }) => {
return await service.approveConnection(origin, chainId);
};
};

Expand All @@ -148,8 +148,8 @@ const handleConnectInterfaceResponseMsg: (
const handleApproveDisconnectInterfaceMsg: (
service: ApprovalsService
) => InternalHandler<ApproveDisconnectInterfaceMsg> = (service) => {
return async (_, { origin }) => {
return await service.approveDisconnection(origin);
return async (_, { origin, chainId }) => {
return await service.approveDisconnection(origin, chainId);
};
};

Expand Down
6 changes: 3 additions & 3 deletions apps/extension/src/background/approvals/messages.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,20 @@ describe("approvals messages", () => {
});

test("valid ConnectInterfaceResponseMsg", () => {
const msg = new ConnectInterfaceResponseMsg("interface", true);
const msg = new ConnectInterfaceResponseMsg("interface", "chainId", true);

expect(msg.type()).toBe(MessageType.ConnectInterfaceResponse);
expect(msg.route()).toBe(ROUTE);
expect(msg.validate()).toBeUndefined();
});

test("invalid ConnectInterfaceResponseMsg", () => {
const msg = new ConnectInterfaceResponseMsg("interface", true);
const msg = new ConnectInterfaceResponseMsg("interface", "chainId", true);
(msg as any).interfaceOrigin = undefined;

expect(() => msg.validate()).toThrow();

const msg2 = new ConnectInterfaceResponseMsg("interface", true);
const msg2 = new ConnectInterfaceResponseMsg("interface", "chainId", true);
(msg2 as any).allowConnection = undefined;

expect(() => msg2.validate()).toThrow();
Expand Down
3 changes: 2 additions & 1 deletion apps/extension/src/background/approvals/messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,14 @@ export class ConnectInterfaceResponseMsg extends Message<void> {

constructor(
public readonly interfaceOrigin: string,
public readonly chainId: string,
public readonly allowConnection: boolean
) {
super();
}

validate(): void {
validateProps(this, ["interfaceOrigin", "allowConnection"]);
validateProps(this, ["interfaceOrigin", "chainId", "allowConnection"]);
}

route(): string {
Expand Down
27 changes: 20 additions & 7 deletions apps/extension/src/background/approvals/service.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -264,13 +264,14 @@ describe("approvals service", () => {
describe("approveConnection", () => {
it("should approve connection if it's not already approved", async () => {
const interfaceOrigin = "origin";
const chainId = "chainId";
const tabId = 1;

jest.spyOn(service, "isConnectionApproved").mockResolvedValue(false);
jest.spyOn(service as any, "launchApprovalPopup");
service["resolverMap"] = {};

const promise = service.approveConnection(interfaceOrigin);
const promise = service.approveConnection(interfaceOrigin, chainId);
await new Promise<void>((r) =>
setTimeout(() => {
r();
Expand All @@ -290,10 +291,11 @@ describe("approvals service", () => {

it("should not approve connection if it was already approved", async () => {
const interfaceOrigin = "origin";
const chainId = "chainId";
jest.spyOn(service, "isConnectionApproved").mockResolvedValue(true);

await expect(
service.approveConnection(interfaceOrigin)
service.approveConnection(interfaceOrigin, chainId)
).resolves.toBeUndefined();
});
});
Expand Down Expand Up @@ -354,13 +356,14 @@ describe("approvals service", () => {
describe("approveDisconnection", () => {
it("should approve disconnection if there is a connection already approved", async () => {
const interfaceOrigin = "origin";
const chainId = "";
const tabId = 1;

jest.spyOn(service, "isConnectionApproved").mockResolvedValue(true);
jest.spyOn(service as any, "launchApprovalPopup");
service["resolverMap"] = {};

const promise = service.approveDisconnection(interfaceOrigin);
const promise = service.approveDisconnection(interfaceOrigin, chainId);
await new Promise<void>((r) =>
setTimeout(() => {
r();
Expand All @@ -380,10 +383,11 @@ describe("approvals service", () => {

it("should not approve disconnection if it is NOT already approved", async () => {
const interfaceOrigin = "origin";
const chainId = "";
jest.spyOn(service, "isConnectionApproved").mockResolvedValue(false);

await expect(
service.approveDisconnection(interfaceOrigin)
service.approveDisconnection(interfaceOrigin, chainId)
).resolves.toBeUndefined();
});
});
Expand Down Expand Up @@ -543,27 +547,36 @@ describe("approvals service", () => {
describe("isConnectionApproved", () => {
it("should return true if origin is approved", async () => {
const origin = "origin";
const chainId = "chainId";
jest
.spyOn(localStorage, "getApprovedOrigins")
.mockResolvedValue([origin]);

await expect(service.isConnectionApproved(origin)).resolves.toBe(true);
await expect(service.isConnectionApproved(origin, chainId)).resolves.toBe(
true
);
});

it("should return false if origin is not approved", async () => {
const origin = "origin";
const chainId = "chainId";
jest.spyOn(localStorage, "getApprovedOrigins").mockResolvedValue([]);

await expect(service.isConnectionApproved(origin)).resolves.toBe(false);
await expect(service.isConnectionApproved(origin, chainId)).resolves.toBe(
false
);
});

it("should return false if there are no origins in store", async () => {
const origin = "origin";
const chainId = "chainId";
jest
.spyOn(localStorage, "getApprovedOrigins")
.mockResolvedValue(undefined);

await expect(service.isConnectionApproved(origin)).resolves.toBe(false);
await expect(service.isConnectionApproved(origin, chainId)).resolves.toBe(
false
);
});
});
});
31 changes: 26 additions & 5 deletions apps/extension/src/background/approvals/service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -184,19 +184,34 @@ export class ApprovalsService {
resolvers.reject(new Error("Sign Tx rejected"));
}

async isConnectionApproved(interfaceOrigin: string): Promise<boolean> {
async isConnectionApproved(
interfaceOrigin: string,
chainId: string
): Promise<boolean> {
const approvedOrigins =
(await this.localStorage.getApprovedOrigins()) || [];

const chain = await this.chainService.getChain();
if (chain.chainId !== chainId) {
return false;
}

return approvedOrigins.includes(interfaceOrigin);
}

async approveConnection(interfaceOrigin: string): Promise<void> {
const alreadyApproved = await this.isConnectionApproved(interfaceOrigin);
async approveConnection(
interfaceOrigin: string,
chainId: string
): Promise<void> {
const alreadyApproved = await this.isConnectionApproved(
interfaceOrigin,
chainId
);

if (!alreadyApproved) {
return this.launchApprovalPopup(TopLevelRoute.ApproveConnection, {
interfaceOrigin,
chainId,
});
}

Expand All @@ -223,8 +238,14 @@ export class ApprovalsService {
}
}

async approveDisconnection(interfaceOrigin: string): Promise<void> {
const isConnected = await this.isConnectionApproved(interfaceOrigin);
async approveDisconnection(
interfaceOrigin: string,
chainId: string
): Promise<void> {
const isConnected = await this.isConnectionApproved(
interfaceOrigin,
chainId
);

if (isConnected) {
return this.launchApprovalPopup(TopLevelRoute.ApproveDisconnection, {
Expand Down
12 changes: 6 additions & 6 deletions apps/extension/src/provider/Namada.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,28 +30,28 @@ export class Namada implements INamada {
protected readonly requester?: MessageRequester
) {}

public async connect(): Promise<void> {
public async connect(chainId: string): Promise<void> {
return await this.requester?.sendMessage(
Ports.Background,
new ApproveConnectInterfaceMsg()
new ApproveConnectInterfaceMsg(chainId)
);
}

public async disconnect(): Promise<void> {
public async disconnect(chainId: string): Promise<void> {
return await this.requester?.sendMessage(
Ports.Background,
new ApproveDisconnectInterfaceMsg(location.origin)
new ApproveDisconnectInterfaceMsg(location.origin, chainId)
);
}

public async isConnected(): Promise<boolean> {
public async isConnected(chainId: string): Promise<boolean> {
if (!this.requester) {
throw new Error("no requester");
}

return await this.requester.sendMessage(
Ports.Background,
new IsConnectionApprovedMsg()
new IsConnectionApprovedMsg(chainId)
);
}

Expand Down
15 changes: 9 additions & 6 deletions apps/extension/src/provider/messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@ export class IsConnectionApprovedMsg extends Message<boolean> {
return MessageType.IsConnectionApproved;
}

constructor() {
constructor(public readonly chainId: string) {
super();
}

validate(): void {
return;
validateProps(this, ["chainId"]);
}

route(): string {
Expand All @@ -118,12 +118,12 @@ export class ApproveConnectInterfaceMsg extends Message<void> {
return MessageType.ApproveConnectInterface;
}

constructor() {
constructor(public readonly chainId: string) {
super();
}

validate(): void {
return;
validateProps(this, ["chainId"]);
}

route(): string {
Expand All @@ -140,12 +140,15 @@ export class ApproveDisconnectInterfaceMsg extends Message<void> {
return MessageType.ApproveDisconnectInterface;
}

constructor(public readonly originToRevoke: string) {
constructor(
public readonly originToRevoke: string,
public readonly chainId: string
) {
super();
}

validate(): void {
validateProps(this, ["originToRevoke"]);
validateProps(this, ["originToRevoke", "chainId"]);
}

route(): string {
Expand Down
12 changes: 6 additions & 6 deletions packages/integrations/src/Namada.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,16 @@ export default class Namada implements Integration<Account, Signer> {
return !!this._namada;
}

public async connect(): Promise<void> {
await this._namada?.connect();
public async connect(chainId: string): Promise<void> {
await this._namada?.connect(chainId);
}

public async disconnect(): Promise<void> {
await this._namada?.disconnect();
public async disconnect(chainId: string): Promise<void> {
await this._namada?.disconnect(chainId);
}

public async isConnected(): Promise<boolean | undefined> {
return await this._namada?.isConnected();
public async isConnected(chainId: string): Promise<boolean | undefined> {
return await this._namada?.isConnected(chainId);
}

public async getChain(): Promise<Chain | undefined> {
Expand Down
Loading

0 comments on commit 6a3b09e

Please sign in to comment.