|
| 1 | +#include "tensorflow/core/common_runtime/copy_tensor.h" |
| 2 | + |
| 3 | +#include <vector> |
| 4 | +#include "tensorflow/core/lib/core/errors.h" |
| 5 | +#include "tensorflow/core/platform/logging.h" |
| 6 | +#include "tensorflow/core/platform/tracing.h" |
| 7 | + |
| 8 | +namespace tensorflow { |
| 9 | +namespace { |
| 10 | + |
// Set to true on entry to CopyTensor::ViaDMA; once set, CopyTensor::Register
// refuses further registrations. Internal linkage comes from the enclosing
// anonymous namespace.
// NOTE(review): this is a plain bool with no synchronization visible in this
// file -- confirm registration always finishes before copies can start.
bool initialization_done = false;
| 12 | + |
| 13 | +struct RegistrationInfo { |
| 14 | + RegistrationInfo(DeviceType s, DeviceType r, CopyTensor::CopyFunction cf) |
| 15 | + : sender_device_type(s), receiver_device_type(r), copy_function(cf) {} |
| 16 | + DeviceType sender_device_type; |
| 17 | + DeviceType receiver_device_type; |
| 18 | + CopyTensor::CopyFunction copy_function; |
| 19 | +}; |
| 20 | + |
| 21 | +// We use a vector instead of a map since we expect there to be very |
| 22 | +// few registrations. |
| 23 | +std::vector<RegistrationInfo>* MutableRegistry() { |
| 24 | + static std::vector<RegistrationInfo>* registry = |
| 25 | + new std::vector<RegistrationInfo>; |
| 26 | + return registry; |
| 27 | +} |
| 28 | + |
| 29 | +} // namespace |
| 30 | + |
| 31 | +// static |
| 32 | +void CopyTensor::ViaDMA(const string& edge_name, |
| 33 | + DeviceContext* send_dev_context, |
| 34 | + DeviceContext* recv_dev_context, Device* src, |
| 35 | + Device* dst, const AllocatorAttributes src_alloc_attr, |
| 36 | + const AllocatorAttributes dst_alloc_attr, |
| 37 | + const Tensor* input, Tensor* output, |
| 38 | + StatusCallback done) { |
| 39 | + initialization_done = true; |
| 40 | + port::Tracing::ScopedAnnotation annotation(edge_name); |
| 41 | + VLOG(1) << "CopyViaDMA " << edge_name; |
| 42 | + const size_t total_bytes = input->TotalBytes(); |
| 43 | + |
| 44 | + // Note that 0-size tensors have no backing buffer. |
| 45 | + if (total_bytes > 0) { |
| 46 | + const DeviceType src_device_type(src_alloc_attr.on_host() |
| 47 | + ? DEVICE_CPU |
| 48 | + : src->attributes().device_type()); |
| 49 | + const DeviceType dst_device_type(dst_alloc_attr.on_host() |
| 50 | + ? DEVICE_CPU |
| 51 | + : dst->attributes().device_type()); |
| 52 | + const bool non_cpu_src = src_device_type != DeviceType(DEVICE_CPU); |
| 53 | + const bool non_cpu_dst = dst_device_type != DeviceType(DEVICE_CPU); |
| 54 | + |
| 55 | + if (non_cpu_src) { |
| 56 | + if (non_cpu_dst) { |
| 57 | + // Device to device copy. Look through registry for an appropriate |
| 58 | + // CopyFunction. |
| 59 | + std::vector<RegistrationInfo>* registry = MutableRegistry(); |
| 60 | + for (const RegistrationInfo& ri : *registry) { |
| 61 | + if (ri.sender_device_type == src_device_type && |
| 62 | + ri.receiver_device_type == dst_device_type) { |
| 63 | + ri.copy_function(send_dev_context, recv_dev_context, src, dst, |
| 64 | + src_alloc_attr, dst_alloc_attr, input, output, |
| 65 | + done); |
| 66 | + return; |
| 67 | + } |
| 68 | + } |
| 69 | + |
| 70 | + // TODO(josh11b): If no CopyFunction is found, we currently fail |
| 71 | + // but we could copy between devices via CPU. |
| 72 | + done(errors::Unimplemented( |
| 73 | + "No function registered to copy from devices of type ", |
| 74 | + src_device_type.type(), " to devices of type ", |
| 75 | + dst_device_type.type())); |
| 76 | + } else { |
| 77 | + // Device to host copy. |
| 78 | + return send_dev_context->CopyDeviceTensorToCPU(input, edge_name, src, |
| 79 | + output, done); |
| 80 | + } |
| 81 | + } else if (non_cpu_dst) { |
| 82 | + // Host to Device copy. |
| 83 | + // Note that this is already an async copy. |
| 84 | + recv_dev_context->CopyCPUTensorToDevice(input, dst, output, done); |
| 85 | + } else { |
| 86 | + *output = *input; |
| 87 | + done(Status::OK()); |
| 88 | + } |
| 89 | + } else { |
| 90 | + // buffer is empty |
| 91 | + done(Status::OK()); |
| 92 | + } |
| 93 | +} |
| 94 | + |
| 95 | +// static |
| 96 | +Status CopyTensor::Register(DeviceType sender_device_type, |
| 97 | + DeviceType receiver_device_type, |
| 98 | + CopyFunction copy_function) { |
| 99 | + if (initialization_done) { |
| 100 | + return errors::FailedPrecondition( |
| 101 | + "May only register CopyTensor functions during before the first tensor " |
| 102 | + "is copied."); |
| 103 | + } |
| 104 | + std::vector<RegistrationInfo>* registry = MutableRegistry(); |
| 105 | + registry->emplace_back(sender_device_type, receiver_device_type, |
| 106 | + copy_function); |
| 107 | + return Status::OK(); |
| 108 | +} |
| 109 | + |
| 110 | +} // namespace tensorflow |
0 commit comments