diff --git a/package.json b/package.json index e200158a2d..fd8db54e16 100644 --- a/package.json +++ b/package.json @@ -95,6 +95,7 @@ "highlight.js": "^11.3.1", "html-entities": "^2.0.0", "is-ip": "^3.1.0", + "js-xxhash": "^3.0.1", "jszip": "^3.7.0", "katex": "^0.16.0", "linkify-element": "4.1.3", diff --git a/playwright/e2e/crypto/staged-rollout.spec.ts b/playwright/e2e/crypto/staged-rollout.spec.ts new file mode 100644 index 0000000000..a735444d1e --- /dev/null +++ b/playwright/e2e/crypto/staged-rollout.spec.ts @@ -0,0 +1,64 @@ +/* +Copyright 2024 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import { test, expect } from "../../element-web-test"; +import { logIntoElement } from "./utils"; + +test.describe("Migration of existing logins", () => { + test("Test migration of existing logins when rollout is 100%", async ({ + page, + context, + app, + credentials, + homeserver, + }, workerInfo) => { + test.skip(workerInfo.project.name === "Rust Crypto", "This test only works with Rust crypto."); + await page.goto("/#/login"); + + let featureRustCrypto = false; + let stagedRolloutPercent = 0; + + await context.route(`http://localhost:8080/config.json*`, async (route) => { + const json = {}; + json["features"] = { + feature_rust_crypto: featureRustCrypto, + }; + json["setting_defaults"] = { + "RustCrypto.staged_rollout_percent": stagedRolloutPercent, + }; + await route.fulfill({ json }); + }); + + await logIntoElement(page, homeserver, credentials); + + await app.settings.openUserSettings("Help & About"); + await expect(page.getByText("Crypto version: Olm")).toBeVisible(); + + featureRustCrypto = true; + + await page.reload(); + + await app.settings.openUserSettings("Help & About"); + await expect(page.getByText("Crypto version: Olm")).toBeVisible(); + + stagedRolloutPercent = 100; + + await page.reload(); + + await app.settings.openUserSettings("Help & About"); + await expect(page.getByText("Crypto version: Rust SDK")).toBeVisible(); + }); +}); diff --git a/src/MatrixClientPeg.ts b/src/MatrixClientPeg.ts index d09b8467fd..fd3f05a7cf 100644 --- a/src/MatrixClientPeg.ts +++ b/src/MatrixClientPeg.ts @@ -18,15 +18,15 @@ limitations under the License. */ import { - ICreateClientOpts, - PendingEventOrdering, - RoomNameState, - RoomNameType, EventTimeline, EventTimelineSet, + ICreateClientOpts, IStartClientOpts, MatrixClient, MemoryStore, + PendingEventOrdering, + RoomNameState, + RoomNameType, TokenRefreshFunction, } from "matrix-js-sdk/src/matrix"; import * as utils from "matrix-js-sdk/src/utils"; @@ -53,6 +53,7 @@ import PlatformPeg from "./PlatformPeg"; import { formatList } from "./utils/FormattingUtils"; import SdkConfig from "./SdkConfig"; import { Features } from "./settings/Settings"; +import { PhasedRolloutFeature } from "./utils/PhasedRolloutFeature"; export interface IMatrixClientCreds { homeserverUrl: string; @@ -302,13 +303,34 @@ class MatrixClientPegClass implements IMatrixClientPeg { throw new Error("createClient must be called first"); } - const useRustCrypto = SettingsStore.getValue(Features.RustCrypto); + let useRustCrypto = SettingsStore.getValue(Features.RustCrypto); + + // We want the value that is set in the config.json for that web instance + const defaultUseRustCrypto = SettingsStore.getValueAt(SettingLevel.CONFIG, Features.RustCrypto); + const migrationPercent = SettingsStore.getValueAt(SettingLevel.CONFIG, "RustCrypto.staged_rollout_percent"); + + // If the default config is to use rust crypto, and the user is on legacy crypto, + // we want to check if we should migrate the current user. + if (!useRustCrypto && defaultUseRustCrypto && Number.isInteger(migrationPercent)) { + // The user is not on rust crypto, but the default stack is now rust; Let's check if we should migrate + // the current user to rust crypto. + try { + const stagedRollout = new PhasedRolloutFeature("RustCrypto.staged_rollout_percent", migrationPercent); + // Device id should not be null at that point, or init crypto will fail anyhow + const deviceId = this.matrixClient.getDeviceId()!; + // we use deviceId rather than userId because we don't particularly want all devices + // of a user to be migrated at the same time. + useRustCrypto = stagedRollout.isFeatureEnabled(deviceId); + } catch (e) { + logger.warn("Failed to create staged rollout feature for rust crypto migration", e); + } + } // we want to make sure that the same crypto implementation is used throughout the lifetime of a device, // so persist the setting at the device layer // (At some point, we'll allow the user to *enable* the setting via labs, which will migrate their existing // device to the rust-sdk implementation, but that won't change anything here). - await SettingsStore.setValue("feature_rust_crypto", null, SettingLevel.DEVICE, useRustCrypto); + await SettingsStore.setValue(Features.RustCrypto, null, SettingLevel.DEVICE, useRustCrypto); // Now we can initialise the right crypto impl. if (useRustCrypto) { diff --git a/src/settings/Settings.tsx b/src/settings/Settings.tsx index d93e83a2d4..3f7013c495 100644 --- a/src/settings/Settings.tsx +++ b/src/settings/Settings.tsx @@ -96,6 +96,7 @@ export enum Features { VoiceBroadcastForceSmallChunks = "feature_voice_broadcast_force_small_chunks", NotificationSettings2 = "feature_notification_settings2", OidcNativeFlow = "feature_oidc_native_flow", + // If true, every new login will use the new rust crypto implementation RustCrypto = "feature_rust_crypto", } @@ -503,6 +504,13 @@ export const SETTINGS: { [setting: string]: ISetting } = { default: false, controller: new RustCryptoSdkController(), }, + // Must be set under `setting_defaults` in config.json. + // If set to 100 in conjunction with `feature_rust_crypto`, all existing users will migrate to the new crypto. + // Default is 0, meaning no existing users on legacy crypto will migrate. + "RustCrypto.staged_rollout_percent": { + supportedLevels: [SettingLevel.CONFIG], + default: 0, + }, "baseFontSize": { displayName: _td("settings|appearance|font_size"), supportedLevels: LEVELS_ACCOUNT_SETTINGS, diff --git a/src/utils/PhasedRolloutFeature.ts b/src/utils/PhasedRolloutFeature.ts new file mode 100644 index 0000000000..c305390389 --- /dev/null +++ b/src/utils/PhasedRolloutFeature.ts @@ -0,0 +1,63 @@ +/* +Copyright 2024 New Vector Ltd + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import { xxHash32 } from "js-xxhash"; + +/** + * The PhasedRolloutFeature class is used to manage the phased rollout of a new feature. + * + * It uses a hash of the user's identifier and the feature name to determine if a feature is enabled for a specific user. + * The rollout percentage determines the probability that a user will be enabled for the feature. + * The feature will be enabled for all users if the rollout percentage is 100, and for no users if the percentage is 0. + * If a user is enabled for a feature at x% rollout, it will also be for any greater than x percent. + * + * The process ensures a uniform distribution of enabled features across users. + * + * @property featureName - The name of the feature to be rolled out. + * @property rolloutPercentage - The int percentage (0..100) of users for whom the feature should be enabled. + */ +export class PhasedRolloutFeature { + public readonly featureName: string; + private readonly rolloutPercentage: number; + private readonly seed: number; + + public constructor(featureName: string, rolloutPercentage: number) { + this.featureName = featureName; + if (!Number.isInteger(rolloutPercentage) || rolloutPercentage < 0 || rolloutPercentage > 100) { + throw new Error("Rollout percentage must be an integer between 0 and 100"); + } + this.rolloutPercentage = rolloutPercentage; + // We add the feature name for the seed to ensure that the hash is different for each feature + this.seed = Array.from(featureName).reduce((sum, char) => sum + char.charCodeAt(0), 0); + } + + /** + * Returns true if the feature should be enabled for the given user. + * @param userIdentifier - Some unique identifier for the user, e.g. their user ID or device ID. + */ + public isFeatureEnabled(userIdentifier: string): boolean { + /* + * We use a hash function to convert the unique user ID string into an integer. + * This integer can then be used as a basis for deciding whether the user should have access to the new feature. + * We need some hash with good uniform distribution properties, security is not a concern here. + * We use xxHash32, which is fast and has good distribution properties. + */ + const hash = xxHash32(userIdentifier, this.seed); + // We use the hash modulo 100 to get a number between 0 and 99. + // Modulo is simple and effective and the distribution should be uniform enough for our purposes. + return hash % 100 < this.rolloutPercentage; + } +} diff --git a/test/MatrixClientPeg-test.ts b/test/MatrixClientPeg-test.ts index 8b137b310b..1f937ba4f8 100644 --- a/test/MatrixClientPeg-test.ts +++ b/test/MatrixClientPeg-test.ts @@ -144,6 +144,117 @@ describe("MatrixClientPeg", () => { expect(mockSetValue).toHaveBeenCalledWith("feature_rust_crypto", null, SettingLevel.DEVICE, true); }); + describe("Rust staged rollout", () => { + function mockSettingStore( + userIsUsingRust: boolean, + newLoginShouldUseRust: boolean, + rolloutPercent: number | null, + ) { + const originalGetValue = SettingsStore.getValue; + jest.spyOn(SettingsStore, "getValue").mockImplementation( + (settingName: string, roomId: string | null = null, excludeDefault = false) => { + if (settingName === "feature_rust_crypto") { + return userIsUsingRust; + } + return originalGetValue(settingName, roomId, excludeDefault); + }, + ); + const originalGetValueAt = SettingsStore.getValueAt; + jest.spyOn(SettingsStore, "getValueAt").mockImplementation( + (level: SettingLevel, settingName: string) => { + if (settingName === "feature_rust_crypto") { + return newLoginShouldUseRust; + } + // if null we let the original implementation handle it to get the default + if (settingName === "RustCrypto.staged_rollout_percent" && rolloutPercent !== null) { + return rolloutPercent; + } + return originalGetValueAt(level, settingName); + }, + ); + } + + let mockSetValue: jest.SpyInstance; + let mockInitCrypto: jest.SpyInstance; + let mockInitRustCrypto: jest.SpyInstance; + + beforeEach(() => { + mockSetValue = jest.spyOn(SettingsStore, "setValue").mockResolvedValue(undefined); + mockInitCrypto = jest.spyOn(testPeg.safeGet(), "initCrypto").mockResolvedValue(undefined); + mockInitRustCrypto = jest.spyOn(testPeg.safeGet(), "initRustCrypto").mockResolvedValue(undefined); + }); + + it("Should not migrate existing login if rollout is 0", async () => { + mockSettingStore(false, true, 0); + + await testPeg.start(); + expect(mockInitCrypto).toHaveBeenCalled(); + expect(mockInitRustCrypto).not.toHaveBeenCalledTimes(1); + + // we should have stashed the setting in the settings store + expect(mockSetValue).toHaveBeenCalledWith("feature_rust_crypto", null, SettingLevel.DEVICE, false); + }); + + it("Should migrate existing login if rollout is 100", async () => { + mockSettingStore(false, true, 100); + await testPeg.start(); + expect(mockInitCrypto).not.toHaveBeenCalled(); + expect(mockInitRustCrypto).toHaveBeenCalledTimes(1); + + // we should have stashed the setting in the settings store + expect(mockSetValue).toHaveBeenCalledWith("feature_rust_crypto", null, SettingLevel.DEVICE, true); + }); + + it("Should migrate existing login if user is in rollout bucket", async () => { + mockSettingStore(false, true, 30); + + // Use a device id that is known to be in the 30% bucket (hash modulo 100 < 30) + const spy = jest.spyOn(testPeg.get()!, "getDeviceId").mockReturnValue("AAA"); + + await testPeg.start(); + expect(mockInitCrypto).not.toHaveBeenCalled(); + expect(mockInitRustCrypto).toHaveBeenCalledTimes(1); + + // we should have stashed the setting in the settings store + expect(mockSetValue).toHaveBeenCalledWith("feature_rust_crypto", null, SettingLevel.DEVICE, true); + + spy.mockReset(); + }); + + it("Should not migrate existing login if rollout is malformed", async () => { + mockSettingStore(false, true, 100.1); + + await testPeg.start(); + expect(mockInitCrypto).toHaveBeenCalled(); + expect(mockInitRustCrypto).not.toHaveBeenCalledTimes(1); + + // we should have stashed the setting in the settings store + expect(mockSetValue).toHaveBeenCalledWith("feature_rust_crypto", null, SettingLevel.DEVICE, false); + }); + + it("Default is to not migrate", async () => { + mockSettingStore(false, true, null); + + await testPeg.start(); + expect(mockInitCrypto).toHaveBeenCalled(); + expect(mockInitRustCrypto).not.toHaveBeenCalledTimes(1); + + // we should have stashed the setting in the settings store + expect(mockSetValue).toHaveBeenCalledWith("feature_rust_crypto", null, SettingLevel.DEVICE, false); + }); + + it("Should not migrate if feature_rust_crypto is false", async () => { + mockSettingStore(false, false, 100); + + await testPeg.start(); + expect(mockInitCrypto).toHaveBeenCalled(); + expect(mockInitRustCrypto).not.toHaveBeenCalledTimes(1); + + // we should have stashed the setting in the settings store + expect(mockSetValue).toHaveBeenCalledWith("feature_rust_crypto", null, SettingLevel.DEVICE, false); + }); + }); + it("should reload when store database closes for a guest user", async () => { testPeg.safeGet().isGuest = () => true; const emitter = new EventEmitter(); diff --git a/test/utils/PhasedRolloutFeature-test.ts b/test/utils/PhasedRolloutFeature-test.ts new file mode 100644 index 0000000000..d76e8c3b94 --- /dev/null +++ b/test/utils/PhasedRolloutFeature-test.ts @@ -0,0 +1,89 @@ +/* +Copyright 2024 The Matrix.org Foundation C.I.C. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +import { PhasedRolloutFeature } from "../../src/utils/PhasedRolloutFeature"; + +describe("Test PhasedRolloutFeature", () => { + function randomUserId() { + const characters = "abcdefghijklmnopqrstuvwxyz0123456789.=_-/+"; + let result = ""; + const charactersLength = characters.length; + const idLength = Math.floor(Math.random() * 15) + 6; // Random number between 6 and 20 + for (let i = 0; i < idLength; i++) { + result += characters.charAt(Math.floor(Math.random() * charactersLength)); + } + return "@" + result + ":matrix.org"; + } + + function randomDeviceId() { + const characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + let result = ""; + const charactersLength = characters.length; + for (let i = 0; i < 10; i++) { + result += characters.charAt(Math.floor(Math.random() * charactersLength)); + } + return result; + } + + it("should only accept valid percentage", () => { + expect(() => new PhasedRolloutFeature("test", 0.8)).toThrow(); + expect(() => new PhasedRolloutFeature("test", -1)).toThrow(); + expect(() => new PhasedRolloutFeature("test", 123)).toThrow(); + }); + + it("should enable for all if percentage is 100", () => { + const phasedRolloutFeature = new PhasedRolloutFeature("test", 100); + + for (let i = 0; i < 1000; i++) { + expect(phasedRolloutFeature.isFeatureEnabled(randomUserId())).toBeTruthy(); + } + }); + + it("should not enable for anyone if percentage is 0", () => { + const phasedRolloutFeature = new PhasedRolloutFeature("test", 0); + + for (let i = 0; i < 1000; i++) { + expect(phasedRolloutFeature.isFeatureEnabled(randomUserId())).toBeFalsy(); + } + }); + + it("should enable for more users if percentage grows", () => { + let rolloutPercentage = 0; + let previousBatch: string[] = []; + const allUsers = new Array(1000).fill(0).map(() => randomDeviceId()); + + while (rolloutPercentage <= 90) { + rolloutPercentage += 10; + const nextRollout = new PhasedRolloutFeature("test", rolloutPercentage); + const nextBatch = allUsers.filter((userId) => nextRollout.isFeatureEnabled(userId)); + expect(previousBatch.length).toBeLessThan(nextBatch.length); + expect(previousBatch.every((user) => nextBatch.includes(user))).toBeTruthy(); + previousBatch = nextBatch; + } + }); + + it("should distribute differently depending on the feature name", () => { + const allUsers = new Array(1000).fill(0).map(() => randomUserId()); + + const featureARollout = new PhasedRolloutFeature("FeatureA", 50); + const featureBRollout = new PhasedRolloutFeature("FeatureB", 50); + + const featureAUsers = allUsers.filter((userId) => featureARollout.isFeatureEnabled(userId)); + const featureBUsers = allUsers.filter((userId) => featureBRollout.isFeatureEnabled(userId)); + + expect(featureAUsers).not.toEqual(featureBUsers); + }); +}); diff --git a/yarn.lock b/yarn.lock index d1665d9229..cf49259d53 100644 --- a/yarn.lock +++ b/yarn.lock @@ -6527,6 +6527,11 @@ jest@^29.6.2: resolved "https://registry.yarnpkg.com/js-tokens/-/js-tokens-4.0.0.tgz#19203fb59991df98e3a287050d4647cdeaf32499" integrity sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ== +js-xxhash@^3.0.1: + version "3.0.1" + resolved "https://registry.yarnpkg.com/js-xxhash/-/js-xxhash-3.0.1.tgz#e093b53d02cd80a830d61f58290c206aaa877b24" + integrity sha512-Y2NSC77RIxJrvi2NoXjMi2LYsVDTlVqBoQRi8PXQg4PtP29wdtIOhsp8Ujw4EjEkBFheCPx8bMOmI9zoxx/3jQ== + js-yaml@^3.13.1: version "3.14.1" resolved "https://registry.yarnpkg.com/js-yaml/-/js-yaml-3.14.1.tgz#dae812fdb3825fa306609a8717383c50c36a0537"