8000 ensure packet responses in same order as requests · etherscan-io/sftp@7f7e75b · GitHub
[go: up one dir, main page]

Skip to content

Commit 7f7e75b

Browse files
committed
ensure packet responses in same order as requests
Previous code used the request ids to do ordering. This worked until a client came along that used un-ordered request ids. This reworks the ordering to use an internal counter (per session) to order all packets ensuring that responses are sent in the same order as the requests were received. Fixes pkg#260
1 parent 7ef932e commit 7f7e75b

File tree

4 files changed

+154
-111
lines changed

4 files changed

+154
-111
lines changed

packet-manager.go

Lines changed: 86 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,62 +7,87 @@ import (
77
)
88

99
// The goal of the packetManager is to keep the outgoing packets in the same
10-
// order as the incoming. This is due to some sftp clients requiring this
11-
// behavior (eg. winscp).
10+
// order as the incoming as is requires by section 7 of the RFC.
1211

13-
type packetSender interface {
14-
sendPacket(encoding.BinaryMarshaler) error
12+
type packetManager struct {
13+
requests chan orderedPacket
14+
responses chan orderedPacket
15+
fini chan struct{}
16+
incoming orderedPackets
17+
outgoing orderedPackets
18+
sender packetSender // connection object
19+
working *sync.WaitGroup
20+
packetCount uint32
1521
}
1622

17-
type packetManager struct {
18-
requests chan requestPacket
19-
responses chan responsePacket
20-
fini chan struct{}
21-
incoming requestPacketIDs
22-
outgoing responsePackets
23-
sender packetSender // connection object
24-
working *sync.WaitGroup
23+
type packetSender interface {
24+
sendPacket(encoding.BinaryMarshaler) error
2525
}
2626

2727
func newPktMgr(sender packetSender) *packetManager {
2828
s := &packetManager{
29-
requests: make(chan requestPacket, SftpServerWorkerCount),
30-
responses: make(chan responsePacket, SftpServerWorkerCount),
29+
requests: make(chan orderedPacket, SftpServerWorkerCount),
30+
responses: make(chan orderedPacket, SftpServerWorkerCount),
3131
fini: make(chan struct{}),
32-
incoming: make([]uint32, 0, SftpServerWorkerCount),
33-
outgoing: make([]responsePacket, 0, SftpServerWorkerCount),
32+
incoming: make([]orderedPacket, 0, SftpServerWorkerCount),
33+
outgoing: make([]orderedPacket, 0, SftpServerWorkerCount),
3434
sender: sender,
3535
working: &sync.WaitGroup{},
3636
}
3737
go s.controller()
3838
return s
3939
}
4040

41-
type responsePackets []responsePacket
41+
//// packet ordering
42+
func (s *packetManager) newOrderId() uint32 {
43+
s.packetCount++
44+
return s.packetCount
45+
}
4246

43-
func (r responsePackets) Sort() {
44-
sort.Slice(r, func(i, j int) bool {
45-
return r[i].id() < r[j].id()
46-
})
47+
type orderedRequest struct {
48+
requestPacket
49+
orderid uint32
4750
}
4851

49-
type requestPacketIDs []uint32
52+
func (s *packetManager) newOrderedRequest(p requestPacket) orderedRequest {
53+
return orderedRequest{requestPacket: p, orderid: s.newOrderId()}
54+
}
55+
func (p orderedRequest) orderId() uint32 { return p.orderid }
56+
func (p orderedRequest) setOrderId(oid uint32) { p.orderid = oid }
5057

51-
func (r requestPacketIDs) Sort() {
52-
sort.Slice(r, func(i, j int) bool {
53-
return r[i] < r[j]
58+
type orderedResponse struct {
59+
responsePacket
60+
orderid uint32
61+
}
62+
63+
func (s *packetManager) newOrderedResponse(p responsePacket, id uint32,
64+
) orderedResponse {
65+
return orderedResponse{responsePacket: p, orderid: id}
66+
}
67+
func (p orderedResponse) orderId() uint32 { return p.orderid }
68+
func (p orderedResponse) setOrderId(oid uint32) { p.orderid = oid }
69+
70+
type orderedPacket interface {
71+
id() uint32
72+
orderId() uint32
73+
}
74+
type orderedPackets []orderedPacket
75+
76+
func (o orderedPackets) Sort() {
77+
sort.Slice(o, func(i, j int) bool {
78+
return o[i].orderId() < o[j].orderId()
5479
})
5580
}
5681

82+
//// packet registry
5783
// register incoming packets to be handled
58-
// send id of 0 for packets without id
59-
func (s *packetManager) incomingPacket(pkt requestPacket) {
84+
func (s *packetManager) incomingPacket(pkt orderedRequest) {
6085
s.working.Add(1)
61-
s.requests <- pkt // buffer == SftpServerWorkerCount
86+
s.requests <- pkt
6287
}
6388

6489
// register outgoing packets as being ready
65-
func (s *packetManager) readyPacket(pkt responsePacket) {
90+
func (s *packetManager) readyPacket(pkt orderedResponse) {
6691
s.responses <- pkt
6792
s.working.Done()
6893
}
@@ -75,27 +100,26 @@ func (s *packetManager) close() {
75100
}
76101

77102
// Passed a worker function, returns a channel for incoming packets.
78-
// The goal is to process packets in the order they are received as is
79-
// requires by section 7 of the RFC, while maximizing throughput of file
80-
// transfers.
81-
func (s *packetManager) workerChan(runWorker func(chan requestPacket),
82-
) chan requestPacket {
103+
// Keep process packet responses in the order they are received while
104+
// maximizing throughput of file transfers.
105+
func (s *packetManager) workerChan(runWorker func(chan orderedRequest),
106+
) chan orderedRequest {
83107

84-
rwChan := make(chan requestPacket, SftpServerWorkerCount)
108+
rwChan := make(chan orderedRequest, SftpServerWorkerCount)
85109
for i := 0; i < SftpServerWorkerCount; i++ {
86110
runWorker(rwChan)
87111
}
88112

89-
cmdChan := make(chan requestPacket)
113+
cmdChan := make(chan orderedRequest)
90114
runWorker(cmdChan)
91115

92-
pktChan := make(chan requestPacket, SftpServerWorkerCount)
116+
pktChan := make(chan orderedRequest, SftpServerWorkerCount)
93117
go func() {
94118
// start with cmdChan
95119
curChan := cmdChan
96120
for pkt := range pktChan {
97121
// on file open packet, switch to rwChan
98-
switch pkt.(type) {
122+
switch pkt.requestPacket.(type) {
99123
case *sshFxpOpenPacket:
100124
curChan = rwChan
101125
// on file close packet, switch back to cmdChan
@@ -122,17 +146,13 @@ func (s *packetManager) controller() {
122146
for {
123147
select {
124148
case pkt := <-s.requests:
125-
debug("incoming id: %v", pkt.id())
126-
s.incoming = append(s.incoming, pkt.id())
127-
if len(s.incoming) > 1 {
128-
s.incoming.Sort()
129-
}
149+
debug("incoming id (oid): %v (%v)", pkt.id(), pkt.orderId())
150+
s.incoming = append(s.incoming, pkt)
151+
s.incoming.Sort()
130152
case pkt := <-s.responses:
131-
debug("outgoing pkt: %v", pkt.id())
153+
debug("outgoing id (oid): %v (%v)", pkt.id(), pkt.orderId())
132154
s.outgoing = append(s.outgoing, pkt)
133-
if len(s.outgoing) > 1 {
134-
s.outgoing.Sort()
135-
}
155+
s.outgoing.Sort()
136156
case <-s.fini:
137157
return
138158
}
@@ -150,10 +170,11 @@ func (s *packetManager) maybeSendPackets() {
150170
}
151171
out := s.outgoing[0]
152172
in := s.incoming[0]
153-
// debug("incoming: %v", s.incoming)
154-
// debug("outgoing: %v", outfilter(s.outgoing))
155-
if in == out.id() {
156-
s.sender.sendPacket(out)
173+
// debug("incoming: %v", ids(s.incoming))
174+
// debug("outgoing: %v", ids(s.outgoing))
175+
if in.orderId() == out.orderId() {
176+
debug("Sending packet: %v", out.id())
177+
s.sender.sendPacket(out.(encoding.BinaryMarshaler))
157178
// pop off heads
158179
copy(s.incoming, s.incoming[1:]) // shift left
159180
s.incoming = s.incoming[:len(s.incoming)-1] // remove last
@@ -165,10 +186,17 @@ func (s *packetManager) maybeSendPackets() {
165186
}
166187
}
167188

168-
//func outfilter(o []responsePacket) []uint32 {
169-
// res := make([]uint32, 0, len(o))
170-
// for _, v := range o {
171-
// res = append(res, v.id())
172-
// }
173-
// return res
174-
//}
189+
// func oids(o []orderedPacket) []uint32 {
190+
// res := make([]uint32, 0, len(o))
191+
// for _, v := range o {
192+
// res = append(res, v.orderId())
193+
// }
194+
// return res
195+
// }
196+
// func ids(o []orderedPacket) []uint32 {
197+
// res := make([]uint32, 0, len(o))
198+
// for _, v := range o {
199+
// res = append(res, v.id())
200+
// }
201+
// return res
202+
// }

packet-manager_test.go

Lines changed: 52 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,14 @@ func (s _testSender) sendPacket(p encoding.BinaryMarshaler) error {
2121
return nil
2222
}
2323

24-
type fakepacket uint32
24+
type fakepacket struct {
25+
reqid uint32
26+
oid uint32
27+
}
28+
29+
func fake(rid, order uint32) fakepacket {
30+
return fakepacket{reqid: rid, oid: order}
31+
}
2532

2633
func (fakepacket) MarshalBinary() ([]byte, error) {
2734
return []byte{}, nil
@@ -32,71 +39,89 @@ func (fakepacket) UnmarshalBinary([]byte) error {
3239
}
3340

3441
func (f fakepacket) id() uint32 {
35-
return uint32(f)
42+
return f.reqid
3643
}
3744

3845
type pair struct {
39-
in fakepacket
40-
out fakepacket
46+
in, out fakepacket
47+
}
48+
49+
type ordered_pair struct {
50+
in orderedRequest
51+
out orderedResponse
4152
}
4253

4354
// basic test
4455
var ttable1 = []pair{
45-
pair{fakepacket(0), fakepacket(0)},
46-
pair{fakepacket(1), fakepacket(1)},
47-
pair{fakepacket(2), fakepacket(2)},
48-
pair{fakepacket(3), fakepacket(3)},
56+
pair{fake(0, 0), fake(0, 0)},
57+
pair{fake(1, 1), fake(1, 1)},
58+
pair{fake(2, 2), fake(2, 2)},
59+
pair{fake(3, 3), fake(3, 3)},
4960
}
5061

5162
// outgoing packets out of order
5263
var ttable2 = []pair{
53-
pair{fakepacket(0), fakepacket(0)},
54-
pair{fakepacket(1), fakepacket(4)},
55-
pair{fakepacket(2), fakepacket(1)},
56-
pair{fakepacket(3), fakepacket(3)},
57-
pair{fakepacket(4), fakepacket(2)},
64+
pair{fake(10, 0), fake(12, 2)},
65+
pair{fake(11, 1), fake(11, 1)},
66+
pair{fake(12, 2), fake(13, 3)},
67+
pair{fake(13, 3), fake(10, 0)},
5868
}
5969

60-
// incoming packets out of order
70+
// request ids are not incremental
6171
var ttable3 = []pair{
62-
pair{fakepacket(2), fakepacket(0)},
63-
pair{fakepacket(1), fakepacket(1)},
64-
pair{fakepacket(3), fakepacket(2)},
65-
pair{fakepacket(0), fakepacket(3)},
72+
pair{fake(7, 0), fake(7, 0)},
73+
pair{fake(1, 1), fake(1, 1)},
74+
pair{fake(9, 2), fake(3, 3)},
75+
pair{fake(3, 3), fake(9, 2)},
6676
}
6777

68-
var tables = [][]pair{ttable1, ttable2, ttable3}
78+
// request ids are all the same
79+
var ttable4 = []pair{
80+
pair{fake(1, 0), fake(1, 0)},
81+
pair{fake(1, 1), fake(1, 1)},
82+
pair{fake(1, 2), fake(1, 3)},
83+
pair{fake(1, 3), fake(1, 2)},
84+
}
85+
86+
var tables = [][]pair{ttable1, ttable2, ttable3, ttable4}
6987

7088
func TestPacketManager(t *testing.T) {
7189
sender := newTestSender()
7290
s := newPktMgr(sender)
7391

7492
for i := range tables {
7593
table := tables[i]
94+
ordered_pairs := make([]ordered_pair, 0, len(table))
7695
for _, p := range table {
96+
ordered_pairs = append(ordered_pairs, ordered_pair{
97+
in: orderedRequest{p.in, p.in.oid},
98+
out: orderedResponse{p.out, p.out.oid},
99+
})
100+
}
101+
for _, p := range ordered_pairs {
77102
s.incomingPacket(p.in)
78103
}
79-
for _, p := range table {
104+
for _, p := range ordered_pairs {
80105
s.readyPacket(p.out)
81106
}
82-
for i := 0; i < len(table); i++ {
107+
for _, p := range table {
83108
pkt := <-sender.sent
84-
id := pkt.(fakepacket).id()
85-
assert.Equal(t, id, uint32(i))
109+
id := pkt.(orderedResponse).id()
110+
assert.Equal(t, id, p.in.id())
86111
}
87112
}
88113
s.close()
89114
}
90115

91116
func (p sshFxpRemovePacket) String() string {
92-
return fmt.Sprintf("RmPct:%d", p.ID)
117+
return fmt.Sprintf("RmPkt:%d", p.ID)
93118
}
94119
func (p sshFxpOpenPacket) String() string {
95-
return fmt.Sprintf("OpPct:%d", p.ID)
120+
return fmt.Sprintf("OpPkt:%d", p.ID)
96121
}
97122
func (p sshFxpWritePacket) String() string {
98-
return fmt.Sprintf("WrPct:%d", p.ID)
123+
return fmt.Sprintf("WrPkt:%d", p.ID)
99124
}
100125
func (p sshFxpClosePacket) String() string {
101-
return fmt.Sprintf("ClPct:%d", p.ID)
126+
return fmt.Sprintf("ClPkt:%d", p.ID)
102127
}

request-server.go

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ func (rs *RequestServer) Serve() error {
105105
ctx, cancel := context.WithCancel(context.Background())
106106
defer cancel()
107107
var wg sync.WaitGroup
108-
runWorker := func(ch chan requestPacket) {
108+
runWorker := func(ch chan orderedRequest) {
109109
wg.Add(1)
110110
go func() {
111111
defer wg.Done()
@@ -142,7 +142,7 @@ func (rs *RequestServer) Serve() error {
142142
}
143143
}
144144

145-
pktChan <- pkt
145+
pktChan <- rs.pktMgr.newOrderedRequest(pkt)
146146
}
147147

148148
close(pktChan) // shuts down sftpServerWorkers
@@ -159,11 +159,11 @@ func (rs *RequestServer) Serve() error {
159159
}
160160

161161
func (rs *RequestServer) packetWorker(
162-
ctx context.Context, pktChan chan requestPacket,
162+
ctx context.Context, pktChan chan orderedRequest,
163163
) error {
164164
for pkt := range pktChan {
165165
var rpkt responsePacket
166-
switch pkt := pkt.(type) {
166+
switch pkt := pkt.requestPacket.(type) {
167167
case *sshFxInitPacket:
168168
rpkt = sshFxVersionPacket{Version: sftpProtocolVersion}
169169
case *sshFxpClosePacket:
@@ -208,7 +208,8 @@ func (rs *RequestServer) packetWorker(
208208
return errors.Errorf("unexpected packet type %T", pkt)
209209
}
210210

211-
rs.sendPacket(rpkt)
211+
rs.pktMgr.readyPacket(
212+
rs.pktMgr.newOrderedResponse(rpkt, pkt.orderId()))
212213
}
213214
return nil
214215
}
@@ -240,8 +241,3 @@ func cleanPath(p string) string {
240241
}
241242
return path.Clean(p)
242243
}
243-
244-
// Wrap underlying connection methods to use packetManager
245-
func (rs *RequestServer) sendPacket(pkt responsePacket) {
246-
rs.pktMgr.readyPacket(pkt)
247-
}

0 commit comments

Comments
 (0)
0