torch fix casting and add ops for sd vae(s) #9297
Conversation
This branch is currently behind tinygrad/master. The line count difference bot is disabled.
do you know why the SDXL output is incorrect?
```diff
@@ -142,7 +142,8 @@ def convolution_backward_overrideable(grad_out, input, weight, stride, padding,
 @torch.library.impl("aten::_copy_from", "privateuseone")
 def _copy_from(src, dest, non_blocking=False):
   if str(src.device) == "tiny" and str(dest.device) == "tiny":
```
the cast should happen regardless of the device (before the if blocks), right?
torch cast recurses back into this (at least when I tried it, it crashed without a traceback…). We need to know it's a tiny device to do the cast, although looking again I'm not sure whether that 3rd block does the cast implicitly.
I’ll add a test
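For illustration, a rough sketch of what such a test could look like (hypothetical, not the PR's actual test; it assumes tinygrad's torch backend has been imported so the "tiny" device is registered):

```python
import numpy as np
import torch
# assumes tinygrad's torch backend has been imported, registering the "tiny" device

def test_copy_from_casts_dtype():
  src = torch.rand(4, dtype=torch.float32, device="tiny")
  dst = torch.empty(4, dtype=torch.float64, device="tiny")
  dst.copy_(src)  # dispatches to aten::_copy_from, which should cast f32 -> f64
  assert dst.dtype == torch.float64
  np.testing.assert_allclose(dst.cpu().numpy(), src.cpu().numpy().astype(np.float64))
```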
added the cast in all branches, after src is converted to a tiny tensor one way or another
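A minimal sketch of the "cast in all branches" idea (illustrative only, not the PR's actual code; `unwrap` and `_from_torch_dtype` are hypothetical helpers that map a torch tensor to its backing tinygrad Tensor and a torch dtype to a tinygrad dtype):

```python
import torch
from tinygrad import Tensor

@torch.library.impl("aten::_copy_from", "privateuseone")
def _copy_from(src, dest, non_blocking=False):
  # _from_torch_dtype: hypothetical torch-dtype -> tinygrad-dtype mapper
  cast = lambda t: t.cast(_from_torch_dtype(dest.dtype))
  if str(src.device) == "tiny" and str(dest.device) == "tiny":
    unwrap(dest).assign(cast(unwrap(src)))          # tiny -> tiny: cast, then assign
  elif str(src.device) == "tiny":
    # tiny -> cpu: cast on the tiny side, then copy the realized data out
    dest.copy_(torch.from_numpy(cast(unwrap(src)).numpy()))
  else:
    # cpu -> tiny: build a tinygrad Tensor from src, cast, then assign
    unwrap(dest).assign(cast(Tensor(src.numpy())))
```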
this change does seem to fix ~10 additional tests in #9302 that were failing due to a dtype mismatch.
narrowed it down to the encoder but no further; it may fix itself with the test_ops fixes, but if not I'll revisit
cool, thanks!
* fix some tests in test_ops for torch backend (171 failing)
* fix more tests (135 failures)
* fix tests (126 failing)
* handle transposed convs (109 tests failing)
* fix slice
* fix lshift & rshift and more tests (87 tests failing)
* revert accidental change
* remove unnecessary changes (82 failures)
* fix backward for avg_pool2d (78 failures)
* fix backward for avg_pool2d (78 failures)
* fix replication back pass
* fix reflection pad back pass (71 failures)
* cummax with indices, aten.mv and move out methods (67 failures)
* extract avg_pool2d and avg_pool3d to separate functions (62 failures)
* revert changes for cat_out
* rewrite avg_pool and pad without repetition
* remove duplicates from decomps
* slice rewrite and add slice_backward (59 failures)
* add dtype fixup from #9297
* fix linter error and remove Tensor.pad (48 failures)
* add select_backward and index_put (40 failures)
* fix some more tests (36 failures)
* fix more tests (12 failures)
* some cleanups and fix a couple more tests (10 failures)
* cleaner way to write upsample
* some more upsample cleanups
* use lambda for upsample
* add autowrapper for upsample forward
* cumsum and max_dim without aten functions
* revert _log_softmax
* fix more tests (1 failure)
* make linter happy
* move import to appropriate func
* make linter happy
* add codes for noqa
* some more refactors
* remove comment
* remove dependency on aten function for conv backward
* some more refactors
* add returns
* revert a change from merge
* some cleanups
* remove whitespace
* remove ruff change
* revert upsample
* add masked_fill_.Tensor and scatter.src_out
* add todo
* fix test_biased_conv2d
* fix test_var_one_in_axis & test_std_one_in_axis but break test_biased_conv2d :(
* revert torch_debug
* revert torch_debug
* skip test_gather_failure for the tiny backend
* make padding registration more concise
* add nonzero
* remove scatter_add since we already have the out
* fix scatter
* remove some repetition
* make upsample backward registrations more concise
* remove select.int
* use Tensor.cumsum
* realize conv2d outputs before backward to fix test_biased_conv2d
* add a todo for realize (1 failure)
* add new_empty and new_empty_strided
* make test_pad_circular_mode forward only and remove redundant stuff
* fix linter errors
* remove expect failure
* just tb
* slice is a view_op
* contiguous only when lazydata.is_realized
* fix backward for test_pad_circular_mode
* revert torch.nn.functional.pad override
* add transpose.int and make constant_pad_nd contiguous
* slice_backwards has no kwargs

Co-authored-by: chenyu <chenyu@fastmail.com>
Fix copy to cast, as that is the expected behavior here.
After this commit, the image AE models from https://github.com/madebyollin/taesd should work. Tested in ComfyUI; it's quite simple to patch in tinygrad. Tried an actual SDXL VAE: it runs but produces incorrect results.
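For reference, a sketch of trying one of those TAESD models on the tiny backend (illustrative only; it assumes tinygrad's torch backend is imported so "tiny" is a valid device, and that `TAESD` is loaded from the taesd.py in the repo above with its default weight files present):

```python
import torch
from taesd import TAESD  # from https://github.com/madebyollin/taesd (assumed import path)

vae = TAESD().to("tiny").eval()  # assumes the default encoder/decoder weight files are present
latents = torch.randn(1, 4, 64, 64, device="tiny")
with torch.no_grad():
  image = vae.decoder(latents).clamp(0, 1)  # decode runs through tinygrad kernels
image_cpu = image.cpu()  # _copy_from handles the copy (and any cast) back to CPU
```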