# OpInfo: mvlgamma #56907
```diff
@@ -2853,6 +2853,67 @@ def generator():
     return list(generator())
 
 
+def sample_inputs_mvlgamma(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    tensor_shapes = ((S, S), ())
+    ns = (1, 2, 3, 4, 5)
+
+    # The accepted lower bound for the input to mvlgamma depends on the
+    # `p` argument, so the following function computes the lower bound
+    # which we pass to `make_tensor`.
+    def compute_min_val(p):
+        return (p - 1.) / 2
+
+    def generator():
+        for shape, n in product(tensor_shapes, ns):
+            min_val = compute_min_val(n)
+            yield SampleInput(make_arg(shape, low=min_val), args=(n,))
+
+    return list(generator())
```
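For context on `compute_min_val`: `torch.mvlgamma(input, p)` requires every element of `input` to be strictly greater than `(p - 1) / 2`, which is why the sample-input helper derives the `low` bound from `p`. A minimal sketch of the constraint (the sample values are arbitrary):

```python
import torch

x = torch.tensor([2.5, 3.0, 4.0])
print(torch.mvlgamma(x, p=5))  # fine: every element > (5 - 1) / 2 = 2.0

try:
    torch.mvlgamma(torch.tensor([1.5]), p=5)  # 1.5 <= 2.0: outside the domain
except RuntimeError as err:
    print(err)  # out-of-domain values are a hard error, hence the skip below
```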
The same hunk continues with a helper for the skips shared across the `mvlgamma` entries:

```diff
+
+
+# Since `mvlgamma` has multiple entries, there are multiple common skips
+# for the additional entries. The following function is a helper to that end.
+def skips_mvlgamma(skip_redundant=False):
+    skips = (
+        # Out-of-domain values are a hard error for the mvlgamma op.
+        SkipInfo('TestUnaryUfuncs', 'test_float_domains'),
+    )
+    if not skip_redundant:
+        # Redundant tests
+        skips = skips + (  # type: ignore[assignment]
+            SkipInfo('TestGradients'),
+            SkipInfo('TestOpInfo'),
+            SkipInfo('TestCommon'),
+        )
+    return skips
```
```diff
+
+
+# To test reference numerics against multiple values of the argument `p`,
+# we make multiple OpInfo entries, each corresponding to a different value of p.
+# We run the op tests from test_ops.py only for `p=1` to avoid redundancy in testing.
+# Class `MvlGammaInfo` already contains the basic information related to the operator;
+# it only takes arguments like `domain`, `skips` and `sample_kwargs`, which
+# differ between the entries.
+class MvlGammaInfo(UnaryUfuncInfo):
```

> **Reviewer:** Add a comment for why this class is helpful.
>
> **Author:** Done.
```diff
+
+    def __init__(self, variant_test_name, domain, skips, sample_kwargs):
+        super(MvlGammaInfo, self).__init__(
+            'mvlgamma',
+            ref=reference_mvlgamma if TEST_SCIPY else _NOTHING,
+            variant_test_name=variant_test_name,
+            domain=domain,
+            decorators=(precisionOverride({torch.float16: 5e-2}),),
+            dtypes=floating_types(),
+            dtypesIfCPU=floating_types(),
+            dtypesIfCUDA=floating_types_and(torch.half),
+            sample_inputs_func=sample_inputs_mvlgamma,
+            supports_out=False,
+            skips=skips,
+            sample_kwargs=sample_kwargs)
 
 
 def sample_inputs_entr(op_info, device, dtype, requires_grad, **kwargs):
     low, _ = op_info.domain
```
```diff
@@ -3330,6 +3391,14 @@ def reference_polygamma(x, n):
     np_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()]
     return scipy.special.polygamma(n, x).astype(np_dtype)
 
 
+def reference_mvlgamma(x, d):
+    if x.dtype == np.float16:
+        return scipy.special.multigammaln(x, d).astype(np.float16)
+
+    return scipy.special.multigammaln(x, d)
 
 
 def gradcheck_wrapper_hermitian_input(op, input, *args, **kwargs):
     """Gradcheck wrapper for functions that take Hermitian matrices as input.
```
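The `float16` branch exists because `scipy.special.multigammaln` computes in double precision, so the reference result is cast back to the input dtype. As a quick sanity check outside the test harness, the op can be compared against the reference directly (a sketch, assuming SciPy is installed; the sample values are arbitrary but inside the `p=5` domain):

```python
import numpy as np
import torch
from scipy import special

x = torch.tensor([2.5, 3.0, 4.0], dtype=torch.float64)
torch_out = torch.mvlgamma(x, p=5)
ref_out = special.multigammaln(x.numpy(), 5)  # scipy names the order `d`, not `p`
np.testing.assert_allclose(torch_out.numpy(), ref_out, rtol=1e-6)
```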
```diff
@@ -4541,6 +4610,22 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
            op=torch.mode,
            dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool),
            sample_inputs_func=sample_inputs_mode,),
+    MvlGammaInfo(variant_test_name='mvlgamma_p_1',
+                 domain=(1e-4, float('inf')),
+                 skips=skips_mvlgamma(),
+                 sample_kwargs=lambda device, dtype, input: ({'p': 1}, {'d': 1})),
+    MvlGammaInfo(variant_test_name='mvlgamma_p_3',
+                 domain=(1.1, float('inf')),
+                 skips=skips_mvlgamma(skip_redundant=True) + (
+                     SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard', dtypes=(torch.float16,)),
+                 ),
+                 sample_kwargs=lambda device, dtype, input: ({'p': 3}, {'d': 3})),
+    MvlGammaInfo(variant_test_name='mvlgamma_p_5',
+                 domain=(2.1, float('inf')),
+                 skips=skips_mvlgamma(skip_redundant=True) + (
+                     SkipInfo('TestUnaryUfuncs', 'test_reference_numerics_hard', dtypes=(torch.float16,)),
+                 ),
+                 sample_kwargs=lambda device, dtype, input: ({'p': 5}, {'d': 5})),
     OpInfo('ne',
            aliases=('not_equal',),
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
```
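Each `sample_kwargs` lambda returns two dicts: the first is passed to the torch operator, which names the order `p`, while the second is passed to the reference function, since `scipy.special.multigammaln` names the same argument `d`. A simplified sketch of how the two dicts line up (not the actual test-harness code):

```python
import scipy.special
import torch

# Mirrors the `p=3` entry above.
sample_kwargs = lambda device, dtype, input: ({'p': 3}, {'d': 3})

t = torch.tensor([2.0, 3.0], dtype=torch.float64)  # inside the p=3 domain (> 1.0)
op_kwargs, ref_kwargs = sample_kwargs('cpu', t.dtype, t)
torch_out = torch.mvlgamma(t, **op_kwargs)
ref_out = scipy.special.multigammaln(t.numpy(), **ref_kwargs)
```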
```diff
@@ -5105,6 +5190,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
     OpInfo('polar',
            dtypes=floating_types(),
            sample_inputs_func=sample_inputs_polar),
+    # TODO(@kshitij12345): Refactor similar to `mvlgamma` entries.
     # To test reference numerics against multiple values of argument `n`,
     # we make multiple OpInfo entries with each entry corresponding to different value of n (currently 0 to 4).
     # We run the op tests from test_ops.py only for `n=0` to avoid redundancy in testing.
```

> **Author:** Added a TODO here.
```diff
@@ -5837,10 +5923,6 @@ def method_tests():
     ('renorm', (S, S, S), (1, 2, 3), 'norm_1'),
     ('renorm', (S, S, S), (inf, 2, 0.5), 'norm_inf'),
     ('log_softmax', (S, S, S), (1, torch.float64,), 'kwarg_dtype_would_break_jit_loader', (True,)),
-    ('mvlgamma', torch.empty(S,).uniform_(0.5, 1), [1], "p=1"),
-    ('mvlgamma', torch.empty(S,).uniform_(1, 2), [2], "p=2"),
-    ('mvlgamma', torch.empty(S, S).uniform_(1.5, 3), [3], "p=3"),
-    ('mvlgamma', torch.empty(S, S).uniform_(2.5, 5), [5], "p=5"),
     ('zero_', (S, S, S), NO_ARGS),
     ('zero_', (), NO_ARGS, 'scalar'),
     ('norm', (S, S), (), 'default'),
```

These legacy `method_tests` entries are removed because the new `MvlGammaInfo` OpInfo entries above now cover them.
> **Reviewer:** This is really interesting and a possible follow-up issue; maybe a good first issue for a new contributor?