8000 Set correct number of dimensions for convolve2_nn and convolve2_gradi… · arrayfire/arrayfire-rust@5209591 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5209591

Browse files
authored
Set correct number of dimensions for convolve2_nn and convolve2_gradient_nn (#324)
* Set correct number of dimensions for convolve2_nn and convolve2_gradient_nn Fixes issue #323 * fix clippy lints
1 parent 140ee83 commit 5209591

File tree

3 files changed

+19
-27
lines changed

3 files changed

+19
-27
lines changed

build.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ fn run_cmake_command(conf: &Config, build_dir: &std::path::Path) {
281281
run(
282282
make_cmd
283283
.arg(format!("-j{}", conf.build_threads))
284-
.arg("install".to_string()),
284+
.arg("install"),
285285
"make",
286286
);
287287
}

opencl-interop/src/lib.rs

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -88,35 +88,27 @@ pub fn get_device_id() -> cl_device_id {
8888
}
8989

9090
/// Set the cl_device_id as the active ArrayFire OpenCL device
91-
pub fn set_device_id(dev_id: cl_device_id) {
92-
unsafe {
93-
let err_val = afcl_set_device_id(dev_id);
94-
handle_error_general(AfError::from(err_val));
95-
}
91+
pub unsafe fn set_device_id(dev_id: cl_device_id) {
92+
let err_val = afcl_set_device_id(dev_id);
93+
handle_error_general(AfError::from(err_val));
9694
}
9795

9896
/// Push user provided device, context and queue tuple to ArrayFire device mamanger
99-
pub fn add_device_context(dev_id: cl_device_id, ctx: cl_context, queue: cl_command_queue) {
100-
unsafe {
101-
let err_val = afcl_add_device_context(dev_id, ctx, queue);
102-
handle_error_general(AfError::from(err_val));
103-
}
97+
pub unsafe fn add_device_context(dev_id: cl_device_id, ctx: cl_context, queue: cl_command_queue) {
98+
let err_val = afcl_add_device_context(dev_id, ctx, queue);
99+
handle_error_general(AfError::from(err_val));
104100
}
105101

106102
/// Set the device identified by device & context pair as the active device for ArrayFire
107-
pub fn set_device_context(dev_id: cl_device_id, ctx: cl_context) {
108-
unsafe {
109-
let err_val = afcl_set_device_context(dev_id, ctx);
110-
handle_error_general(AfError::from(err_val));
111-
}
103+
pub unsafe fn set_device_context(dev_id: cl_device_id, ctx: cl_context) {
104+
let err_val = afcl_set_device_context(dev_id, ctx);
105+
handle_error_general(AfError::from(err_val));
112106
}
113107

114108
/// Remove the user provided device, context pair from ArrayFire device mamanger
115-
pub fn delete_device_context(dev_id: cl_device_id, ctx: cl_context) {
116-
unsafe {
117-
let err_val = afcl_delete_device_context(dev_id, ctx);
118-
handle_error_general(AfError::from(err_val));
119-
}
109+
pub unsafe fn delete_device_context(dev_id: cl_device_id, ctx: cl_context) {
110+
let err_val = afcl_delete_device_context(dev_id, ctx);
111+
handle_error_general(AfError::from(err_val));
120112
}
121113

122114
///// Fetch Active ArrayFire device's type i.e. CPU/GPU/Accelerator etc.

src/ml/mod.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,11 @@ where
7676
&mut temp as *mut af_array,
7777
signal.get(),
7878
filter.get(),
79-
4,
79+
2,
8080
strides.get().as_ptr() as *const dim_t,
81-
4,
81+
2,
8282
padding.get().as_ptr() as *const dim_t,
83-
4,
83+
2,
8484
dilation.get().as_ptr() as *const dim_t,
8585
);
8686
HANDLE_ERROR(AfError::from(err_val));
@@ -126,11 +126,11 @@ where
126126
original_signal.get(),
127127
original_filter.get(),
128128
convolved_output.get(),
129-
4,
129+
2,
130130
strides.get().as_ptr() as *const dim_t,
131-
4,
131+
2,
132132
padding.get().as_ptr() as *const dim_t,
133-
4,
133+
2,
134134
dilation.get().as_ptr() as *const dim_t,
135135
grad_type as c_uint,
136136
);

0 commit comments

Comments
 (0)
0