Skip to content

Commit

Permalink
Merge pull request #23 from ikura-hamu/refactor/args
Browse files Browse the repository at this point in the history
flagまわりのリファクタリング
  • Loading branch information
ikura-hamu committed Sep 22, 2024
2 parents 877cca4 + de5fe04 commit 677c49f
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 89 deletions.
39 changes: 27 additions & 12 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package cmd

import (
"bufio"
"context"
"errors"
"fmt"
"os"
Expand Down Expand Up @@ -57,6 +58,12 @@ It reads the configuration file and sends the message to the specified webhook.`
return nil
}

cmdCtx := cmd.Context()
rootFlagsData, ok := cmdCtx.Value(rootFlagsCtxKey{}).(*rootFlags)
if !ok {
return errors.New("failed to get root options")
}

channelsStr := viper.GetStringMapString(configKeyChannels)
channels := make(map[string]uuid.UUID, len(channelsStr))
for k, v := range channelsStr {
Expand Down Expand Up @@ -91,7 +98,7 @@ It reads the configuration file and sends the message to the specified webhook.`
message = strings.TrimSpace(sb.String())
}

if withCodeBlock {
if rootFlagsData.codeBlock {
leadingBackQuoteCountMax := 0

for _, line := range strings.Split(message, "\n") {
Expand All @@ -105,15 +112,15 @@ It reads the configuration file and sends the message to the specified webhook.`

codeBlockBackQuote := strings.Repeat("`", max(leadingBackQuoteCountMax+1, 3))

message = fmt.Sprintf("%s%s\n%s\n%s", codeBlockBackQuote, codeBlockLang, message, codeBlockBackQuote)
message = fmt.Sprintf("%s%s\n%s\n%s", codeBlockBackQuote, rootFlagsData.codeBlockLang, message, codeBlockBackQuote)
}

channelID := uuid.Nil
if channelName != "" {
if rootFlagsData.channelName != "" {
var ok bool
channelID, ok = conf.channels[channelName]
channelID, ok = conf.channels[rootFlagsData.channelName]
if !ok {
return fmt.Errorf("channel '%s' is not found: %w", channelName, ErrChannelNotFound)
return fmt.Errorf("channel '%s' is not found: %w", rootFlagsData.channelName, ErrChannelNotFound)
}
}

Expand All @@ -133,13 +140,15 @@ It reads the configuration file and sends the message to the specified webhook.`
},
}

var (
printVersion bool
withCodeBlock bool
type rootFlags struct {
codeBlock bool
codeBlockLang string
channelName string
}

version string
var (
printVersion bool
version string
)

const (
Expand All @@ -165,6 +174,8 @@ func Execute() {
}
}

type rootFlagsCtxKey struct{}

func init() {
cobra.OnInitialize(initConfig)

Expand All @@ -178,9 +189,13 @@ func init() {
// when this action is called directly.
rootCmd.Flags().BoolVarP(&printVersion, "version", "v", false, "Print version information and quit")

rootCmd.Flags().BoolVarP(&withCodeBlock, "code", "c", false, "Send message with code block")
rootCmd.Flags().StringVarP(&codeBlockLang, "lang", "l", "", "Code block language")
rootCmd.Flags().StringVarP(&channelName, "channel", "C", "", "Channel name")
var rootFlagsData rootFlags
rootCmd.Flags().BoolVarP(&rootFlagsData.codeBlock, "code", "c", false, "Send message with code block")
rootCmd.Flags().StringVarP(&rootFlagsData.codeBlockLang, "lang", "l", "", "Code block language")
rootCmd.Flags().StringVarP(&rootFlagsData.channelName, "channel", "C", "", "Channel name")

ctx := context.WithValue(context.Background(), rootFlagsCtxKey{}, &rootFlagsData)
rootCmd.SetContext(ctx)
}

// initConfig reads in config file and ENV variables if set.
Expand Down
161 changes: 84 additions & 77 deletions cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cmd

import (
"bytes"
"context"
"fmt"
"os"
"testing"
Expand All @@ -22,6 +23,7 @@ func TestRoot(t *testing.T) {
codeBlock bool
codeBlockLang string
channelName string
printVersion bool
stdin string
args []string
}
Expand All @@ -32,68 +34,89 @@ func TestRoot(t *testing.T) {
SendMessageErr error
expectedMessage string
SkipCallSendMessage bool
expectedChannelID uuid.UUID
isError bool
expectedErr error

expectedChannelID uuid.UUID
isError bool
expectedErr error
expectedStdout string
}{
"ok": {
webhookConfig: defaultWebhookConfig,
input: input{false, "", "", "", []string{"test"}},
input: input{false, "", "", false, "", []string{"test"}},
expectedMessage: "test",
expectedChannelID: uuid.Nil,
},
"コードブロックがあっても問題なし": {
webhookConfig: defaultWebhookConfig,
input: input{true, "", "", "", []string{"print('Hello, World!')"}},
input: input{true, "", "", false, "", []string{"print('Hello, World!')"}},
expectedMessage: "```\nprint('Hello, World!')\n```",
expectedChannelID: uuid.Nil,
},
"コードブロックと言語指定があっても問題なし": {
webhookConfig: defaultWebhookConfig,
input: input{true, "python", "", "", []string{"print('Hello, World!')"}},
input: input{true, "python", "", false, "", []string{"print('Hello, World!')"}},
expectedMessage: "```python\nprint('Hello, World!')\n```",
expectedChannelID: uuid.Nil,
},
"メッセージがない場合は標準入力から": {
webhookConfig: defaultWebhookConfig,
input: input{false, "", "", "stdin test", nil},
input: input{false, "", "", false, "stdin test", nil},
expectedMessage: "stdin test",
expectedChannelID: uuid.Nil,
},
"メッセージがあったら標準入力は無視": {
webhookConfig: defaultWebhookConfig,
input: input{false, "", "", "stdin test", []string{"test"}},
input: input{false, "", "", false, "stdin test", []string{"test"}},
expectedMessage: "test",
expectedChannelID: uuid.Nil,
},
"SendMessageがErrEmptyMessageを返す": {
webhookConfig: defaultWebhookConfig,
input: input{false, "", "", "", nil},
input: input{false, "", "", false, "", nil},
SendMessageErr: client.ErrEmptyMessage,
expectedChannelID: uuid.Nil,
isError: true,
},
"メッセージにコードブロックが含まれていて、そこにコードブロックを付けても問題なし": {
webhookConfig: defaultWebhookConfig,
input: input{true, "", "", "```python\nprint('Hello, World!')\n```", nil},
input: input{true, "", "", false, "```python\nprint('Hello, World!')\n```", nil},
expectedMessage: "````\n```python\nprint('Hello, World!')\n```\n````",
expectedChannelID: uuid.Nil,
},
"チャンネル名を指定しても問題なし": {
webhookConfig: defaultWebhookConfig,
input: input{false, "", "channel", "test", nil},
input: input{false, "", "channel", false, "test", nil},
expectedMessage: "test",
expectedChannelID: channelID,
},
"チャンネル名が存在しない場合はエラー": {
webhookConfig: defaultWebhookConfig,
input: input{false, "", "notfound", "test", nil},
input: input{false, "", "notfound", false, "test", nil},
SendMessageErr: nil,
SkipCallSendMessage: true,
expectedChannelID: uuid.Nil,
isError: true,
expectedErr: ErrChannelNotFound,
},
"print version": {
webhookConfig: defaultWebhookConfig,
input: input{false, "", "", true, "", nil},
SkipCallSendMessage: true,
expectedStdout: "q version unknown\n",
},
"設定が不十分でもバージョンを表示": {
webhookConfig: webhookConfig{},
input: input{false, "", "", true, "", nil},
SkipCallSendMessage: true,
expectedStdout: "q version unknown\n",
},
"設定が不十分なのでエラーメッセージ": {
webhookConfig: webhookConfig{},
input: input{false, "", "", false, "", nil},
SkipCallSendMessage: true,
isError: true,
expectedErr: ErrEmptyConfiguration,
},
}

for description, tt := range test {
Expand All @@ -110,29 +133,21 @@ func TestRoot(t *testing.T) {
viper.Set("webhook_secret", tt.webhookConfig.secret)
viper.Set("channels", channelsStr)

withCodeBlock = tt.codeBlock
codeBlockLang = tt.codeBlockLang
channelName = tt.channelName
rootFlagsData := rootFlags{
codeBlock: tt.codeBlock,
codeBlockLang: tt.codeBlockLang,
channelName: tt.channelName,
}
rootCmd.SetContext(context.WithValue(context.Background(), rootFlagsCtxKey{}, &rootFlagsData))

printVersion = tt.printVersion
t.Cleanup(func() {
withCodeBlock = false
codeBlockLang = ""
channelName = ""
printVersion = false
})

r, w, err := os.Pipe()
require.NoError(t, err, "failed to create pipe")
stdinW := ReplaceStdin(t)

origStdin := os.Stdin
os.Stdin = r
defer func() {
os.Stdin = origStdin
r.Close()
}()

_, err = fmt.Fprint(w, tt.stdin)
require.NoError(t, err, "failed to write to pipe")
w.Close()
stdoutR := ReplaceStdout(t)

mockClient := &mock.ClientMock{
SendMessageFunc: func(message string, channelID uuid.UUID) error {
Expand All @@ -142,7 +157,12 @@ func TestRoot(t *testing.T) {

SetClient(mockClient)

_, err := fmt.Fprint(stdinW, tt.stdin)
require.NoError(t, err, "failed to write to pipe")
stdinW.Close()

cmdErr := rootCmd.RunE(rootCmd, tt.args)
os.Stdout.Close()

if tt.SkipCallSendMessage {
assert.Len(t, mockClient.SendMessageCalls(), 0)
Expand All @@ -153,6 +173,14 @@ func TestRoot(t *testing.T) {
assert.Equal(t, tt.expectedChannelID, mockClient.SendMessageCalls()[0].ChannelID)
}

if tt.expectedStdout != "" {
var buffer bytes.Buffer
_, err := buffer.ReadFrom(stdoutR)
require.NoError(t, err, "failed to read from pipe")

assert.Equal(t, tt.expectedStdout, buffer.String())
}

if tt.isError {
if tt.expectedErr != nil {
assert.ErrorIs(t, cmdErr, tt.expectedErr)
Expand All @@ -166,59 +194,38 @@ func TestRoot(t *testing.T) {
}
}

func TestRoot_NoSendMessage(t *testing.T) {
test := map[string]struct {
webhookHost string
webhookID string
webhookSecret string
args []string
printVersion bool
wantStdout string
wantErr error
}{
"print version": {"http://localhost:8080", "test", "test", []string{}, true, "q version unknown\n", nil},
"設定が不十分でもversionをprint": {"", "test", "test", []string{}, true, "q version unknown\n", nil},
"設定が不十分なのでエラーメッセージ": {"", "", "", []string{"aaa"}, false, "", ErrEmptyConfiguration},
}
// 標準出力に書き込むとそれを読めるReaderを返す。
// テスト対象の関数実行後、os.Stdoutをcloseすること。
func ReplaceStdout(t *testing.T) *os.File {
t.Helper()

for description, tt := range test {
t.Run(description, func(t *testing.T) {
viper.Set("webhook_host", tt.webhookHost)
viper.Set("webhook_id", tt.webhookID)
viper.Set("webhook_secret", tt.webhookSecret)
stdoutR, stdoutW, err := os.Pipe()
require.NoError(t, err, "failed to create pipe")

mockClient := &mock.ClientMock{
SendMessageFunc: func(message string, channelID uuid.UUID) error {
return nil
},
}
origStdout := os.Stdout
os.Stdout = stdoutW

r, w, err := os.Pipe()
require.NoError(t, err, "failed to create pipe")
origStdout := os.Stdout
os.Stdout = w
defer func() {
os.Stdout = origStdout
}()
t.Cleanup(func() {
os.Stdout = origStdout
})

printVersion = tt.printVersion
return stdoutR
}

SetClient(mockClient)
// 書き込むと標準入力に書き込まれるWriterを返す。
func ReplaceStdin(t *testing.T) *os.File {
t.Helper()

cmdErr := rootCmd.RunE(rootCmd, []string{})
w.Close()
stdinR, stdinW, err := os.Pipe()
require.NoError(t, err, "failed to create pipe")

assert.Len(t, mockClient.SendMessageCalls(), 0)
var buffer bytes.Buffer
_, err = buffer.ReadFrom(r)
require.NoError(t, err, "failed to read from pipe")
origStdin := os.Stdin
os.Stdin = stdinR

assert.Equal(t, buffer.String(), tt.wantStdout)
if tt.wantErr != nil {
assert.ErrorIs(t, tt.wantErr, cmdErr)
} else {
assert.NoError(t, cmdErr)
}
})
}
t.Cleanup(func() {
os.Stdin = origStdin
stdinR.Close()
})

return stdinW
}

0 comments on commit 677c49f

Please sign in to comment.