8000 fix printing empty tensors · DiffSharp/DiffSharp@03a38d5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 03a38d5

Browse files
committed
fix printing empty tensors
1 parent 6642c2a commit 03a38d5

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

src/DiffSharp.Core/RawTensor.fs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,8 @@ type RawTensor() =
622622
member t.Print(?postfix: string) =
623623
// TODO: this code is not ideal and can be reimplemented to be cleaner and more efficient
624624
let postfix = defaultArg postfix ""
625+
if t.Nelement = 0 then sprintf "tensor([])%s" postfix
626+
else
625627
let threshold = Printer.Default.threshold
626628
let edgeItems = Printer.Default.edgeItems
627629
let precision = Printer.Default.precision

tests/DiffSharp.Tests/TestTensor.fs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,11 +1159,13 @@ type TestTensor () =
11591159
[<Test>]
11601160
member _.TestTensorToString () =
11611161
for combo in Combos.IntegralAndFloatingPoint do
1162+
let tempty = combo.tensor([])
11621163
let t0 = combo.tensor(2.)
11631164
let t1 = combo.tensor([[2.]; [2.]])
11641165
let t2 = combo.tensor([[[2.; 2.]]])
11651166
let t3 = combo.tensor([[1.;2.]; [3.;4.]])
11661167
let t4 = combo.tensor([[[[1.]]]])
1168+
let temptyString = tempty.ToString()
11671169
let t0String = t0.ToString()
11681170
let t1String = t1.ToString()
11691171
let t2String = t2.ToString()
@@ -1198,11 +1200,13 @@ type TestTensor () =
11981200
sprintf ",backend=%s" (combo.backend.ToString())
11991201

12001202
let extraText = dtypeText + deviceText + backendText
1203+
let temptyStringCorrect = "tensor([])"
12011204
let t0StringCorrect = sprintf "tensor(2%s%s)" suffix extraText
12021205
let t1StringCorrect = sprintf "tensor([[2%s],\n [2%s]]%s)" suffix suffix extraText
12031206
let t2StringCorrect = sprintf "tensor([[[2%s, 2%s]]]%s)" suffix suffix extraText
12041207
let t3StringCorrect = sprintf "tensor([[1%s, 2%s],\n [3%s, 4%s]]%s)" suffix suffix suffix suffix extraText
12051208
let t4StringCorrect = sprintf "tensor([[[[1%s]]]]%s)" suffix extraText
1209+
Assert.CheckEqual(temptyStringCorrect, temptyString)
12061210
Assert.CheckEqual(t0StringCorrect, t0String)
12071211
Assert.CheckEqual(t1StringCorrect, t1String)
12081212
Assert.CheckEqual(t2StringCorrect, t2String)

0 commit comments

Comments
 (0)
0