8000 Merge pull request #263 from pkg/packet-embedded-in-ordered · etherscan-io/sftp@6b9fa10 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6b9fa10

Browse files
authored
Merge pull request pkg#263 from pkg/packet-embedded-in-ordered
ensure packet responses in same order as requests Fixes pkg#260
2 parents 7ef932e + 7f7e75b commit 6b9fa10

File tree

4 files changed

+154
-111
lines changed

4 files changed

+154
-111
lines changed

packet-manager.go

Lines changed: 86 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,62 +7,87 @@ import (
77
)
88

99
// The goal of the packetManager is to keep the outgoing packets in the same
10-
// order as the incoming. This is due to some sftp clients requiring this
11-
// behavior (eg. winscp).
10+
// order as the incoming as is requires by section 7 of the RFC.
1211

13-
type packetSender interface {
14-
sendPacket(encoding.BinaryMarshaler) error
12+
type packetManager struct {
13+
requests chan orderedPacket
14+
responses chan orderedPacket
15+
fini chan struct{}
16+
incoming orderedPackets
17+
outgoing orderedPackets
18+
sender packetSender // connection object
19+
working *sync.WaitGroup
20+
packetCount uint32
1521
}
1622

17-
type packetManager struct {
18-
requests chan requestPacket
19-
responses chan responsePacket
20-
fini chan struct{}
21-
incoming requestPacketIDs
22-
outgoing responsePackets
23-
sender packetSender // connection object
24-
working *sync.WaitGroup
23+
type packetSender interface {
24+
sendPacket(encoding.BinaryMarshaler) error
2525
}
2626

2727
func newPktMgr(sender packetSender) *packetManager {
2828
s := &packetManager{
29-
requests: make(chan requestPacket, SftpServerWorkerCount),
30-
responses: make(chan responsePacket, SftpServerWorkerCount),
29+
requests: make(chan orderedPacket, SftpServerWorkerCount),
30+
responses: make(chan orderedPacket, SftpServerWorkerCount),
3131
fini: make(chan struct{}),
32-
incoming: make([]uint32, 0, SftpServerWorkerCount),
33-
outgoing: make([]responsePacket, 0, SftpServerWorkerCount),
32+
incoming: make([]orderedPacket, 0, SftpServerWorkerCount),
33+
outgoing: make([]orderedPacket, 0, SftpServerWorkerCount),
3434
sender: sender,
3535
working: &sync.WaitGroup{},
3636
}
3737
go s.controller()
3838
return s
3939
}
4040

41-
type responsePackets []responsePacket
41+
//// packet ordering
42+
func (s *packetManager) newOrderId() uint32 {
43+
s.packetCount++
44+
return s.packetCount
45+
}
4246

43-
func (r responsePackets) Sort() {
44-
sort.Slice(r, func(i, j int) bool {
45-
return r[i].id() < r[j].id()
46-
})
47+
type orderedRequest struct {
48+
requestPacket
49+
orderid uint32
4750
}
4851

49-
type requestPacketIDs []uint32
52+
func (s *packetManager) newOrderedRequest(p requestPacket) orderedRequest {
53+
return orderedRequest{requestPacket: p, orderid: s.newOrderId()}
54+
}
55+
func (p orderedRequest) orderId() uint32 { return p.orderid }
56+
func (p orderedRequest) setOrderId(oid uint32) { p.orderid = oid }
5057

51-
func (r requestPacketIDs) Sort() {
52-
sort.Slice(r, func(i, j int) bool {
53-
return r[i] < r[j]
58+
type orderedResponse struct {
59+
responsePacket
60+
orderid uint32
61+
}
62+
63+
func (s *packetManager) newOrderedResponse(p responsePacket, id uint32,
64+
) orderedResponse {
65+
return orderedResponse{responsePacket: p, orderid: id}
66+
}
67+
func (p orderedResponse) orderId() uint32 { return p.orderid }
68+
func (p orderedResponse) setOrderId(oid uint32) { p.orderid = oid }
69+
70+
type orderedPacket interface {
71+
id() uint32
72+
orderId() uint32
73+
}
74+
type orderedPackets []orderedPacket
75+
76+
func (o orderedPackets) Sort() {
77+
sort.Slice(o, func(i, j int) bool {
78+
return o[i].orderId() < o[j].orderId()
5479
})
5580
}
5681

82+
//// packet registry
5783
// register incoming packets to be handled
58-
// send id of 0 for packets without id
59-
func (s *packetManager) incomingPacket(pkt requestPacket) {
84+
func (s *packetManager) incomingPacket(pkt orderedRequest) {
6085
s.working.Add(1)
61-
s.requests <- pkt // buffer == SftpServerWorkerCount
86+
s.requests <- pkt
6287
}
6388

6489
// register outgoing packets as being ready
65-
func (s *packetManager) readyPacket(pkt responsePacket) {
90+
func (s *packetManager) readyPacket(pkt orderedResponse) {
6691
s.responses <- pkt
6792
s.working.Done()
6893
}
@@ -75,27 +100,26 @@ func (s *packetManager) close() {
75100
}
76101

77102
// Passed a worker function, returns a channel for incoming packets.
78-
// The goal is to process packets in the order they are received as is
79-
// requires by section 7 of the RFC, while maximizing throughput of file
80-
// transfers.
81-
func (s *packetManager) workerChan(runWorker func(chan requestPacket),
82-
) chan requestPacket {
103+
// Keep process packet responses in the order they are received while
104+
// maximizing throughput of file transfers.
105+
func (s *packetManager) workerChan(runWorker func(chan orderedRequest),
106+
) chan orderedRequest {
83107

84-
rwChan := make(chan requestPacket, SftpServerWorkerCount)
1 10000 08+
rwChan := make(chan orderedRequest, SftpServerWorkerCount)
85109
for i := 0; i < SftpServerWorkerCount; i++ {
86110
runWorker(rwChan)
87111
}
88112

89-
cmdChan := make(chan requestPacket)
113+
cmdChan := make(chan orderedRequest)
90114
runWorker(cmdChan)
91115

92-
pktChan := make(chan requestPacket, SftpServerWorkerCount)
116+
pktChan := make(chan orderedRequest, SftpServerWorkerCount)
93117
go func() {
94118
// start with cmdChan
95119
curChan := cmdChan
96120
for pkt := range pktChan {
97121
// on file open packet, switch to rwChan
98-
switch pkt.(type) {
122+
switch pkt.requestPacket.(type) {
99123
case *sshFxpOpenPacket:
100124
curChan = rwChan
101125
// on file close packet, switch back to cmdChan
@@ -122,17 +146,13 @@ func (s *packetManager) controller() {
122146
for {
123147
select {
124148
case pkt := <-s.requests:
125-
debug("incoming id: %v", pkt.id())
126-
s.incoming = append(s.incoming, pkt.id())
127-
if len(s.incoming) > 1 {
128-
s.incoming.Sort()
129-
}
149+
debug("incoming id (oid): %v (%v)", pkt.id(), pkt.orderId())
150+
s.incoming = append(s.incoming, pkt)
151+
s.incoming.Sort()
130152
case pkt := <-s.responses:
131-
debug("outgoing pkt: %v", pkt.id())
153+
debug("outgoing id (oid): %v (%v)", pkt.id(), pkt.orderId())
132154
s.outgoing = append(s.outgoing, pkt)
133-
if len(s.outgoing) > 1 {
134-
s.outgoing.Sort()
135-
}
155+
s.outgoing.Sort()
136156
case <-s.fini:
137157
return
138158
}
@@ -150,10 +170,11 @@ func (s *packetManager) maybeSendPackets() {
150170
}
151171
out := s.outgoing[0]
152172
in := s.incoming[0]
153-
// debug("incoming: %v", s.incoming)
154-
// debug("outgoing: %v", outfilter(s.outgoing))
155-
if in == out.id() {
156-
s.sender.sendPacket(out)
173+
// debug("incoming: %v", ids(s.incoming))
174+
// debug("outgoing: %v", ids(s.outgoing))
175+
if in.orderId() == out.orderId() {
176+
debug("Sending packet: %v", out.id())
177+
s.sender.sendPacket(out.(encoding.BinaryMarshaler))
157178
// pop off heads
158179
copy(s.incoming, s.incoming[1:]) // shift left
159180
s.incoming = s.incoming[:len(s.incoming)-1] // remove last
@@ -165,10 +186,17 @@ func (s *packetManager) maybeSendPackets() {
165186
}
166187
}
167188

168-
//func outfilter(o []responsePacket) []uint32 {
169-
// res := make([]uint32, 0, len(o))
170-
// for _, v := range o {
171-
// res = append(res, v.id())
172-
// }
173-
// return res
174-
//}
189+
// func oids(o []orderedPacket) []uint32 {
190+
// res := make([]uint32, 0, len(o))
191+
// for _, v := range o {
192+
// res = append(res, v.orderId())
193+
// }
194+
// return res
195+
// }
196+
// func ids(o []orderedPacket) []uint32 {
197+
// res := make([]uint32, 0, len(o))
198+
// for _, v := range o {
199+
// res = append(res, v.id())
200+
// }
201+
// return res
202+
// }

packet-manager_test.go

Lines changed: 52 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,14 @@ func (s _testSender) sendPacket(p encoding.BinaryMarshaler) error {
2121
return nil
2222
}
2323

24-
type fakepacket uint32
24+
type fakepacket struct {
25+
reqid uint32
26+
oid uint32
27+
}
28+
29+
func fake(rid, order uint32) fakepacket {
30+
return fakepacket{reqid: rid, oid: order}
31+
}
2532

2633
func (fakepacket) MarshalBinary() ([]byte, error) {
2734
return []byte{}, nil
@@ -32,71 +39,89 @@ func (fakepacket) UnmarshalBinary([]byte) error {
3239
}
3340

3441
func (f fakepacket) id() uint32 {
35-
return uint32(f)
42+
return f.reqid
3643
}
3744

3845
type pair struct {
39-
in fakepacket
40-
out fakepacket
46+
in, out fakepacket
47+
}
48+
49+
type ordered_pair struct {
50+
in orderedRequest
51+
out orderedResponse
4152
}
4253

4354
// basic test
4455
var ttable1 = []pair{
45-
pair{fakepacket(0), fakepacket(0)},
46-
pair{fakepacket(1), fakepacket(1)},
47-
pair{fakepacket(2), fakepacket(2)},
48-
pair{fakepacket(3), fakepacket(3)},
56+
pair{fake(0, 0), fake(0, 0)},
57+
pair{fake(1, 1), fake(1, 1)},
58+
pair{fake(2, 2), fake(2, 2)},
59+
pair{fake(3, 3), fake(3, 3)},
4960
}
5061

5162
// outgoing packets out of order
5263
var ttable2 = []pair{
53-
pair{fakepacket(0), fakepacket(0)},
54-
pair{fakepacket(1), fakepacket(4)},
55-
pair{fakepacket(2), fakepacket(1)},
56-
pair{fakepacket(3), fakepacket(3)},
57-
pair{fakepacket(4), fakepacket(2)},
64+
pair{fake(10, 0), fake(12, 2)},
65+
pair{fake(11, 1), fake(11, 1)},
66+
pair{fake(12, 2), fake(13, 3)},
67+
pair{fake(13, 3), fake(10, 0)},
5868
}
5969

60-
// incoming packets out of order
70+
// request ids are not incremental
6171
var ttable3 = []pair{
62-
pair{fakepacket(2), fakepacket(0)},
63-
pair{fakepacket(1), fakepacket(1)},
64-
pair{fakepacket(3), fakepacket(2)},
65-
pair{fakepacket(0), fakepacket(3)},
72+
pair{fake(7, 0), fake(7, 0)},
73+
pair{fake(1, 1), fake(1, 1)},
74+
pair{fake(9, 2), fake(3, 3)},
75+
pair{fake(3, 3), fake(9, 2)},
6676
}
6777

68-
var tables = [][]pair{ttable1, ttable2, ttable3}
78+
// request ids are all the same
79+
var ttable4 = []pair{
80+
pair{fake(1, 0), fake(1, 0)},
81+
pair{fake(1, 1), fake(1, 1)},
82+
pair{fake(1, 2), fake(1, 3)},
83+
pair{fake(1, 3), fake(1, 2)},
84+
}
85+
86+
var tables = [][]pair{ttable1, ttable2, ttable3, ttable4}
6987

7088
func TestPacketManager(t *testing.T) {
7189
sender := newTestSender()
7290
s := newPktMgr(sender)
7391

7492
for i := range tables {
7593
table := tables[i]
94+
ordered_pairs := make([]ordered_pair, 0, len(table))
7695
for _, p := range table {
96+
ordered_pairs = append(ordered_pairs, ordered_pair{
97+
in: orderedRequest{p.in, p.in.oid},
98+
out: orderedResponse{p.out, p.out.oid},
99+
})
100+
}
101+
for _, p := range ordered_pairs {
77102
s.incomingPacket(p.in)
78103
}
79-
for _, p := range table {
104+
for _, p := range ordered_pairs {
80105
s.readyPacket(p.out)
81106
}
82-
for i := 0; i < len(table); i++ {
107+
for _, p := range table {
83108
pkt := <-sender.sent
84-
id := pkt.(fakepacket).id()
85-
assert.Equal(t, id, uint32(i))
109+
id := pkt.(orderedResponse).id()
110+
assert.Equal(t, id, p.in.id())
86111
}
87112
}
88113
s.close()
89114
}
90115

91116
func (p sshFxpRemovePacket) String() string {
92-
return fmt.Sprintf("RmPct:%d", p.ID)
117+
return fmt.Sprintf("RmPkt:%d", p.ID)
93118
}
94119
func (p sshFxpOpenPacket) String() string {
95-
return fmt.Sprintf("OpPct:%d", p.ID)
120+
return fmt.Sprintf("OpPkt:%d", p.ID)
96121
}
97122
func (p sshFxpWritePacket) String() string {
98-
return fmt.Sprintf("WrPct:%d", p.ID)
123+
return fmt.Sprintf("WrPkt:%d", p.ID)
99124
}
100125
func (p sshFxpClosePacket) String() string {
101-
return fmt.Sprintf("ClPct:%d", p.ID)
126+
return fmt.Sprintf("ClPkt:%d", p.ID)
102127
}

request-server.go

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ func (rs *RequestServer) Serve() error {
105105
ctx, cancel := context.WithCancel(context.Background())
106106
defer cancel()
107107
var wg sync.WaitGroup
108-
runWorker := func(ch chan requestPacket) {
108+
runWorker := func(ch chan orderedRequest) {
109109
wg.Add(1)
110110
go func() {
111111
defer wg.Done()
@@ -142,7 +142,7 @@ func (rs *RequestServer) Serve() error {
142142
}
143143
}
144144

145-
pktChan <- pkt
145+
pktChan <- rs.pktMgr.newOrderedRequest(pkt)
146146
}
147147

148148
close(pktChan) // shuts down sftpServerWorkers
@@ -159,11 +159,11 @@ func (rs *RequestServer) Serve() error {
159159
}
160160

161161
func (rs *RequestServer) packetWorker(
162-
ctx context.Context, pktChan chan requestPacket,
162+
ctx context.Context, pktChan chan orderedRequest,
163163
) error {
164164
for pkt := range pktChan {
165165
var rpkt responsePacket
166-
switch pkt := pkt.(type) {
166+
switch pkt := pkt.requestPacket.(type) {
167167
case *sshFxInitPacket:
168168
rpkt = sshFxVersionPacket{Version: sftpProtocolVersion}
169169
case *sshFxpClosePacket:
@@ -208,7 +208,8 @@ func (rs *RequestServer) packetWorker(
208208
return errors.Errorf("unexpected packet type %T", pkt)
209209
}
210210

211-
rs.sendPacket(rpkt)
211+
rs.pktMgr.readyPacket(
212+
rs.pktMgr.newOrderedResponse(rpkt, pkt.orderId()))
212213
}
213214
return nil
214215
}
@@ -240,8 +241,3 @@ func cleanPath(p string) string {
240241
}
241242
return path.Clean(p)
242243
}
243-
244-
// Wrap underlying connection methods to use packetManager
245-
func (rs *RequestServer) sendPacket(pkt responsePacket) {
246-
rs.pktMgr.readyPacket(pkt)
247-
}

0 commit comments

Comments
 (0)
0