Skip to content

Commit 0b8526d

Browse files
Bugfixes:
- Removed the data_service_dataset_op metrics dump (it requires a newer version of TF, >= 2.8)
- Switched from dataset key to dataset fingerprint in ShouldUseLocalWorkers
1 parent d7a04b7 commit 0b8526d

File tree

7 files changed

+37
-63
lines changed

7 files changed

+37
-63
lines changed

tensorflow/core/data/service/dispatcher_impl.cc

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -832,9 +832,14 @@ Status DataServiceDispatcherImpl::GetOrCreateJob(
832832
GetOrCreateJobRequest::kNumConsumers) {
833833
num_consumers = request->num_consumers();
834834
}
835+
836+
absl::flat_hash_set<std::string> local_workers;
837+
local_workers.insert(request->local_workers().cbegin(),
838+
request->local_workers().cend());
839+
835840
TF_RETURN_IF_ERROR(CreateJob(request->dataset_id(),
836841
requested_processing_mode, key, num_consumers,
837-
job));
842+
job, local_workers));
838843
int64 job_client_id;
839844
TF_RETURN_IF_ERROR(AcquireJobClientId(job, job_client_id));
840845
response->set_job_client_id(job_client_id);
@@ -943,7 +948,9 @@ Status DataServiceDispatcherImpl::ValidateMatchingJob(
943948
Status DataServiceDispatcherImpl::CreateJob(
944949
int64 dataset_id, ProcessingMode processing_mode,
945950
absl::optional<NamedJobKey> named_job_key,
946-
absl::optional<int64> num_consumers, std::shared_ptr<const Job>& job)
951+
absl::optional<int64> num_consumers,
952+
std::shared_ptr<const Job>& job,
953+
absl::flat_hash_set<std::string> local_workers)
947954
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
948955
switch (processing_mode) {
949956
case ProcessingMode::PARALLEL_EPOCHS:
@@ -1005,24 +1012,23 @@ Status DataServiceDispatcherImpl::CreateJob(
10051012

10061013
bool should_use_local_workers; // Do we have enough throughput to decide to use local workers to save network bandwidth?
10071014
TF_RETURN_IF_ERROR(service::easl::local_workers_utils::ShouldUseLocalWorkers(
1008-
config_, metadata_store_, compute_dataset_key, should_use_local_workers
1015+
config_, metadata_store_, dataset_fingerprint, should_use_local_workers
10091016
));
10101017

1011-
if(should_use_local_workers && request.local_workers().size() >= 1) {
1018+
if(should_use_local_workers && local_workers.size() >= 1) {
10121019
target_remote_workers = suggested_worker_count - 1;
10131020
target_local_workers = 1;
10141021
} else {
10151022
target_remote_workers = suggested_worker_count;
10161023
target_local_workers = 0;
10171024
}
10181025
} else if(config_.scaling_policy() == 2) { // Use all available workers
1019-
target_remote_workers = total_workers - request.local_workers().size();
1020-
target_local_workers = request.local_workers().size();
1026+
target_remote_workers = total_workers - local_workers.size();
1027+
target_local_workers = local_workers.size();
10211028
} else if(config_.scaling_policy() == 3) { // Grid search over local and remote workers
10221029
TF_RETURN_IF_ERROR(service::easl::local_workers_utils::DecideTargetWorkersGridSearch(
1023-
config_, metadata_store_, compute_dataset_key,
1024-
total_workers - request.local_workers().size(), request.local_workers().size(),
1025-
target_remote_workers, target_local_workers
1030+
total_workers - local_workers.size(), local_workers.size(),
1031+
target_remote_workers, target_local_workers // passed by reference
10261032
));
10271033
}
10281034

@@ -1061,7 +1067,7 @@ Status DataServiceDispatcherImpl::CreateJob(
10611067
create_job->set_target_worker_count(suggested_worker_count);
10621068
create_job->set_target_local_workers(target_local_workers);
10631069
create_job->set_target_remote_workers(target_remote_workers);
1064-
*create_job->mutable_local_workers() = {request.local_workers().begin(), request.local_workers().end()};
1070+
*create_job->mutable_local_workers() = {local_workers.begin(), local_workers.end()};
10651071
if (named_job_key.has_value()) {
10661072
NamedJobKeyDef* key = create_job->mutable_named_job_key();
10671073
key->set_name(named_job_key->name);
@@ -1120,7 +1126,7 @@ Status DataServiceDispatcherImpl::CreateTasksForJob(
11201126
std::vector<std::shared_ptr<const Task>>& tasks)
11211127
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
11221128
std::vector<std::shared_ptr<const Worker>> workers = state_.ReserveWorkers(
1123-
job->job_id, job->target_worker_count, job->target_remote_workers, job->target_local_workers, job->local_workers);
1129+
job->job_id, job->target_remote_workers, job->target_local_workers, job->local_workers);
11241130
if (workers.size() < job->target_worker_count){
11251131
VLOG(0)
11261132
<< "EASL - Not enough workers for job. Elasticity policy requires "
@@ -1415,7 +1421,6 @@ Status DataServiceDispatcherImpl::ClientHeartbeat(
14151421
}
14161422
response->set_job_finished(job->finished);
14171423
response->set_target_local_workers(job->target_local_workers);
1418-
response->set_target_remote_workers(job->target_remote_workers);
14191424
VLOG(4) << "Found " << response->task_info_size()
14201425
<< " tasks for job client id " << request->job_client_id();
14211426

tensorflow/core/data/service/dispatcher_impl.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,8 @@ class DataServiceDispatcherImpl {
197197
Status CreateJob(int64 dataset_id, ProcessingMode processing_mode,
198198
absl::optional<DispatcherState::NamedJobKey> named_job_key,
199199
absl::optional<int64> num_consumers,
200-
std::shared_ptr<const DispatcherState::Job>& job)
200+
std::shared_ptr<const DispatcherState::Job>& job,
201+
absl::flat_hash_set<std::string> local_workers)
201202
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
202203
// Creates tasks for the specified worker, one task for every unfinished job.
203204
Status CreateTasksForWorker(const std::string& worker_address);

tensorflow/core/data/service/dispatcher_state.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -431,10 +431,10 @@ DispatcherState::ReserveWorkers(
431431
workers.push_back(it->second);
432432
VLOG(0) << "(ReserveWorkers) Assigning worker at address "
433433
<< it->second->address << " to job " << job_id;
434-
workers_by_job_[job_id].push_back(it->second);
434+
workers_by_job_[job_id][it->second->address] = it->second;
435435
jobs_by_worker_[it->second->address][job_id] = jobs_[job_id];
436436
avail_workers_.erase(it++);
437-
if (target_worker_count == 0)
437+
if (target_local_workers + target_remote_workers == 0)
438438
break;
439439
}
440440
VLOG(0) << "(ReserveWorkers) Number of workers for job " << job_id << " is: "

tensorflow/core/data/service/easl/local_workers_utils.cc

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@ namespace local_workers_utils {
1414
Status ShouldUseLocalWorkers(
1515
const experimental::DispatcherConfig& dispatcher_config,
1616
const ::tensorflow::data::easl::MetadataStore& metadata_store,
17-
const std::string& dataset_key,
17+
const int64 dataset_fingerprint,
1818
bool& should_use_local_workers) {
1919
using NodeMetrics = ::tensorflow::data::easl::NodeMetrics;
2020
using ModelMetrics = ::tensorflow::data::easl::ModelMetrics;
2121

2222
// Check if we have any metrics for this dataset
2323
std::shared_ptr<data::easl::InputPipelineMetrics> job_metrics;
24-
Status s = metadata_store.GetLastNodeMetricsByDatasetFingerprint(
25-
dataset_key, job_metrics);
24+
Status s = metadata_store.GetInputPipelineMetricsByDatasetFingerprint(
25+
dataset_fingerprint, job_metrics);
2626

2727
// We do not yet have the metrics for this dataset --> use 1 worker
2828
if(errors::IsNotFound(s)) {
@@ -36,8 +36,7 @@ Status ShouldUseLocalWorkers(
3636

3737
// Pipeline stats: last TF node metrics
3838
std::shared_ptr<NodeMetrics> last_tf_node_metrics;
39-
40-
s = metadata_store.GetLastNodeMetricsByDatasetKey(dataset_key, last_tf_node_metrics);
39+
s = metadata_store.GetLastNodeMetricsByDatasetFingerprint(dataset_fingerprint, last_tf_node_metrics);
4140
if (!s.ok()) {
4241
VLOG(0) << "DSL (ShouldUseLocalWorkers) Failed to get the last TF node metrics";
4342
return s;
@@ -55,9 +54,9 @@ Status ShouldUseLocalWorkers(
5554
VLOG(0) << "DSL (ShouldUseLocalWorkers) Total bytes produced: " << total_bytes_produced << "\n"
5655
<< "Total num elements: " << total_num_elements << "\n"
5756
<< "Avg bytes produced per element: " << avg_bytes_per_element << "\n"
58-
<< "Decision Threshold: " << dispatcher_config.avg_bytes_per_element_local_thres() << "\n";
57+
<< "Decision Threshold: " << dispatcher_config.avg_bytes_per_element_local_workers_threshold() << "\n";
5958

60-
if (avg_bytes_per_element > dispatcher_config.avg_bytes_per_element_local_thres()) {
59+
if (avg_bytes_per_element > dispatcher_config.avg_bytes_per_element_local_workers_threshold()) {
6160
should_use_local_workers = true;
6261
VLOG(0) << "DSL (ShouldUseLocalWorkers) Using local workers! (because avg. bytes per element > threshold) \n";
6362
}
@@ -71,8 +70,14 @@ Status ShouldUseLocalWorkers(
7170

7271
std::vector<int64> records;
7372

74-
void grid_search(int64 num_worker_remote_avail, int64 num_worker_local_avail,
75-
int64& num_worker_remote_target, int64& num_worker_local_target) {
73+
Status DecideTargetWorkersGridSearch(
74+
int64 num_worker_remote_avail,
75+
int64 num_worker_local_avail,
76+
int64& num_worker_remote_target,
77+
int64& num_worker_local_target) {
78+
std::time_t t = std::time(nullptr);
79+
records.push_back(t);
80+
7681
std::vector<std::pair<int64, int64>> test_set = std::vector<std::pair<int64, int64>>();
7782
for(int64 n_r = 0; n_r <= num_worker_remote_avail; n_r++) {
7883
for(int64 n_l = 0; n_l <= num_worker_local_avail; n_l++) {
@@ -95,19 +100,7 @@ void grid_search(int64 num_worker_remote_avail, int64 num_worker_local_avail,
95100
auto p = test_set[index];
96101
num_worker_remote_target = p.first;
97102
num_worker_local_target = p.second;
98-
}
99103

100-
Status DecideTargetWorkersGridSearch(
101-
const experimental::DispatcherConfig& dispatcher_config,
102-
const ::tensorflow::data::easl::MetadataStore& metadata_store,
103-
const std::string& dataset_key,
104-
int64 num_worker_remote_avail,
105-
int64 num_worker_local_avail,
106-
int64& num_worker_remote_target,
107-
int64& num_worker_local_target) {
108-
std::time_t t = std::time(nullptr);
109-
records.push_back(t);
110-
grid_search(num_worker_remote_avail, num_worker_local_avail, num_worker_remote_target, num_worker_local_target);
111104
VLOG(0) << "DSL (DecideTargetWorkersGridSearch)" << "\n"
112105
<< "Available remote: " << num_worker_remote_avail << "\n"
113106
<< "Available local: " << num_worker_local_avail << "\n"

tensorflow/core/data/service/easl/local_workers_utils.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,10 @@ namespace local_workers_utils {
2323
Status ShouldUseLocalWorkers(
2424
const experimental::DispatcherConfig& dispatcher_config,
2525
const ::tensorflow::data::easl::MetadataStore& metadata_store,
26-
const std::string& dataset_key,
26+
const int64 dataset_key,
2727
bool& should_use_local_workers);
2828

2929
Status DecideTargetWorkersGridSearch(
30-
const experimental::DispatcherConfig& dispatcher_config,
31-
const ::tensorflow::data::easl::MetadataStore& metadata_store,
32-
const std::string& dataset_key,
3330
int64 num_worker_remote_avail,
3431
int64 num_worker_local_avail,
3532
int64& num_worker_remote_target,

tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,29 +1019,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
10191019
}
10201020

10211021
if (enqueue_result && !result.end_of_sequence) {
1022-
uint64 current_micro_timestamp = Env::Default()->NowMicros();
1023-
std::string data_source = task.info.worker_address();
1024-
bool if_local = false;
1025-
int result_size = result.element.size();
1026-
if (local_tasks_.contains(task.info.worker_address())) {
1027-
if_local = true;
1028-
local_results_buffer_.push(std::move(result));
1029-
} else {
10301022
results_.push(std::move(result));
1031-
}
1032-
1033-
const char* log_location = std::getenv("EASL_MUYU_WORKER_METRICS");
1034-
if (log_location) {
1035-
std::ofstream file(log_location, std::ios_base::app);
1036-
1037-
file << current_micro_timestamp << ","
1038-
<< data_source << ","
1039-
<< if_local << ","
1040-
<< result_size << "\n";
1041-
1042-
file.flush();
1043-
file.clear();
1044-
}
10451023
}
10461024
get_next_cv_.notify_all();
10471025
}

tensorflow/core/protobuf/service_config.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ message DispatcherConfig {
4444
// The interval at which the dispatcher should dump log files.
4545
int64 log_dumps_interval_ms = 14;
4646
// MUYU's modification
47-
int64 avg_bytes_per_element_local_thres = 15;
47+
int64 avg_bytes_per_element_local_workers_threshold = 15;
4848
}
4949

5050
// Configuration for a tf.data service WorkerServer.

0 commit comments

Comments
 (0)