Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions server/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,14 @@ PLUGIN_PACKAGES += mattermost-plugin-calls-v1.11.0
PLUGIN_PACKAGES += mattermost-plugin-github-v2.5.0
PLUGIN_PACKAGES += mattermost-plugin-gitlab-v1.12.0
PLUGIN_PACKAGES += mattermost-plugin-jira-v4.5.1
PLUGIN_PACKAGES += mattermost-plugin-playbooks-v2.6.2
PLUGIN_PACKAGES += mattermost-plugin-playbooks-v2.7.0
PLUGIN_PACKAGES += mattermost-plugin-servicenow-v2.4.0
PLUGIN_PACKAGES += mattermost-plugin-zoom-v1.11.0
PLUGIN_PACKAGES += mattermost-plugin-agents-v1.7.2
PLUGIN_PACKAGES += mattermost-plugin-boards-v9.2.2
PLUGIN_PACKAGES += mattermost-plugin-user-survey-v1.1.1
PLUGIN_PACKAGES += mattermost-plugin-mscalendar-v1.5.0
PLUGIN_PACKAGES += mattermost-plugin-msteams-meetings-v2.3.0
PLUGIN_PACKAGES += mattermost-plugin-msteams-meetings-v2.4.0
PLUGIN_PACKAGES += mattermost-plugin-metrics-v0.7.0
PLUGIN_PACKAGES += mattermost-plugin-channel-export-v1.3.0

Expand Down
7 changes: 7 additions & 0 deletions server/channels/api4/post.go
Original file line number Diff line number Diff line change
Expand Up @@ -1702,13 +1702,20 @@ func rewriteMessage(c *Context, w http.ResponseWriter, r *http.Request) {
return
}

// Validate root_id if provided
if req.RootID != "" && !model.IsValidId(req.RootID) {
c.SetInvalidParam("root_id")
return
}

// Call app layer to handle business logic
response, appErr := c.App.RewriteMessage(
c.AppContext,
req.AgentID,
req.Message,
req.Action,
req.CustomPrompt,
req.RootID,
)
if appErr != nil {
c.Err = appErr
Expand Down
173 changes: 152 additions & 21 deletions server/channels/app/post.go
Original file line number Diff line number Diff line change
Expand Up @@ -3089,8 +3089,21 @@ func (a *App) RewriteMessage(
message string,
action model.RewriteAction,
customPrompt string,
rootID string,
) (*model.RewriteResponse, *model.AppError) {
userPrompt := getRewritePromptForAction(action, message, customPrompt)
// Build thread context if rootID is provided
var threadContext string
if rootID != "" {
context, appErr := a.buildThreadContextForRewrite(rctx, rootID)
if appErr != nil {
// Log error but continue without context rather than failing the rewrite
rctx.Logger().Warn("Failed to build thread context for rewrite", mlog.String("root_id", rootID), mlog.Err(appErr))
} else {
threadContext = context
}
}

userPrompt := getRewritePromptForAction(action, message, customPrompt, threadContext)
if userPrompt == "" {
return nil, model.NewAppError("RewriteMessage", "app.post.rewrite.invalid_action", nil, fmt.Sprintf("invalid action: %s", action), 400)
}
Expand Down Expand Up @@ -3121,38 +3134,156 @@ func (a *App) RewriteMessage(
return &response, nil
}

// getRewritePromptForAction returns the appropriate prompt and system prompt for the given rewrite action
func getRewritePromptForAction(action model.RewriteAction, message string, customPrompt string) string {
if message == "" {
return fmt.Sprintf(`Write according to these instructions: %s`, customPrompt)
// buildThreadContextForRewrite builds context from root post + last 10 posts in the thread
func (a *App) buildThreadContextForRewrite(rctx request.CTX, rootID string) (string, *model.AppError) {
const maxContextPosts = 10

// Get the thread posts
postList, appErr := a.GetPostThread(rctx, rootID, model.GetPostsOptions{}, rctx.Session().UserId)
if appErr != nil {
return "", appErr
}

if postList == nil || len(postList.Posts) == 0 {
return "", nil
}

// Get root post
rootPost, ok := postList.Posts[rootID]
if !ok {
return "", nil
}

// Skip if root post is a system post or deleted
if strings.HasPrefix(rootPost.Type, model.PostSystemMessagePrefix) || rootPost.DeleteAt > 0 {
return "", nil
}

// Collect reply posts, filtering out system posts and deleted posts
var replies []*model.Post
for _, postID := range postList.Order {
if postID == rootID {
continue // Skip root post
}
post, ok := postList.Posts[postID]
if !ok {
continue
}
// Skip system posts
if strings.HasPrefix(post.Type, model.PostSystemMessagePrefix) {
continue
}
// Skip deleted posts
if post.DeleteAt > 0 {
continue
}
replies = append(replies, post)
}

// Get last maxContextPosts replies
var contextReplies []*model.Post
startIdx := 0
if len(replies) > maxContextPosts {
startIdx = len(replies) - maxContextPosts
}
contextReplies = replies[startIdx:]

// Get user profiles for all posts in context
userIDs := []string{rootPost.UserId}
for _, reply := range contextReplies {
userIDs = append(userIDs, reply.UserId)
}
slices.Sort(userIDs)
userIDs = slices.Compact(userIDs)

users, appErr := a.GetUsersByIds(rctx, userIDs, &store.UserGetByIdsOpts{})
if appErr != nil {
return "", appErr
}

userMap := make(map[string]string, len(users))
for _, user := range users {
userMap[user.Id] = user.Username
}

switch action {
case model.RewriteActionCustom:
return fmt.Sprintf(`%s
// Build context string
var contextBuilder strings.Builder
contextBuilder.WriteString("Thread context:\n")

rootUsername := userMap[rootPost.UserId]
if rootUsername == "" {
rootUsername = "Unknown"
}
contextBuilder.WriteString(fmt.Sprintf("Root post (%s): %s\n", rootUsername, rootPost.Message))

if len(contextReplies) > 0 {
contextBuilder.WriteString("\nRecent replies:\n")
for _, reply := range contextReplies {
username := userMap[reply.UserId]
if username == "" {
username = "Unknown"
}
contextBuilder.WriteString(fmt.Sprintf("- %s: %s\n", username, reply.Message))
}
}

return contextBuilder.String(), nil
}

// getRewritePromptForAction returns the appropriate prompt and system prompt for the given rewrite action
func getRewritePromptForAction(action model.RewriteAction, message string, customPrompt string, threadContext string) string {
var actionPrompt string

if message == "" {
actionPrompt = fmt.Sprintf(`Write according to these instructions: %s`, customPrompt)
} else {
switch action {
case model.RewriteActionCustom:
actionPrompt = fmt.Sprintf(`%s

%s`, customPrompt, message)

case model.RewriteActionShorten:
return fmt.Sprintf(`Make this up to 2 to 3 times shorter: %s`, message)
case model.RewriteActionShorten:
actionPrompt = fmt.Sprintf(`Make this up to 2 to 3 times shorter: %s`, message)

case model.RewriteActionElaborate:
actionPrompt = fmt.Sprintf(`Make this up to 2 to 3 times longer, using Markdown if necessary: %s`, message)

case model.RewriteActionElaborate:
return fmt.Sprintf(`Make this up to 2 to 3 times longer, using Markdown if necessary: %s`, message)
case model.RewriteActionImproveWriting:
actionPrompt = fmt.Sprintf(`Improve this writing, using Markdown if necessary: %s`, message)

case model.RewriteActionImproveWriting:
return fmt.Sprintf(`Improve this writing, using Markdown if necessary: %s`, message)
case model.RewriteActionFixSpelling:
actionPrompt = fmt.Sprintf(`Fix spelling and grammar: %s`, message)

case model.RewriteActionFixSpelling:
return fmt.Sprintf(`Fix spelling and grammar: %s`, message)
case model.RewriteActionSimplify:
actionPrompt = fmt.Sprintf(`Simplify this: %s`, message)

case model.RewriteActionSimplify:
return fmt.Sprintf(`Simplify this: %s`, message)
case model.RewriteActionSummarize:
actionPrompt = fmt.Sprintf(`Summarize this, using Markdown if necessary: %s`, message)

case model.RewriteActionSummarize:
return fmt.Sprintf(`Summarize this, using Markdown if necessary: %s`, message)
default:
// Invalid action - return empty string to trigger validation error
return ""
}
}

return ""
// If no action prompt was generated, return empty string
if actionPrompt == "" {
return ""
}

// Build final prompt with thread context if available
if threadContext != "" {
var promptBuilder strings.Builder
promptBuilder.WriteString("=== THREAD CONTEXT (for reference only) ===\n")
promptBuilder.WriteString(threadContext)
promptBuilder.WriteString("\n\n=== REWRITE TASK ===\n")
promptBuilder.WriteString(actionPrompt)
promptBuilder.WriteString("\n\nRewrite the message considering the thread context above.")
return promptBuilder.String()
}

return actionPrompt
}

// RevealPost reveals a burn-on-read post for a specific user, creating a read receipt
Expand Down
2 changes: 1 addition & 1 deletion server/channels/app/slashcommands/command_mute.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (*MuteProvider) DoCommand(a *app.App, rctx request.CTX, args *model.Command

channelMember, err := a.ToggleMuteChannel(rctx, channel.Id, args.UserId)
if err != nil {
return &model.CommandResponse{Text: args.T("api.command_mute.not_member.error", map[string]any{"Channel": channelName}), ResponseType: model.CommandResponseTypeEphemeral}
return &model.CommandResponse{Text: args.T("api.command_mute.error", map[string]any{"Channel": channelName}), ResponseType: model.CommandResponseTypeEphemeral}
}

// Direct and Group messages won't have a nice channel title, omit it
Expand Down
5 changes: 3 additions & 2 deletions server/channels/app/slashcommands/command_mute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,14 @@ func TestMuteCommandNotMember(t *testing.T) {

cmd := &MuteProvider{}

// First mute the channel
// Muting a channel that the user is not a member of should return
// the same error as a non-existent channel to prevent channel enumeration
resp := cmd.DoCommand(th.App, th.Context, &model.CommandArgs{
T: i18n.IdentityTfunc(),
ChannelId: channel1.Id,
UserId: th.BasicUser.Id,
}, channel2.Name)
assert.Equal(t, "api.command_mute.not_member.error", resp.Text)
assert.Equal(t, "api.command_mute.error", resp.Text)
}

func TestMuteCommandNotChannel(t *testing.T) {
Expand Down
4 changes: 0 additions & 4 deletions server/i18n/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -1261,10 +1261,6 @@
"id": "api.command_mute.no_channel.error",
"translation": "Could not find the specified channel. Please use the [channel handle](https://docs.mattermost.com/messaging/managing-channels.html#naming-a-channel) to identify channels."
},
{
"id": "api.command_mute.not_member.error",
"translation": "Could not mute channel {{.Channel}} as you are not a member."
},
{
"id": "api.command_mute.success_mute",
"translation": "You will not receive notifications for {{.Channel}} until channel mute is turned off."
Expand Down
5 changes: 3 additions & 2 deletions server/public/model/post.go
Original file line number Diff line number Diff line change
Expand Up @@ -1188,14 +1188,15 @@ type RewriteRequest struct {
Message string `json:"message"`
Action RewriteAction `json:"action"`
CustomPrompt string `json:"custom_prompt,omitempty"`
RootID string `json:"root_id,omitempty"`
}

type RewriteResponse struct {
RewrittenText string `json:"rewritten_text"`
}

const RewriteSystemPrompt = `You are a JSON API that rewrites text. Your response must be valid JSON only.
Return this exact format: {"rewritten_text":"content"}.
const RewriteSystemPrompt = `You are a JSON API that rewrites text. Your response must be valid JSON only.
Return this exact format: {"rewritten_text":"content"}.
Do not use markdown, code blocks, or any formatting. Start with { and end with }.`

// ReportPostOptionsCursor contains cursor information for pagination.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,13 @@ jest.mock('components/common/agents/agent_dropdown', () => ({

jest.mock('components/widgets/inputs/input/input', () => {
const React = require('react');
const ForwardRefComponent = React.forwardRef(({placeholder, value, onChange, onKeyDown, disabled, inputPrefix}: any, ref: any) => (
const ForwardRefComponent = React.forwardRef(({placeholder, label, value, onChange, onKeyDown, disabled, inputPrefix}: any, ref: any) => (
<div data-testid='prompt-input'>
{inputPrefix}
<input
ref={ref}
placeholder={placeholder}
data-label={label}
value={value}
onChange={onChange}
onKeyDown={onKeyDown}
Expand Down Expand Up @@ -304,6 +305,7 @@ describe('RewriteMenu', () => {
);
let input = screen.getByTestId('prompt-input-field');
expect(input).toHaveAttribute('placeholder', 'Ask AI to edit message...');
expect(input).toHaveAttribute('data-label', 'Ask AI to edit message...');

rerender(
<RewriteMenu
Expand All @@ -313,6 +315,7 @@ describe('RewriteMenu', () => {
);
input = screen.getByTestId('prompt-input-field');
expect(input).toHaveAttribute('placeholder', 'Create a new message...');
expect(input).toHaveAttribute('data-label', 'Create a new message...');

rerender(
<RewriteMenu
Expand All @@ -323,6 +326,7 @@ describe('RewriteMenu', () => {
);
input = screen.getByTestId('prompt-input-field');
expect(input).toHaveAttribute('placeholder', 'What would you like AI to do next?');
expect(input).toHaveAttribute('data-label', 'What would you like AI to do next?');
});

test('should not render agent dropdown when processing', () => {
Expand All @@ -348,4 +352,3 @@ describe('RewriteMenu', () => {
expect(setSelectedAgentId).toHaveBeenCalledWith('agent2');
});
});

Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ export default function RewriteMenu({
onBotSelect={setSelectedAgentId}
bots={agents}
disabled={isProcessing}
showLabel={true}
/>
)}
{isProcessing &&
Expand All @@ -170,6 +171,7 @@ export default function RewriteMenu({
<Input
ref={customPromptRef}
inputPrefix={<CreationOutlineIcon size={18}/>}
label={placeholderText}
placeholder={placeholderText}
disabled={isProcessing}
value={prompt}
Expand Down Expand Up @@ -264,4 +266,3 @@ export default function RewriteMenu({
</Menu.Container>
);
}

Loading
Loading