Merge branch 'updated-user-config-handling' into disappearing-message-redesign

pull/941/head
Ryan Zhao 11 months ago
commit 0fc00ab527

@ -6649,7 +6649,7 @@
"CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer";
CODE_SIGN_STYLE = Automatic;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 416;
CURRENT_PROJECT_VERSION = 418;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
DEVELOPMENT_TEAM = SUQ8J2PCT7;
FRAMEWORK_SEARCH_PATHS = "$(inherited)";
@ -6721,7 +6721,7 @@
"CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer";
CODE_SIGN_STYLE = Automatic;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 416;
CURRENT_PROJECT_VERSION = 418;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
DEVELOPMENT_TEAM = SUQ8J2PCT7;
ENABLE_NS_ASSERTIONS = NO;
@ -6786,7 +6786,7 @@
"CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer";
CODE_SIGN_STYLE = Automatic;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 416;
CURRENT_PROJECT_VERSION = 418;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
DEVELOPMENT_TEAM = SUQ8J2PCT7;
FRAMEWORK_SEARCH_PATHS = "$(inherited)";
@ -6860,7 +6860,7 @@
"CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer";
CODE_SIGN_STYLE = Automatic;
COPY_PHASE_STRIP = NO;
CURRENT_PROJECT_VERSION = 416;
CURRENT_PROJECT_VERSION = 418;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
DEVELOPMENT_TEAM = SUQ8J2PCT7;
ENABLE_NS_ASSERTIONS = NO;
@ -7768,7 +7768,7 @@
CODE_SIGN_ENTITLEMENTS = Session/Meta/Signal.entitlements;
CODE_SIGN_IDENTITY = "iPhone Developer";
"CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer";
CURRENT_PROJECT_VERSION = 416;
CURRENT_PROJECT_VERSION = 418;
DEVELOPMENT_TEAM = SUQ8J2PCT7;
FRAMEWORK_SEARCH_PATHS = (
"$(inherited)",
@ -7839,7 +7839,7 @@
CODE_SIGN_ENTITLEMENTS = Session/Meta/Signal.entitlements;
CODE_SIGN_IDENTITY = "iPhone Developer";
"CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer";
CURRENT_PROJECT_VERSION = 416;
CURRENT_PROJECT_VERSION = 418;
DEVELOPMENT_TEAM = SUQ8J2PCT7;
FRAMEWORK_SEARCH_PATHS = (
"$(inherited)",

@ -356,8 +356,12 @@ final class NewClosedGroupVC: BaseVC, UITableViewDataSource, UITableViewDelegate
}
},
receiveValue: { thread in
self?.presentingViewController?.dismiss(animated: true, completion: nil)
SessionApp.presentConversation(for: thread.id, action: .compose, animated: false)
SessionApp.presentConversationCreatingIfNeeded(
for: thread.id,
variant: thread.variant,
dismissing: self?.presentingViewController,
animated: false
)
}
)
}

@ -213,6 +213,7 @@ public class ConversationViewModel: OWSAudioPlayerDelegate {
// MARK: - Interaction Data
private var lastInteractionIdMarkedAsRead: Int64? = nil
private var lastInteractionTimestampMsMarkedAsRead: Int64 = 0
public private(set) var unobservedInteractionDataChanges: ([SectionModel], StagedChangeset<[SectionModel]>)?
public private(set) var interactionData: [SectionModel] = []
@ -645,8 +646,8 @@ public class ConversationViewModel: OWSAudioPlayerDelegate {
/// Since this method now gets triggered when scrolling we want to try to optimise it and avoid busying the database
/// write queue when it isn't needed, in order to do this we:
/// - Throttle the updates to 100ms (quick enough that users shouldn't notice, but will help the DB when the user flings the list)
/// - Don't bother marking anything as read if this was called with the same `interactionId` that we previously marked as
/// read (ie. when scrolling and the last message hasn't changed)
/// - Only mark interactions as read if they have newer `timestampMs` or `id` values (ie. were sent later or were more-recent
/// entries in the database), **Note:** Old messages will be marked as read upon insertion so shouldn't be an issue
///
/// The `ThreadViewModel.markAsRead` method also tries to avoid marking as read if a conversation is already fully read
if markAsReadPublisher == nil {
@ -656,10 +657,11 @@ public class ConversationViewModel: OWSAudioPlayerDelegate {
receiveOutput: { [weak self] target, timestampMs in
switch target {
case .thread: self?.threadData.markAsRead(target: target)
case .threadAndInteractions:
case .threadAndInteractions(let interactionId):
guard
timestampMs == nil ||
(self?.lastInteractionTimestampMsMarkedAsRead ?? 0) < (timestampMs ?? 0)
(self?.lastInteractionTimestampMsMarkedAsRead ?? 0) < (timestampMs ?? 0) ||
(self?.lastInteractionIdMarkedAsRead ?? 0) < (interactionId ?? 0)
else {
self?.threadData.markAsRead(target: .thread)
return
@ -671,6 +673,7 @@ public class ConversationViewModel: OWSAudioPlayerDelegate {
self?.lastInteractionTimestampMsMarkedAsRead = timestampMs
}
self?.lastInteractionIdMarkedAsRead = (interactionId ?? self?.threadData.interactionId)
self?.threadData.markAsRead(target: target)
}
}

@ -179,16 +179,13 @@ final class NewConversationVC: BaseVC, ThemedNavigation, UITableViewDelegate, UI
tableView.deselectRow(at: indexPath, animated: true)
let sessionId = newConversationViewModel.sectionData[indexPath.section].contacts[indexPath.row].id
let maybeThread: SessionThread? = Storage.shared.write { db in
try SessionThread
.fetchOrCreate(db, id: sessionId, variant: .contact, shouldBeVisible: nil)
}
guard maybeThread != nil else { return }
self.navigationController?.dismiss(animated: true, completion: nil)
SessionApp.presentConversation(for: sessionId, action: .compose, animated: false)
SessionApp.presentConversationCreatingIfNeeded(
for: sessionId,
variant: .contact,
dismissing: navigationController,
animated: false
)
}
func tableView(_ tableView: UITableView, willDisplayHeaderView view: UIView, forSection section: Int) {

@ -260,16 +260,12 @@ final class NewDMVC: BaseVC, UIPageViewControllerDataSource, UIPageViewControlle
}
private func startNewDM(with sessionId: String) {
let maybeThread: SessionThread? = Storage.shared.write { db in
try SessionThread
.fetchOrCreate(db, id: sessionId, variant: .contact, shouldBeVisible: nil)
}
guard maybeThread != nil else { return }
presentingViewController?.dismiss(animated: true, completion: nil)
SessionApp.presentConversation(for: sessionId, action: .compose, animated: false)
SessionApp.presentConversationCreatingIfNeeded(
for: sessionId,
variant: .contact,
dismissing: presentingViewController,
animated: false
)
}
}

@ -143,6 +143,7 @@ class AppDelegate: UIResponder, UIApplicationDelegate, UNUserNotificationCenterD
// If we've already completed migrations at least once this launch then check
// to see if any "delayed" migrations now need to run
if Storage.shared.hasCompletedMigrations {
SNLog("Checking for pending migrations")
let initialLaunchFailed: Bool = self.initialLaunchFailed
AppReadiness.invalidate()
@ -154,30 +155,33 @@ class AppDelegate: UIResponder, UIApplicationDelegate, UNUserNotificationCenterD
self.window?.rootViewController?.dismiss(animated: false)
}
AppSetup.runPostSetupMigrations(
migrationProgressChanged: { [weak self] progress, minEstimatedTotalTime in
self?.loadingViewController?.updateProgress(
progress: progress,
minEstimatedTotalTime: minEstimatedTotalTime
)
},
migrationsCompletion: { [weak self] result, needsConfigSync in
if case .failure(let error) = result {
DispatchQueue.main.async {
self?.showFailedStartupAlert(
calledFrom: .enterForeground(initialLaunchFailed: initialLaunchFailed),
error: .databaseError(error)
)
// Dispatch async so things can continue to be progressed if a migration does need to run
DispatchQueue.global(qos: .userInitiated).async { [weak self] in
AppSetup.runPostSetupMigrations(
migrationProgressChanged: { progress, minEstimatedTotalTime in
self?.loadingViewController?.updateProgress(
progress: progress,
minEstimatedTotalTime: minEstimatedTotalTime
)
},
migrationsCompletion: { result, needsConfigSync in
if case .failure(let error) = result {
DispatchQueue.main.async {
self?.showFailedStartupAlert(
calledFrom: .enterForeground(initialLaunchFailed: initialLaunchFailed),
error: .databaseError(error)
)
}
return
}
return
self?.completePostMigrationSetup(
calledFrom: .enterForeground(initialLaunchFailed: initialLaunchFailed),
needsConfigSync: needsConfigSync
)
}
self?.completePostMigrationSetup(
calledFrom: .enterForeground(initialLaunchFailed: initialLaunchFailed),
needsConfigSync: needsConfigSync
)
}
)
)
}
}
}
@ -322,8 +326,8 @@ class AppDelegate: UIResponder, UIApplicationDelegate, UNUserNotificationCenterD
// the user is in an invalid state (and should have already been shown a modal)
guard success else { return }
SNLog("RootViewController ready, readying remaining processes")
self?.initialLaunchFailed = false
SNLog("Migrations completed, performing setup and ensuring rootViewController")
/// Trigger any launch-specific jobs and start the JobRunner with `JobRunner.appDidFinishLaunching()` some
/// of these jobs (eg. DisappearingMessages job) can impact the interactions which get fetched to display on the home

@ -35,59 +35,78 @@ public struct SessionApp {
// MARK: - View Convenience Methods
public static func presentConversation(for threadId: String, action: ConversationViewModel.Action = .none, animated: Bool) {
let maybeThreadInfo: (thread: SessionThread, isMessageRequest: Bool)? = Storage.shared.write { db in
let thread: SessionThread = try SessionThread
.fetchOrCreate(db, id: threadId, variant: .contact, shouldBeVisible: nil)
public static func presentConversationCreatingIfNeeded(
for threadId: String,
variant: SessionThread.Variant,
action: ConversationViewModel.Action = .none,
dismissing presentingViewController: UIViewController?,
animated: Bool
) {
let threadInfo: (threadExists: Bool, isMessageRequest: Bool)? = Storage.shared.read { db in
let isMessageRequest: Bool = {
switch variant {
case .contact:
return SessionThread
.isMessageRequest(
id: threadId,
variant: .contact,
currentUserPublicKey: getUserHexEncodedPublicKey(db),
shouldBeVisible: nil,
contactIsApproved: (try? Contact
.filter(id: threadId)
.select(.isApproved)
.asRequest(of: Bool.self)
.fetchOne(db))
.defaulting(to: false),
includeNonVisible: true
)
default: return false
}
}()
return (thread, thread.isMessageRequest(db))
return (SessionThread.filter(id: threadId).isNotEmpty(db), isMessageRequest)
}
guard
let variant: SessionThread.Variant = maybeThreadInfo?.thread.variant,
let isMessageRequest: Bool = maybeThreadInfo?.isMessageRequest
else { return }
// Store the post-creation logic in a closure to avoid duplication
let afterThreadCreated: () -> () = {
presentingViewController?.dismiss(animated: true, completion: nil)
homeViewController.wrappedValue?.show(
threadId,
variant: variant,
isMessageRequest: (threadInfo?.isMessageRequest == true),
with: action,
focusedInteractionInfo: nil,
animated: animated
)
}
self.presentConversation(
for: threadId,
threadVariant: variant,
isMessageRequest: isMessageRequest,
action: action,
focusInteractionInfo: nil,
animated: animated
)
}
public static func presentConversation(
for threadId: String,
threadVariant: SessionThread.Variant,
isMessageRequest: Bool,
action: ConversationViewModel.Action,
focusInteractionInfo: Interaction.TimestampInfo?,
animated: Bool
) {
/// The thread should generally exist at the time of calling this method but, on the off chance it doesn't, we need to `fetchOrCreate` it and
/// should do it on a background thread just in case something is keeping the DBWrite thread busy as in the past this could cause the app to hang
guard threadInfo?.threadExists == true else {
DispatchQueue.global(qos: .userInitiated).async {
Storage.shared.write { db in
try SessionThread.fetchOrCreate(db, id: threadId, variant: variant, shouldBeVisible: nil)
}
// Send back to main thread for UI transitions
DispatchQueue.main.async {
afterThreadCreated()
}
}
return
}
// Send to main thread if needed
guard Thread.isMainThread else {
DispatchQueue.main.async {
self.presentConversation(
for: threadId,
threadVariant: threadVariant,
isMessageRequest: isMessageRequest,
action: action,
focusInteractionInfo: focusInteractionInfo,
animated: animated
)
afterThreadCreated()
}
return
}
homeViewController.wrappedValue?.show(
threadId,
variant: threadVariant,
isMessageRequest: isMessageRequest,
with: action,
focusedInteractionInfo: focusInteractionInfo,
animated: animated
)
afterThreadCreated()
}
// MARK: - Functions

@ -37,6 +37,7 @@ enum AppNotificationAction: CaseIterable {
struct AppNotificationUserInfoKey {
static let threadId = "Signal.AppNotificationsUserInfoKey.threadId"
static let threadVariantRaw = "Signal.AppNotificationsUserInfoKey.threadVariantRaw"
static let callBackNumber = "Signal.AppNotificationsUserInfoKey.callBackNumber"
static let localCallId = "Signal.AppNotificationsUserInfoKey.localCallId"
static let threadNotificationCounter = "Session.AppNotificationsUserInfoKey.threadNotificationCounter"
@ -232,8 +233,9 @@ public class NotificationPresenter: NotificationsProtocol {
// "no longer verified".
let category = AppNotificationCategory.incomingMessage
let userInfo = [
AppNotificationUserInfoKey.threadId: thread.id
let userInfo: [AnyHashable: Any] = [
AppNotificationUserInfoKey.threadId: thread.id,
AppNotificationUserInfoKey.threadVariantRaw: thread.variant.rawValue
]
let userPublicKey: String = getUserHexEncodedPublicKey(db)
@ -301,8 +303,9 @@ public class NotificationPresenter: NotificationsProtocol {
let previewType: Preferences.NotificationPreviewType = db[.preferencesNotificationPreviewType]
.defaulting(to: .nameAndPreview)
let userInfo = [
AppNotificationUserInfoKey.threadId: thread.id
let userInfo: [AnyHashable: Any] = [
AppNotificationUserInfoKey.threadId: thread.id,
AppNotificationUserInfoKey.threadVariantRaw: thread.variant.rawValue
]
let notificationTitle: String = "Session"
@ -378,8 +381,9 @@ public class NotificationPresenter: NotificationsProtocol {
let category = AppNotificationCategory.incomingMessage
let userInfo = [
AppNotificationUserInfoKey.threadId: thread.id
let userInfo: [AnyHashable: Any] = [
AppNotificationUserInfoKey.threadId: thread.id,
AppNotificationUserInfoKey.threadVariantRaw: thread.variant.rawValue
]
let threadName: String = SessionThread.displayName(
@ -440,8 +444,9 @@ public class NotificationPresenter: NotificationsProtocol {
let notificationBody = NotificationStrings.failedToSendBody
let userInfo = [
AppNotificationUserInfoKey.threadId: thread.id
let userInfo: [AnyHashable: Any] = [
AppNotificationUserInfoKey.threadId: thread.id,
AppNotificationUserInfoKey.threadVariantRaw: thread.variant.rawValue
]
let fallbackSound: Preferences.Sound = db[.defaultNotificationSound]
.defaulting(to: Preferences.Sound.defaultNotificationSound)
@ -603,15 +608,22 @@ class NotificationActionHandler {
}
func showThread(userInfo: [AnyHashable: Any]) -> AnyPublisher<Void, Never> {
guard let threadId = userInfo[AppNotificationUserInfoKey.threadId] as? String else {
return showHomeVC()
}
guard
let threadId = userInfo[AppNotificationUserInfoKey.threadId] as? String,
let threadVariantRaw = userInfo[AppNotificationUserInfoKey.threadVariantRaw] as? Int,
let threadVariant: SessionThread.Variant = SessionThread.Variant(rawValue: threadVariantRaw)
else { return showHomeVC() }
// If this happens when the app is not visible, we skip the animation so the thread
// can be visible to the user immediately upon opening the app, rather than having to watch
// it animate in from the homescreen.
let shouldAnimate: Bool = (UIApplication.shared.applicationState == .active)
SessionApp.presentConversation(for: threadId, animated: shouldAnimate)
SessionApp.presentConversationCreatingIfNeeded(
for: threadId,
variant: threadVariant,
dismissing: nil,
animated: (UIApplication.shared.applicationState == .active)
)
return Just(())
.eraseToAnyPublisher()
}

@ -217,12 +217,10 @@ final class JoinOpenGroupVC: BaseVC, UIPageViewControllerDataSource, UIPageViewC
self?.presentingViewController?.dismiss(animated: true, completion: nil)
if shouldOpenCommunity {
SessionApp.presentConversation(
SessionApp.presentConversationCreatingIfNeeded(
for: OpenGroup.idFor(roomToken: roomToken, server: server),
threadVariant: .community,
isMessageRequest: false,
action: .compose,
focusInteractionInfo: nil,
variant: .community,
dismissing: nil,
animated: false
)
}

@ -138,16 +138,12 @@ final class QRCodeVC : BaseVC, UIPageViewControllerDataSource, UIPageViewControl
self.present(modal, animated: true)
}
else {
let maybeThread: SessionThread? = Storage.shared.write { db in
try SessionThread
.fetchOrCreate(db, id: hexEncodedPublicKey, variant: .contact, shouldBeVisible: nil)
}
guard maybeThread != nil else { return }
presentingViewController?.dismiss(animated: true, completion: nil)
SessionApp.presentConversation(for: hexEncodedPublicKey, action: .compose, animated: false)
SessionApp.presentConversationCreatingIfNeeded(
for: hexEncodedPublicKey,
variant: .contact,
dismissing: presentingViewController,
animated: false
)
}
}
}

@ -192,20 +192,6 @@ public extension SessionThread {
)
}
/// Returns `true` when this thread should be treated as a message request.
///
/// A thread qualifies only when it is a one-to-one (`.contact`) thread with someone other than
/// the current user ("Note to Self" is excluded) and the contact has not yet been approved.
/// Hidden threads are excluded by default; pass `includeNonVisible: true` to classify them too.
/// - Parameters:
///   - db: Database handle used to resolve the current user's key and the contact's approval state.
///   - includeNonVisible: When `true`, the thread's `shouldBeVisible` flag is ignored.
func isMessageRequest(_ db: Database, includeNonVisible: Bool = false) -> Bool {
return (
(includeNonVisible || shouldBeVisible) &&
variant == .contact &&
id != getUserHexEncodedPublicKey(db) && // Note to self
// A missing Contact row (or a failed fetch) is treated as "not approved"
(try? Contact
.filter(id: id)
.select(.isApproved)
.asRequest(of: Bool.self)
.fetchOne(db))
.defaulting(to: false) == false
)
}
static func canSendReadReceipt(
_ db: Database,
threadId: String,
@ -431,6 +417,38 @@ public extension SessionThread {
).sqlExpression
}
/// Instance convenience over the static `SessionThread.isMessageRequest`, resolving the
/// current user's public key and the contact's approval state from the database before
/// delegating to the pure predicate.
func isMessageRequest(_ db: Database, includeNonVisible: Bool = false) -> Bool {
    // Resolve the contact's approval flag up front; a missing row (or failed fetch)
    // counts as "not approved"
    let contactIsApproved: Bool = (try? Contact
        .filter(id: id)
        .select(.isApproved)
        .asRequest(of: Bool.self)
        .fetchOne(db))
        .defaulting(to: false)
    
    return SessionThread.isMessageRequest(
        id: id,
        variant: variant,
        currentUserPublicKey: getUserHexEncodedPublicKey(db),
        shouldBeVisible: shouldBeVisible,
        contactIsApproved: contactIsApproved,
        includeNonVisible: includeNonVisible
    )
}
/// Pure predicate determining whether a thread should be shown as a message request.
///
/// Only one-to-one (`.contact`) threads with an unapproved contact qualify, and "Note to Self"
/// (a thread whose id matches `currentUserPublicKey`) never does. The thread's visibility is
/// only considered when `includeNonVisible` is `false`.
/// - Parameters:
///   - id: The thread/contact id being checked.
///   - variant: The thread's variant; anything other than `.contact` is never a message request.
///   - currentUserPublicKey: The current user's key, used to exclude "Note to Self".
///   - shouldBeVisible: The thread's visibility flag (`nil` is treated as not visible).
///   - contactIsApproved: The contact's approval state (`nil` is treated as not approved).
///   - includeNonVisible: When `true`, visibility is ignored.
static func isMessageRequest(
    id: String,
    variant: SessionThread.Variant?,
    currentUserPublicKey: String,
    shouldBeVisible: Bool?,
    contactIsApproved: Bool?,
    includeNonVisible: Bool = false
) -> Bool {
    // "Note to Self" and non-contact threads are never message requests
    guard variant == .contact, id != currentUserPublicKey else { return false }
    
    // Hidden threads only count when explicitly requested
    guard includeNonVisible || shouldBeVisible == true else { return false }
    
    // An unknown approval state is treated as "not approved"
    return (contactIsApproved != true)
}
func isNoteToSelf(_ db: Database? = nil) -> Bool {
return (
variant == .contact &&

@ -521,61 +521,65 @@ public final class OpenGroupManager {
}
}
db.afterNextTransactionNested { db in
// Start the poller if needed
if dependencies.cache.pollers[server.lowercased()] == nil {
dependencies.mutableCache.mutate {
$0.pollers[server.lowercased()]?.stop()
$0.pollers[server.lowercased()] = OpenGroupAPI.Poller(for: server.lowercased())
db.afterNextTransactionNested { _ in
// Dispatch async to the workQueue to prevent holding up the DBWrite thread from the
// above transaction
OpenGroupAPI.workQueue.async {
// Start the poller if needed
if dependencies.cache.pollers[server.lowercased()] == nil {
dependencies.mutableCache.mutate {
$0.pollers[server.lowercased()]?.stop()
$0.pollers[server.lowercased()] = OpenGroupAPI.Poller(for: server.lowercased())
}
dependencies.cache.pollers[server.lowercased()]?.startIfNeeded(using: dependencies)
}
dependencies.cache.pollers[server.lowercased()]?.startIfNeeded(using: dependencies)
}
/// Start downloading the room image (if we don't have one or it's been updated)
if
let imageId: String = (pollInfo.details?.imageId ?? openGroup.imageId),
(
openGroup.imageData == nil ||
openGroup.imageId != imageId
)
{
OpenGroupManager
.roomImage(
fileId: imageId,
for: roomToken,
on: server,
existingData: openGroup.imageData,
using: dependencies
/// Start downloading the room image (if we don't have one or it's been updated)
if
let imageId: String = (pollInfo.details?.imageId ?? openGroup.imageId),
(
openGroup.imageData == nil ||
openGroup.imageId != imageId
)
// Note: We need to subscribe and receive on different threads to ensure the
// logic in 'receiveValue' doesn't result in a reentrancy database issue
.subscribe(on: OpenGroupAPI.workQueue)
.receive(on: DispatchQueue.global(qos: .default))
.sinkUntilComplete(
receiveCompletion: { _ in
if waitForImageToComplete {
completion?()
}
},
receiveValue: { data in
dependencies.storage.write { db in
_ = try OpenGroup
.filter(id: threadId)
.updateAll(db, OpenGroup.Columns.imageData.set(to: data))
{
OpenGroupManager
.roomImage(
fileId: imageId,
for: roomToken,
on: server,
existingData: openGroup.imageData,
using: dependencies
)
// Note: We need to subscribe and receive on different threads to ensure the
// logic in 'receiveValue' doesn't result in a reentrancy database issue
.subscribe(on: OpenGroupAPI.workQueue)
.receive(on: DispatchQueue.global(qos: .default))
.sinkUntilComplete(
receiveCompletion: { _ in
if waitForImageToComplete {
completion?()
}
},
receiveValue: { data in
dependencies.storage.write { db in
_ = try OpenGroup
.filter(id: threadId)
.updateAll(db, OpenGroup.Columns.imageData.set(to: data))
}
}
}
)
}
else if waitForImageToComplete {
)
}
else if waitForImageToComplete {
completion?()
}
// If we want to wait for the image to complete then don't call the completion here
guard !waitForImageToComplete else { return }
// Finish
completion?()
}
// If we want to wait for the image to complete then don't call the completion here
guard !waitForImageToComplete else { return }
// Finish
completion?()
}
}

@ -36,7 +36,7 @@ extension MessageReceiver {
guard
let profilePictureUrl: String = profile.profilePictureUrl,
let profileKey: Data = profile.profileKey
else { return .none }
else { return .remove }
return .updateTo(
url: profilePictureUrl,

@ -14,7 +14,7 @@ public final class ClosedGroupPoller: Poller {
override var namespaces: [SnodeAPI.Namespace] { ClosedGroupPoller.namespaces }
override var maxNodePollCount: UInt { 0 }
private static let minPollInterval: Double = 2
private static let minPollInterval: Double = 3
private static let maxPollInterval: Double = 30
// MARK: - Initialization
@ -78,30 +78,12 @@ public final class ClosedGroupPoller: Poller {
return nextPollInterval
}
override func getSnodeForPolling(
for publicKey: String
) -> AnyPublisher<Snode, Error> {
return SnodeAPI.getSwarm(for: publicKey)
.tryMap { swarm -> Snode in
guard let snode: Snode = swarm.randomElement() else {
throw OnionRequestAPIError.insufficientSnodes
}
return snode
}
.eraseToAnyPublisher()
}
override func handlePollError(
_ error: Error,
for publicKey: String,
using dependencies: SMKDependencies = SMKDependencies()
) {
) -> Bool {
SNLog("Polling failed for closed group with public key: \(publicKey) due to error: \(error).")
// Try to restart the poller from scratch
Threading.pollerQueue.async { [weak self] in
self?.setUpPolling(for: publicKey, using: dependencies)
}
return true
}
}

@ -11,9 +11,6 @@ public final class CurrentUserPoller: Poller {
public static var namespaces: [SnodeAPI.Namespace] = [
.default, .configUserProfile, .configContacts, .configConvoInfoVolatile, .configUserGroups
]
private var targetSnode: Atomic<Snode?> = Atomic(nil)
private var usedSnodes: Atomic<Set<Snode>> = Atomic([])
// MARK: - Settings
@ -63,53 +60,16 @@ public final class CurrentUserPoller: Poller {
return min(maxRetryInterval, nextDelay)
}
override func getSnodeForPolling(
for publicKey: String
) -> AnyPublisher<Snode, Error> {
if let targetSnode: Snode = self.targetSnode.wrappedValue {
return Just(targetSnode)
.setFailureType(to: Error.self)
.eraseToAnyPublisher()
}
// Use the cached swarm for the given key and update the list of unusedSnodes
let swarm: Set<Snode> = (SnodeAPI.swarmCache.wrappedValue[publicKey] ?? [])
let unusedSnodes: Set<Snode> = swarm.subtracting(usedSnodes.wrappedValue)
// randomElement() uses the system's default random generator, which is cryptographically secure
if let nextSnode: Snode = unusedSnodes.randomElement() {
self.targetSnode.mutate { $0 = nextSnode }
self.usedSnodes.mutate { $0.insert(nextSnode) }
return Just(nextSnode)
.setFailureType(to: Error.self)
.eraseToAnyPublisher()
}
// If we haven't retrieved a target snode at this point then either the cache
// is empty or we have used all of the snodes and need to start from scratch
return SnodeAPI.getSwarm(for: publicKey)
.tryFlatMap { [weak self] _ -> AnyPublisher<Snode, Error> in
guard let strongSelf = self else { throw SnodeAPIError.generic }
self?.targetSnode.mutate { $0 = nil }
self?.usedSnodes.mutate { $0.removeAll() }
return strongSelf.getSnodeForPolling(for: publicKey)
}
.eraseToAnyPublisher()
}
override func handlePollError(
_ error: Error,
for publicKey: String,
using dependencies: SMKDependencies = SMKDependencies()
) {
) -> Bool {
if UserDefaults.sharedLokiProject?[.isMainAppActive] != true {
// Do nothing when an error gets throws right after returning from the background (happens frequently)
}
else if let targetSnode: Snode = targetSnode.wrappedValue {
SNLog("Polling \(targetSnode) failed; dropping it and switching to next snode.")
SNLog("Main Poller polling \(targetSnode) failed; dropping it and switching to next snode.")
self.targetSnode.mutate { $0 = nil }
SnodeAPI.dropSnodeFromSwarmIfNeeded(targetSnode, publicKey: publicKey)
}
@ -117,9 +77,6 @@ public final class CurrentUserPoller: Poller {
SNLog("Polling failed due to having no target service node.")
}
// Try to restart the poller from scratch
Threading.pollerQueue.async { [weak self] in
self?.setUpPolling(for: publicKey, using: dependencies)
}
return true
}
}

@ -57,49 +57,42 @@ extension OpenGroupAPI {
) {
guard hasStarted else { return }
dependencies.storage
.readPublisher { [server = server] db in
try OpenGroup
.filter(OpenGroup.Columns.server == server)
.select(min(OpenGroup.Columns.pollFailureCount))
.asRequest(of: TimeInterval.self)
.fetchOne(db)
}
.tryFlatMap { [weak self] minPollFailureCount -> AnyPublisher<(TimeInterval, TimeInterval), Error> in
guard let strongSelf = self else { throw OpenGroupAPIError.invalidPoll }
let lastPollStart: TimeInterval = Date().timeIntervalSince1970
let nextPollInterval: TimeInterval = Poller.getInterval(
for: (minPollFailureCount ?? 0),
minInterval: Poller.minPollInterval,
maxInterval: Poller.maxPollInterval
)
// Wait until the last poll completes before polling again ensuring we don't poll any faster than
// the 'nextPollInterval' value
return strongSelf.poll(using: dependencies)
.map { _ in (lastPollStart, nextPollInterval) }
.eraseToAnyPublisher()
}
let server: String = self.server
let lastPollStart: TimeInterval = Date().timeIntervalSince1970
poll(using: dependencies)
.subscribe(on: dependencies.subscribeQueue)
.receive(on: dependencies.receiveQueue)
.sinkUntilComplete(
receiveValue: { [weak self] lastPollStart, nextPollInterval in
receiveCompletion: { [weak self] _ in
let minPollFailureCount: Int64 = dependencies.storage
.read { db in
try OpenGroup
.filter(OpenGroup.Columns.server == server)
.select(min(OpenGroup.Columns.pollFailureCount))
.asRequest(of: Int64.self)
.fetchOne(db)
}
.defaulting(to: 0)
// Calculate the remaining poll delay
let currentTime: TimeInterval = Date().timeIntervalSince1970
let nextPollInterval: TimeInterval = Poller.getInterval(
for: TimeInterval(minPollFailureCount),
minInterval: Poller.minPollInterval,
maxInterval: Poller.maxPollInterval
)
let remainingInterval: TimeInterval = max(0, nextPollInterval - (currentTime - lastPollStart))
// Schedule the next poll
guard remainingInterval > 0 else {
return dependencies.subscribeQueue.async {
self?.pollRecursively(using: dependencies)
}
}
self?.timer = Timer.scheduledTimerOnMainThread(withTimeInterval: remainingInterval, repeats: false) { timer in
timer.invalidate()
dependencies.subscribeQueue.async {
self?.pollRecursively(using: dependencies)
}
dependencies.subscribeQueue.asyncAfter(deadline: .now() + .milliseconds(Int(remainingInterval * 1000)), qos: .default) {
self?.pollRecursively(using: dependencies)
}
}
)
@ -227,7 +220,7 @@ extension OpenGroupAPI {
.defaulting(to: 0)
var prunedIds: [String] = []
Storage.shared.writeAsync { db in
dependencies.storage.writeAsync { db in
struct Info: Decodable, FetchableRecord {
let id: String
let shouldBeVisible: Bool

@ -8,11 +8,14 @@ import SessionSnodeKit
import SessionUtilitiesKit
public class Poller {
private var timers: Atomic<[String: Timer]> = Atomic([:])
private var cancellables: Atomic<[String: AnyCancellable]> = Atomic([:])
internal var isPolling: Atomic<[String: Bool]> = Atomic([:])
internal var pollCount: Atomic<[String: Int]> = Atomic([:])
internal var failureCount: Atomic<[String: Int]> = Atomic([:])
internal var targetSnode: Atomic<Snode?> = Atomic(nil)
private var usedSnodes: Atomic<Set<Snode>> = Atomic([])
// MARK: - Settings
/// The namespaces which this poller queries
@ -20,7 +23,7 @@ public class Poller {
preconditionFailure("abstract class - override in subclass")
}
/// The number of times the poller can poll before swapping to a new snode
/// The number of times the poller can poll a single snode before swapping to a new snode
internal var maxNodePollCount: UInt {
preconditionFailure("abstract class - override in subclass")
}
@ -39,7 +42,7 @@ public class Poller {
public func stopPolling(for publicKey: String) {
isPolling.mutate { $0[publicKey] = false }
timers.mutate { $0[publicKey]?.invalidate() }
cancellables.mutate { $0[publicKey]?.cancel() }
}
// MARK: - Abstract Methods
@ -49,17 +52,13 @@ public class Poller {
preconditionFailure("abstract class - override in subclass")
}
/// Calculate the delay which should occur before the next poll
internal func nextPollDelay(for publicKey: String) -> TimeInterval {
preconditionFailure("abstract class - override in subclass")
}
internal func getSnodeForPolling(
for publicKey: String
) -> AnyPublisher<Snode, Error> {
preconditionFailure("abstract class - override in subclass")
}
internal func handlePollError(_ error: Error, for publicKey: String, using dependencies: SMKDependencies) {
/// Perform any logic which should occur when the poll errors; polling will stop if `false` is returned
internal func handlePollError(_ error: Error, for publicKey: String, using dependencies: SMKDependencies) -> Bool {
preconditionFailure("abstract class - override in subclass")
}
@ -75,48 +74,65 @@ public class Poller {
// and the timer is not created, if we mark the group as is polling
// after setUpPolling. So the poller may not work, thus misses messages
self?.isPolling.mutate { $0[publicKey] = true }
self?.setUpPolling(for: publicKey)
self?.pollRecursively(for: publicKey)
}
}
/// We want to initially trigger a poll against the target service node and then run the recursive polling,
/// if an error is thrown during the poll then this should automatically restart the polling
internal func setUpPolling(
internal func getSnodeForPolling(
for publicKey: String,
using dependencies: SMKDependencies = SMKDependencies(
subscribeQueue: Threading.pollerQueue,
receiveQueue: Threading.pollerQueue
)
) {
guard isPolling.wrappedValue[publicKey] == true else { return }
using dependencies: SMKDependencies = SMKDependencies()
) -> AnyPublisher<Snode, Error> {
// If we don't want to poll a snode multiple times then just grab a random one from the swarm
guard maxNodePollCount > 0 else {
return SnodeAPI.getSwarm(for: publicKey, using: dependencies)
.tryMap { swarm -> Snode in
try swarm.randomElement() ?? { throw OnionRequestAPIError.insufficientSnodes }()
}
.eraseToAnyPublisher()
}
let namespaces: [SnodeAPI.Namespace] = self.namespaces
// If we already have a target snode then use that
if let targetSnode: Snode = self.targetSnode.wrappedValue {
return Just(targetSnode)
.setFailureType(to: Error.self)
.eraseToAnyPublisher()
}
getSnodeForPolling(for: publicKey)
.flatMap { snode -> AnyPublisher<[Message], Error> in
Poller.poll(
namespaces: namespaces,
from: snode,
for: publicKey,
poller: self,
using: dependencies
)
}
.subscribe(on: dependencies.subscribeQueue)
.receive(on: dependencies.receiveQueue)
.sinkUntilComplete(
receiveCompletion: { [weak self] result in
switch result {
case .finished: self?.pollRecursively(for: publicKey, using: dependencies)
case .failure(let error):
guard self?.isPolling.wrappedValue[publicKey] == true else { return }
self?.handlePollError(error, for: publicKey, using: dependencies)
}
// Select the next unused snode from the swarm (if we've used them all then clear the used list and
// start cycling through them again)
return SnodeAPI.getSwarm(for: publicKey, using: dependencies)
.tryMap { [usedSnodes = self.usedSnodes, targetSnode = self.targetSnode] swarm -> Snode in
let unusedSnodes: Set<Snode> = swarm.subtracting(usedSnodes.wrappedValue)
// If we've used all of the SNodes then clear out the used list
if unusedSnodes.isEmpty {
usedSnodes.mutate { $0.removeAll() }
}
)
// Select the next SNode
let nextSnode: Snode = try swarm.randomElement() ?? { throw OnionRequestAPIError.insufficientSnodes }()
targetSnode.mutate { $0 = nextSnode }
usedSnodes.mutate { $0.insert(nextSnode) }
return nextSnode
}
.eraseToAnyPublisher()
}
internal func incrementPollCount(publicKey: String) {
guard maxNodePollCount > 0 else { return }
let pollCount: Int = (self.pollCount.wrappedValue[publicKey] ?? 0)
self.pollCount.mutate { $0[publicKey] = (pollCount + 1) }
// Check if we've polled the serice node too many times
guard pollCount > maxNodePollCount else { return }
// If we have polled this service node more than the maximum allowed then clear out
// the 'targetServiceNode' value
self.targetSnode.mutate { $0 = nil }
}
private func pollRecursively(
for publicKey: String,
using dependencies: SMKDependencies = SMKDependencies()
@ -124,65 +140,60 @@ public class Poller {
guard isPolling.wrappedValue[publicKey] == true else { return }
let namespaces: [SnodeAPI.Namespace] = self.namespaces
let nextPollInterval: TimeInterval = nextPollDelay(for: publicKey)
let lastPollStart: TimeInterval = Date().timeIntervalSince1970
let lastPollInterval: TimeInterval = nextPollDelay(for: publicKey)
let getSnodePublisher: AnyPublisher<Snode, Error> = getSnodeForPolling(for: publicKey)
timers.mutate {
$0[publicKey] = Timer.scheduledTimerOnMainThread(
withTimeInterval: nextPollInterval,
repeats: false
) { [weak self] timer in
timer.invalidate()
self?.getSnodeForPolling(for: publicKey)
.flatMap { snode -> AnyPublisher<[Message], Error> in
Poller.poll(
namespaces: namespaces,
from: snode,
for: publicKey,
poller: self,
using: dependencies
// Store the publisher intp the cancellables dictionary
cancellables.mutate { [weak self] cancellables in
cancellables[publicKey] = getSnodePublisher
.flatMap { snode -> AnyPublisher<[Message], Error> in
Poller.poll(
namespaces: namespaces,
from: snode,
for: publicKey,
poller: self,
using: dependencies
)
}
.subscribe(on: dependencies.subscribeQueue)
.receive(on: dependencies.receiveQueue)
.sink(
receiveCompletion: { result in
switch result {
case .failure(let error):
// Determine if the error should stop us from polling anymore
guard self?.handlePollError(error, for: publicKey, using: dependencies) == true else {
return
}
case .finished: break
}
// Increment the poll count
self?.incrementPollCount(publicKey: publicKey)
// Calculate the remaining poll delay
let currentTime: TimeInterval = Date().timeIntervalSince1970
let nextPollInterval: TimeInterval = (
self?.nextPollDelay(for: publicKey) ??
lastPollInterval
)
}
.subscribe(on: dependencies.subscribeQueue)
.receive(on: dependencies.receiveQueue)
.sinkUntilComplete(
receiveCompletion: { result in
switch result {
case .failure(let error): self?.handlePollError(error, for: publicKey, using: dependencies)
case .finished:
let maxNodePollCount: UInt = (self?.maxNodePollCount ?? 0)
// If we have polled this service node more than the
// maximum allowed then throw an error so the parent
// loop can restart the polling
if maxNodePollCount > 0 {
let pollCount: Int = (self?.pollCount.wrappedValue[publicKey] ?? 0)
self?.pollCount.mutate { $0[publicKey] = (pollCount + 1) }
guard pollCount < maxNodePollCount else {
let newSnodeNextPollInterval: TimeInterval = (self?.nextPollDelay(for: publicKey) ?? nextPollInterval)
self?.timers.mutate {
$0[publicKey] = Timer.scheduledTimerOnMainThread(
withTimeInterval: newSnodeNextPollInterval,
repeats: false
) { [weak self] timer in
timer.invalidate()
self?.pollCount.mutate { $0[publicKey] = 0 }
self?.setUpPolling(for: publicKey, using: dependencies)
}
}
return
}
}
// Otherwise just loop
self?.pollRecursively(for: publicKey, using: dependencies)
let remainingInterval: TimeInterval = max(0, nextPollInterval - (currentTime - lastPollStart))
// Schedule the next poll
guard remainingInterval > 0 else {
return dependencies.subscribeQueue.async {
self?.pollRecursively(for: publicKey, using: dependencies)
}
}
)
}
dependencies.subscribeQueue.asyncAfter(deadline: .now() + .milliseconds(Int(remainingInterval * 1000)), qos: .default) {
self?.pollRecursively(for: publicKey, using: dependencies)
}
},
receiveValue: { _ in }
)
}
}
@ -199,6 +210,7 @@ public class Poller {
isBackgroundPollValid: @escaping (() -> Bool) = { true },
poller: Poller? = nil,
using dependencies: SMKDependencies = SMKDependencies(
subscribeQueue: Threading.pollerQueue,
receiveQueue: Threading.pollerQueue
)
) -> AnyPublisher<[Message], Error> {

@ -43,7 +43,7 @@ internal extension SessionUtil {
// The current users contact data is handled separately so exclude it if it's present (as that's
// actually a bug)
let userPublicKey: String = getUserHexEncodedPublicKey(db)
let targetContactData: [String: ContactData] = extractContacts(
let targetContactData: [String: ContactData] = try extractContacts(
from: conf,
latestConfigSentTimestampMs: latestConfigSentTimestampMs
).filter { $0.key != userPublicKey }
@ -669,12 +669,15 @@ private extension SessionUtil {
static func extractContacts(
from conf: UnsafeMutablePointer<config_object>?,
latestConfigSentTimestampMs: Int64
) -> [String: ContactData] {
) throws -> [String: ContactData] {
var infiniteLoopGuard: Int = 0
var result: [String: ContactData] = [:]
var contact: contacts_contact = contacts_contact()
let contactIterator: UnsafeMutablePointer<contacts_iterator> = contacts_iterator_new(conf)
while !contacts_iterator_done(contactIterator, &contact) {
try SessionUtil.checkLoopLimitReached(&infiniteLoopGuard, for: .contacts)
let contactId: String = String(cString: withUnsafeBytes(of: contact.session_id) { [UInt8]($0) }
.map { CChar($0) }
.nullTerminated()

@ -23,7 +23,7 @@ internal extension SessionUtil {
guard conf != nil else { throw SessionUtilError.nilConfigObject }
// Get the volatile thread info from the conf and local conversations
let volatileThreadInfo: [VolatileThreadInfo] = extractConvoVolatileInfo(from: conf)
let volatileThreadInfo: [VolatileThreadInfo] = try extractConvoVolatileInfo(from: conf)
let localVolatileThreadInfo: [String: VolatileThreadInfo] = VolatileThreadInfo.fetchAll(db)
.reduce(into: [:]) { result, next in result[next.threadId] = next }
@ -80,7 +80,8 @@ internal extension SessionUtil {
try Interaction
.filter(
Interaction.Columns.threadId == threadId &&
Interaction.Columns.timestampMs <= lastReadTimestampMs
Interaction.Columns.timestampMs <= lastReadTimestampMs &&
Interaction.Columns.wasRead == false
)
.updateAll( // Handling a config update so don't use `updateAllAndConfig`
db,
@ -320,10 +321,7 @@ public extension SessionUtil {
openGroup: OpenGroup?
) -> Bool {
return SessionUtil
.config(
for: .convoInfoVolatile,
publicKey: userPublicKey
)
.config(for: .convoInfoVolatile, publicKey: userPublicKey)
.wrappedValue
.map { conf in
switch threadVariant {
@ -518,7 +516,8 @@ public extension SessionUtil {
internal static func extractConvoVolatileInfo(
from conf: UnsafeMutablePointer<config_object>?
) -> [VolatileThreadInfo] {
) throws -> [VolatileThreadInfo] {
var infiniteLoopGuard: Int = 0
var result: [VolatileThreadInfo] = []
var oneToOne: convo_info_volatile_1to1 = convo_info_volatile_1to1()
var community: convo_info_volatile_community = convo_info_volatile_community()
@ -526,6 +525,8 @@ public extension SessionUtil {
let convoIterator: OpaquePointer = convo_info_volatile_iterator_new(conf)
while !convo_info_volatile_iterator_done(convoIterator) {
try SessionUtil.checkLoopLimitReached(&infiniteLoopGuard, for: .convoInfoVolatile)
if convo_info_volatile_it_is_1to1(convoIterator, &oneToOne) {
result.append(
VolatileThreadInfo(

@ -59,10 +59,7 @@ internal extension SessionUtil {
do {
needsPush = try SessionUtil
.config(
for: variant,
publicKey: publicKey
)
.config(for: variant, publicKey: publicKey)
.mutate { conf in
guard conf != nil else { throw SessionUtilError.nilConfigObject }
@ -332,6 +329,15 @@ internal extension SessionUtil {
// Ensure the change occurred after the last config message was handled (minus the buffer period)
return (changeTimestampMs >= (configDumpTimestampMs - Int64(SessionUtil.configChangeBufferPeriod * 1000)))
}
static func checkLoopLimitReached(_ loopCounter: inout Int, for variant: ConfigDump.Variant, maxLoopCount: Int = 50000) throws {
loopCounter += 1
guard loopCounter < maxLoopCount else {
SNLog("[libSession] Got stuck in infinite loop processing '\(variant.configMessageKind.description)' data")
throw SessionUtilError.processingLoopLimitReached
}
}
}
// MARK: - External Outgoing Changes

@ -34,6 +34,7 @@ internal extension SessionUtil {
guard mergeNeedsDump else { return }
guard conf != nil else { throw SessionUtilError.nilConfigObject }
var infiniteLoopGuard: Int = 0
var communities: [PrioritisedData<OpenGroupUrlInfo>] = []
var legacyGroups: [LegacyGroupInfo] = []
var community: ugroups_community_info = ugroups_community_info()
@ -41,6 +42,8 @@ internal extension SessionUtil {
let groupsIterator: OpaquePointer = user_groups_iterator_new(conf)
while !user_groups_iterator_done(groupsIterator) {
try SessionUtil.checkLoopLimitReached(&infiniteLoopGuard, for: .userGroups)
if user_groups_it_is_community(groupsIterator, &community) {
let server: String = String(libSessionVal: community.base_url)
let roomToken: String = String(libSessionVal: community.room)

@ -314,9 +314,10 @@ public enum SessionUtil {
.compactMap { variant -> OutgoingConfResult? in
try SessionUtil
.config(for: variant, publicKey: publicKey)
.mutate { conf in
.wrappedValue
.map { conf in
// Check if the config needs to be pushed
guard conf != nil && config_needs_push(conf) else { return nil }
guard config_needs_push(conf) else { return nil }
var cPushData: UnsafeMutablePointer<config_push_data>!
let configCountInfo: String = {
@ -375,10 +376,7 @@ public enum SessionUtil {
publicKey: String
) -> ConfigDump? {
return SessionUtil
.config(
for: message.kind.configDumpVariant,
publicKey: publicKey
)
.config(for: message.kind.configDumpVariant, publicKey: publicKey)
.mutate { conf in
guard conf != nil else { return nil }

@ -7,4 +7,5 @@ public enum SessionUtilError: Error {
case nilConfigObject
case userDoesNotExist
case getOrConstructFailedUnexpectedly
case processingLoopLimitReached
}

@ -265,12 +265,15 @@ public struct ProfileManager {
return
}
// Update the cache first (in case the DBWrite thread is blocked, this way other threads
// can retrieve from the cache and avoid triggering a download)
profileAvatarCache.mutate { $0[fileName] = decryptedData }
// Store the updated 'profilePictureFileName'
Storage.shared.write { db in
_ = try? Profile
.filter(id: profile.id)
.updateAll(db, Profile.Columns.profilePictureFileName.set(to: fileName))
profileAvatarCache.mutate { $0[fileName] = decryptedData }
}
}
)

@ -438,14 +438,6 @@ class OpenGroupManagerSpec: QuickSpec {
mockOGMCache.when { $0.isPolling }.thenReturn(true)
mockOGMCache.when { $0.pollers }.thenReturn(["testserver": OpenGroupAPI.Poller(for: "testserver")])
mockUserDefaults
.when { (defaults: inout any UserDefaultsType) -> Any? in
defaults.object(forKey: SNUserDefaults.Date.lastOpen.rawValue)
}
.thenReturn(Date(timeIntervalSince1970: 1234567890))
openGroupManager.startPolling(using: dependencies)
}
it("removes all pollers") {

@ -242,7 +242,7 @@ class ThreadDisappearingMessagesSettingsViewModelSpec: QuickSpec {
var footerButtonInfo: SessionButton.Info?
cancellables.append(
viewModel.rightNavItems
viewModel.footerButtonInfo
.receive(on: ImmediateScheduler.shared)
.sink(
receiveCompletion: { _ in },
@ -275,7 +275,7 @@ class ThreadDisappearingMessagesSettingsViewModelSpec: QuickSpec {
)
cancellables.append(
viewModel.observableTableData
.receiveOnMain(immediately: true)
.receive(on: DispatchQueue.main)
.sink(
receiveCompletion: { _ in },
receiveValue: { viewModel.updateTableData($0.0) }
@ -333,7 +333,7 @@ class ThreadDisappearingMessagesSettingsViewModelSpec: QuickSpec {
cancellables.append(
viewModel.footerButtonInfo
.receiveOnMain(immediately: true)
.receive(on: DispatchQueue.main)
.sink(
receiveCompletion: { _ in },
receiveValue: { info in footerButtonInfo = info }
@ -348,7 +348,7 @@ class ThreadDisappearingMessagesSettingsViewModelSpec: QuickSpec {
beforeEach {
cancellables.append(
viewModel.rightNavItems
viewModel.footerButtonInfo
.receive(on: ImmediateScheduler.shared)
.sink(
receiveCompletion: { _ in },

@ -206,11 +206,6 @@ open class Storage {
}
})
// If we have an unperformed migration then trigger the progress updater immediately
if let firstMigrationKey: String = unperformedMigrations.first?.key {
self.migrationProgressUpdater?.wrappedValue(firstMigrationKey, 0)
}
// Store the logic to run when the migration completes
let migrationCompleted: (Swift.Result<Void, Error>) -> () = { [weak self] result in
self?.migrationsCompleted.mutate { $0 = true }
@ -230,10 +225,17 @@ open class Storage {
onComplete(result, needsConfigSync)
}
// Update the 'migrationsCompleted' state (since we not support running migrations when
// returning from the background it's possible for this flag to transition back to false)
if unperformedMigrations.isEmpty {
self.migrationsCompleted.mutate { $0 = false }
// if there aren't any migrations to run then just complete immediately (this way the migrator
// doesn't try to execute on the DBWrite thread so returning from the background can't get blocked
// due to some weird endless process running)
guard !unperformedMigrations.isEmpty else {
migrationCompleted(.success(()))
return
}
// If we have an unperformed migration then trigger the progress updater immediately
if let firstMigrationKey: String = unperformedMigrations.first?.key {
self.migrationProgressUpdater?.wrappedValue(firstMigrationKey, 0)
}
// Note: The non-async migration should only be used for unit tests
@ -377,16 +379,28 @@ open class Storage {
updates: @escaping (Database) throws -> T
) -> (Database) throws -> T {
return { db in
let start: CFTimeInterval = CACurrentMediaTime()
let fileName: String = (info.file.components(separatedBy: "/").last.map { " \($0):\(info.line)" } ?? "")
let timeout: Timer = Timer.scheduledTimerOnMainThread(withTimeInterval: writeWarningThreadshold) {
$0.invalidate()
// Don't want to log on the main thread as to avoid confusion when debugging issues
DispatchQueue.global(qos: .default).async {
let fileName: String = (info.file.components(separatedBy: "/").last.map { " \($0):\(info.line)" } ?? "")
SNLog("[Storage\(fileName)] Slow write taking longer than \(writeWarningThreadshold)s - \(info.function)")
SNLog("[Storage\(fileName)] Slow write taking longer than \(writeWarningThreadshold, format: ".2", omitZeroDecimal: true)s - \(info.function)")
}
}
defer {
// If we timed out then log the actual duration to help us prioritise performance issues
if !timeout.isValid {
let end: CFTimeInterval = CACurrentMediaTime()
DispatchQueue.global(qos: .default).async {
SNLog("[Storage\(fileName)] Slow write completed after \(end - start, format: ".2", omitZeroDecimal: true)s")
}
}
timeout.invalidate()
}
defer { timeout.invalidate() }
return try updates(db)
}

@ -10,6 +10,12 @@ import DifferenceKit
///
/// **Note:** We **MUST** have accurate `filterSQL` and `orderSQL` values otherwise the indexing won't work
public class PagedDatabaseObserver<ObservedTable, T>: TransactionObserver where ObservedTable: TableRecord & ColumnExpressible & Identifiable, T: FetchableRecordWithRowId & Identifiable {
private let commitProcessingQueue: DispatchQueue = DispatchQueue(
label: "PagedDatabaseObserver.commitProcessingQueue",
qos: .userInitiated,
attributes: [] // Must be serial in order to avoid updates getting processed in the wrong order
)
// MARK: - Variables
private let pagedTableName: String
@ -145,74 +151,58 @@ public class PagedDatabaseObserver<ObservedTable, T>: TransactionObserver where
changesInCommit.mutate { $0.insert(trackedChange) }
}
// Note: We will process all updates which come through this method even if
// 'onChange' is null because if the UI stops observing and then starts again
// later we don't want to have missed any changes which happened while the UI
// wasn't subscribed (and doing a full re-query seems painful...)
/// We will process all updates which come through this method even if 'onChange' is null because if the UI stops observing and then starts
/// again later we don't want to have missed any changes which happened while the UI wasn't subscribed (and doing a full re-query seems painful...)
///
/// **Note:** This function is generally called within the DBWrite thread but we don't actually need write access to process the commit, in order
/// to avoid blocking the DBWrite thread we dispatch to a serial `commitProcessingQueue` to process the incoming changes (in the past not doing
/// so was resulting in hanging when there was a lot of activity happening)
public func databaseDidCommit(_ db: Database) {
// If there were no pending changes in the commit then do nothing
guard !self.changesInCommit.wrappedValue.isEmpty else { return }
// Since we can't be sure the behaviours of 'databaseDidChange' and 'databaseDidCommit' won't change in
// the future we extract and clear the values in 'changesInCommit' since it's 'Atomic<T>' so will different
// threads modifying the data resulting in us missing a change
var committedChanges: Set<PagedData.TrackedChange> = []
self.changesInCommit.mutate { cachedChanges in
committedChanges = cachedChanges
cachedChanges.removeAll()
}
// Note: This method will be called regardless of whether there were actually changes
// in the areas we are observing so we want to early-out if there aren't any relevant
// updated rows
guard !committedChanges.isEmpty else { return }
commitProcessingQueue.async { [weak self] in
self?.processDatabaseCommit(committedChanges: committedChanges)
}
}
private func processDatabaseCommit(committedChanges: Set<PagedData.TrackedChange>) {
typealias AssociatedDataInfo = [(hasChanges: Bool, data: ErasedAssociatedRecord)]
typealias UpdatedData = (cache: DataCache<T>, pageInfo: PagedData.PageInfo, hasChanges: Bool, associatedData: AssociatedDataInfo)
// Store the instance variables locally to avoid unwrapping
let dataCache: DataCache<T> = self.dataCache.wrappedValue
let pageInfo: PagedData.PageInfo = self.pageInfo.wrappedValue
let joinSQL: SQL? = self.joinSQL
let orderSQL: SQL = self.orderSQL
let filterSQL: SQL = self.filterSQL
let associatedRecords: [ErasedAssociatedRecord] = self.associatedRecords
let updateDataAndCallbackIfNeeded: (DataCache<T>, PagedData.PageInfo, Bool) -> () = { [weak self] updatedDataCache, updatedPageInfo, cacheHasChanges in
let associatedDataInfo: [(hasChanges: Bool, data: ErasedAssociatedRecord)] = associatedRecords
.map { associatedRecord in
let hasChanges: Bool = associatedRecord.tryUpdateForDatabaseCommit(
db,
changes: committedChanges,
joinSQL: joinSQL,
orderSQL: orderSQL,
filterSQL: filterSQL,
pageInfo: updatedPageInfo
)
return (hasChanges, associatedRecord)
}
// Check if we need to trigger a change callback
guard cacheHasChanges || associatedDataInfo.contains(where: { hasChanges, _ in hasChanges }) else {
return
}
// If the associated data changed then update the updatedCachedData with the
// updated associated data
var finalUpdatedDataCache: DataCache<T> = updatedDataCache
associatedDataInfo.forEach { hasChanges, associatedData in
guard cacheHasChanges || hasChanges else { return }
let getAssociatedDataInfo: (Database, PagedData.PageInfo) -> AssociatedDataInfo = { db, updatedPageInfo in
associatedRecords.map { associatedRecord in
let hasChanges: Bool = associatedRecord.tryUpdateForDatabaseCommit(
db,
changes: committedChanges,
joinSQL: joinSQL,
orderSQL: orderSQL,
filterSQL: filterSQL,
pageInfo: updatedPageInfo
)
finalUpdatedDataCache = associatedData.updateAssociatedData(to: finalUpdatedDataCache)
return (hasChanges, associatedRecord)
}
// Update the cache, pageInfo and the change callback
self?.dataCache.mutate { $0 = finalUpdatedDataCache }
self?.pageInfo.mutate { $0 = updatedPageInfo }
// Make sure the updates run on the main thread
guard Thread.isMainThread else {
DispatchQueue.main.async { [weak self] in
self?.onChangeUnsorted(finalUpdatedDataCache.values, updatedPageInfo)
}
return
}
self?.onChangeUnsorted(finalUpdatedDataCache.values, updatedPageInfo)
}
// Determing if there were any direct or related data changes
// Determine if there were any direct or related data changes
let directChanges: Set<PagedData.TrackedChange> = committedChanges
.filter { $0.tableName == pagedTableName }
let relatedChanges: [String: [PagedData.TrackedChange]] = committedChanges
@ -227,215 +217,248 @@ public class PagedDatabaseObserver<ObservedTable, T>: TransactionObserver where
.filter { $0.tableName != pagedTableName }
.filter { $0.kind == .delete }
guard !directChanges.isEmpty || !relatedChanges.isEmpty || !relatedDeletions.isEmpty else {
updateDataAndCallbackIfNeeded(self.dataCache.wrappedValue, self.pageInfo.wrappedValue, false)
return
}
var updatedPageInfo: PagedData.PageInfo = self.pageInfo.wrappedValue
var updatedDataCache: DataCache<T> = self.dataCache.wrappedValue
let deletionChanges: [Int64] = directChanges
.filter { $0.kind == .delete }
.map { $0.rowId }
let oldDataCount: Int = dataCache.wrappedValue.count
// First remove any items which have been deleted
if !deletionChanges.isEmpty {
updatedDataCache = updatedDataCache.deleting(rowIds: deletionChanges)
// Make sure there were actually changes
if updatedDataCache.count != oldDataCount {
let dataSizeDiff: Int = (updatedDataCache.count - oldDataCount)
// Process and retrieve the updated data
let updatedData: UpdatedData = Storage.shared
.read { db -> UpdatedData in
// If there aren't any direct or related changes then early-out
guard !directChanges.isEmpty || !relatedChanges.isEmpty || !relatedDeletions.isEmpty else {
return (dataCache, pageInfo, false, getAssociatedDataInfo(db, pageInfo))
}
updatedPageInfo = PagedData.PageInfo(
pageSize: updatedPageInfo.pageSize,
pageOffset: updatedPageInfo.pageOffset,
currentCount: (updatedPageInfo.currentCount + dataSizeDiff),
totalCount: (updatedPageInfo.totalCount + dataSizeDiff)
)
}
}
// If there are no inserted/updated rows then trigger the update callback and stop here
let changesToQuery: [PagedData.TrackedChange] = directChanges
.filter { $0.kind != .delete }
guard !changesToQuery.isEmpty || !relatedChanges.isEmpty || !relatedDeletions.isEmpty else {
updateDataAndCallbackIfNeeded(updatedDataCache, updatedPageInfo, !deletionChanges.isEmpty)
return
}
// First we need to get the rowIds for the paged data connected to any of the related changes
let pagedRowIdsForRelatedChanges: Set<Int64> = {
guard !relatedChanges.isEmpty else { return [] }
return relatedChanges
.reduce(into: []) { result, next in
guard
let observedChange: PagedData.ObservedChanges = observedTableChangeTypes[next.key],
let joinToPagedType: SQL = observedChange.joinToPagedType
else { return }
// Store a mutable copies of the dataCache and pageInfo for updating
var updatedDataCache: DataCache<T> = dataCache
var updatedPageInfo: PagedData.PageInfo = pageInfo
let deletionChanges: [Int64] = directChanges
.filter { $0.kind == .delete }
.map { $0.rowId }
let oldDataCount: Int = dataCache.count
// First remove any items which have been deleted
if !deletionChanges.isEmpty {
updatedDataCache = updatedDataCache.deleting(rowIds: deletionChanges)
let pagedRowIds: [Int64] = PagedData.pagedRowIdsForRelatedRowIds(
db,
tableName: next.key,
pagedTableName: pagedTableName,
relatedRowIds: Array(next.value.map { $0.rowId }.asSet()),
joinToPagedType: joinToPagedType
)
// Make sure there were actually changes
if updatedDataCache.count != oldDataCount {
let dataSizeDiff: Int = (updatedDataCache.count - oldDataCount)
updatedPageInfo = PagedData.PageInfo(
pageSize: updatedPageInfo.pageSize,
pageOffset: updatedPageInfo.pageOffset,
currentCount: (updatedPageInfo.currentCount + dataSizeDiff),
totalCount: (updatedPageInfo.totalCount + dataSizeDiff)
)
}
}
// If there are no inserted/updated rows then trigger then early-out
let changesToQuery: [PagedData.TrackedChange] = directChanges
.filter { $0.kind != .delete }
guard !changesToQuery.isEmpty || !relatedChanges.isEmpty || !relatedDeletions.isEmpty else {
let associatedData: AssociatedDataInfo = getAssociatedDataInfo(db, updatedPageInfo)
return (updatedDataCache, updatedPageInfo, !deletionChanges.isEmpty, associatedData)
}
// Next we need to determine if any related changes were associated to the pagedData we are
// observing, if they aren't (and there were no other direct changes) we can early-out
let pagedRowIdsForRelatedChanges: Set<Int64> = {
guard !relatedChanges.isEmpty else { return [] }
result.append(contentsOf: pagedRowIds)
return relatedChanges
.reduce(into: []) { result, next in
guard
let observedChange: PagedData.ObservedChanges = observedTableChangeTypes[next.key],
let joinToPagedType: SQL = observedChange.joinToPagedType
else { return }
let pagedRowIds: [Int64] = PagedData.pagedRowIdsForRelatedRowIds(
db,
tableName: next.key,
pagedTableName: pagedTableName,
relatedRowIds: Array(next.value.map { $0.rowId }.asSet()),
joinToPagedType: joinToPagedType
)
result.append(contentsOf: pagedRowIds)
}
.asSet()
}()
guard !changesToQuery.isEmpty || !pagedRowIdsForRelatedChanges.isEmpty || !relatedDeletions.isEmpty else {
let associatedData: AssociatedDataInfo = getAssociatedDataInfo(db, updatedPageInfo)
return (updatedDataCache, updatedPageInfo, !deletionChanges.isEmpty, associatedData)
}
.asSet()
}()
guard !changesToQuery.isEmpty || !pagedRowIdsForRelatedChanges.isEmpty || !relatedDeletions.isEmpty else {
updateDataAndCallbackIfNeeded(updatedDataCache, updatedPageInfo, !deletionChanges.isEmpty)
return
}
// Fetch the indexes of the rowIds so we can determine whether they should be added to the screen
let directRowIds: Set<Int64> = changesToQuery.map { $0.rowId }.asSet()
let pagedRowIdsForRelatedDeletions: Set<Int64> = relatedDeletions
.compactMap { $0.pagedRowIdsForRelatedDeletion }
.flatMap { $0 }
.asSet()
let itemIndexes: [PagedData.RowIndexInfo] = PagedData.indexes(
db,
rowIds: Array(directRowIds),
tableName: pagedTableName,
requiredJoinSQL: joinSQL,
orderSQL: orderSQL,
filterSQL: filterSQL
)
let relatedChangeIndexes: [PagedData.RowIndexInfo] = PagedData.indexes(
db,
rowIds: Array(pagedRowIdsForRelatedChanges),
tableName: pagedTableName,
requiredJoinSQL: joinSQL,
orderSQL: orderSQL,
filterSQL: filterSQL
)
let relatedDeletionIndexes: [PagedData.RowIndexInfo] = PagedData.indexes(
db,
rowIds: Array(pagedRowIdsForRelatedDeletions),
tableName: pagedTableName,
requiredJoinSQL: joinSQL,
orderSQL: orderSQL,
filterSQL: filterSQL
)
// Determine if the indexes for the row ids should be displayed on the screen and remove any
// which shouldn't - values less than 'currentCount' or if there is at least one value less than
// 'currentCount' and the indexes are sequential (ie. more than the current loaded content was
// added at once)
func determineValidChanges(for indexInfo: [PagedData.RowIndexInfo]) -> [Int64] {
let indexes: [Int64] = Array(indexInfo
.map { $0.rowIndex }
.sorted()
.asSet())
let indexesAreSequential: Bool = (indexes.map { $0 - 1 }.dropFirst() == indexes.dropLast())
let hasOneValidIndex: Bool = indexInfo.contains(where: { info -> Bool in
info.rowIndex >= updatedPageInfo.pageOffset && (
info.rowIndex < updatedPageInfo.currentCount || (
updatedPageInfo.currentCount < updatedPageInfo.pageSize &&
info.rowIndex <= (updatedPageInfo.pageOffset + updatedPageInfo.pageSize)
)
// Fetch the indexes of the rowIds so we can determine whether they should be added to the screen
let directRowIds: Set<Int64> = changesToQuery.map { $0.rowId }.asSet()
let pagedRowIdsForRelatedDeletions: Set<Int64> = relatedDeletions
.compactMap { $0.pagedRowIdsForRelatedDeletion }
.flatMap { $0 }
.asSet()
let itemIndexes: [PagedData.RowIndexInfo] = PagedData.indexes(
db,
rowIds: Array(directRowIds),
tableName: pagedTableName,
requiredJoinSQL: joinSQL,
orderSQL: orderSQL,
filterSQL: filterSQL
)
})
return (indexesAreSequential && hasOneValidIndex ?
indexInfo.map { $0.rowId } :
indexInfo
.filter { info -> Bool in
let relatedChangeIndexes: [PagedData.RowIndexInfo] = PagedData.indexes(
db,
rowIds: Array(pagedRowIdsForRelatedChanges),
tableName: pagedTableName,
requiredJoinSQL: joinSQL,
orderSQL: orderSQL,
filterSQL: filterSQL
)
let relatedDeletionIndexes: [PagedData.RowIndexInfo] = PagedData.indexes(
db,
rowIds: Array(pagedRowIdsForRelatedDeletions),
tableName: pagedTableName,
requiredJoinSQL: joinSQL,
orderSQL: orderSQL,
filterSQL: filterSQL
)
// Determine if the indexes for the row ids should be displayed on the screen and remove any
// which shouldn't - values less than 'currentCount' or if there is at least one value less than
// 'currentCount' and the indexes are sequential (ie. more than the current loaded content was
// added at once)
func determineValidChanges(for indexInfo: [PagedData.RowIndexInfo]) -> [Int64] {
let indexes: [Int64] = Array(indexInfo
.map { $0.rowIndex }
.sorted()
.asSet())
let indexesAreSequential: Bool = (indexes.map { $0 - 1 }.dropFirst() == indexes.dropLast())
let hasOneValidIndex: Bool = indexInfo.contains(where: { info -> Bool in
info.rowIndex >= updatedPageInfo.pageOffset && (
info.rowIndex < updatedPageInfo.currentCount || (
updatedPageInfo.currentCount < updatedPageInfo.pageSize &&
info.rowIndex <= (updatedPageInfo.pageOffset + updatedPageInfo.pageSize)
)
)
})
return (indexesAreSequential && hasOneValidIndex ?
indexInfo.map { $0.rowId } :
indexInfo
.filter { info -> Bool in
info.rowIndex >= updatedPageInfo.pageOffset && (
info.rowIndex < updatedPageInfo.currentCount || (
updatedPageInfo.currentCount < updatedPageInfo.pageSize &&
info.rowIndex <= (updatedPageInfo.pageOffset + updatedPageInfo.pageSize)
)
)
}
.map { info -> Int64 in info.rowId }
)
}
let validChangeRowIds: [Int64] = determineValidChanges(for: itemIndexes)
let validRelatedChangeRowIds: [Int64] = determineValidChanges(for: relatedChangeIndexes)
let validRelatedDeletionRowIds: [Int64] = determineValidChanges(for: relatedDeletionIndexes)
let countBefore: Int = itemIndexes.filter { $0.rowIndex < updatedPageInfo.pageOffset }.count
// If the number of indexes doesn't match the number of rowIds then it means something changed
// resulting in an item being filtered out
func performRemovalsIfNeeded(for rowIds: Set<Int64>, indexes: [PagedData.RowIndexInfo]) {
let uniqueIndexes: Set<Int64> = indexes.map { $0.rowId }.asSet()
// If they have the same count then nothin was filtered out so do nothing
guard rowIds.count != uniqueIndexes.count else { return }
// Otherwise something was probably removed so try to remove it from the cache
let rowIdsRemoved: Set<Int64> = rowIds.subtracting(uniqueIndexes)
let preDeletionCount: Int = updatedDataCache.count
updatedDataCache = updatedDataCache.deleting(rowIds: Array(rowIdsRemoved))
// Lastly make sure there were actually changes before updating the page info
guard updatedDataCache.count != preDeletionCount else { return }
let dataSizeDiff: Int = (updatedDataCache.count - preDeletionCount)
updatedPageInfo = PagedData.PageInfo(
pageSize: updatedPageInfo.pageSize,
pageOffset: updatedPageInfo.pageOffset,
currentCount: (updatedPageInfo.currentCount + dataSizeDiff),
totalCount: (updatedPageInfo.totalCount + dataSizeDiff)
)
}
// Actually perform any required removals
performRemovalsIfNeeded(for: directRowIds, indexes: itemIndexes)
performRemovalsIfNeeded(for: pagedRowIdsForRelatedChanges, indexes: relatedChangeIndexes)
performRemovalsIfNeeded(for: pagedRowIdsForRelatedDeletions, indexes: relatedDeletionIndexes)
// Update the offset and totalCount even if the rows are outside of the current page (need to
// in order to ensure the 'load more' sections are accurate)
updatedPageInfo = PagedData.PageInfo(
pageSize: updatedPageInfo.pageSize,
pageOffset: (updatedPageInfo.pageOffset + countBefore),
currentCount: updatedPageInfo.currentCount,
totalCount: (
updatedPageInfo.totalCount +
changesToQuery
.filter { $0.kind == .insert }
.filter { validChangeRowIds.contains($0.rowId) }
.count
)
)
// If there are no valid row ids then early-out (at this point the pageInfo would have changed
// so we want to flag 'hasChanges' as true)
guard !validChangeRowIds.isEmpty || !validRelatedChangeRowIds.isEmpty || !validRelatedDeletionRowIds.isEmpty else {
let associatedData: AssociatedDataInfo = getAssociatedDataInfo(db, updatedPageInfo)
return (updatedDataCache, updatedPageInfo, true, associatedData)
}
// Fetch the inserted/updated rows
let targetRowIds: [Int64] = Array((validChangeRowIds + validRelatedChangeRowIds + validRelatedDeletionRowIds).asSet())
let updatedItems: [T] = {
do { return try dataQuery(targetRowIds).fetchAll(db) }
catch {
SNLog("[PagedDatabaseObserver] Error fetching data during change: \(error)")
return []
}
.map { info -> Int64 in info.rowId }
)
}()
updatedDataCache = updatedDataCache.upserting(items: updatedItems)
// Update the currentCount for the upserted data
let dataSizeDiff: Int = (updatedDataCache.count - oldDataCount)
updatedPageInfo = PagedData.PageInfo(
pageSize: updatedPageInfo.pageSize,
pageOffset: updatedPageInfo.pageOffset,
currentCount: (updatedPageInfo.currentCount + dataSizeDiff),
totalCount: updatedPageInfo.totalCount
)
// Return the final updated data
let associatedData: AssociatedDataInfo = getAssociatedDataInfo(db, updatedPageInfo)
return (updatedDataCache, updatedPageInfo, true, associatedData)
}
.defaulting(to: (cache: dataCache, pageInfo: pageInfo, hasChanges: false, associatedData: []))
// Now that we have all of the changes, check if there were actually any changes
guard updatedData.hasChanges || updatedData.associatedData.contains(where: { hasChanges, _ in hasChanges }) else {
return
}
let validChangeRowIds: [Int64] = determineValidChanges(for: itemIndexes)
let validRelatedChangeRowIds: [Int64] = determineValidChanges(for: relatedChangeIndexes)
let validRelatedDeletionRowIds: [Int64] = determineValidChanges(for: relatedDeletionIndexes)
let countBefore: Int = itemIndexes.filter { $0.rowIndex < updatedPageInfo.pageOffset }.count
// If the number of indexes doesn't match the number of rowIds then it means something changed
// resulting in an item being filtered out
func performRemovalsIfNeeded(for rowIds: Set<Int64>, indexes: [PagedData.RowIndexInfo]) {
let uniqueIndexes: Set<Int64> = indexes.map { $0.rowId }.asSet()
// If they have the same count then nothing was filtered out so do nothing
guard rowIds.count != uniqueIndexes.count else { return }
// Otherwise something was probably removed so try to remove it from the cache
let rowIdsRemoved: Set<Int64> = rowIds.subtracting(uniqueIndexes)
let preDeletionCount: Int = updatedDataCache.count
updatedDataCache = updatedDataCache.deleting(rowIds: Array(rowIdsRemoved))
// If the associated data changed then update the updatedCachedData with the updated associated data
var finalUpdatedDataCache: DataCache<T> = updatedData.cache
// Lastly make sure there were actually changes before updating the page info
guard updatedDataCache.count != preDeletionCount else { return }
let dataSizeDiff: Int = (updatedDataCache.count - preDeletionCount)
updatedData.associatedData.forEach { hasChanges, associatedData in
guard updatedData.hasChanges || hasChanges else { return }
updatedPageInfo = PagedData.PageInfo(
pageSize: updatedPageInfo.pageSize,
pageOffset: updatedPageInfo.pageOffset,
currentCount: (updatedPageInfo.currentCount + dataSizeDiff),
totalCount: (updatedPageInfo.totalCount + dataSizeDiff)
)
finalUpdatedDataCache = associatedData.updateAssociatedData(to: finalUpdatedDataCache)
}
// Actually perform any required removals
performRemovalsIfNeeded(for: directRowIds, indexes: itemIndexes)
performRemovalsIfNeeded(for: pagedRowIdsForRelatedChanges, indexes: relatedChangeIndexes)
performRemovalsIfNeeded(for: pagedRowIdsForRelatedDeletions, indexes: relatedDeletionIndexes)
// Update the offset and totalCount even if the rows are outside of the current page (need to
// in order to ensure the 'load more' sections are accurate)
updatedPageInfo = PagedData.PageInfo(
pageSize: updatedPageInfo.pageSize,
pageOffset: (updatedPageInfo.pageOffset + countBefore),
currentCount: updatedPageInfo.currentCount,
totalCount: (
updatedPageInfo.totalCount +
changesToQuery
.filter { $0.kind == .insert }
.filter { validChangeRowIds.contains($0.rowId) }
.count
)
)
// If there are no valid row ids then stop here (trigger updates though since the page info
// has changes)
guard !validChangeRowIds.isEmpty || !validRelatedChangeRowIds.isEmpty || !validRelatedDeletionRowIds.isEmpty else {
updateDataAndCallbackIfNeeded(updatedDataCache, updatedPageInfo, true)
return
}
// Fetch the inserted/updated rows
let targetRowIds: [Int64] = Array((validChangeRowIds + validRelatedChangeRowIds + validRelatedDeletionRowIds).asSet())
let updatedItems: [T] = (try? dataQuery(targetRowIds)
.fetchAll(db))
.defaulting(to: [])
// Update the cache, pageInfo and the change callback
self.dataCache.mutate { $0 = finalUpdatedDataCache }
self.pageInfo.mutate { $0 = updatedData.pageInfo }
// Process the upserted data
updatedDataCache = updatedDataCache.upserting(items: updatedItems)
// Update the currentCount for the upserted data
let dataSizeDiff: Int = (updatedDataCache.count - oldDataCount)
updatedPageInfo = PagedData.PageInfo(
pageSize: updatedPageInfo.pageSize,
pageOffset: updatedPageInfo.pageOffset,
currentCount: (updatedPageInfo.currentCount + dataSizeDiff),
totalCount: updatedPageInfo.totalCount
)
updateDataAndCallbackIfNeeded(updatedDataCache, updatedPageInfo, true)
// Trigger the unsorted change callback (the actual UI update triggering should eventually be run on
// the main thread via the `PagedData.processAndTriggerUpdates` function)
self.onChangeUnsorted(finalUpdatedDataCache.values, updatedData.pageInfo)
}
public func databaseDidRollback(_ db: Database) {}

@ -77,6 +77,23 @@ public extension String {
// MARK: - Formatting
extension String.StringInterpolation {
    /// Interpolates an integer using a printf-style width/flags specifier,
    /// e.g. `"\(5, format: "03")"` produces `"005"`.
    mutating func appendInterpolation(_ value: Int, format: String) {
        appendLiteral(String(format: "%\(format)d", value))
    }
    
    /// Interpolates a double using a printf-style precision specifier,
    /// e.g. `"\(1.25, format: ".2")"` produces `"1.25"`.
    ///
    /// When `omitZeroDecimal` is `true` and the value is exactly representable
    /// as an `Int`, the integer form is appended instead (so `2.0` renders as
    /// `"2"` rather than `"2.0"`).
    mutating func appendInterpolation(_ value: Double, format: String, omitZeroDecimal: Bool = false) {
        if omitZeroDecimal, let wholeValue: Int = Int(exactly: value) {
            appendLiteral("\(wholeValue)")
        }
        else {
            appendLiteral(String(format: "%\(format)f", value))
        }
    }
}
public extension String {
static func formattedDuration(_ duration: TimeInterval, format: TimeInterval.DurationFormat = .short) -> String {
let secondsPerMinute: TimeInterval = 60

@ -19,9 +19,12 @@ class SynchronousStorage: Storage {
}
override func writePublisher<T>(
fileName: String = #file,
functionName: String = #function,
lineNumber: Int = #line,
updates: @escaping (Database) throws -> T
) -> AnyPublisher<T, Error> {
guard let result: T = super.write(updates: updates) else {
guard let result: T = super.write(fileName: fileName, functionName: functionName, lineNumber: lineNumber, updates: updates) else {
return Fail(error: StorageError.generic)
.eraseToAnyPublisher()
}

Loading…
Cancel
Save