Add PerformInvite and refactor how errors get handled (#1158)
* Add PerformInvite and refactor how errors get handled - Rename `JoinError` to `PerformError` - Remove `error` from the API function signature entirely. This forces errors to be bundled into `PerformError` which makes it easier for callers to detect and handle errors. On network errors, HTTP clients will make a `PerformError`. * Unbreak everything; thanks Go! * Send back JSONResponse according to the PerformError * Update federation invite code too
This commit is contained in:
parent
ebaaf65c54
commit
002fe05a20
16 changed files with 469 additions and 332 deletions
|
|
@ -14,40 +14,63 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// WriteOutputEvents implements OutputRoomEventWriter
|
||||
// PerformJoin handles joining matrix rooms, including over federation by talking to the federationsender.
|
||||
func (r *RoomserverInternalAPI) PerformJoin(
|
||||
ctx context.Context,
|
||||
req *api.PerformJoinRequest,
|
||||
res *api.PerformJoinResponse,
|
||||
) error {
|
||||
) {
|
||||
roomID, err := r.performJoin(ctx, req)
|
||||
if err != nil {
|
||||
perr, ok := err.(*api.PerformError)
|
||||
if ok {
|
||||
res.Error = perr
|
||||
} else {
|
||||
res.Error = &api.PerformError{
|
||||
Msg: err.Error(),
|
||||
}
|
||||
}
|
||||
}
|
||||
res.RoomID = roomID
|
||||
}
|
||||
|
||||
func (r *RoomserverInternalAPI) performJoin(
|
||||
ctx context.Context,
|
||||
req *api.PerformJoinRequest,
|
||||
) (string, error) {
|
||||
_, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||
if err != nil {
|
||||
res.Error = api.JoinErrorBadRequest
|
||||
return fmt.Errorf("Supplied user ID %q in incorrect format", req.UserID)
|
||||
return "", &api.PerformError{
|
||||
Code: api.PerformErrorBadRequest,
|
||||
Msg: fmt.Sprintf("Supplied user ID %q in incorrect format", req.UserID),
|
||||
}
|
||||
}
|
||||
if domain != r.Cfg.Matrix.ServerName {
|
||||
res.Error = api.JoinErrorBadRequest
|
||||
return fmt.Errorf("User %q does not belong to this homeserver", req.UserID)
|
||||
return "", &api.PerformError{
|
||||
Code: api.PerformErrorBadRequest,
|
||||
Msg: fmt.Sprintf("User %q does not belong to this homeserver", req.UserID),
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(req.RoomIDOrAlias, "!") {
|
||||
return r.performJoinRoomByID(ctx, req, res)
|
||||
return r.performJoinRoomByID(ctx, req)
|
||||
}
|
||||
if strings.HasPrefix(req.RoomIDOrAlias, "#") {
|
||||
return r.performJoinRoomByAlias(ctx, req, res)
|
||||
return r.performJoinRoomByAlias(ctx, req)
|
||||
}
|
||||
return "", &api.PerformError{
|
||||
Code: api.PerformErrorBadRequest,
|
||||
Msg: fmt.Sprintf("Room ID or alias %q is invalid", req.RoomIDOrAlias),
|
||||
}
|
||||
res.Error = api.JoinErrorBadRequest
|
||||
return fmt.Errorf("Room ID or alias %q is invalid", req.RoomIDOrAlias)
|
||||
}
|
||||
|
||||
func (r *RoomserverInternalAPI) performJoinRoomByAlias(
|
||||
ctx context.Context,
|
||||
req *api.PerformJoinRequest,
|
||||
res *api.PerformJoinResponse,
|
||||
) error {
|
||||
) (string, error) {
|
||||
// Get the domain part of the room alias.
|
||||
_, domain, err := gomatrixserverlib.SplitID('#', req.RoomIDOrAlias)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Alias %q is not in the correct format", req.RoomIDOrAlias)
|
||||
return "", fmt.Errorf("Alias %q is not in the correct format", req.RoomIDOrAlias)
|
||||
}
|
||||
req.ServerNames = append(req.ServerNames, domain)
|
||||
|
||||
|
|
@ -65,7 +88,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByAlias(
|
|||
err = r.fsAPI.PerformDirectoryLookup(ctx, &dirReq, &dirRes)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("error looking up alias %q", req.RoomIDOrAlias)
|
||||
return fmt.Errorf("Looking up alias %q over federation failed: %w", req.RoomIDOrAlias, err)
|
||||
return "", fmt.Errorf("Looking up alias %q over federation failed: %w", req.RoomIDOrAlias, err)
|
||||
}
|
||||
roomID = dirRes.RoomID
|
||||
req.ServerNames = append(req.ServerNames, dirRes.ServerNames...)
|
||||
|
|
@ -73,18 +96,18 @@ func (r *RoomserverInternalAPI) performJoinRoomByAlias(
|
|||
// Otherwise, look up if we know this room alias locally.
|
||||
roomID, err = r.DB.GetRoomIDForAlias(ctx, req.RoomIDOrAlias)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Lookup room alias %q failed: %w", req.RoomIDOrAlias, err)
|
||||
return "", fmt.Errorf("Lookup room alias %q failed: %w", req.RoomIDOrAlias, err)
|
||||
}
|
||||
}
|
||||
|
||||
// If the room ID is empty then we failed to look up the alias.
|
||||
if roomID == "" {
|
||||
return fmt.Errorf("Alias %q not found", req.RoomIDOrAlias)
|
||||
return "", fmt.Errorf("Alias %q not found", req.RoomIDOrAlias)
|
||||
}
|
||||
|
||||
// If we do, then pluck out the room ID and continue the join.
|
||||
req.RoomIDOrAlias = roomID
|
||||
return r.performJoinRoomByID(ctx, req, res)
|
||||
return r.performJoinRoomByID(ctx, req)
|
||||
}
|
||||
|
||||
// TODO: Break this function up a bit
|
||||
|
|
@ -92,19 +115,14 @@ func (r *RoomserverInternalAPI) performJoinRoomByAlias(
|
|||
func (r *RoomserverInternalAPI) performJoinRoomByID(
|
||||
ctx context.Context,
|
||||
req *api.PerformJoinRequest,
|
||||
res *api.PerformJoinResponse, // nolint:unparam
|
||||
) error {
|
||||
// By this point, if req.RoomIDOrAlias contained an alias, then
|
||||
// it will have been overwritten with a room ID by performJoinRoomByAlias.
|
||||
// We should now include this in the response so that the CS API can
|
||||
// return the right room ID.
|
||||
res.RoomID = req.RoomIDOrAlias
|
||||
|
||||
) (string, error) {
|
||||
// Get the domain part of the room ID.
|
||||
_, domain, err := gomatrixserverlib.SplitID('!', req.RoomIDOrAlias)
|
||||
if err != nil {
|
||||
res.Error = api.JoinErrorBadRequest
|
||||
return fmt.Errorf("Room ID %q is invalid", req.RoomIDOrAlias)
|
||||
return "", &api.PerformError{
|
||||
Code: api.PerformErrorBadRequest,
|
||||
Msg: fmt.Sprintf("Room ID %q is invalid: %s", req.RoomIDOrAlias, err),
|
||||
}
|
||||
}
|
||||
req.ServerNames = append(req.ServerNames, domain)
|
||||
|
||||
|
|
@ -118,7 +136,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByID(
|
|||
Redacts: "",
|
||||
}
|
||||
if err = eb.SetUnsigned(struct{}{}); err != nil {
|
||||
return fmt.Errorf("eb.SetUnsigned: %w", err)
|
||||
return "", fmt.Errorf("eb.SetUnsigned: %w", err)
|
||||
}
|
||||
|
||||
// It is possible for the request to include some "content" for the
|
||||
|
|
@ -129,7 +147,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByID(
|
|||
}
|
||||
req.Content["membership"] = gomatrixserverlib.Join
|
||||
if err = eb.SetContent(req.Content); err != nil {
|
||||
return fmt.Errorf("eb.SetContent: %w", err)
|
||||
return "", fmt.Errorf("eb.SetContent: %w", err)
|
||||
}
|
||||
|
||||
// First work out if this is in response to an existing invite
|
||||
|
|
@ -142,7 +160,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByID(
|
|||
// Check if there's an invite pending.
|
||||
_, inviterDomain, ierr := gomatrixserverlib.SplitID('@', inviteSender)
|
||||
if ierr != nil {
|
||||
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
||||
return "", fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
||||
}
|
||||
|
||||
// Check that the domain isn't ours. If it's local then we don't
|
||||
|
|
@ -154,7 +172,7 @@ func (r *RoomserverInternalAPI) performJoinRoomByID(
|
|||
req.ServerNames = append(req.ServerNames, inviterDomain)
|
||||
|
||||
// Perform a federated room join.
|
||||
return r.performFederatedJoinRoomByID(ctx, req, res)
|
||||
return req.RoomIDOrAlias, r.performFederatedJoinRoomByID(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -205,9 +223,12 @@ func (r *RoomserverInternalAPI) performJoinRoomByID(
|
|||
if err = r.InputRoomEvents(ctx, &inputReq, &inputRes); err != nil {
|
||||
var notAllowed *gomatrixserverlib.NotAllowed
|
||||
if errors.As(err, ¬Allowed) {
|
||||
res.Error = api.JoinErrorNotAllowed
|
||||
return "", &api.PerformError{
|
||||
Code: api.PerformErrorNotAllowed,
|
||||
Msg: fmt.Sprintf("InputRoomEvents auth failed: %s", err),
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("r.InputRoomEvents: %w", err)
|
||||
return "", fmt.Errorf("r.InputRoomEvents: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -216,25 +237,30 @@ func (r *RoomserverInternalAPI) performJoinRoomByID(
|
|||
// room. If it is then there's nothing more to do - the room just
|
||||
// hasn't been created yet.
|
||||
if domain == r.Cfg.Matrix.ServerName {
|
||||
res.Error = api.JoinErrorNoRoom
|
||||
return fmt.Errorf("Room ID %q does not exist", req.RoomIDOrAlias)
|
||||
return "", &api.PerformError{
|
||||
Code: api.PerformErrorNoRoom,
|
||||
Msg: fmt.Sprintf("Room ID %q does not exist", req.RoomIDOrAlias),
|
||||
}
|
||||
}
|
||||
|
||||
// Perform a federated room join.
|
||||
return r.performFederatedJoinRoomByID(ctx, req, res)
|
||||
return req.RoomIDOrAlias, r.performFederatedJoinRoomByID(ctx, req)
|
||||
|
||||
default:
|
||||
// Something else went wrong.
|
||||
return fmt.Errorf("Error joining local room: %q", err)
|
||||
return "", fmt.Errorf("Error joining local room: %q", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
// By this point, if req.RoomIDOrAlias contained an alias, then
|
||||
// it will have been overwritten with a room ID by performJoinRoomByAlias.
|
||||
// We should now include this in the response so that the CS API can
|
||||
// return the right room ID.
|
||||
return req.RoomIDOrAlias, nil
|
||||
}
|
||||
|
||||
func (r *RoomserverInternalAPI) performFederatedJoinRoomByID(
|
||||
ctx context.Context,
|
||||
req *api.PerformJoinRequest,
|
||||
res *api.PerformJoinResponse, // nolint:unparam
|
||||
) error {
|
||||
// Try joining by all of the supplied server names.
|
||||
fedReq := fsAPI.PerformJoinRequest{
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue