Add support for sparse complex tensors for CPU/CUDA by aocsa · Pull Request #50984 · pytorch/pytorch · GitHub

Add support for sparse complex tensors for CPU/CUDA #50984


Closed
wants to merge 1 commit into from

Conversation

aocsa
Contributor
@aocsa aocsa commented Jan 23, 2021

Fixes #50690

Currently, sparse tensors only support real floating-point dtypes. This PR adds complex support for CPU/CUDA.

  • add complex support (torch.cfloat and torch.cdouble) to torch.sparse_coo_tensor constructors
  • add complex support to coalesce function
  • add complex support to to_dense function
  • add complex support to to_sparse function
  • add complex support to sparse_add function
  • add unit tests

Note: This PR contains only complex support for the torch.sparse_coo_tensor (forward/backward) function and the related ops used with it (coalesce, to_dense, to_sparse, and sparse_add). Follow-up PRs should cover the other sparse operations needed for more complete sparse complex support, specifically those that rely on the dedicated accelerated linear algebra APIs in cuSparse and MKL.
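
As a quick illustration (a minimal sketch, not taken from the PR's test suite), the workflow this enables on a build containing the change looks roughly like this:

import torch

# Hedged example: complex values in a sparse COO tensor and the ops touched by this PR.
indices = torch.tensor([[0, 0, 1], [1, 1, 2]])                        # duplicate entry at (0, 1)
values = torch.tensor([1 + 2j, 3 - 1j, 2 + 0j], dtype=torch.cdouble)

s = torch.sparse_coo_tensor(indices, values, size=(2, 3))             # complex sparse constructor
sc = s.coalesce()                                                      # sums the duplicated (0, 1) entries
d = sc.to_dense()                                                      # sparse -> dense (complex)
s2 = d.to_sparse()                                                     # dense -> sparse (complex)

total = torch.add(sc, s2)                                              # sparse + sparse addition
assert total.dtype == torch.cdouble
assert torch.equal(total.to_dense(), 2 * d)

# The same calls are expected to work on CUDA (e.g. via s.cuda()) when a GPU is available.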

@facebook-github-bot
Contributor
facebook-github-bot commented Jan 23, 2021

💊 CI failures summary and remediations

As of commit 9dfa1d8 (more details on the Dr. CI page):



🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_py3_clang7_onnx_ort_test2 (1/1)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Mar 17 16:01:54 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_faster_rcnn FAILED [ 31%]
Mar 17 16:00:04 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_embedding_bag_1d_per_sample_weights PASSED [ 29%]
Mar 17 16:00:04 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_embedding_bag_2d_per_sample_weights PASSED [ 29%]
Mar 17 16:00:04 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_embedding_bag_dynamic_input SKIPPED [ 29%]
Mar 17 16:00:04 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_embedding_model_with_external_data PASSED [ 30%]
Mar 17 16:00:04 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_embedding_module PASSED [ 30%]
Mar 17 16:00:04 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_empty_branch PASSED [ 30%]
Mar 17 16:00:04 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_expand PASSED [ 30%]
Mar 17 16:00:04 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_eye PASSED [ 30%]
Mar 17 16:00:04 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_fake_quantize_per_channel SKIPPED [ 30%]
Mar 17 16:00:05 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_fake_quantize_per_tensor PASSED [ 30%]
Mar 17 16:01:54 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_faster_rcnn FAILED [ 31%]
Mar 17 16:01:54 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_flatten PASSED [ 31%]
Mar 17 16:01:54 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_flatten2d PASSED [ 31%]
Mar 17 16:01:54 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_flatten2d_neg PASSED [ 31%]
Mar 17 16:01:54 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_flatten_dynamic_axes PASSED [ 31%]
Mar 17 16:01:54 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_flip PASSED [ 31%]
Mar 17 16:01:54 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_floating_point PASSED [ 32%]
Mar 17 16:01:54 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_floating_point_infer_dtype PASSED [ 32%]
Mar 17 16:01:54 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_floor_div PASSED [ 32%]
Mar 17 16:01:54 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_floor_div_script PASSED [ 32%]
Mar 17 16:01:54 test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime_opset11::test_floordiv PASSED [ 32%]

2 failures not recognized by patterns:

Job Step Action
CircleCI pytorch_linux_xenial_py3_clang7_onnx_ort_test1 Run tests 🔁 rerun
CircleCI pytorch_windows_vs2019_py36_cuda10.1_test2 Checkout code 🔁 rerun

❄️ 1 failure tentatively classified as flaky

but reruns have not yet been triggered to confirm:

See CircleCI build pytorch_macos_10_13_py3_test (1/1)

Step: "Checkout code" (full log | diagnosis details | 🔁 rerun) ❄️

fatal: the remote end hung up unexpectedly
Using SSH Config Dir '/Users/distiller/.ssh'
git version 2.28.0
Cloning git repository
Cloning into '.'...
Warning: Permanently added the RSA host key for IP address '140.82.114.3' to the list of known hosts.

fatal: the remote end hung up unexpectedly
fatal: early EOF
fatal: index-pack failed


exit status 128


This comment was automatically generated by Dr. CI. Follow this link to opt out of these comments for your Pull Requests.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

@aocsa aocsa changed the title Add support for sparse complex tensors Add support for sparse complex tensors for CPU/CUDA Jan 23, 2021
@rgommers rgommers added the module: sparse Related to torch.sparse label Jan 23, 2021
Collaborator
@pearu pearu left a comment


Looks good. This PR adds very basic support for complex COO tensors.

I think that the minimal complex sparse support should pass most of the sparse tests. I suggest applying the following patch

diff --git a/test/test_sparse.py b/test/test_sparse.py
index 41ebd60b41..ee1bd58d9f 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -48,6 +48,8 @@ def cuda_only(inner):
 
 class TestSparse(TestCase):
 
+    value_dtype = torch.float64
+
     def setUp(self):
         # These parameters control the various ways we can run the test.
         # We will subclass and override this method to implement CUDA
@@ -56,7 +58,6 @@ class TestSparse(TestCase):
         self.is_uncoalesced = False
         self.device = 'cpu'
         self.exact_dtype = True
-        self.value_dtype = torch.float64
         self.index_tensor = lambda *args: torch.tensor(*args, dtype=torch.int64, device=self.device)
         self.value_empty = lambda *args: torch.empty(*args, dtype=self.value_dtype, device=self.device)
         self.value_tensor = lambda *args: torch.tensor(*args, dtype=self.value_dtype, device=self.device)
@@ -3185,6 +3186,10 @@ class TestSparse(TestCase):
         self.assertRaises(TypeError, assign_to)
 
 
+class TestSparseComplex(TestSparse):
+    value_dtype = torch.complex128
+
+
 class TestUncoalescedSparse(TestSparse):
     def setUp(self):
         super(TestUncoalescedSparse, self).setUp()

to see what features are missing.
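
With the patch applied, TestSparseComplex reuses every TestSparse test with complex128 values. Assuming PyTorch's usual per-file unittest entry point, the new class can be run on its own with something along these lines:

python test/test_sparse.py TestSparseComplex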

@heitorschueroff heitorschueroff added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jan 26, 2021
@aocsa aocsa force-pushed the aocsa/sparse_complex branch 2 times, most recently from 1a0434f to c0a8bef Compare January 28, 2021 17:01
@aocsa
Contributor Author
aocsa commented Jan 28, 2021

Looks good. This PR adds very basic support for complex COO tensors.

I think that the minimal complex sparse support should pass most of the sparse tests. I suggest applying the following patch

to see what features are missing.

I addressed the internal review from @pearu. Specifically, this PR fixes some errors in the tests when the new class TestSparseComplex is enabled. I think this PR is ready to be merged; the remaining sparse operations needed for minimal complex sparse support should be covered in follow-up PRs, since they will involve many more changes, specifically related to the use of dedicated APIs for accelerated linear algebra such as cuSparse and MKL. cc @anjali411, @mruberry, @rgommers.

Note: this is the latest report with TestSparseComplex fully enabled.

======================================================================
test_dsmm (__main__.TestSparseComplex)
----------------------------------------------------------------------
    res = torch.dsmm(x, y)
======================================================================
test_hsmm (__main__.TestSparseComplex)
----------------------------------------------------------------------
    res = torch.hsmm(x, y)
======================================================================
test_pickle (__main__.TestSparseComplex)
    values = torch.arange(values_numel, dtype=self.value_dtype,

======================================================================
test_print (__main__.TestSparseComplex)
----------------------------------------------------------------------
    values = torch.arange(values_numel, dtype=self.value_dtype,
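
For context, a hedged sketch (not from the PR) of the kind of call the test_dsmm/test_hsmm failures above exercise: sparse-by-dense matrix multiplication with complex values, which this PR intentionally leaves to follow-up work.

import torch

x = torch.sparse_coo_tensor(torch.tensor([[0, 1], [0, 1]]),
                            torch.tensor([1 + 1j, 2 - 1j], dtype=torch.cdouble),
                            (2, 2))
y = torch.randn(2, 2, dtype=torch.cdouble)
# res = torch.dsmm(x, y)   # sparse @ dense; failed in the report above because complex
#                          # support for the underlying mm kernels is left to follow-up PRs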

@anjali411 anjali411 added the module: complex Related to complex number support in PyTorch label Jan 28, 2021
@anjali411
Contributor

@aocsa thanks for the PR! I'll try to review it this week, but in the meantime please drop the third-party commits that you might have added to your PR by mistake :)

@codecov
codecov bot commented Jan 28, 2021

Codecov Report

Merging #50984 (d234b49) into master (fb7bab9) will decrease coverage by 0.00%.
The diff coverage is 75.55%.

❗ Current head d234b49 differs from pull request most recent head 9dfa1d8. Consider uploading reports for the commit 9dfa1d8 to get more accurate results

@@            Coverage Diff             @@
##           master   #50984      +/-   ##
==========================================
- Coverage   77.32%   77.32%   -0.01%     
==========================================
  Files        1888     1888              
  Lines      185065   185083      +18     
==========================================
+ Hits       143105   143112       +7     
- Misses      41960    41971      +11     

@aocsa aocsa force-pushed the aocsa/sparse_complex branch 2 times, most recently from e6bf740 to 0912cd5 Compare January 28, 2021 23:08
Collaborator
@pearu pearu left a comment


Thanks, @aocsa Looks good**2!

I'd like to propose implementing support for print and pickle for sparse complex tensors. If this requires a (more complicated) fix for general complex tensor support, then never mind, as it could be done in a separate PR.

Also, I'd suggest testing sparse complex tensor support for the uncoalesced as well as the CUDA device cases.
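
A minimal sketch of the round-trip being requested here (it assumes print and pickle support for complex sparse tensors is in place; this is not code from the PR):

import io
import pickle
import torch

t = torch.sparse_coo_tensor(torch.tensor([[0, 1], [2, 0]]),
                            torch.tensor([1 + 2j, 3 - 1j], dtype=torch.cdouble),
                            (2, 3))
print(t)                              # exercises the sparse complex repr

buf = io.BytesIO()
pickle.dump(t, buf)                   # exercises pickling
buf.seek(0)
t2 = pickle.load(buf)
assert torch.equal(t.to_dense(), t2.to_dense())

# The same checks would be repeated for an uncoalesced input and, when
# available, on a CUDA device (e.g. t.cuda()).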

@aocsa aocsa force-pushed the aocsa/sparse_complex branch from c6432e3 to 5550442 Compare January 29, 2021 16:37
Collaborator
@pearu pearu left a comment


LGTM. Thanks, @aocsa !

at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_sparse_cuda", [&] {
-      if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
+      if (value.to<scalar_t>() != scalar_t(1)) {
Contributor


We should continue using static_cast here, so please revert this change.

Contributor Author


I tried to keep it, but with static_cast enabled it doesn't compile for the complex data type.

Collaborator


Just an idea: if c10::complex (which I presume corresponds to scalar_t in the complex case) defined a constructor taking an int, wouldn't static_cast work then?


Collaborator


Looks like this still needs to be addressed?

Contributor Author


I made the changes where possible. I was able to use static_cast with the complex dtype when the C++ compiler is used; however, it doesn't work when the nvcc compiler is used. I tried the strategy suggested by pearu, but it generates conflicts with other functions. So I guess the issue is with the nvcc compiler, as in other cases 😔

@@ -8,3 +8,6 @@
 
 #include <TH/generic/THBlas.cpp>
 #include <TH/THGenerateHalfType.h>
+
+#include <TH/generic/THBlas.cpp>
+#include <TH/THGenerateComplexTypes.h>
Contributor


So far we have refrained from adding complex support to the TH code, barring the storage code that was blocking the autograd work. I think it's OK to add this; however, it's strongly preferable to migrate the functions from TH to ATen and then add complex support.

Collaborator


@aocsa please port swap, copy and axpy functions to ATen before adding complex support to them (can be done in a separate PR, and this can be rebased on top). There is absolutely no reason to add complex support to TH in this case.

Contributor Author


I created a new PR to port the copy and axpy functions; I didn't port swap as it is not used anywhere.
#52345

Contributor
@anjali411 anjali411 left a comment


General comments:

  1. Addition of TH code logic for complex. How many other TH functions do you plan to enable for complex?
  2. Needs tests for coalesce, nonzero, add and to_dense (forward and autograd).
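
A hedged sketch of the requested coverage for point 2, exercising the forward and autograd paths through the constructor and to_dense (shapes and values are illustrative, not from the PR):

import torch

i = torch.tensor([[0, 1, 1], [2, 0, 2]])
v = torch.tensor([1 + 1j, 2 - 3j, 4 + 0j], dtype=torch.cdouble, requires_grad=True)

s = torch.sparse_coo_tensor(i, v, (2, 3))
d = s.to_dense()                      # forward path
assert d[1, 0] == 2 - 3j

(d.abs() ** 2).sum().backward()       # real-valued loss; relies on the backward
                                      # support this PR adds for complex sparse values
assert v.grad is not None and v.grad.dtype == torch.cdouble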

aocsa added commits that referenced this pull request on Apr 9 and Apr 10, 2021 (ghstack re-exports of this PR; each commit message notes "Before using ghstack the original PR was #50984" and carries the [ghstack-poisoned] marker).
aocsa added further commits that referenced this pull request on Apr 14, Apr 17, Apr 19, and Apr 20, 2021. Their messages repeat the PR description above, note "Before using ghstack the original PR was #50984", and link Differential Revision [D27765618](https://our.internmc.facebook.com/intern/diff/D27765618) ([ghstack-poisoned]).
facebook-github-bot pushed a commit that referenced this pull request Apr 27, 2021
…54153)

Summary:
Pull Request resolved: #54153

Currently, sparse tensors only support real floating point tensors. Complex support is added in this PR for CPU/CUDA.

- [x] add complex support (torch.cfloat and torch.cdouble) to torch.sparse_coo_tensor constructors
- [x] add complex support to coalesce function
- [x] add complex support to to_dense function
- [x] add complex support to to_sparse function
- [x] add complex support to sparse_add function
- [x] add unit tests

Note: This PR contains only complex support for the torch.sparse_coo_tensor forward function and the related ops used with this function (coalesce, to_dense, to_sparse, and sparse_add). The following PRs in ghstack should cover other sparse operations to provide more complete sparse complex support, specifically related to the use of dedicated APIs for accelerated linear algebra.

Note: Before using ghstack the original PR  was  #50984

Test Plan: Imported from OSS

Reviewed By: H-Huang

Differential Revision: D27765618

Pulled By: ezyang

fbshipit-source-id: a9cdd31d5c7a7dafd790f6cc148f3df26e884c89
aocsa added a commit to Quansight/pytorch that referenced this pull request May 3, 2021
aocsa added a commit that referenced this pull request May 3, 2021
crcrpar pushed a commit to crcrpar/pytorch that referenced this pull request May 7, 2021
krshrimali pushed a commit to krshrimali/pytorch that referenced this pull request May 19, 2021

Each of these commits carries the same message as the merged commit above.
Labels
cla signed module: complex Related to complex number support in PyTorch module: sparse Related to torch.sparse open source triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add support for sparse complex tensors
10 participants