@@ -4,12 +4,16 @@ import (
4
4
"bufio"
5
5
"fmt"
6
6
"os"
7
+ "reflect"
7
8
"runtime"
8
9
"slices"
9
10
"strings"
11
+ "testing"
10
12
11
13
"github.com/go-git/go-billy/v5/osfs"
12
14
"github.com/go-git/go-billy/v5/util"
15
+ "github.com/stretchr/testify/assert"
16
+ "github.com/stretchr/testify/require"
13
17
"golang.org/x/crypto/ssh"
14
18
"golang.org/x/crypto/ssh/testdata"
15
19
@@ -317,3 +321,100 @@ func (*SuiteCommon) TestNewKnownHostsDbWithCert(c *C) {
317
321
}
318
322
}
319
323
}
324
+
325
+ func TestHostKeyCallbackHelper (t * testing.T ) {
326
+ cb1 := ssh .FixedHostKey (nil )
327
+ tests := []struct {
328
+ name string
329
+ cb ssh.HostKeyCallback
330
+ algos []string
331
+ fallback func (files ... string ) (ssh.HostKeyCallback , error )
332
+ cc * ssh.ClientConfig
333
+ want * ssh.ClientConfig
334
+ wantErr string
335
+ }{
336
+ {
337
+ name : "keep existing callback if set" ,
338
+ cb : cb1 ,
339
+ cc : & ssh.ClientConfig {},
340
+ want : & ssh.ClientConfig {
341
+ HostKeyCallback : cb1 ,
342
+ },
343
+ },
344
+ {
345
+ name : "create new client config is one isn't provided" ,
346
+ cb : cb1 ,
347
+ cc : nil ,
348
+ want : & ssh.ClientConfig {
349
+ HostKeyCallback : cb1 ,
350
+ },
351
+ },
352
+ {
353
+ name : "respect pre-set algos" ,
354
+ cb : cb1 ,
355
+ algos : []string {"foo" },
356
+ cc : & ssh.ClientConfig {},
357
+ want : & ssh.ClientConfig {
358
+ HostKeyCallback : cb1 ,
359
+ HostKeyAlgorithms : []string {"foo" },
360
+ },
361
+ },
362
+ {
363
+ name : "no callback is set, call fallback" ,
364
+ cc : & ssh.ClientConfig {},
365
+ fallback : func (files ... string ) (ssh.HostKeyCallback , error ) {
366
+ return cb1 , nil
367
+ },
368
+ want : & ssh.ClientConfig {
369
+ HostKeyCallback : cb1 ,
370
+ },
371
+ },
372
+ {
373
+ name : "no callback is set with nil client config" ,
374
+ fallback : func (files ... string ) (ssh.HostKeyCallback , error ) {
375
+ return cb1 , nil
376
+ },
377
+ want : & ssh.ClientConfig {
378
+ HostKeyCallback : cb1 ,
379
+ },
380
+ },
381
+ {
382
+ name : "algos with no callback, call fallback" ,
383
+ algos : []string {"bar" },
384
+ cc : & ssh.ClientConfig {},
385
+ fallback : func (files ... string ) (ssh.HostKeyCallback , error ) {
386
+ return cb1 , nil
387
+ },
388
+ want : & ssh.ClientConfig {
389
+ HostKeyCallback : cb1 ,
390
+ HostKeyAlgorithms : []string {"bar" },
391
+ },
392
+ },
393
+ }
394
+
395
+ for _ , tc := range tests {
396
+ t .Run (tc .name , func (t * testing.T ) {
397
+ helper := HostKeyCallbackHelper {
398
+ HostKeyCallback : tc .cb ,
399
+ HostKeyAlgorithms : tc .algos ,
400
+ fallback : tc .fallback ,
401
+ }
402
+
403
+ got , gotErr := helper .SetHostKeyCallback (tc .cc )
404
+
405
+ if tc .wantErr == "" {
406
+ require .NoError (t , gotErr )
407
+ require .NotNil (t , got )
408
+
409
+ wantFunc := runtime .FuncForPC (reflect .ValueOf (tc .want .HostKeyCallback ).Pointer ()).Name ()
410
+ gotFunc := runtime .FuncForPC (reflect .ValueOf (got .HostKeyCallback ).Pointer ()).Name ()
411
+ assert .Equal (t , wantFunc , gotFunc )
412
+
413
+ assert .Equal (t , tc .want .HostKeyAlgorithms , got .HostKeyAlgorithms )
414
+ } else {
415
+ assert .ErrorContains (t , gotErr , tc .wantErr )
416
+ assert .Nil (t , got )
417
+ }
418
+ })
419
+ }
420
+ }
0 commit comments