前回の続きです。前回までで一応 CONNECTパケットをstructに変換する処理ができました。これでクライアントからのCONNECTパケットをサーバー側で解釈できます。
今回は、CONNECTに対するレスポンスであるCONNACKに取り掛かります。また、実際にサーバーとして動かし、mosquitto_clientと実際にMQTT通信(CONNECT->CONNACK)ができるようにします。その過程で、今まで単純に []byte
として扱っていた部分を io.Reader
に書き直すことになりました。
目次。
CONNACKパケット
CONNECTパケットを受け取ったサーバーは、クライアントにCONNACKパケットを返す。
CONNACKパケットは、固定ヘッダーと可変ヘッダーから構成される。ペイロードはない。可変ヘッダーは2byte。なので固定ヘッダーの ReminingLength は2で固定値となる。
Session Present と Connect Return Code
CONNACKパケットの可変ヘッダーは、以下の情報を持つ。
- Session Present
- Connect Return code
Session PresentはCONNECTパケットで指定されたClientIDとのセッションがサーバーに管理されているかどうかを示す。まだセッションについては考えれていないので常に 0
をセットすることにする。
Connect Return codeは、何通りかある。正常な場合は 0x00
。クライアントから指定されたMQTTのバージョンを受け入れられない場合は 0x01
。クライアントから指定されたClient Identifierを受け入れられない場合は 0x02
。現時点ではこの3パターンを実装する。
以下、Connactパケットの実装。
type ConnackVariableHeader struct { SessionPresent bool ReturnCode uint8 } type Connack struct { FixedHeader ConnackVariableHeader } func NewConnackForAccepted() Connack { result := newConnack() result.ReturnCode = 0 return result } func NewConnackForRefusedByUnacceptableProtocolVersion() Connack { result := newConnack() result.ReturnCode = 1 return result } func NewConnackForRefusedByIdentifierRejected() Connack { result := newConnack() result.ReturnCode = 2 return result } func newConnack() Connack { fixedHeader := FixedHeader{ PacketType: 2, RemainingLength: 2, } // TODO SessionPresentは固定にしておく variableHeader := ConnackVariableHeader{SessionPresent: false} return Connack{fixedHeader, variableHeader} }
struct から []byte へ変換
サーバーからクライアントへCONNACKパケットを返す際に、 Connack
structをバイト列に変換する必要がある。 ToBytes()
メソッドを実装する。
func (c Connack) ToBytes() []byte { var result []byte result = append(result, c.FixedHeader.ToBytes()...) result = append(result, c.ConnackVariableHeader.ToBytes()...) return result }
FixedHeader
と ConnackVariableHeader
にも ToBytes()
メソッドを実装する。
ConnackVariableHeader
func (h ConnackVariableHeader) ToBytes() []byte { var result []byte if h.SessionPresent { result = append(result, 1) } else { result = append(result, 0) } result = append(result, h.ReturnCode) return result }
FixedHeader
Remining Lengthのencodeのロジックは仕様に書いてある。
func (h FixedHeader) ToBytes() []byte { var result []byte b := h.PacketType << 4 result = append(result, b) remainingLength := encodeRemainingLength(h.RemainingLength) result = append(result, remainingLength...) return result } func encodeRemainingLength(x uint) []byte { var encodedByte byte var result []byte for { encodedByte = byte(x % 128) x = x / 128 if x > 0 { encodedByte = encodedByte | 128 } result = append(result, encodedByte) if x <= 0 { break } } return result }
メソッドのレシーバー
メソッドのレシーバーをstructにするべきかstructのポインタにするべきか、という話がA Tour of GoやEffective Goに書いてある。
個人的には可能な限り不変にしたいのでメソッドのレシーバーはポインタにしたくないけど、不変にしたいならstructのフィールドを全部プライベートにするところまでやらないと意味ないよなぁ。参照はできるけど変更はできない修飾子があったら嬉しい、けどGoにはない。
サーバー実装
サーバーを実装する。複数クライアントを同時に捌くことは後で考える。まずは以下の流れを実現する。
- TCPをListen
- 固定ヘッダーを取り出す
- PacketTypeに応じたhandlerに処理を移譲する
- CONNECTパケットを解釈して処理する
- CONNACKパケットをクライアントに返す
- コネクションを切断する
TCPサーバーは以前にechoサーバーを作ったのでそれを参考に。
netパッケージの Listen
と Accept
を使って、 net.Conn
を取得する。 net.Conn
を使ってクライアントから送られてきたバイト列を取得する。
net package - net - Go Packages
[]byteからbufio.Readerに変更する
サーバーを実装しようと思ったら早速困ったことに。
net.Conn
からバイト列を取得して、これまでに実装してきた packet.ToFixedHeader(bs []byte)
を呼び出したいが、この []byte
の長さはどうしたらいいのだろう。。ToFixedHeader関数でバイト列を引数にとるようにしてるのが設計ミスっぽい。 []byte
ではなく bufio.Reader
を引数にとるように変更する。
io.Reader
周りは以下が参考になる。
まずはToFixedHeaderの引数を変更する
- func ToFixedHeader(bs []byte) (FixedHeader, error) + func ToFixedHeader(r *bufio.Reader) (FixedHeader, error)
1バイトの取得は ReadByte()
を使う。 https://golang.org/pkg/bufio/#Reader.ReadByte
- b := bs[0] + b, err := r.ReadByte()
ToConnectVariableHeaderの引数も変更する
- func ToConnectVariableHeader(fixedHeader FixedHeader, bs []byte) (ConnectVariableHeader, error) + func ToConnectVariableHeader(fixedHeader FixedHeader, r *bufio.Reader) (ConnectVariableHeader, error)
nバイトの取得は io.ReadFull
を使う。 https://golang.org/pkg/io/#ReadFull
- if !isValidProtocolName(bs[:6]) {` + protocolName := make([]byte, 6) + _, err := io.ReadFull(r, protocolName) + if err != nil || !isValidProtocolName(protocolName) {
ToConnectPayloadの引数も変更する。
- func ToConnectPayload(bs []byte) (ConnectPayload, error) + func ToConnectPayload(r *bufio.Reader) (ConnectPayload, error)
diff全体は大きいので最後に載っけておく。
テストでbufio.Readerをどうするか
[]byte
から bufio.Reader
に変えたことで、テストコードにおいて引数として []byte{ 0x10, 0x00 }
というような値を渡せなくなってしまった。困った。
こういう時は bytes.Buffer
を使う。 bytes.Buffer
は []byte
を保持できて、かつ io.Reader
インタフェースを満たしている。
- []byte{0x1B, 0x7F} + bufio.NewReader(bytes.NewBuffer([]byte{0x1B, 0x7F}))
CONNECTパケットのhandlerとサーバーを実装する
- handler/connect_handler.goを実装。エラーハンドリングは次回やる。
- server.goを実装
- main関数も実装
connect_handler.go
package handler import ( "bufio" "fmt" "github.com/bati11/oreno-mqtt/mqtt/packet" ) // CONNECTパケットの可変ヘッダーのバイト数 var variableHeaderLength = 10 func HandleConnect(fixedHeader packet.FixedHeader, r *bufio.Reader) (packet.Connack, error) { fmt.Printf("HandleConnect\n") variableHeader, err := packet.ToConnectVariableHeader(fixedHeader, r) if err != nil { // TODO err応じたCONNACKを生成して返す return packet.NewConnackForRefusedByUnacceptableProtocolVersion(), nil } payload, err := packet.ToConnectPayload(r) if err != nil { // TODO err応じたCONNACKを生成して返す return packet.NewConnackForRefusedByIdentifierRejected(), nil } // TODO variableHeaderとpayloadを使って何かしらの処理 fmt.Printf(" %#v\n", variableHeader) fmt.Printf(" %#v\n", payload) return packet.NewConnackForAccepted(), nil }
server.go
package mqtt import ( "bufio" "fmt" "net" "github.com/bati11/oreno-mqtt/mqtt/handler" "github.com/bati11/oreno-mqtt/mqtt/packet" ) func Run() { ln, err := net.Listen("tcp", "localhost:1883") if err != nil { panic(err) } fmt.Println("server starts at localhost:1883") conn, err := ln.Accept() if err != nil { panic(err) } defer conn.Close() r := bufio.NewReader(conn) fixedHeader, err := packet.ToFixedHeader(r) if err != nil { panic(err) } switch fixedHeader.PacketType { case 1: connack, err := handler.HandleConnect(fixedHeader, r) if err != nil { panic(err) } _, err = conn.Write(connack.ToBytes()) if err != nil { panic(err) } } }
main.go
package main import "github.com/bati11/oreno-mqtt/mqtt" func main() { mqtt.Run() }
CONNECTパケットを送信する
モリモリと実装したので、実際にMQTTパケットを送って動作確認をする。
初回でやった $ mosquitto_pub -t hoge -m "Hello"
を自作のサーバーに対して送信する。Wiresharkを起動しておく。
自作サーバーを起動する。
$ go run app/main.go server starts at localhost:1883
CONNECTパケットを送信する。
$ mosquitto_pub -t hoge -m "Hello"
Wiresharkを見てみると...
できたー!!!ちゃんとCONNACKパケットが返ってる!
おしまい
まず、CONNACKパケットを表すstructを実装しました。その次に、 []byte
から各種structを生成する処理を bufio.Reader
から各種structを生成するように実装を書き換えました。最後に、TCPサーバーを起動して、 mosquitto_pub
コマンドでCONNECTパケットを実際に送りました。
次回はエラーハンドリングのTODOのところから。
今回の学び
- CONNACKパケット
- メソッドのレシーバー
- GoでTCPサーバー
- io.Reader周り
[]byte
から bufio.Reader
に書き換えた時の差分
--- a/mqtt/packet/connect_payload.go +++ b/mqtt/packet/connect_payload.go @@ -1,7 +1,9 @@ package packet import ( + "bufio" "encoding/binary" + "io" "regexp" "github.com/pkg/errors" @@ -13,17 +15,20 @@ type ConnectPayload struct { var clientIDRegex = regexp.MustCompile("^[a-zA-Z0-9-|]*$") -func ToConnectPayload(bs []byte) (ConnectPayload, error) { - if len(bs) < 3 { - return ConnectPayload{}, errors.New("payload length is invalid") +func ToConnectPayload(r *bufio.Reader) (ConnectPayload, error) { + lengthBytes := make([]byte, 2) + _, err := io.ReadFull(r, lengthBytes) + if err != nil { + return ConnectPayload{}, err } - length := binary.BigEndian.Uint16(bs[0:2]) - var clientID string - if len(bs) < 2+int(length) { - return ConnectPayload{}, errors.New("specified length is not equals ClientID length") - } else { - clientID = string(bs[2 : 2+length]) + length := binary.BigEndian.Uint16(lengthBytes) + + clientIDBytes := make([]byte, length) + _, err = io.ReadFull(r, clientIDBytes) + if err != nil { + return ConnectPayload{}, err } + clientID := string(clientIDBytes) if len(clientID) < 1 || len(clientID) > 23 { return ConnectPayload{}, errors.New("ClientID length is invalid") }
--- a/mqtt/packet/connect_payload_test.go +++ b/mqtt/packet/connect_payload_test.go @@ -1,13 +1,15 @@ package packet import ( + "bufio" + "bytes" "reflect" "testing" ) func TestToConnectPayload(t *testing.T) { type args struct { - bs []byte + r *bufio.Reader } tests := []struct { name string @@ -17,38 +19,38 @@ func TestToConnectPayload(t *testing.T) { }{ { name: "ClientIDが1文字", - args: args{[]byte{0x00, 0x01, 'a'}}, + args: args{bufio.NewReader(bytes.NewBuffer([]byte{0x00, 0x01, 'a'}))}, want: ConnectPayload{ClientID: "a"}, wantErr: false, }, { name: "ペイロードが0byte", - args: args{[]byte{}}, + args: args{bufio.NewReader(bytes.NewBuffer([]byte{}))}, want: ConnectPayload{}, wantErr: true, }, { name: "ClientIDが23文字を超える", - args: args{[]byte{0x00, 0x18, '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', 'a', 'b', 'c', 'd'}}, + args: args{bufio.NewReader(bytes.NewBuffer([]byte{0x00, 0x18, '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '0', 'a', 'b', 'c', 'd'}))}, want: ConnectPayload{}, wantErr: true, }, { name: "使えない文字がある", - args: args{[]byte{0x00, 0x02, '1', '%'}}, + args: args{bufio.NewReader(bytes.NewBuffer([]byte{0x00, 0x02, '1', '%'}))}, want: ConnectPayload{}, wantErr: true, }, { name: "指定された長さよりも実際に取得できたClientIDが短い", - args: args{[]byte{0x00, 0x03, '1', '2'}}, + args: args{bufio.NewReader(bytes.NewBuffer([]byte{0x00, 0x03, '1', '2'}))}, want: ConnectPayload{}, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := ToConnectPayload(tt.args.bs) + got, err := ToConnectPayload(tt.args.r) if (err != nil) != tt.wantErr { t.Errorf("ToConnectPayload() error = %v, wantErr %v", err, tt.wantErr) return
--- a/mqtt/packet/connect_variable_header.go +++ b/mqtt/packet/connect_variable_header.go @@ -1,6 +1,9 @@ package packet import ( + "bufio" + "io" + "github.com/pkg/errors" ) @@ -20,16 +23,34 @@ type ConnectVariableHeader struct { KeepAlive uint16 } -func ToConnectVariableHeader(fixedHeader FixedHeader, bs []byte) (ConnectVariableHeader, error) { +func ToConnectVariableHeader(fixedHeader FixedHeader, r *bufio.Reader) (ConnectVariableHeader, error) { if fixedHeader.PacketType != 1 { return ConnectVariableHeader{}, errors.New("fixedHeader.PacketType must be 1") } - if !isValidProtocolName(bs[:6]) { + protocolName := make([]byte, 6) + _, err := io.ReadFull(r, protocolName) + if err != nil || !isValidProtocolName(protocolName) { return ConnectVariableHeader{}, errors.New("protocol name is invalid") } - if bs[6] != 4 { + protocolLevel, err := r.ReadByte() + if err != nil || protocolLevel != 4 { return ConnectVariableHeader{}, errors.New("protocol level must be 4") } + + // TODO + _, err = r.ReadByte() // connectFlags + if err != nil { + return ConnectVariableHeader{}, err + } + _, err = r.ReadByte() // keepAlive MSB + if err != nil { + return ConnectVariableHeader{}, err + } + _, err = r.ReadByte() // keepAlive LSB + if err != nil { + return ConnectVariableHeader{}, err + } + return ConnectVariableHeader{ ProtocolName: "MQTT", ProtocolLevel: 4,
--- a/mqtt/packet/connect_variable_header_test.go +++ b/mqtt/packet/connect_variable_header_test.go @@ -1,6 +1,8 @@ package packet_test import ( + "bufio" + "bytes" "reflect" "testing" @@ -10,7 +12,7 @@ import ( func TestToConnectVariableHeader(t *testing.T) { type args struct { fixedHeader packet.FixedHeader - bs []byte + r *bufio.Reader } tests := []struct { name string @@ -22,12 +24,12 @@ func TestToConnectVariableHeader(t *testing.T) { name: "仕様書のexample", args: args{ fixedHeader: packet.FixedHeader{PacketType: 1}, - bs: []byte{ + r: bufio.NewReader(bytes.NewBuffer([]byte{ 0x00, 0x04, 'M', 'Q', 'T', 'T', // Protocol Name 0x04, // Protocol Level 0xCE, // Connect Flags 0x00, 0x0A, // Keep Alive - }, + })), }, want: packet.ConnectVariableHeader{ ProtocolName: "MQTT", @@ -41,12 +43,12 @@ func TestToConnectVariableHeader(t *testing.T) { name: "固定ヘッダーのPacketTypeが1ではない", args: args{ fixedHeader: packet.FixedHeader{PacketType: 2}, - bs: []byte{ + r: bufio.NewReader(bytes.NewReader([]byte{ 0x00, 0x04, 'M', 'Q', 'T', 'T', // Protocol Name 0x04, // Protocol Level 0xCE, // Connect Flags 0x00, 0x0A, // Keep Alive - }, + })), }, want: packet.ConnectVariableHeader{}, wantErr: true, @@ -55,12 +57,12 @@ func TestToConnectVariableHeader(t *testing.T) { name: "Protocol Nameが不正", args: args{ fixedHeader: packet.FixedHeader{PacketType: 1}, - bs: []byte{ + r: bufio.NewReader(bytes.NewReader([]byte{ 0x00, 0x04, 'M', 'Q', 'T', 't', // Protocol Name 0x04, // Protocol Level 0xCE, // Connect Flags 0x00, 0x0A, // Keep Alive - }, + })), }, want: packet.ConnectVariableHeader{}, wantErr: true, @@ -69,12 +71,12 @@ func TestToConnectVariableHeader(t *testing.T) { name: "Protocol Levelが不正", args: args{ fixedHeader: packet.FixedHeader{PacketType: 1}, - bs: []byte{ + r: bufio.NewReader(bytes.NewReader([]byte{ 0x00, 0x04, 'M', 'Q', 'T', 'T', // Protocol Name 0x03, // Protocol Level 0xCE, // Connect Flags 0x00, 0x0A, // Keep Alive - }, + })), }, want: packet.ConnectVariableHeader{}, wantErr: true, @@ -82,7 +84,7 @@ func TestToConnectVariableHeader(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := packet.ToConnectVariableHeader(tt.args.fixedHeader, tt.args.bs) + got, err := packet.ToConnectVariableHeader(tt.args.fixedHeader, tt.args.r) if (err != nil) != tt.wantErr { t.Errorf("ToConnectVariableHeader() error = %v, wantErr %v", err, tt.wantErr) return
--- a/mqtt/packet/fixed_header.go +++ b/mqtt/packet/fixed_header.go @@ -1,6 +1,8 @@ package packet -import "github.com/pkg/errors" +import ( + "bufio" +) type FixedHeader struct { PacketType byte @@ -20,17 +22,20 @@ func (h FixedHeader) ToBytes() []byte { return result } -func ToFixedHeader(bs []byte) (FixedHeader, error) { - if len(bs) <= 1 { - return FixedHeader{}, errors.New("len(bs) should be greater than 1") +func ToFixedHeader(r *bufio.Reader) (FixedHeader, error) { + b, err := r.ReadByte() + if err != nil { + return FixedHeader{}, err } - b := bs[0] packetType := b >> 4 - dup := refbit(bs[0], 3) - qos1 := refbit(bs[0], 2) - qos2 := refbit(bs[0], 1) - retain := refbit(bs[0], 0) - remainingLength := decodeRemainingLength(bs[1:]) + dup := refbit(b, 3) + qos1 := refbit(b, 2) + qos2 := refbit(b, 1) + retain := refbit(b, 0) + remainingLength, err := decodeRemainingLength(r) + if err != nil { + return FixedHeader{}, err + } return FixedHeader{ PacketType: packetType, Dup: dup, @@ -46,12 +51,15 @@ func refbit(b byte, n uint) byte { } -func decodeRemainingLength(bs []byte) uint { +func decodeRemainingLength(r *bufio.Reader) (uint, error) { multiplier := uint(1) var value uint i := uint(0) for ; i < 8; i++ { - b := bs[i] + b, err := r.ReadByte() + if err != nil { + return 0, err + } digit := b value = value + uint(digit&127)*multiplier multiplier = multiplier * 128 @@ -59,7 +67,7 @@ func decodeRemainingLength(bs []byte) uint { break } } - return value + return value, nil } func encodeRemainingLength(x uint) []byte {
--- a/mqtt/packet/fixed_header_test.go +++ b/mqtt/packet/fixed_header_test.go @@ -1,7 +1,8 @@ package packet_test import ( - "fmt" + "bufio" + "bytes" "reflect" "testing" @@ -10,51 +11,57 @@ import ( func TestToFixedHeader(t *testing.T) { type args struct { - bs []byte + r *bufio.Reader } tests := []struct { + name string args args want packet.FixedHeader wantErr bool }{ { - args: args{[]byte{ + name: "[0x00,0x00]", + args: args{bufio.NewReader(bytes.NewBuffer([]byte{ 0x00, // 0000 0 00 0 0x00, // 0 - }}, + }))}, want: packet.FixedHeader{PacketType: 0, Dup: 0, QoS1: 0, QoS2: 0, Retain: 0, RemainingLength: 0}, wantErr: false, }, { - args: args{[]byte{ + name: "[0x1b,0x7F]", + args: args{bufio.NewReader(bytes.NewBuffer([]byte{ 0x1B, // 0001 1 01 1 0x7F, // 127 - }}, + }))}, want: packet.FixedHeader{PacketType: 1, Dup: 1, QoS1: 0, QoS2: 1, Retain: 1, RemainingLength: 127}, wantErr: false, }, { - args: args{[]byte{ + name: "[0x24,0x80,0x01]", + args: args{bufio.NewReader(bytes.NewBuffer([]byte{ 0x24, // 0002 0 10 0 0x80, 0x01, //128 - }}, + }))}, want: packet.FixedHeader{PacketType: 2, Dup: 0, QoS1: 1, QoS2: 0, Retain: 0, RemainingLength: 128}, wantErr: false, }, { - args: args{nil}, + name: "[]", + args: args{bufio.NewReader(bytes.NewBuffer(nil))}, want: packet.FixedHeader{}, wantErr: true, }, { - args: args{[]byte{0x24}}, + name: "[0x24]", + args: args{bufio.NewReader(bytes.NewBuffer([]byte{0x24}))}, want: packet.FixedHeader{}, wantErr: true, }, } for _, tt := range tests { - t.Run(fmt.Sprintf("%#v", tt.args.bs), func(t *testing.T) { - got, err := packet.ToFixedHeader(tt.args.bs) + t.Run(tt.name, func(t *testing.T) { + got, err := packet.ToFixedHeader(tt.args.r) if (err != nil) != tt.wantErr { t.Errorf("ToFixedHeader() error = %v, wantErr %v", err, tt.wantErr) return