Refactor roomserver/internal - split perform stuff out (#1380)

- New package `perform` which contains all `Perform` functions
- New package `helpers` which contains helper functions used by both
  perform and query/input functions.
- Perform invite/leave have no idea how to `WriteOutputEvents` and this
  is now returned from `PerformInvite` or `PerformLeave` respectively.

Still to do:
 - RSAPI is fed into the inviter/joiner/leaver - this introduces circular
   logic so will need to be removed.
 - Put query operations in a `query` package.
 - Put input operations (and output) in an `input` package.
 - Factor out helper functions as much as possible, possibly rejigging the
   storage layer in the process.
This commit is contained in:
Kegsay 2020-09-02 13:47:31 +01:00 committed by GitHub
parent 02a73f29f8
commit e473320e73
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 820 additions and 647 deletions

View file

@ -20,11 +20,9 @@ import (
"context"
"fmt"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/auth"
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/gomatrixserverlib"
@ -74,7 +72,7 @@ func (r *RoomserverInternalAPI) QueryLatestEventsAndState(
return err
}
stateEvents, err := r.loadStateEvents(ctx, stateEntries)
stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries)
if err != nil {
return err
}
@ -123,7 +121,7 @@ func (r *RoomserverInternalAPI) QueryStateAfterEvents(
return err
}
stateEvents, err := r.loadStateEvents(ctx, stateEntries)
stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, stateEntries)
if err != nil {
return err
}
@ -151,7 +149,7 @@ func (r *RoomserverInternalAPI) QueryEventsByID(
eventNIDs = append(eventNIDs, nid)
}
events, err := r.loadEvents(ctx, eventNIDs)
events, err := helpers.LoadEvents(ctx, r.DB, eventNIDs)
if err != nil {
return err
}
@ -168,31 +166,6 @@ func (r *RoomserverInternalAPI) QueryEventsByID(
return nil
}
func (r *RoomserverInternalAPI) loadStateEvents(
ctx context.Context, stateEntries []types.StateEntry,
) ([]gomatrixserverlib.Event, error) {
eventNIDs := make([]types.EventNID, len(stateEntries))
for i := range stateEntries {
eventNIDs[i] = stateEntries[i].EventNID
}
return r.loadEvents(ctx, eventNIDs)
}
func (r *RoomserverInternalAPI) loadEvents(
ctx context.Context, eventNIDs []types.EventNID,
) ([]gomatrixserverlib.Event, error) {
stateEvents, err := r.DB.Events(ctx, eventNIDs)
if err != nil {
return nil, err
}
result := make([]gomatrixserverlib.Event, len(stateEvents))
for i := range stateEvents {
result[i] = stateEvents[i].Event
}
return result, nil
}
// QueryMembershipForUser implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryMembershipForUser(
ctx context.Context,
@ -266,12 +239,12 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom(
events, err = r.DB.Events(ctx, eventNIDs)
} else {
stateEntries, err = stateBeforeEvent(ctx, r.DB, *info, membershipEventNID)
stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, *info, membershipEventNID)
if err != nil {
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
return err
}
events, err = getMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly)
events, err = helpers.GetMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly)
}
if err != nil {
@ -286,65 +259,6 @@ func (r *RoomserverInternalAPI) QueryMembershipsForRoom(
return nil
}
func stateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) {
roomState := state.NewStateResolution(db, info)
// Lookup the event NID
eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
if err != nil {
return nil, err
}
eventIDs := []string{eIDs[eventNID]}
prevState, err := db.StateAtEventIDs(ctx, eventIDs)
if err != nil {
return nil, err
}
// Fetch the state as it was when this event was fired
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
}
// getMembershipsAtState filters the state events to
// only keep the "m.room.member" events with a "join" membership. These events are returned.
// Returns an error if there was an issue fetching the events.
func getMembershipsAtState(
ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool,
) ([]types.Event, error) {
var eventNIDs []types.EventNID
for _, entry := range stateEntries {
// Filter the events to retrieve to only keep the membership events
if entry.EventTypeNID == types.MRoomMemberNID {
eventNIDs = append(eventNIDs, entry.EventNID)
}
}
// Get all of the events in this state
stateEvents, err := db.Events(ctx, eventNIDs)
if err != nil {
return nil, err
}
if !joinedOnly {
return stateEvents, nil
}
// Filter the events to only keep the "join" membership events
var events []types.Event
for _, event := range stateEvents {
membership, err := event.Membership()
if err != nil {
return nil, err
}
if membership == gomatrixserverlib.Join {
events = append(events, event)
}
}
return events, nil
}
// QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent(
ctx context.Context,
@ -360,7 +274,7 @@ func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent(
return
}
roomID := events[0].RoomID()
isServerInRoom, err := r.isServerCurrentlyInRoom(ctx, request.ServerName, roomID)
isServerInRoom, err := helpers.IsServerCurrentlyInRoom(ctx, r.DB, request.ServerName, roomID)
if err != nil {
return
}
@ -371,31 +285,12 @@ func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent(
if info == nil {
return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID)
}
response.AllowedToSeeEvent, err = r.checkServerAllowedToSeeEvent(
ctx, *info, request.EventID, request.ServerName, isServerInRoom,
response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent(
ctx, r.DB, *info, request.EventID, request.ServerName, isServerInRoom,
)
return
}
func (r *RoomserverInternalAPI) checkServerAllowedToSeeEvent(
ctx context.Context, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
) (bool, error) {
roomState := state.NewStateResolution(r.DB, info)
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
if err != nil {
return false, err
}
// TODO: We probably want to make it so that we don't have to pull
// out all the state if possible.
stateAtEvent, err := r.loadStateEvents(ctx, stateEntries)
if err != nil {
return false, err
}
return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
}
// QueryMissingEvents implements api.RoomserverInternalAPI
// nolint:gocyclo
func (r *RoomserverInternalAPI) QueryMissingEvents(
@ -431,12 +326,12 @@ func (r *RoomserverInternalAPI) QueryMissingEvents(
return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID())
}
resultNIDs, err := r.scanEventTree(ctx, *info, front, visited, request.Limit, request.ServerName)
resultNIDs, err := helpers.ScanEventTree(ctx, r.DB, *info, front, visited, request.Limit, request.ServerName)
if err != nil {
return err
}
loadedEvents, err := r.loadEvents(ctx, resultNIDs)
loadedEvents, err := helpers.LoadEvents(ctx, r.DB, resultNIDs)
if err != nil {
return err
}
@ -456,299 +351,6 @@ func (r *RoomserverInternalAPI) QueryMissingEvents(
return err
}
// PerformBackfill implements api.RoomServerQueryAPI
func (r *RoomserverInternalAPI) PerformBackfill(
ctx context.Context,
request *api.PerformBackfillRequest,
response *api.PerformBackfillResponse,
) error {
// if we are requesting the backfill then we need to do a federation hit
// TODO: we could be more sensible and fetch as many events we already have then request the rest
// which is what the syncapi does already.
if request.ServerName == r.ServerName {
return r.backfillViaFederation(ctx, request, response)
}
// someone else is requesting the backfill, try to service their request.
var err error
var front []string
// The limit defines the maximum number of events to retrieve, so it also
// defines the highest number of elements in the map below.
visited := make(map[string]bool, request.Limit)
// this will include these events which is what we want
front = request.PrevEventIDs()
info, err := r.DB.RoomInfo(ctx, request.RoomID)
if err != nil {
return err
}
if info == nil || info.IsStub {
return fmt.Errorf("PerformBackfill: missing room info for room %s", request.RoomID)
}
// Scan the event tree for events to send back.
resultNIDs, err := r.scanEventTree(ctx, *info, front, visited, request.Limit, request.ServerName)
if err != nil {
return err
}
// Retrieve events from the list that was filled previously.
var loadedEvents []gomatrixserverlib.Event
loadedEvents, err = r.loadEvents(ctx, resultNIDs)
if err != nil {
return err
}
for _, event := range loadedEvents {
response.Events = append(response.Events, event.Headered(info.RoomVersion))
}
return err
}
func (r *RoomserverInternalAPI) backfillViaFederation(ctx context.Context, req *api.PerformBackfillRequest, res *api.PerformBackfillResponse) error {
roomVer, err := r.roomVersion(req.RoomID)
if err != nil {
return fmt.Errorf("backfillViaFederation: unknown room version for room %s : %w", req.RoomID, err)
}
requester := newBackfillRequester(r.DB, r.FedClient, r.ServerName, req.BackwardsExtremities)
// Request 100 items regardless of what the query asks for.
// We don't want to go much higher than this.
// We can't honour exactly the limit as some sytests rely on requesting more for tests to pass
// (so we don't need to hit /state_ids which the test has no listener for)
// Specifically the test "Outbound federation can backfill events"
events, err := gomatrixserverlib.RequestBackfill(
ctx, requester,
r.KeyRing, req.RoomID, roomVer, req.PrevEventIDs(), 100)
if err != nil {
return err
}
logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events))
// persist these new events - auth checks have already been done
roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events)
if err != nil {
return err
}
for _, ev := range backfilledEventMap {
// now add state for these events
stateIDs, ok := requester.eventIDToBeforeStateIDs[ev.EventID()]
if !ok {
// this should be impossible as all events returned must have pass Step 5 of the PDU checks
// which requires a list of state IDs.
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to find state IDs for event which passed auth checks")
continue
}
var entries []types.StateEntry
if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs); err != nil {
// attempt to fetch the missing events
r.fetchAndStoreMissingEvents(ctx, roomVer, requester, stateIDs)
// try again
entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs)
if err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event")
return err
}
}
var beforeStateSnapshotNID types.StateSnapshotNID
if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid")
return err
}
if err = r.DB.SetState(ctx, ev.EventNID, beforeStateSnapshotNID); err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist snapshot nid")
}
}
// TODO: update backwards extremities, as that should be moved from syncapi to roomserver at some point.
res.Events = events
return nil
}
func (r *RoomserverInternalAPI) isServerCurrentlyInRoom(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) {
info, err := r.DB.RoomInfo(ctx, roomID)
if err != nil {
return false, err
}
if info == nil {
return false, fmt.Errorf("unknown room %s", roomID)
}
eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
if err != nil {
return false, err
}
events, err := r.DB.Events(ctx, eventNIDs)
if err != nil {
return false, err
}
gmslEvents := make([]gomatrixserverlib.Event, len(events))
for i := range events {
gmslEvents[i] = events[i].Event
}
return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil
}
// fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just
// best effort.
func (r *RoomserverInternalAPI) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion,
backfillRequester *backfillRequester, stateIDs []string) {
servers := backfillRequester.servers
// work out which are missing
nidMap, err := r.DB.EventNIDs(ctx, stateIDs)
if err != nil {
util.GetLogger(ctx).WithError(err).Warn("cannot query missing events")
return
}
missingMap := make(map[string]*gomatrixserverlib.HeaderedEvent) // id -> event
for _, id := range stateIDs {
if _, ok := nidMap[id]; !ok {
missingMap[id] = nil
}
}
util.GetLogger(ctx).Infof("Fetching %d missing state events (from %d possible servers)", len(missingMap), len(servers))
// fetch the events from federation. Loop the servers first so if we find one that works we stick with them
for _, srv := range servers {
for id, ev := range missingMap {
if ev != nil {
continue // already found
}
logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id)
res, err := r.FedClient.GetEvent(ctx, srv, id)
if err != nil {
logger.WithError(err).Warn("failed to get event from server")
continue
}
loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents)
if err != nil {
logger.WithError(err).Warn("failed to load and verify event")
continue
}
logger.Infof("returned %d PDUs which made events %+v", len(res.PDUs), result)
for _, res := range result {
if res.Error != nil {
logger.WithError(err).Warn("event failed PDU checks")
continue
}
missingMap[id] = res.Event
}
}
}
var newEvents []gomatrixserverlib.HeaderedEvent
for _, ev := range missingMap {
if ev != nil {
newEvents = append(newEvents, *ev)
}
}
util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents))
persistEvents(ctx, r.DB, newEvents)
}
// TODO: Remove this when we have tests to assert correctness of this function
// nolint:gocyclo
func (r *RoomserverInternalAPI) scanEventTree(
ctx context.Context, info types.RoomInfo, front []string, visited map[string]bool, limit int,
serverName gomatrixserverlib.ServerName,
) ([]types.EventNID, error) {
var resultNIDs []types.EventNID
var err error
var allowed bool
var events []types.Event
var next []string
var pre string
// TODO: add tests for this function to ensure it meets the contract that callers expect (and doc what that is supposed to be)
// Currently, callers like PerformBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing
// so means that the events in that map will NOT be returned from this function. That is not currently true, resulting in
// duplicate events being sent in response to /backfill requests.
initialIgnoreList := make(map[string]bool, len(visited))
for k, v := range visited {
initialIgnoreList[k] = v
}
resultNIDs = make([]types.EventNID, 0, limit)
var checkedServerInRoom bool
var isServerInRoom bool
// Loop through the event IDs to retrieve the requested events and go
// through the whole tree (up to the provided limit) using the events'
// "prev_event" key.
BFSLoop:
for len(front) > 0 {
// Prevent unnecessary allocations: reset the slice only when not empty.
if len(next) > 0 {
next = make([]string, 0)
}
// Retrieve the events to process from the database.
events, err = r.DB.EventsFromIDs(ctx, front)
if err != nil {
return resultNIDs, err
}
if !checkedServerInRoom && len(events) > 0 {
// It's nasty that we have to extract the room ID from an event, but many federation requests
// only talk in event IDs, no room IDs at all (!!!)
ev := events[0]
isServerInRoom, err = r.isServerCurrentlyInRoom(ctx, serverName, ev.RoomID())
if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.")
}
checkedServerInRoom = true
}
for _, ev := range events {
// Break out of the loop if the provided limit is reached.
if len(resultNIDs) == limit {
break BFSLoop
}
if !initialIgnoreList[ev.EventID()] {
// Update the list of events to retrieve.
resultNIDs = append(resultNIDs, ev.EventNID)
}
// Loop through the event's parents.
for _, pre = range ev.PrevEventIDs() {
// Only add an event to the list of next events to process if it
// hasn't been seen before.
if !visited[pre] {
visited[pre] = true
allowed, err = r.checkServerAllowedToSeeEvent(ctx, info, pre, serverName, isServerInRoom)
if err != nil {
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
"Error checking if allowed to see event",
)
return resultNIDs, err
}
// If the event hasn't been seen before and the HS
// requesting to retrieve it is allowed to do so, add it to
// the list of events to retrieve.
if allowed {
next = append(next, pre)
} else {
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).Info("Not allowed to see event")
}
}
}
}
// Repeat the same process with the parent events we just processed.
front = next
}
return resultNIDs, err
}
// QueryStateAndAuthChain implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryStateAndAuthChain(
ctx context.Context,
@ -823,7 +425,7 @@ func (r *RoomserverInternalAPI) loadStateAtEventIDs(ctx context.Context, roomInf
return nil, err
}
return r.loadStateEvents(ctx, stateEntries)
return helpers.LoadStateEvents(ctx, r.DB, stateEntries)
}
type eventsFromIDs func(context.Context, []string) ([]types.Event, error)
@ -879,50 +481,6 @@ func getAuthChain(
return authEvents, nil
}
func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) {
var roomNID types.RoomNID
backfilledEventMap := make(map[string]types.Event)
for j, ev := range events {
nidMap, err := db.EventNIDs(ctx, ev.AuthEventIDs())
if err != nil { // this shouldn't happen as RequestBackfill already found them
logrus.WithError(err).WithField("auth_events", ev.AuthEventIDs()).Error("Failed to find one or more auth events")
continue
}
authNids := make([]types.EventNID, len(nidMap))
i := 0
for _, nid := range nidMap {
authNids[i] = nid
i++
}
var stateAtEvent types.StateAtEvent
var redactedEventID string
var redactionEvent *gomatrixserverlib.Event
roomNID, stateAtEvent, redactionEvent, redactedEventID, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids)
if err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event")
continue
}
// If storing this event results in it being redacted, then do so.
// It's also possible for this event to be a redaction which results in another event being
// redacted, which we don't care about since we aren't returning it in this backfill.
if redactedEventID == ev.EventID() {
eventToRedact := ev.Unwrap()
redactedEvent, err := eventutil.RedactEvent(redactionEvent, &eventToRedact)
if err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to redact event")
continue
}
ev = redactedEvent.Headered(ev.RoomVersion)
events[j] = ev
}
backfilledEventMap[ev.EventID()] = types.Event{
EventNID: stateAtEvent.StateEntry.EventNID,
Event: ev.Unwrap(),
}
}
return roomNID, backfilledEventMap
}
// QueryRoomVersionCapabilities implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryRoomVersionCapabilities(
ctx context.Context,