Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

flagまわりのリファクタリング #23

Merged
merged 2 commits into from
Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 27 additions & 12 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package cmd

import (
"bufio"
"context"
"errors"
"fmt"
"os"
Expand Down Expand Up @@ -57,6 +58,12 @@ It reads the configuration file and sends the message to the specified webhook.`
return nil
}

cmdCtx := cmd.Context()
rootFlagsData, ok := cmdCtx.Value(rootFlagsCtxKey{}).(*rootFlags)
if !ok {
return errors.New("failed to get root options")
}

channelsStr := viper.GetStringMapString(configKeyChannels)
channels := make(map[string]uuid.UUID, len(channelsStr))
for k, v := range channelsStr {
Expand Down Expand Up @@ -91,7 +98,7 @@ It reads the configuration file and sends the message to the specified webhook.`
message = strings.TrimSpace(sb.String())
}

if withCodeBlock {
if rootFlagsData.codeBlock {
leadingBackQuoteCountMax := 0

for _, line := range strings.Split(message, "\n") {
Expand All @@ -105,15 +112,15 @@ It reads the configuration file and sends the message to the specified webhook.`

codeBlockBackQuote := strings.Repeat("`", max(leadingBackQuoteCountMax+1, 3))

message = fmt.Sprintf("%s%s\n%s\n%s", codeBlockBackQuote, codeBlockLang, message, codeBlockBackQuote)
message = fmt.Sprintf("%s%s\n%s\n%s", codeBlockBackQuote, rootFlagsData.codeBlockLang, message, codeBlockBackQuote)
}

channelID := uuid.Nil
if channelName != "" {
if rootFlagsData.channelName != "" {
var ok bool
channelID, ok = conf.channels[channelName]
channelID, ok = conf.channels[rootFlagsData.channelName]
if !ok {
return fmt.Errorf("channel '%s' is not found: %w", channelName, ErrChannelNotFound)
return fmt.Errorf("channel '%s' is not found: %w", rootFlagsData.channelName, ErrChannelNotFound)
}
}

Expand All @@ -133,13 +140,15 @@ It reads the configuration file and sends the message to the specified webhook.`
},
}

var (
printVersion bool
withCodeBlock bool
type rootFlags struct {
codeBlock bool
codeBlockLang string
channelName string
}

version string
var (
printVersion bool
version string
)

const (
Expand All @@ -165,6 +174,8 @@ func Execute() {
}
}

type rootFlagsCtxKey struct{}

func init() {
cobra.OnInitialize(initConfig)

Expand All @@ -178,9 +189,13 @@ func init() {
// when this action is called directly.
rootCmd.Flags().BoolVarP(&printVersion, "version", "v", false, "Print version information and quit")

rootCmd.Flags().BoolVarP(&withCodeBlock, "code", "c", false, "Send message with code block")
rootCmd.Flags().StringVarP(&codeBlockLang, "lang", "l", "", "Code block language")
rootCmd.Flags().StringVarP(&channelName, "channel", "C", "", "Channel name")
var rootFlagsData rootFlags
rootCmd.Flags().BoolVarP(&rootFlagsData.codeBlock, "code", "c", false, "Send message with code block")
rootCmd.Flags().StringVarP(&rootFlagsData.codeBlockLang, "lang", "l", "", "Code block language")
rootCmd.Flags().StringVarP(&rootFlagsData.channelName, "channel", "C", "", "Channel name")

ctx := context.WithValue(context.Background(), rootFlagsCtxKey{}, &rootFlagsData)
rootCmd.SetContext(ctx)
}

// initConfig reads in config file and ENV variables if set.
Expand Down
161 changes: 84 additions & 77 deletions cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cmd

import (
"bytes"
"context"
"fmt"
"os"
"testing"
Expand All @@ -22,6 +23,7 @@ func TestRoot(t *testing.T) {
codeBlock bool
codeBlockLang string
channelName string
printVersion bool
stdin string
args []string
}
Expand All @@ -32,68 +34,89 @@ func TestRoot(t *testing.T) {
SendMessageErr error
expectedMessage string
SkipCallSendMessage bool
expectedChannelID uuid.UUID
isError bool
expectedErr error

expectedChannelID uuid.UUID
isError bool
expectedErr error
expectedStdout string
}{
"ok": {
webhookConfig: defaultWebhookConfig,
input: input{false, "", "", "", []string{"test"}},
input: input{false, "", "", false, "", []string{"test"}},
expectedMessage: "test",
expectedChannelID: uuid.Nil,
},
"コードブロックがあっても問題なし": {
webhookConfig: defaultWebhookConfig,
input: input{true, "", "", "", []string{"print('Hello, World!')"}},
input: input{true, "", "", false, "", []string{"print('Hello, World!')"}},
expectedMessage: "```\nprint('Hello, World!')\n```",
expectedChannelID: uuid.Nil,
},
"コードブロックと言語指定があっても問題なし": {
webhookConfig: defaultWebhookConfig,
input: input{true, "python", "", "", []string{"print('Hello, World!')"}},
input: input{true, "python", "", false, "", []string{"print('Hello, World!')"}},
expectedMessage: "```python\nprint('Hello, World!')\n```",
expectedChannelID: uuid.Nil,
},
"メッセージがない場合は標準入力から": {
webhookConfig: defaultWebhookConfig,
input: input{false, "", "", "stdin test", nil},
input: input{false, "", "", false, "stdin test", nil},
expectedMessage: "stdin test",
expectedChannelID: uuid.Nil,
},
"メッセージがあったら標準入力は無視": {
webhookConfig: defaultWebhookConfig,
input: input{false, "", "", "stdin test", []string{"test"}},
input: input{false, "", "", false, "stdin test", []string{"test"}},
expectedMessage: "test",
expectedChannelID: uuid.Nil,
},
"SendMessageがErrEmptyMessageを返す": {
webhookConfig: defaultWebhookConfig,
input: input{false, "", "", "", nil},
input: input{false, "", "", false, "", nil},
SendMessageErr: client.ErrEmptyMessage,
expectedChannelID: uuid.Nil,
isError: true,
},
"メッセージにコードブロックが含まれていて、そこにコードブロックを付けても問題なし": {
webhookConfig: defaultWebhookConfig,
input: input{true, "", "", "```python\nprint('Hello, World!')\n```", nil},
input: input{true, "", "", false, "```python\nprint('Hello, World!')\n```", nil},
expectedMessage: "````\n```python\nprint('Hello, World!')\n```\n````",
expectedChannelID: uuid.Nil,
},
"チャンネル名を指定しても問題なし": {
webhookConfig: defaultWebhookConfig,
input: input{false, "", "channel", "test", nil},
input: input{false, "", "channel", false, "test", nil},
expectedMessage: "test",
expectedChannelID: channelID,
},
"チャンネル名が存在しない場合はエラー": {
webhookConfig: defaultWebhookConfig,
input: input{false, "", "notfound", "test", nil},
input: input{false, "", "notfound", false, "test", nil},
SendMessageErr: nil,
SkipCallSendMessage: true,
expectedChannelID: uuid.Nil,
isError: true,
expectedErr: ErrChannelNotFound,
},
"print version": {
webhookConfig: defaultWebhookConfig,
input: input{false, "", "", true, "", nil},
SkipCallSendMessage: true,
expectedStdout: "q version unknown\n",
},
"設定が不十分でもバージョンを表示": {
webhookConfig: webhookConfig{},
input: input{false, "", "", true, "", nil},
SkipCallSendMessage: true,
expectedStdout: "q version unknown\n",
},
"設定が不十分なのでエラーメッセージ": {
webhookConfig: webhookConfig{},
input: input{false, "", "", false, "", nil},
SkipCallSendMessage: true,
isError: true,
expectedErr: ErrEmptyConfiguration,
},
}

for description, tt := range test {
Expand All @@ -110,29 +133,21 @@ func TestRoot(t *testing.T) {
viper.Set("webhook_secret", tt.webhookConfig.secret)
viper.Set("channels", channelsStr)

withCodeBlock = tt.codeBlock
codeBlockLang = tt.codeBlockLang
channelName = tt.channelName
rootFlagsData := rootFlags{
codeBlock: tt.codeBlock,
codeBlockLang: tt.codeBlockLang,
channelName: tt.channelName,
}
rootCmd.SetContext(context.WithValue(context.Background(), rootFlagsCtxKey{}, &rootFlagsData))

printVersion = tt.printVersion
t.Cleanup(func() {
withCodeBlock = false
codeBlockLang = ""
channelName = ""
printVersion = false
Comment on lines +136 to +145
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

グローバル変数printVersionの使用を避け、テスト間の副作用を防止

テスト内でグローバル変数printVersionを直接設定・リセットしていますが、これはテスト間で状態が共有され、副作用を引き起こす可能性があります。代わりに、rootFlags構造体にprintVersionフィールドを追加し、コンテキスト経由で渡す方法を検討してください。

以下の差分を適用して修正できます:

+type rootFlags struct {
+    codeBlock     bool
+    codeBlockLang string
+    channelName   string
+    printVersion  bool
+}

 rootFlagsData := rootFlags{
     codeBlock:     tt.codeBlock,
     codeBlockLang: tt.codeBlockLang,
     channelName:   tt.channelName,
+    printVersion:  tt.printVersion,
 }

 rootCmd.SetContext(context.WithValue(context.Background(), rootFlagsCtxKey{}, &rootFlagsData))

- printVersion = tt.printVersion
- t.Cleanup(func() {
-     printVersion = false
- })

// コマンド内でグローバル変数ではなく、rootFlagsからprintVersionを取得
if rootFlags.printVersion {
    // バージョン情報を表示する処理
}

Committable suggestion was skipped due to low confidence.

})

r, w, err := os.Pipe()
require.NoError(t, err, "failed to create pipe")
stdinW := ReplaceStdin(t)

origStdin := os.Stdin
os.Stdin = r
defer func() {
os.Stdin = origStdin
r.Close()
}()

_, err = fmt.Fprint(w, tt.stdin)
require.NoError(t, err, "failed to write to pipe")
w.Close()
stdoutR := ReplaceStdout(t)

mockClient := &mock.ClientMock{
SendMessageFunc: func(message string, channelID uuid.UUID) error {
Expand All @@ -142,7 +157,12 @@ func TestRoot(t *testing.T) {

SetClient(mockClient)

_, err := fmt.Fprint(stdinW, tt.stdin)
require.NoError(t, err, "failed to write to pipe")
stdinW.Close()

cmdErr := rootCmd.RunE(rootCmd, tt.args)
os.Stdout.Close()

if tt.SkipCallSendMessage {
assert.Len(t, mockClient.SendMessageCalls(), 0)
Expand All @@ -153,6 +173,14 @@ func TestRoot(t *testing.T) {
assert.Equal(t, tt.expectedChannelID, mockClient.SendMessageCalls()[0].ChannelID)
}

if tt.expectedStdout != "" {
var buffer bytes.Buffer
_, err := buffer.ReadFrom(stdoutR)
require.NoError(t, err, "failed to read from pipe")

assert.Equal(t, tt.expectedStdout, buffer.String())
}

if tt.isError {
if tt.expectedErr != nil {
assert.ErrorIs(t, cmdErr, tt.expectedErr)
Expand All @@ -166,59 +194,38 @@ func TestRoot(t *testing.T) {
}
}

func TestRoot_NoSendMessage(t *testing.T) {
test := map[string]struct {
webhookHost string
webhookID string
webhookSecret string
args []string
printVersion bool
wantStdout string
wantErr error
}{
"print version": {"http://localhost:8080", "test", "test", []string{}, true, "q version unknown\n", nil},
"設定が不十分でもversionをprint": {"", "test", "test", []string{}, true, "q version unknown\n", nil},
"設定が不十分なのでエラーメッセージ": {"", "", "", []string{"aaa"}, false, "", ErrEmptyConfiguration},
}
// 標準出力に書き込むとそれを読めるReaderを返す。
// テスト対象の関数実行後、os.Stdoutをcloseすること。
func ReplaceStdout(t *testing.T) *os.File {
t.Helper()

for description, tt := range test {
t.Run(description, func(t *testing.T) {
viper.Set("webhook_host", tt.webhookHost)
viper.Set("webhook_id", tt.webhookID)
viper.Set("webhook_secret", tt.webhookSecret)
stdoutR, stdoutW, err := os.Pipe()
require.NoError(t, err, "failed to create pipe")

mockClient := &mock.ClientMock{
SendMessageFunc: func(message string, channelID uuid.UUID) error {
return nil
},
}
origStdout := os.Stdout
os.Stdout = stdoutW

r, w, err := os.Pipe()
require.NoError(t, err, "failed to create pipe")
origStdout := os.Stdout
os.Stdout = w
defer func() {
os.Stdout = origStdout
}()
t.Cleanup(func() {
os.Stdout = origStdout
})

printVersion = tt.printVersion
return stdoutR
}

SetClient(mockClient)
// 書き込むと標準入力に書き込まれるWriterを返す。
func ReplaceStdin(t *testing.T) *os.File {
t.Helper()

cmdErr := rootCmd.RunE(rootCmd, []string{})
w.Close()
stdinR, stdinW, err := os.Pipe()
require.NoError(t, err, "failed to create pipe")

assert.Len(t, mockClient.SendMessageCalls(), 0)
var buffer bytes.Buffer
_, err = buffer.ReadFrom(r)
require.NoError(t, err, "failed to read from pipe")
origStdin := os.Stdin
os.Stdin = stdinR

assert.Equal(t, buffer.String(), tt.wantStdout)
if tt.wantErr != nil {
assert.ErrorIs(t, tt.wantErr, cmdErr)
} else {
assert.NoError(t, cmdErr)
}
})
}
t.Cleanup(func() {
os.Stdin = origStdin
stdinR.Close()
})

return stdinW
}