@@ -68,23 +68,35 @@ inline Tensor pad(
6868 return Tensor (ret0);
6969}
7070
71- // We expect these to be the stable version of the amax.default op
72- // with identical semantics to the existing amax.default op.
71+ // We expect the following two functions to be stable versions of the amax.default op
72+ // with identical semantics to the existing amax.default op. If `keepdim` is true, the
73+ // result will have the same number of dimensions as `self`, with the specified dimension
74+ // having size 1. Otherwise, the result will have one fewer dimension than `self`, with the
75+ // specified dimension removed.
76+
77+ // This function is an overload to compute the maximum value along each slice of `self`
78+ // along a single dimension `dim`.
7379inline Tensor amax (Tensor& self, int64_t dim, bool keepdim = false ) {
74- AtenTensorHandle ret0 = nullptr ;
80+ AtenTensorHandle ret = nullptr ;
7581 TORCH_ERROR_CODE_CHECK (
76- aoti_torch_aten_amax (self.get (), &dim, 1 , keepdim, &ret0 ));
77- return Tensor (ret0 );
82+ aoti_torch_aten_amax (self.get (), &dim, 1 , keepdim, &ret ));
83+ return Tensor (ret );
7884}
7985
86+ // This function is an overload to compute the maximum value along each slice of `self`
87+ // reducing over all the dimensions in the vector `dims`.
88+ // The amax.default op takes in a SymInt[] as the dims argument, however dims is typed as
89+ // use std::vector<int64_t> here because
90+ // (1) IntArrayRef is not yet header-only
91+ // (2) SymInt is not yet header-only
8092inline Tensor amax (
8193 Tensor& self,
8294 std::vector<int64_t > dims,
8395 bool keepdim = false ) {
84- AtenTensorHandle ret0 = nullptr ;
96+ AtenTensorHandle ret = nullptr ;
8597 TORCH_ERROR_CODE_CHECK (aoti_torch_aten_amax (
86- self.get (), dims.data (), ( int64_t ) dims.size (), keepdim, &ret0 ));
87- return Tensor (ret0 );
98+ self.get (), dims.data (), static_cast < int64_t >( dims.size ()) , keepdim, &ret ));
99+ return Tensor (ret );
88100}
89101
90102// We expect this to be the stable version of the transpose op with identical
0 commit comments